server.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284
  1. # SPDX-License-Identifier: Apache-2.0
  2. # Copyright 2020 Contributors to OpenLEADR
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. # http://www.apache.org/licenses/LICENSE-2.0
  7. # Unless required by applicable law or agreed to in writing, software
  8. # distributed under the License is distributed on an "AS IS" BASIS,
  9. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. # See the License for the specific language governing permissions and
  11. # limitations under the License.
  12. from aiohttp import web
  13. from openleadr.service import EventService, PollService, RegistrationService, ReportService, \
  14. OptService, VTNService
  15. from openleadr.messaging import create_message
  16. from openleadr import objects
  17. from openleadr import utils
  18. from functools import partial
  19. from datetime import datetime, timedelta, timezone
  20. from collections import deque
  21. import asyncio
  22. import logging
  23. import ssl
  24. import re
  25. import inspect
  26. logger = logging.getLogger('openleadr')
  27. class OpenADRServer:
  28. _MAP = {'on_created_event': 'event_service',
  29. 'on_request_event': 'event_service',
  30. 'on_register_report': 'report_service',
  31. 'on_create_report': 'report_service',
  32. 'on_created_report': 'report_service',
  33. 'on_request_report': 'report_service',
  34. 'on_update_report': 'report_service',
  35. 'on_poll': 'poll_service',
  36. 'on_query_registration': 'registration_service',
  37. 'on_create_party_registration': 'registration_service',
  38. 'on_cancel_party_registration': 'registration_service'}
  39. def __init__(self, vtn_id, cert=None, key=None, passphrase=None, fingerprint_lookup=None,
  40. show_fingerprint=True, http_port=8080, http_host='127.0.0.1', http_cert=None,
  41. http_key=None, http_key_passphrase=None, http_path_prefix='/OpenADR2/Simple/2.0b',
  42. requested_poll_freq=timedelta(seconds=10), http_ca_file=None):
  43. """
  44. Create a new OpenADR VTN (Server).
  45. :param str vtn_id: An identifier string for this VTN. This is how you identify yourself
  46. to the VENs that talk to you.
  47. :param str cert: Path to the PEM-formatted certificate file that is used to sign outgoing
  48. messages
  49. :param str key: Path to the PEM-formatted private key file that is used to sign outgoing
  50. messages
  51. :param str passphrase: The passphrase used to decrypt the private key file
  52. :param callable fingerprint_lookup: A callable that receives a ven_id and should return the
  53. registered fingerprint for that VEN. You should receive
  54. these fingerprints outside of OpenADR and configure them
  55. manually.
  56. :param bool show_fingerprint: Whether to print the fingerprint to your stdout on startup.
  57. Defaults to True.
  58. :param int http_port: The port that the web server is exposed on (default: 8080)
  59. :param str http_host: The host or IP address to bind the server to (default: 127.0.0.1).
  60. :param str http_cert: The path to the PEM certificate for securing HTTP traffic.
  61. :param str http_key: The path to the PEM private key for securing HTTP traffic.
  62. :param str http_ca_file: The path to the CA-file that client certificates are checked against.
  63. :param str http_key_passphrase: The passphrase for the HTTP private key.
  64. """
  65. # Set up the message queues
  66. self.message_queues = {}
  67. self.app = web.Application()
  68. self.services = {'event_service': EventService(vtn_id, message_queues=self.message_queues),
  69. 'report_service': ReportService(vtn_id, message_queues=self.message_queues),
  70. 'poll_service': PollService(vtn_id, message_queues=self.message_queues),
  71. 'opt_service': OptService(vtn_id),
  72. 'registration_service': RegistrationService(vtn_id,
  73. poll_freq=requested_poll_freq)}
  74. if http_path_prefix[-1] == "/":
  75. http_path_prefix = http_path_prefix[:-1]
  76. self.app.add_routes([web.post(f"{http_path_prefix}/{s.__service_name__}", s.handler)
  77. for s in self.services.values()])
  78. self.http_port = http_port
  79. self.http_host = http_host
  80. self.http_path_prefix = http_path_prefix
  81. # Create SSL context for running the server
  82. if http_cert and http_key:
  83. self.ssl_context = ssl.create_default_context(cafile=http_ca_file,
  84. purpose=ssl.Purpose.CLIENT_AUTH)
  85. self.ssl_context.verify_mode = ssl.CERT_REQUIRED
  86. self.ssl_context.load_cert_chain(http_cert, http_key, http_key_passphrase)
  87. else:
  88. self.ssl_context = None
  89. # Configure message signing
  90. if cert and key:
  91. with open(cert, "rb") as file:
  92. cert = file.read()
  93. with open(key, "rb") as file:
  94. key = file.read()
  95. if show_fingerprint:
  96. print("")
  97. print("*" * 80)
  98. print("Your VTN Certificate Fingerprint is "
  99. f"{utils.certificate_fingerprint(cert)}".center(80))
  100. print("Please deliver this fingerprint to the VENs that connect to you.".center(80))
  101. print("You do not need to keep this a secret.".center(80))
  102. print("*" * 80)
  103. print("")
  104. VTNService._create_message = partial(create_message, cert=cert, key=key,
  105. passphrase=passphrase)
  106. VTNService.fingerprint_lookup = staticmethod(fingerprint_lookup)
  107. self.__setattr__ = self.add_handler
  108. async def run(self):
  109. """
  110. Starts the server in an already-running asyncio loop.
  111. """
  112. self.app_runner = web.AppRunner(self.app)
  113. await self.app_runner.setup()
  114. site = web.TCPSite(self.app_runner,
  115. port=self.http_port,
  116. host=self.http_host,
  117. ssl_context=self.ssl_context)
  118. await site.start()
  119. protocol = 'https' if self.ssl_context else 'http'
  120. print("")
  121. print("*" * 80)
  122. print("Your VTN Server is now running at ".center(80))
  123. print(f"{protocol}://{self.http_host}:{self.http_port}{self.http_path_prefix}".center(80))
  124. print("*" * 80)
  125. print("")
  126. async def run_async(self):
  127. await self.run()
  128. async def stop(self):
  129. delayed_call_tasks = [task for task in asyncio.all_tasks()
  130. if task.get_name().startswith('DelayedCall')]
  131. for task in delayed_call_tasks:
  132. task.cancel()
  133. await self.app_runner.cleanup()
  134. def add_event(self, ven_id, signal_name, signal_type, intervals, callback=None, event_id=None,
  135. targets=None, targets_by_type=None, target=None, response_required='always',
  136. market_context="oadr://unknown.context", notification_period=None,
  137. ramp_up_period=None, recovery_period=None):
  138. """
  139. Convenience method to add an event with a single signal.
  140. :param str ven_id: The ven_id to whom this event must be delivered.
  141. :param str signal_name: The OpenADR name of the signal; one of openleadr.objects.SIGNAL_NAME
  142. :param str signal_type: The OpenADR type of the signal; one of openleadr.objects.SIGNAL_TYPE
  143. :param str intervals: A list of intervals with a dtstart, duration and payload member.
  144. :param str callback: A callback function for when your event has been accepted (optIn) or refused (optOut).
  145. :param list targets: A list of Targets that this Event applies to.
  146. :param target: A single target for this event.
  147. :param dict targets_by_type: A dict of targets, grouped by type.
  148. :param str market_context: A URI for the DR program that this event belongs to.
  149. :param timedelta notification_period: The Notification period for the Event's Active Period.
  150. :param timedelta ramp_up_period: The Ramp Up period for the Event's Active Period.
  151. :param timedelta recovery_period: The Recovery period for the Event's Active Period.
  152. If you don't provide a target using any of the three arguments, the target will be set to the given ven_id.
  153. """
  154. if self.services['event_service'].polling_method == 'external':
  155. logger.error("You cannot use the add_event method after you assign your own on_poll "
  156. "handler. If you use your own on_poll handler, you are responsible for "
  157. "delivering events from that handler. If you want to use OpenLEADRs "
  158. "message queuing system, you should not assign an on_poll handler. "
  159. "Your Event will NOT be added.")
  160. return
  161. if not re.match(r"^(([^:/?#]+):)?(//([^/?#]*))?([^?#]*)(\?([^#]*))?(#(.*))?", market_context):
  162. raise ValueError("The Market Context must be a valid URI.")
  163. event_id = event_id or utils.generate_id()
  164. if response_required not in ('always', 'never'):
  165. raise ValueError("'response_required' should be either 'always' or 'never'; "
  166. f"you provided '{response_required}'.")
  167. # Figure out the target for this Event
  168. if target is None and targets is None and targets_by_type is None:
  169. targets = [{'ven_id': ven_id}]
  170. elif target is not None:
  171. targets = [target]
  172. elif targets_by_type is not None:
  173. targets = utils.ungroup_targets_by_type(targets_by_type)
  174. if not isinstance(targets, list):
  175. targets = [targets]
  176. event_descriptor = objects.EventDescriptor(event_id=event_id,
  177. modification_number=0,
  178. market_context=market_context,
  179. event_status="far",
  180. created_date_time=datetime.now(timezone.utc))
  181. event_signal = objects.EventSignal(intervals=intervals,
  182. signal_name=signal_name,
  183. signal_type=signal_type,
  184. signal_id=utils.generate_id(),
  185. targets=targets)
  186. # Make sure the intervals carry timezone-aware timestamps
  187. for interval in intervals:
  188. if utils.getmember(interval, 'dtstart').tzinfo is None:
  189. utils.setmember(interval, 'dtstart',
  190. utils.getmember(interval, 'dtstart').astimezone(timezone.utc))
  191. logger.warning("You supplied a naive datetime object to your interval's dtstart. "
  192. "This will be interpreted as a timestamp in your local timezone "
  193. "and then converted to UTC before sending. Please supply timezone-"
  194. "aware timestamps like datetime.datetime.new(timezone.utc) or "
  195. "datetime.datetime(..., tzinfo=datetime.timezone.utc)")
  196. active_period = utils.get_active_period_from_intervals(intervals, False)
  197. active_period.ramp_up_period = ramp_up_period
  198. active_period.notification_period = notification_period
  199. active_period.recovery_period = recovery_period
  200. event = objects.Event(active_period=active_period,
  201. event_descriptor=event_descriptor,
  202. event_signals=[event_signal],
  203. targets=targets,
  204. response_required=response_required)
  205. self.add_raw_event(ven_id=ven_id, event=event, callback=callback)
  206. return event_id
  207. def add_raw_event(self, ven_id, event, callback=None):
  208. """
  209. Add a new event to the queue for a specific VEN.
  210. :param str ven_id: The ven_id to which this event should be distributed.
  211. :param dict event: The event (as a dict or as a objects.Event instance)
  212. that contains the event details.
  213. :param callable callback: A callback that will receive the opt status for this event.
  214. This callback receives ven_id, event_id, opt_type as its arguments.
  215. """
  216. if utils.getmember(event, 'response_required') == 'always':
  217. if callback is None:
  218. logger.warning("You did not provide a 'callback', which means you won't know if the "
  219. "VEN will opt in or opt out of your event. You should consider adding "
  220. "a callback for this.")
  221. elif not asyncio.isfuture(callback):
  222. args = inspect.signature(callback).parameters
  223. if not all(['ven_id' in args, 'event_id' in args, 'opt_type' in args]):
  224. raise ValueError("The 'callback' must have at least the following parameters: "
  225. "'ven_id' (str), 'event_id' (str), 'opt_type' (str). Please fix "
  226. "your 'callback' handler.")
  227. if ven_id not in self.message_queues:
  228. self.message_queues[ven_id] = deque()
  229. event_id = utils.getmember(utils.getmember(event, 'event_descriptor'), 'event_id')
  230. self.message_queues[ven_id].append(event)
  231. if callback is not None:
  232. self.services['event_service'].pending_events[event_id] = (event, callback)
  233. if utils.getmember(event, 'response_required') == 'never':
  234. self.services['event_service'].schedule_event_updates(ven_id, event)
  235. return event_id
  236. def add_handler(self, name, func):
  237. """
  238. Add a handler to the OpenADRServer.
  239. :param str name: The name for this handler. Should be one of: on_created_event,
  240. on_request_event, on_register_report, on_create_report,
  241. on_created_report, on_request_report, on_update_report, on_poll,
  242. on_query_registration, on_create_party_registration,
  243. on_cancel_party_registration.
  244. :param callable func: A function or coroutine that handles this type of occurrence.
  245. It receives the message, and should return the contents of a response.
  246. """
  247. logger.debug(f"Adding handler: {name} {func}")
  248. if name in self._MAP:
  249. setattr(self.services[self._MAP[name]], name, func)
  250. if name == 'on_poll':
  251. self.services['poll_service'].polling_method = 'external'
  252. self.services['event_service'].polling_method = 'external'
  253. else:
  254. raise NameError(f"""Unknown handler '{name}'. """
  255. f"""Correct handler names are: '{"', '".join(self._MAP.keys())}'.""")