ソースを参照

Merge pull request #49 from OpenLEADR/distribute-event-redesign

Distribute event redesign
Stan Janssen 3 年 前
コミット
fb338a885c

+ 1 - 1
docs/roadmap.rst

@@ -95,7 +95,7 @@ New features:
 
 - Events now cycle through the correct 'far', 'near', 'active', 'completed'.
 - The Client now implements the ``on_update_event handler``, so that you can catch these event updates separately from the regular event messages.
-- Added support for the ramp_up_duration parameter on the ``server.add_event`` method.
+- Added support for the ramp_up_period parameter on the ``server.add_event`` method.
 
 Bug fixes:
 

+ 29 - 3
openleadr/client.py

@@ -80,6 +80,8 @@ class OpenADRClient:
         self.scheduler = AsyncIOScheduler()
         self.client_session = None
         self.report_queue_task = None
+
+        self.received_events = {}               # Holds the events that we received.
         self.responded_events = {}              # Holds the events that we already saw.
 
         self.cert_path = cert
@@ -138,6 +140,9 @@ class OpenADRClient:
         self.scheduler.add_job(self._poll,
                                trigger='cron',
                                **cron_config)
+        self.scheduler.add_job(self._event_cleanup,
+                               trigger='interval',
+                               seconds=300)
         self.scheduler.start()
 
     async def stop(self):
@@ -312,6 +317,7 @@ class OpenADRClient:
                                                        market_context=market_context)
         self.report_callbacks[(report.report_specifier_id, r_id)] = callback
         report.report_descriptions.append(report_description)
+        return report_specifier_id, r_id
 
     ###########################################################################
     #                                                                         #
@@ -745,9 +751,17 @@ class OpenADRClient:
             for event in message['events']:
                 event_id = event['event_descriptor']['event_id']
                 event_status = event['event_descriptor']['event_status']
-                if event_id in self.responded_events:
-                    result = self.on_update_event(event)
+                modification_number = event['event_descriptor']['modification_number']
+                if event_id in self.received_events:
+                    if self.received_events[event_id]['event_descriptor']['modification_number'] == modification_number:
+                        # Re-submit the same opt type as we already had previously
+                        result = self.responded_events[event_id]
+                    else:
+                        # Wait for the result of the on_update_event handler
+                        result = self.on_update_event(event)
                 else:
+                    # Wait for the result of the on_event
+                    self.received_events[event_id] = event
                     result = self.on_event(event)
                 if asyncio.iscoroutine(result):
                     result = await result
@@ -772,7 +786,9 @@ class OpenADRClient:
                             'request_id': message['request_id'],
                             'modification_number': 1,
                             'event_id': events[i]['event_descriptor']['event_id']}
-                           for i, event in enumerate(events) if event['response_required'] == 'always']
+                           for i, event in enumerate(events)
+                           if event['response_required'] == 'always'
+                           and not utils.determine_event_status(event['active_period']) == 'completed']
 
         if len(event_responses) > 0:
             response = {'response_code': 200,
@@ -788,6 +804,16 @@ class OpenADRClient:
         else:
             logger.info("Not sending any event responses, because a response was not required/allowed by the VTN.")
 
+    async def _event_cleanup(self):
+        """
+        Periodic task that will clean up completed events in our memory.
+        """
+        print("Checking for stale events")
+        for event in list(self.received_events):
+            if utils.determine_event_status(self.received_events[event]['active_period']) == 'completed':
+                logger.debug(f"Removing event {event} because it is completed.")
+                self.received_events.pop(event)
+
     async def _poll(self):
         logger.debug("Now polling for new messages")
         response_type, response_payload = await self.poll()

+ 42 - 18
openleadr/server.py

@@ -16,13 +16,12 @@
 
 from aiohttp import web
 from openleadr.service import EventService, PollService, RegistrationService, ReportService, \
-                              OptService, VTNService
+                              VTNService
 from openleadr.messaging import create_message
 from openleadr import objects
 from openleadr import utils
 from functools import partial
 from datetime import datetime, timedelta, timezone
-from collections import deque
 import asyncio
 import inspect
 import logging
@@ -76,19 +75,25 @@ class OpenADRServer:
         :param str http_key_passphrase: The passphrase for the HTTP private key.
         """
         # Set up the message queues
-        self.message_queues = {}
 
         self.app = web.Application()
-        self.services = {'event_service': EventService(vtn_id, message_queues=self.message_queues),
-                         'report_service': ReportService(vtn_id, message_queues=self.message_queues),
-                         'poll_service': PollService(vtn_id, message_queues=self.message_queues),
-                         'opt_service': OptService(vtn_id),
-                         'registration_service': RegistrationService(vtn_id,
-                                                                     poll_freq=requested_poll_freq)}
+        self.services = {}
+        self.services['event_service'] = EventService(vtn_id)
+        self.services['report_service'] = ReportService(vtn_id)
+        self.services['poll_service'] = PollService(vtn_id)
+        self.services['registration_service'] = RegistrationService(vtn_id, poll_freq=requested_poll_freq)
+
+        # Register the other services with the poll service
+        self.services['poll_service'].event_service = self.services['event_service']
+        self.services['poll_service'].report_service = self.services['report_service']
+
+        # Set up the HTTP handlers for the services
         if http_path_prefix[-1] == "/":
             http_path_prefix = http_path_prefix[:-1]
         self.app.add_routes([web.post(f"{http_path_prefix}/{s.__service_name__}", s.handler)
                              for s in self.services.values()])
+
+        # Configure the web server
         self.http_port = http_port
         self.http_host = http_host
         self.http_path_prefix = http_path_prefix
@@ -155,7 +160,7 @@ class OpenADRServer:
     def add_event(self, ven_id, signal_name, signal_type, intervals, callback=None, event_id=None,
                   targets=None, targets_by_type=None, target=None, response_required='always',
                   market_context="oadr://unknown.context", notification_period=None,
-                  ramp_up_period=None, recovery_period=None):
+                  ramp_up_period=None, recovery_period=None, signal_target_mrid=None):
         """
         Convenience method to add an event with a single signal.
 
@@ -207,8 +212,7 @@ class OpenADRServer:
         event_signal = objects.EventSignal(intervals=intervals,
                                            signal_name=signal_name,
                                            signal_type=signal_type,
-                                           signal_id=utils.generate_id(),
-                                           targets=targets)
+                                           signal_id=utils.generate_id())
 
         # Make sure the intervals carry timezone-aware timestamps
         for interval in intervals:
@@ -253,14 +257,18 @@ class OpenADRServer:
                                      "'ven_id' (str), 'event_id' (str), 'opt_type' (str). Please fix "
                                      "your 'callback' handler.")
 
-        if ven_id not in self.message_queues:
-            self.message_queues[ven_id] = deque()
         event_id = utils.getmember(utils.getmember(event, 'event_descriptor'), 'event_id')
-        self.message_queues[ven_id].append(event)
+        # Create the event queue if it does not exist yet
+        if ven_id not in self.events:
+            self.events[ven_id] = []
+
+        # Add event to the queue
+        self.events[ven_id].append(event)
+        self.events_updated[ven_id] = True
+
+        # Add the callback for the response to this event
         if callback is not None:
-            self.services['event_service'].pending_events[event_id] = (event, callback)
-        if utils.getmember(event, 'response_required') == 'never':
-            self.services['event_service'].schedule_event_updates(ven_id, event)
+            self.event_callbacks[event_id] = (event, callback)
         return event_id
 
     def add_handler(self, name, func):
@@ -284,3 +292,19 @@ class OpenADRServer:
         else:
             raise NameError(f"""Unknown handler '{name}'. """
                             f"""Correct handler names are: '{"', '".join(self._MAP.keys())}'.""")
+
+    @property
+    def registered_reports(self):
+        return self.services['report_service'].registered_reports
+
+    @property
+    def events(self):
+        return self.services['event_service'].events
+
+    @property
+    def events_updated(self):
+        return self.services['poll_service'].events_updated
+
+    @property
+    def event_callbacks(self):
+        return self.services['event_service'].event_callbacks

+ 37 - 88
openleadr/service/event_service.py

@@ -16,49 +16,51 @@
 
 from . import service, handler, VTNService
 import asyncio
-from openleadr import objects, utils, enums
+from openleadr import utils, errors
 import logging
-import sys
-from datetime import datetime, timezone
-from functools import partial
-from dataclasses import asdict
 logger = logging.getLogger('openleadr')
 
 
 @service('EiEvent')
 class EventService(VTNService):
 
-    def __init__(self, vtn_id, polling_method='internal', message_queues=None):
+    def __init__(self, vtn_id, polling_method='internal'):
         super().__init__(vtn_id)
         self.polling_method = polling_method
-        self.message_queues = message_queues
-        self.pending_events = {}        # Holds the event callbacks
-        self.running_events = {}        # Holds the event callbacks for accepted events
+        self.events = {}
+        self.completed_event_ids = {}   # Holds the ids of completed events
+        self.event_callbacks = {}
+        self.event_opt_types = {}
 
     @handler('oadrRequestEvent')
     async def request_event(self, payload):
         """
         The VEN requests us to send any events we have.
         """
-        if self.polling_method == 'external':
+        ven_id = payload['ven_id']
+        if self.polling_method == 'internal':
+            if ven_id in self.events and self.events[ven_id]:
+                events = utils.order_events(self.events[ven_id])
+                for event in events:
+                    event_status = utils.getmember(utils.getmember(event, 'event_descriptor'), 'event_status')
+                    # Pop the event from the events so that this is the last time it is communicated
+                    if event_status == 'completed':
+                        self.events[ven_id].pop(self.events[ven_id].index(event))
+            else:
+                events = None
+        else:
             result = self.on_request_event(ven_id=payload['ven_id'])
             if asyncio.iscoroutine(result):
                 result = await result
-        elif payload['ven_id'] in self.message_queues:
-            queue = self.message_queues[payload['ven_id']]
-            result = utils.get_next_event_from_deque(queue)
-        else:
-            return 'oadrResponse', {}
+            if result is None:
+                events = None
+            else:
+                events = utils.order_events(result)
 
-        if result is None:
+        if events is None:
             return 'oadrResponse', {}
-        if isinstance(result, dict) and 'event_descriptor' in result:
-            return 'oadrDistributeEvent', {'events': [result]}
-        elif isinstance(result, objects.Event):
-            return 'oadrDistributeEvent', {'events': [asdict(result)]}
-
-        logger.warning("Could not determine type of message "
-                       f"in response to oadrRequestEvent: {result}")
+        else:
+            return 'oadrDistributeEvent', {'events': events}
         return 'oadrResponse', result
 
     def on_request_event(self, ven_id):
@@ -81,24 +83,19 @@ class EventService(VTNService):
             for event_response in payload['event_responses']:
                 event_id = event_response['event_id']
                 opt_type = event_response['opt_type']
-                if event_response['event_id'] in self.pending_events:
-                    event, callback = self.pending_events.pop(event_id)
-                    if isinstance(callback, asyncio.Future):
-                        callback.set_result(opt_type)
-                    else:
-                        result = callback(ven_id=ven_id, event_id=event_id, opt_type=opt_type)
-                        if asyncio.iscoroutine(result):
-                            result = await result
-                    if opt_type == 'optIn':
-                        self.running_events[event_id] = (event, callback)
-                        self.schedule_event_updates(ven_id, event)
-                elif event_response['event_id'] in self.running_events:
-                    event, callback = self.running_events.pop(event_id)
+                if event_id not in [utils.getmember(utils.getmember(event, 'event_descriptor'), 'event_id')
+                                    for event in self.events[ven_id]] + self.completed_event_ids.get(ven_id, []):
+                    raise errors.InvalidIdError
+                if event_response['event_id'] in self.event_callbacks:
+                    event, callback = self.event_callbacks.pop(event_id)
                     if isinstance(callback, asyncio.Future):
-                        logger.warning(f"Got a second response '{opt_type}' from ven '{ven_id}' "
-                                       f"to event '{event_id}', which we cannot use because the "
-                                       "callback future you provided was already completed during "
-                                       "the first response.")
+                        if callback.done():
+                            logger.warning(f"Got a second response '{opt_type}' from ven '{ven_id}' "
+                                           f"to event '{event_id}', which we cannot use because the "
+                                           "callback future you provided was already completed during "
+                                           "the first response.")
+                        else:
+                            callback.set_result(opt_type)
                     else:
                         result = callback(ven_id=ven_id, event_id=event_id, opt_type=opt_type)
                         if asyncio.iscoroutine(result):
@@ -118,51 +115,3 @@ class EventService(VTNService):
                        "handler will receive a ven_id, event_id and opt_status. "
                        "You don't need to return anything from this handler.")
         return None
-
-    def _update_event_status(self, ven_id, event, event_status):
-        """
-        Update the event to the given Status.
-        """
-        event.event_descriptor.event_status = event_status
-        if event_status == enums.EVENT_STATUS.CANCELLED:
-            event.event_descriptor.modification_number += 1
-        self.message_queues[ven_id].append(event)
-
-    def schedule_event_updates(self, ven_id, event):
-        """
-        Schedules the event updates.
-        """
-        loop = asyncio.get_event_loop()
-        now = datetime.now(timezone.utc)
-        active_period = event.active_period
-
-        # Named tasks is only supported in Python 3.8+
-        if sys.version_info.minor >= 8:
-            named_tasks = True
-        else:
-            named_tasks = False
-            name = {}
-
-        # Schedule status update to 'near' if applicable
-        if active_period.ramp_up_period is not None and event.event_descriptor.event_status == 'far':
-            ramp_up_start_delay = (active_period.dtstart - active_period.ramp_up_period) - now
-            update_coro = partial(self._update_event_status, ven_id, event, 'near')
-            if named_tasks:
-                name = {'name': f'DelayedCall-EventStatusToNear-{event.event_descriptor.event_id}'}
-            loop.create_task(utils.delayed_call(func=update_coro, delay=ramp_up_start_delay), **name)
-
-        # Schedule status update to 'active'
-        if event.event_descriptor.event_status in ('near', 'far'):
-            active_start_delay = active_period.dtstart - now
-            update_coro = partial(self._update_event_status, ven_id, event, 'active')
-            if named_tasks:
-                name = {'name': f'DelayedCall-EventStatusToActive-{event.event_descriptor.event_id}'}
-            loop.create_task(utils.delayed_call(func=update_coro, delay=active_start_delay), **name)
-
-        # Schedule status update to 'completed'
-        if event.event_descriptor.event_status in ('near', 'far', 'active'):
-            active_end_delay = active_period.dtstart + active_period.duration - now
-            update_coro = partial(self._update_event_status, ven_id, event, 'completed')
-            if named_tasks:
-                name = {'name': f'DelayedCall-EventStatusToActive-{event.event_descriptor.event_id}'}
-            loop.create_task(utils.delayed_call(func=update_coro, delay=active_end_delay), **name)

+ 10 - 7
openleadr/service/poll_service.py

@@ -102,10 +102,13 @@ logger = logging.getLogger('openleadr')
 @service('OadrPoll')
 class PollService(VTNService):
 
-    def __init__(self, vtn_id, polling_method='internal', message_queues=None):
+    def __init__(self, vtn_id, polling_method='internal', event_service=None, report_service=None):
         super().__init__(vtn_id)
         self.polling_method = polling_method
-        self.message_queues = message_queues
+        self.events_updated = {}
+        self.report_requests = {}
+        self.event_service = event_service
+        self.report_service = report_service
 
     @handler('oadrPoll')
     async def poll(self, payload):
@@ -115,13 +118,13 @@ class PollService(VTNService):
         """
         if self.polling_method == 'external':
             result = self.on_poll(ven_id=payload['ven_id'])
-        elif payload['ven_id'] in self.message_queues:
-            try:
-                result = self.message_queues[payload['ven_id']].popleft()
-            except IndexError:
-                return 'oadrResponse', {}
+        elif self.events_updated.get(payload['ven_id']):
+            # Send a oadrDistributeEvent whenever the events were updated
+            result = await self.event_service.request_event({'ven_id': payload['ven_id']})
+            self.events_updated[payload['ven_id']] = False
         else:
             return 'oadrResponse', {}
+
         if asyncio.iscoroutine(result):
             result = await result
         if result is None:

+ 2 - 1
openleadr/service/vtn_service.py

@@ -72,7 +72,8 @@ class VTNService:
                 response_type, response_payload = await self.handle_message(message_type,
                                                                             message_payload)
             except Exception as err:
-                logger.error("An exception occurred during the execution of your handler: "
+                logger.error("An exception occurred during the execution of your "
+                             f"{self.__class__.__name__} handler: "
                              f"{err.__class__.__name__}: {err}")
                 raise err
 

+ 84 - 11
openleadr/utils.py

@@ -568,19 +568,18 @@ def get_active_period_from_intervals(intervals, as_dict=True):
 
 
 def determine_event_status(active_period):
-    if is_dataclass(active_period):
-        active_period = asdict(active_period)
     now = datetime.now(timezone.utc)
-    if active_period['dtstart'].tzinfo is None:
-        active_period['dtstart'] = active_period['dtstart'].astimezone(timezone.utc)
-    active_period_start = active_period['dtstart']
-    active_period_end = active_period['dtstart'] + active_period['duration']
+    active_period_start = getmember(active_period, 'dtstart')
+    if active_period_start.tzinfo is None:
+        active_period_start = active_period_start.astimezone(timezone.utc)
+        setmember(active_period, 'dtstart', active_period_start)
+    active_period_end = active_period_start + getmember(active_period, 'duration')
     if now >= active_period_end:
         return 'completed'
     if now >= active_period_start:
         return 'active'
-    if active_period.get('ramp_up_duration') is not None:
-        ramp_up_start = active_period_start - active_period['ramp_up_duration']
+    if getmember(active_period, 'ramp_up_period', None) is not None:
+        ramp_up_start = active_period_start - getmember(active_period, 'ramp_up_period')
         if now >= ramp_up_start:
             return 'near'
     return 'far'
@@ -614,14 +613,20 @@ def hasmember(obj, member):
     return False
 
 
-def getmember(obj, member):
+def getmember(obj, member, missing='_RAISE_'):
     """
     Get a member from a dict or dataclass
     """
     if is_dataclass(obj):
-        return getattr(obj, member)
+        if not missing == '_RAISE_' and not hasattr(obj, member):
+            return missing
+        else:
+            return getattr(obj, member)
     else:
-        return obj[member]
+        if missing == '_RAISE_':
+            return obj[member]
+        else:
+            return obj.get(member, missing)
 
 
 def setmember(obj, member, value):
@@ -727,3 +732,71 @@ def validate_report_request_tuples(list_of_report_requests, full_mode=False):
                                  "(callback, sampling_interval, reporting_interval) tuple, where "
                                  "sampling_interval and reporting_interval are of type datetime.timedelta. "
                                  f"It returned: '{rrq[1:]}'. The third element was not of type timedelta.")
+
+
+async def await_if_required(result):
+    if asyncio.iscoroutine(result):
+        result = await result
+    return result
+
+
+async def gather_if_required(results):
+    if results is None:
+        return results
+    if len(results) > 0:
+        if not any([asyncio.iscoroutine(r) for r in results]):
+            results = results
+        elif all([asyncio.iscoroutine(r) for r in results]):
+            results = await asyncio.gather(*results)
+        else:
+            results = [await await_if_required(result) for result in results]
+    return results
+
+
+def order_events(events, limit=None, offset=None):
+    """
+    Order the events according to the OpenADR rules:
+    - active events before inactive events
+    - high priority before low priority
+    - earlier before later
+    """
+    def event_priority(event):
+        # The default and lowest priority is 0, which we should interpret as a high value.
+        priority = getmember(getmember(event, 'event_descriptor'), 'priority', float('inf'))
+        if priority == 0:
+            priority = float('inf')
+        return priority
+
+    if events is None:
+        return None
+    if isinstance(events, objects.Event):
+        events = [events]
+    elif isinstance(events, dict):
+        events = [events]
+
+    # Update the event statuses
+    for event in events:
+        event_status = determine_event_status(getmember(event, 'active_period'))
+        setmember(getmember(event, 'event_descriptor'), 'event_status', event_status)
+
+    # Short circuit if we only have one event:
+    if len(events) == 1:
+        return events
+
+    # Get all the active events first
+    active_events = [event for event in events if getmember(getmember(event, 'event_descriptor'), 'event_status') == 'active']
+    other_events = [event for event in events if getmember(getmember(event, 'event_descriptor'), 'event_status') != 'active']
+
+    # Sort the active events by priority
+    active_events.sort(key=lambda e: event_priority(e))
+
+    # Sort the active events by start date
+    active_events.sort(key=lambda e: getmember(getmember(e, 'active_period'), 'dtstart'))
+
+    # Sort the non-active events by their start date
+    other_events.sort(key=lambda e: getmember(getmember(e, 'active_period'), 'dtstart'))
+
+    ordered_events = active_events + other_events
+    if limit and offset:
+        return ordered_events[offset:offset+limit]
+    return ordered_events

+ 12 - 4
test/integration_tests/test_event_warnings_errors.py

@@ -308,6 +308,7 @@ async def test_client_warning_no_update_event_handler(caplog):
     logger.setLevel(logging.DEBUG)
     server = OpenADRServer(vtn_id='myvtn', requested_poll_freq=timedelta(seconds=1))
     server.add_handler('on_create_party_registration', on_create_party_registration)
+    event_accepted_future = asyncio.get_event_loop().create_future()
     server.add_event(ven_id='venid',
                      event_id='test_client_warning_no_update_event_handler',
                      signal_name='simple',
@@ -316,14 +317,21 @@ async def test_client_warning_no_update_event_handler(caplog):
                                  'duration': timedelta(seconds=1),
                                  'signal_payload': 1.1}],
                      target={'ven_id': 'venid'},
-                     callback=on_event_accepted)
+                     callback=event_accepted_future)
     client = OpenADRClient(ven_name='myven',
                            vtn_url='http://localhost:8080/OpenADR2/Simple/2.0b')
     client.add_handler('on_event', good_on_event)
-    await server.run_async()
-    # await asyncio.sleep(0.5)
+    print("Starting server")
+    await server.run()
     await client.run()
-    await asyncio.sleep(2)
+    print("Waiting for first event to be accepted...")
+    await event_accepted_future
+
+    # Manually update the event
+    server.events['venid'][0].event_descriptor.modification_number = 1
+    server.events_updated['venid'] = True
+
+    await asyncio.sleep(1)
     assert ("You should implement your own on_update_event handler. This handler receives "
             "an Event dict and should return either 'optIn' or 'optOut' based on your "
             "choice. Will re-use the previous opt status for this event_id for now") in [record.msg for record in caplog.records]

+ 106 - 105
test/test_queues.py → test/test_event_distribution.py

@@ -13,8 +13,8 @@ def on_create_party_registration(registration_info):
 async def on_event(event):
     return 'optIn'
 
-async def on_event_opt_in(event, future):
-    if future.done() is False:
+async def on_event_opt_in(event, future=None):
+    if future and future.done() is False:
         future.set_result(event)
     return 'optIn'
 
@@ -70,107 +70,6 @@ async def test_internal_message_queue():
     await client.stop()
     await server.stop()
 
-
-@pytest.mark.asyncio
-async def test_event_status_opt_in():
-    loop = asyncio.get_event_loop()
-    client = OpenADRClient(ven_name='myven',
-                           vtn_url='http://localhost:8080/OpenADR2/Simple/2.0b')
-    distribute_event_future = loop.create_future()
-    event_update_futures = [loop.create_future() for i in range(2)]
-    client.add_handler('on_event', partial(on_event_opt_in, future=distribute_event_future))
-    client.add_handler('on_update_event', partial(on_update_event, futures=event_update_futures))
-
-    server = OpenADRServer(vtn_id='myvtn', requested_poll_freq=datetime.timedelta(seconds=1))
-    server.add_handler('on_create_party_registration', on_create_party_registration)
-
-    event_callback_future = loop.create_future()
-    event_id = server.add_event(ven_id='ven123',
-                                signal_name='simple',
-                                signal_type='level',
-                                intervals=[{'dtstart': datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(seconds=2),
-                                            'duration': datetime.timedelta(seconds=2),
-                                            'signal_payload': 1}],
-                                callback=partial(event_callback, future=event_callback_future))
-
-    assert server.services['event_service'].pending_events[event_id][0].event_descriptor.event_status == 'far'
-    await server.run_async()
-    #await asyncio.sleep(0.5)
-    await client.run()
-
-    await event_callback_future
-
-    print("Waiting for event future 1")
-    result = await distribute_event_future
-    assert result['event_descriptor']['event_status'] == 'far'
-    assert len(client.responded_events) == 1
-
-    print("Watiting for event future 2")
-    result = await event_update_futures[0]
-    assert result['event_descriptor']['event_status'] == 'active'
-    assert len(client.responded_events) == 1
-
-    print("Watiting for event future 3")
-    result = await event_update_futures[1]
-    assert result['event_descriptor']['event_status'] == 'completed'
-    assert len(client.responded_events) == 0
-
-    await client.stop()
-    await server.stop()
-    #await asyncio.sleep(0)
-
-@pytest.mark.asyncio
-async def test_event_status_opt_in_with_ramp_up():
-    loop = asyncio.get_event_loop()
-    client = OpenADRClient(ven_name='myven',
-                           vtn_url='http://localhost:8080/OpenADR2/Simple/2.0b')
-    distribute_event_future = loop.create_future()
-    event_update_futures = [loop.create_future() for i in range(3)]
-    client.add_handler('on_event', partial(on_event_opt_in, future=distribute_event_future))
-    client.add_handler('on_update_event', partial(on_update_event, futures=event_update_futures))
-
-    server = OpenADRServer(vtn_id='myvtn', requested_poll_freq=datetime.timedelta(seconds=1))
-    server.add_handler('on_create_party_registration', on_create_party_registration)
-
-    event_callback_future = loop.create_future()
-    event_id = server.add_event(ven_id='ven123',
-                                signal_name='simple',
-                                signal_type='level',
-                                intervals=[{'dtstart': datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(seconds=4),
-                                            'duration': datetime.timedelta(seconds=2),
-                                            'signal_payload': 1}],
-                                ramp_up_period=datetime.timedelta(seconds=2),
-                                callback=partial(event_callback, future=event_callback_future))
-
-    assert server.services['event_service'].pending_events[event_id][0].event_descriptor.event_status == 'far'
-    await server.run_async()
-    #await asyncio.sleep(0.5)
-    await client.run()
-
-    await event_callback_future
-
-    print("Waiting for event future 1")
-    result = await distribute_event_future
-    assert result['event_descriptor']['event_status'] == 'far'
-
-    print("Watiting for event future 2")
-    result = await event_update_futures[0]
-    assert result['event_descriptor']['event_status'] == 'near'
-
-    print("Watiting for event future 3")
-    result = await event_update_futures[1]
-    assert result['event_descriptor']['event_status'] == 'active'
-
-    print("Watiting for event future 4")
-    result = await event_update_futures[2]
-    assert result['event_descriptor']['event_status'] == 'completed'
-    #await asyncio.sleep(0.5)
-
-    await client.stop()
-    await server.stop()
-    #await asyncio.sleep(1)
-
-
 @pytest.mark.asyncio
 async def test_request_event():
     loop = asyncio.get_event_loop()
@@ -189,13 +88,13 @@ async def test_request_event():
                                 ramp_up_period=datetime.timedelta(seconds=2),
                                 callback=partial(event_callback))
 
-    assert server.services['event_service'].pending_events[event_id][0].event_descriptor.event_status == 'far'
+    assert server.events['ven123'][0].event_descriptor.event_status == 'far'
     await server.run_async()
     await client.create_party_registration()
     message_type, message_payload = await client.request_event()
     assert message_type == 'oadrDistributeEvent'
     message_type, message_payload = await client.request_event()
-    assert message_type == 'oadrResponse'
+    assert message_type == 'oadrDistributeEvent'
     await client.stop()
     await server.stop()
 
@@ -283,3 +182,105 @@ async def test_create_event_with_future_as_callback():
     assert result == 'optIn'
     await client.stop()
     await server.stop()
+
+@pytest.mark.asyncio
+async def test_multiple_events_in_queue():
+    now = datetime.datetime.now(datetime.timezone.utc)
+    server = OpenADRServer(vtn_id='myvtn')
+    server.add_handler('on_create_party_registration', on_create_party_registration)
+
+    loop = asyncio.get_event_loop()
+    event_1_callback_future = loop.create_future()
+    event_2_callback_future = loop.create_future()
+    server.add_event(ven_id='ven123',
+                     signal_name='simple',
+                     signal_type='level',
+                     intervals=[objects.Interval(dtstart=now,
+                                                 duration=datetime.timedelta(seconds=1),
+                                                 signal_payload=1)],
+                     callback=event_1_callback_future)
+
+    await server.run()
+
+    on_event_future = loop.create_future()
+    client = OpenADRClient(ven_name='ven123',
+                           vtn_url='http://localhost:8080/OpenADR2/Simple/2.0b')
+    await client.create_party_registration()
+    response_type, response_payload = await client.request_event()
+    assert response_type == 'oadrDistributeEvent'
+    events = response_payload['events']
+    assert len(events) == 1
+    event_id = events[0]['event_descriptor']['event_id']
+    request_id = response_payload['request_id']
+    await client.created_event(request_id=request_id,
+                               event_id=event_id,
+                               opt_type='optIn',
+                               modification_number=0)
+
+    server.add_event(ven_id='ven123',
+                     signal_name='simple',
+                     signal_type='level',
+                     intervals=[objects.Interval(dtstart=now + datetime.timedelta(seconds=1),
+                                                 duration=datetime.timedelta(seconds=1),
+                                                 signal_payload=1)],
+                     callback=event_2_callback_future)
+    response_type, response_payload = await client.request_event()
+    assert response_type == 'oadrDistributeEvent'
+    events = response_payload['events']
+
+    # Assert that we still have two events in the response
+    assert len(events) == 2
+
+    # Wait one second and retrieve the events again
+    await asyncio.sleep(1)
+    response_type, response_payload = await client.request_event()
+    assert response_type == 'oadrDistributeEvent'
+    events = response_payload['events']
+    assert len(events) == 2
+    assert events[1]['event_descriptor']['event_status'] == 'completed'
+
+    response_type, response_payload = await client.request_event()
+    assert response_type == 'oadrDistributeEvent'
+    events = response_payload['events']
+    assert len(events) == 1
+    await asyncio.sleep(1)
+
+    response_type, response_payload = await client.request_event()
+    assert response_type == 'oadrDistributeEvent'
+
+    response_type, response_payload = await client.request_event()
+    assert response_type == 'oadrResponse'
+
+    await server.stop()
+
+@pytest.mark.asyncio
+async def test_client_event_cleanup():
+    now = datetime.datetime.now(datetime.timezone.utc)
+    server = OpenADRServer(vtn_id='myvtn')
+    server.add_handler('on_create_party_registration', on_create_party_registration)
+
+    loop = asyncio.get_event_loop()
+    event_1_callback_future = loop.create_future()
+    event_2_callback_future = loop.create_future()
+    server.add_event(ven_id='ven123',
+                     signal_name='simple',
+                     signal_type='level',
+                     intervals=[objects.Interval(dtstart=now,
+                                                 duration=datetime.timedelta(seconds=1),
+                                                 signal_payload=1)],
+                     callback=event_1_callback_future)
+    await server.run()
+
+    client = OpenADRClient(ven_name='ven123',
+                           vtn_url='http://localhost:8080/OpenADR2/Simple/2.0b')
+    client.add_handler('on_event', on_event_opt_in)
+    await client.run()
+    await asyncio.sleep(0.5)
+    assert len(client.received_events) == 1
+
+    await asyncio.sleep(0.5)
+    await client._event_cleanup()
+    assert len(client.received_events) == 0
+
+    await server.stop()
+    await client.stop()

+ 142 - 3
test/test_utils.py

@@ -1,5 +1,5 @@
 from openleadr import utils, objects
-from dataclasses import dataclass
+from dataclasses import dataclass, asdict
 import pytest
 from datetime import datetime, timezone, timedelta
 from collections import deque
@@ -33,6 +33,14 @@ def test_setmember():
     utils.setmember(obj, 'a', 10)
     assert utils.getmember(obj, 'a') == 10
 
+def test_setmember_nested():
+    dc_parent = dc()
+    dc_parent.a = dc()
+
+    assert utils.getmember(utils.getmember(dc_parent, 'a'), 'a') == 2
+    utils.setmember(utils.getmember(dc_parent, 'a'), 'a', 3)
+    assert dc_parent.a.a == 3
+
 @pytest.mark.asyncio
 async def test_delayed_call_with_func():
     async def myfunc():
@@ -64,7 +72,7 @@ def test_determine_event_status_active():
 def test_determine_event_status_near():
     active_period = {'dtstart': datetime.now(timezone.utc) + timedelta(seconds=3),
                      'duration': timedelta(seconds=5),
-                     'ramp_up_duration': timedelta(seconds=5)}
+                     'ramp_up_period': timedelta(seconds=5)}
     assert utils.determine_event_status(active_period) == 'near'
 
 def test_determine_event_status_far():
@@ -75,7 +83,7 @@ def test_determine_event_status_far():
 def test_determine_event_status_far_with_ramp_up():
     active_period = {'dtstart': datetime.now(timezone.utc) + timedelta(seconds=10),
                      'duration': timedelta(seconds=5),
-                     'ramp_up_duration': timedelta(seconds=5)}
+                     'ramp_up_period': timedelta(seconds=5)}
     assert utils.determine_event_status(active_period) == 'far'
 
 def test_get_active_period_from_intervals():
@@ -290,3 +298,134 @@ def test_parse_datetime():
     assert utils.parse_datetime("2020-12-15T11:29:34.123456Z") == datetime(2020, 12, 15, 11, 29, 34, 123456, tzinfo=timezone.utc)
     assert utils.parse_datetime("2020-12-15T11:29:34.123Z") == datetime(2020, 12, 15, 11, 29, 34, 123000, tzinfo=timezone.utc)
     assert utils.parse_datetime("2020-12-15T11:29:34.123456789Z") == datetime(2020, 12, 15, 11, 29, 34, 123456, tzinfo=timezone.utc)
+
+@pytest.mark.asyncio
+async def test_await_if_required():
+    def normal_func():
+        return 123
+
+    async def coro_func():
+        return 456
+
+    result = await utils.await_if_required(normal_func())
+    assert result == 123
+
+    result = await utils.await_if_required(coro_func())
+    assert result == 456
+
+    result = await utils.await_if_required(None)
+    assert result == None
+
+@pytest.mark.asyncio
+async def test_gather_if_required():
+    def normal_func():
+        return 123
+
+    async def coro_func():
+        return 456
+
+    raw_results = [normal_func(), normal_func(), normal_func()]
+    results = await utils.gather_if_required(raw_results)
+    assert results == [123, 123, 123]
+
+    raw_results = [coro_func(), coro_func(), coro_func()]
+    results = await utils.gather_if_required(raw_results)
+    assert results == [456, 456, 456]
+
+    raw_results = [coro_func(), normal_func(), None]
+    results = await utils.gather_if_required(raw_results)
+    assert results == [456, 123, None]
+
+    raw_results = []
+    results = await utils.gather_if_required(raw_results)
+    assert results == []
+
+def test_order_events():
+    now = datetime.now(timezone.utc)
+    event_1_active_high_prio = objects.Event(event_descriptor=objects.EventDescriptor(event_id='event001',
+                                                                     modification_number=0,
+                                                                     created_date_time=now,
+                                                                     event_status='far',
+                                                                     priority=1,
+                                                                     market_context='http://context01'),
+                                             active_period=objects.ActivePeriod(dtstart=now - timedelta(minutes=5),
+                                                                                duration=timedelta(minutes=10)),
+                                             event_signals=[objects.EventSignal(intervals=[objects.Interval(dtstart=now,
+                                                                                                            duration=timedelta(minutes=10),
+                                                                                                            signal_payload=1)],
+                                                                                signal_name='simple',
+                                                                                signal_type='level',
+                                                                                signal_id='signal001')],
+                                             targets=[{'ven_id': 'ven001'}])
+
+    event_2_active_low_prio = objects.Event(event_descriptor=objects.EventDescriptor(event_id='event001',
+                                                                     modification_number=0,
+                                                                     created_date_time=now,
+                                                                     event_status='far',
+                                                                     priority=2,
+                                                                     market_context='http://context01'),
+                                            active_period=objects.ActivePeriod(dtstart=now - timedelta(minutes=5),
+                                                                               duration=timedelta(minutes=10)),
+                                            event_signals=[objects.EventSignal(intervals=[objects.Interval(dtstart=now,
+                                                                                                           duration=timedelta(minutes=10),
+                                                                                                           signal_payload=1)],
+                                                                               signal_name='simple',
+                                                                               signal_type='level',
+                                                                               signal_id='signal001')],
+                                            targets=[{'ven_id': 'ven001'}])
+
+    event_3_active_no_prio = objects.Event(event_descriptor=objects.EventDescriptor(event_id='event001',
+                                                                     modification_number=0,
+                                                                     created_date_time=now,
+                                                                     event_status='far',
+                                                                     market_context='http://context01'),
+                                            active_period=objects.ActivePeriod(dtstart=now - timedelta(minutes=5),
+                                                                               duration=timedelta(minutes=10)),
+                                            event_signals=[objects.EventSignal(intervals=[objects.Interval(dtstart=now,
+                                                                                                           duration=timedelta(minutes=10),
+                                                                                                           signal_payload=1)],
+                                                                               signal_name='simple',
+                                                                               signal_type='level',
+                                                                               signal_id='signal001')],
+                                            targets=[{'ven_id': 'ven001'}])
+
+    event_4_far_early = objects.Event(event_descriptor=objects.EventDescriptor(event_id='event001',
+                                                                     modification_number=0,
+                                                                     created_date_time=now,
+                                                                     event_status='far',
+                                                                     market_context='http://context01'),
+                                     active_period=objects.ActivePeriod(dtstart=now + timedelta(minutes=5),
+                                                                        duration=timedelta(minutes=10)),
+                                     event_signals=[objects.EventSignal(intervals=[objects.Interval(dtstart=now,
+                                                                                                    duration=timedelta(minutes=10),
+                                                                                                    signal_payload=1)],
+                                                                        signal_name='simple',
+                                                                        signal_type='level',
+                                                                        signal_id='signal001')],
+                                     targets=[{'ven_id': 'ven001'}])
+
+    event_5_far_later = objects.Event(event_descriptor=objects.EventDescriptor(event_id='event001',
+                                                                     modification_number=0,
+                                                                     created_date_time=now,
+                                                                     event_status='far',
+                                                                     market_context='http://context01'),
+                                     active_period=objects.ActivePeriod(dtstart=now + timedelta(minutes=10),
+                                                                        duration=timedelta(minutes=10)),
+                                     event_signals=[objects.EventSignal(intervals=[objects.Interval(dtstart=now,
+                                                                                                    duration=timedelta(minutes=10),
+                                                                                                    signal_payload=1)],
+                                                                        signal_name='simple',
+                                                                        signal_type='level',
+                                                                        signal_id='signal001')],
+                                     targets=[{'ven_id': 'ven001'}])
+
+    events = [event_5_far_later, event_4_far_early, event_3_active_no_prio, event_2_active_low_prio, event_1_active_high_prio]
+    ordered_events = utils.order_events(events)
+    assert ordered_events == [event_1_active_high_prio, event_2_active_low_prio, event_3_active_no_prio, event_4_far_early, event_5_far_later]
+
+    ordered_events = utils.order_events(event_1_active_high_prio)
+    assert ordered_events == [event_1_active_high_prio]
+
+    event_1_as_dict = asdict(event_1_active_high_prio)
+    ordered_events = utils.order_events(event_1_as_dict)
+    assert ordered_events == [event_1_as_dict]