server.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272
  1. # SPDX-License-Identifier: Apache-2.0
  2. # Copyright 2020 Contributors to OpenLEADR
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. # http://www.apache.org/licenses/LICENSE-2.0
  7. # Unless required by applicable law or agreed to in writing, software
  8. # distributed under the License is distributed on an "AS IS" BASIS,
  9. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. # See the License for the specific language governing permissions and
  11. # limitations under the License.
  12. from aiohttp import web
  13. from openleadr.service import EventService, PollService, RegistrationService, ReportService, \
  14. OptService, VTNService
  15. from openleadr.messaging import create_message
  16. from openleadr import objects
  17. from openleadr import utils
  18. from functools import partial
  19. from datetime import datetime, timedelta, timezone
  20. from collections import deque
  21. import asyncio
  22. import logging
  23. import ssl
  24. import re
  25. logger = logging.getLogger('openleadr')
  26. class OpenADRServer:
  27. _MAP = {'on_created_event': 'event_service',
  28. 'on_request_event': 'event_service',
  29. 'on_register_report': 'report_service',
  30. 'on_create_report': 'report_service',
  31. 'on_created_report': 'report_service',
  32. 'on_request_report': 'report_service',
  33. 'on_update_report': 'report_service',
  34. 'on_poll': 'poll_service',
  35. 'on_query_registration': 'registration_service',
  36. 'on_create_party_registration': 'registration_service',
  37. 'on_cancel_party_registration': 'registration_service'}
  38. def __init__(self, vtn_id, cert=None, key=None, passphrase=None, fingerprint_lookup=None,
  39. show_fingerprint=True, http_port=8080, http_host='127.0.0.1', http_cert=None,
  40. http_key=None, http_key_passphrase=None, http_path_prefix='/OpenADR2/Simple/2.0b',
  41. requested_poll_freq=timedelta(seconds=10), http_ca_file=None):
  42. """
  43. Create a new OpenADR VTN (Server).
  44. :param str vtn_id: An identifier string for this VTN. This is how you identify yourself
  45. to the VENs that talk to you.
  46. :param str cert: Path to the PEM-formatted certificate file that is used to sign outgoing
  47. messages
  48. :param str key: Path to the PEM-formatted private key file that is used to sign outgoing
  49. messages
  50. :param str passphrase: The passphrase used to decrypt the private key file
  51. :param callable fingerprint_lookup: A callable that receives a ven_id and should return the
  52. registered fingerprint for that VEN. You should receive
  53. these fingerprints outside of OpenADR and configure them
  54. manually.
  55. :param bool show_fingerprint: Whether to print the fingerprint to your stdout on startup.
  56. Defaults to True.
  57. :param int http_port: The port that the web server is exposed on (default: 8080)
  58. :param str http_host: The host or IP address to bind the server to (default: 127.0.0.1).
  59. :param str http_cert: The path to the PEM certificate for securing HTTP traffic.
  60. :param str http_key: The path to the PEM private key for securing HTTP traffic.
  61. :param str http_ca_file: The path to the CA-file that client certificates are checked against.
  62. :param str http_key_passphrase: The passphrase for the HTTP private key.
  63. """
  64. # Set up the message queues
  65. self.message_queues = {}
  66. self.app = web.Application()
  67. self.services = {'event_service': EventService(vtn_id, message_queues=self.message_queues),
  68. 'report_service': ReportService(vtn_id, message_queues=self.message_queues),
  69. 'poll_service': PollService(vtn_id, message_queues=self.message_queues),
  70. 'opt_service': OptService(vtn_id),
  71. 'registration_service': RegistrationService(vtn_id,
  72. poll_freq=requested_poll_freq)}
  73. if http_path_prefix[-1] == "/":
  74. http_path_prefix = http_path_prefix[:-1]
  75. self.app.add_routes([web.post(f"{http_path_prefix}/{s.__service_name__}", s.handler)
  76. for s in self.services.values()])
  77. self.http_port = http_port
  78. self.http_host = http_host
  79. self.http_path_prefix = http_path_prefix
  80. # Create SSL context for running the server
  81. if http_cert and http_key:
  82. self.ssl_context = ssl.create_default_context(cafile=http_ca_file,
  83. purpose=ssl.Purpose.CLIENT_AUTH)
  84. self.ssl_context.verify_mode = ssl.CERT_REQUIRED
  85. self.ssl_context.load_cert_chain(http_cert, http_key, http_key_passphrase)
  86. else:
  87. self.ssl_context = None
  88. # Configure message signing
  89. if cert and key:
  90. with open(cert, "rb") as file:
  91. cert = file.read()
  92. with open(key, "rb") as file:
  93. key = file.read()
  94. if show_fingerprint:
  95. print("")
  96. print("*" * 80)
  97. print("Your VTN Certificate Fingerprint is "
  98. f"{utils.certificate_fingerprint(cert)}".center(80))
  99. print("Please deliver this fingerprint to the VENs that connect to you.".center(80))
  100. print("You do not need to keep this a secret.".center(80))
  101. print("*" * 80)
  102. print("")
  103. VTNService._create_message = partial(create_message, cert=cert, key=key,
  104. passphrase=passphrase)
  105. VTNService.fingerprint_lookup = staticmethod(fingerprint_lookup)
  106. self.__setattr__ = self.add_handler
  107. def run(self):
  108. """
  109. Starts the asyncio-loop and runs the server in it. This function is
  110. blocking. For other ways to run the server in a more flexible context,
  111. please refer to the `aiohttp documentation
  112. <https://docs.aiohttp.org/en/stable/web_advanced.html#aiohttp-web-app-runners>`_.
  113. """
  114. web.run_app(self.app)
  115. async def run_async(self):
  116. """
  117. Starts the server in an already-running asyncio loop.
  118. """
  119. self.app_runner = web.AppRunner(self.app)
  120. await self.app_runner.setup()
  121. site = web.TCPSite(self.app_runner,
  122. port=self.http_port,
  123. host=self.http_host,
  124. ssl_context=self.ssl_context)
  125. await site.start()
  126. protocol = 'https' if self.ssl_context else 'http'
  127. print("")
  128. print("*" * 80)
  129. print("Your VTN Server is now running at ".center(80))
  130. print(f"{protocol}://{self.http_host}:{self.http_port}{self.http_path_prefix}".center(80))
  131. print("*" * 80)
  132. print("")
  133. async def run_async(self):
  134. await self.run()
  135. async def stop(self):
  136. delayed_call_tasks = [task for task in asyncio.all_tasks() if task.get_name().startswith('DelayedCall')]
  137. for task in delayed_call_tasks:
  138. task.cancel()
  139. await self.app_runner.cleanup()
  140. def add_event(self, ven_id, signal_name, signal_type, intervals, callback, targets=None,
  141. targets_by_type=None, target=None, market_context="oadr://unknown.context",
  142. notification_period=None, ramp_up_period=None, recovery_period=None):
  143. """
  144. Convenience method to add an event with a single signal.
  145. :param str ven_id: The ven_id to whom this event must be delivered.
  146. :param str signal_name: The OpenADR name of the signal; one of openleadr.objects.SIGNAL_NAME
  147. :param str signal_type: The OpenADR type of the signal; one of openleadr.objects.SIGNAL_TYPE
  148. :param str intervals: A list of intervals with a dtstart, duration and payload member.
  149. :param str callback: A callback function for when your event has been accepted (optIn) or refused (optOut).
  150. :param list targets: A list of Targets that this Event applies to.
  151. :param target: A single target for this event.
  152. :param dict targets_by_type: A dict of targets, grouped by type.
  153. :param str market_context: A URI for the DR program that this event belongs to.
  154. :param timedelta notification_period: The Notification period for the Event's Active Period.
  155. :param timedelta ramp_up_period: The Ramp Up period for the Event's Active Period.
  156. :param timedelta recovery_period: The Recovery period for the Event's Active Period.
  157. If you don't provide a target using any of the three arguments, the target will be set to the given ven_id.
  158. """
  159. if self.services['event_service'].polling_method == 'external':
  160. logger.error("You cannot use the add_event method after you assign your own on_poll "
  161. "handler. If you use your own on_poll handler, you are responsible for "
  162. "delivering events from that handler. If you want to use OpenLEADRs "
  163. "message queuing system, you should not assign an on_poll handler. "
  164. "Your Event will NOT be added.")
  165. return
  166. if not re.match(r"^(([^:/?#]+):)?(//([^/?#]*))?([^?#]*)(\?([^#]*))?(#(.*))?", market_context):
  167. raise ValueError("The Market Context must be a valid URI.")
  168. event_id = utils.generate_id()
  169. # Figure out the target for this Event
  170. if target is None and targets is None and targets_by_type is None:
  171. targets = [{'ven_id': ven_id}]
  172. elif target is not None:
  173. targets = [target]
  174. elif targets_by_type is not None:
  175. targets = utils.ungroup_targets_by_type(targets_by_type)
  176. if not isinstance(targets, list):
  177. targets = [targets]
  178. event_descriptor = objects.EventDescriptor(event_id=event_id,
  179. modification_number=0,
  180. market_context=market_context,
  181. event_status="far",
  182. created_date_time=datetime.now(timezone.utc))
  183. event_signal = objects.EventSignal(intervals=intervals,
  184. signal_name=signal_name,
  185. signal_type=signal_type,
  186. signal_id=utils.generate_id(),
  187. targets=targets)
  188. # Make sure the intervals carry timezone-aware timestamps
  189. for interval in intervals:
  190. if utils.getmember(interval, 'dtstart').tzinfo is None:
  191. utils.setmember(interval, 'dtstart',
  192. utils.getmember(interval, 'dtstart').astimezone(timezone.utc))
  193. logger.warning("You supplied a naive datetime object to your interval's dtstart. "
  194. "This will be interpreted as a timestamp in your local timezone "
  195. "and then converted to UTC before sending. Please supply timezone-"
  196. "aware timestamps like datetime.datetime.new(timezone.utc) or "
  197. "datetime.datetime(..., tzinfo=datetime.timezone.utc)")
  198. active_period = utils.get_active_period_from_intervals(intervals, False)
  199. active_period.ramp_up_period = ramp_up_period
  200. active_period.notification_period = notification_period
  201. active_period.recovery_period = recovery_period
  202. event = objects.Event(active_period=active_period,
  203. event_descriptor=event_descriptor,
  204. event_signals=[event_signal],
  205. targets=targets)
  206. if ven_id not in self.message_queues:
  207. self.message_queues[ven_id] = deque()
  208. self.message_queues[ven_id].append(event)
  209. self.services['event_service'].pending_events[event_id] = (event, callback)
  210. return event_id
  211. def add_raw_event(self, ven_id, event, callback):
  212. """
  213. Add a new event to the queue for a specific VEN.
  214. :param str ven_id: The ven_id to which this event should be distributed.
  215. :param dict event: The event (as a dict or as a objects.Event instance)
  216. that contains the event details.
  217. :param callable callback: A callback that will receive the opt status for this event.
  218. This callback receives ven_id, event_id, opt_type as its arguments.
  219. """
  220. if ven_id not in self.message_queues:
  221. self.message_queues[ven_id] = deque()
  222. event_id = utils.getmember(utils.getmember(event, 'event_descriptor'), 'event_id')
  223. self.message_queues[ven_id].append(event)
  224. self.services['event_service'].pending_events[event_id] = (event, callback)
  225. return event_id
  226. def add_handler(self, name, func):
  227. """
  228. Add a handler to the OpenADRServer.
  229. :param str name: The name for this handler. Should be one of: on_created_event,
  230. on_request_event, on_register_report, on_create_report,
  231. on_created_report, on_request_report, on_update_report, on_poll,
  232. on_query_registration, on_create_party_registration,
  233. on_cancel_party_registration.
  234. :param callable func: A function or coroutine that handles this type of occurrence.
  235. It receives the message, and should return the contents of a response.
  236. """
  237. logger.debug(f"Adding handler: {name} {func}")
  238. if name in self._MAP:
  239. setattr(self.services[self._MAP[name]], name, func)
  240. if name == 'on_poll':
  241. self.services['poll_service'].polling_method = 'external'
  242. self.services['event_service'].polling_method = 'external'
  243. else:
  244. raise NameError(f"Unknown handler {name}. "
  245. f"Correct handler names are: {self._MAP.keys()}")