Kaynağa Gözat

Add support for automatic Event status updates

This adds support for automatically updating the Event status from far to near to active to completed, and it implements the on_update_event handling on the client side.

It also adds support for the ramp_up_duration parameter on the server.add_event method.

Signed-off-by: Stan Janssen <stan.janssen@elaad.nl>
Stan Janssen 4 yıl önce
ebeveyn
işleme
535f7e55b6

+ 29 - 6
openleadr/client.py

@@ -82,6 +82,7 @@ class OpenADRClient:
         self.scheduler = AsyncIOScheduler()
         self.client_session = None
         self.report_queue_task = None
+        self.responded_events = {}              # Holds the events that we already saw.
 
         self.cert_path = cert
         self.key_path = key
@@ -654,6 +655,16 @@ class OpenADRClient:
                        "choice. Will opt out of the event for now.")
         return 'optOut'
 
+    async def on_update_event(self, event):
+        """
+        Placeholder for the on_update_event handler.
+        """
+        logger.warning("You should implement your own on_update_event handler. This handler receives "
+                       "an Event dict and should return either 'optIn' or 'optOut' based on your "
+                       "choice. Will re-use the previous opt status for this event_id for now")
+        if event['event_descriptor']['event_id'] in self.events:
+            return self.responded_events['event_id']
+
     ###########################################################################
     #                                                                         #
     #                                  LOW LEVEL                              #
@@ -703,12 +714,24 @@ class OpenADRClient:
         logger.debug("The VEN received an event")
         events = message['events']
         try:
-            results = [self.on_event(event) for event in message['events']]
-            if asyncio.iscoroutine(results[0]):
-                results = await asyncio.gather(*results, return_exceptions=False)
+            results = []
+            for event in message['events']:
+                event_id = event['event_descriptor']['event_id']
+                event_status = event['event_descriptor']['event_status']
+                if event_id in self.responded_events:
+                    result = self.on_update_event(event)
+                else:
+                    result = self.on_event(event)
+                if asyncio.iscoroutine(result):
+                    result = await result
+                results.append(result)
+                if event_status == 'completed':
+                    self.responded_events.pop(event_id)
+                else:
+                    self.responded_events[event_id] = result
             for i, result in enumerate(results):
                 if result not in ('optIn', 'optOut'):
-                    logger.error("Your on_event handler must return 'optIn' or 'optOut'; "
+                    logger.error("Your on_event or on_update_event handler must return 'optIn' or 'optOut'; "
                                  f"you supplied {result}. Please fix your on_event handler.")
                     results[i] = 'optOut'
         except Exception as err:
@@ -716,7 +739,6 @@ class OpenADRClient:
                          f"The error was {err.__class__.__name__}: {str(err)}")
             results = ['optOut'] * len(events)
 
-        print("Done executing on_event for events")
         if len(events) == 1:
             logger.debug(f"Now responding with {results[0]}")
         else:
@@ -735,7 +757,8 @@ class OpenADRClient:
                     'request_id': message['request_id']}
         message = self._create_message('oadrCreatedEvent',
                                        response=response,
-                                       event_responses=event_responses)
+                                       event_responses=event_responses,
+                                       ven_id=self.ven_id)
         service = 'EiEvent'
         response_type, response_payload = await self._perform_request(service, message)
         logger.info(response_type, response_payload)

+ 0 - 5
openleadr/messaging.py

@@ -109,9 +109,7 @@ def validate_xml_signature(xml_tree, cert_fingerprint=None):
 
 async def authenticate_message(request, message_tree, message_payload, fingerprint_lookup):
     if request.secure and 'ven_id' in message_payload:
-        print("Getting cert fingerprint from request")
         connection_fingerprint = utils.get_cert_fingerprint_from_request(request)
-        print("Checking cert fingerprint")
         if connection_fingerprint is None:
             msg = ("Your request must use a client side SSL certificate, of which the "
                    "fingerprint must match the fingerprint that you have given to this VTN")
@@ -133,13 +131,11 @@ async def authenticate_message(request, message_tree, message_payload, fingerpri
                    "following fingerprint to make this request:")
             raise errors.NotRegisteredOrAuthorizedError(msg)
 
-        print("Checking connection fingerprint")
         if connection_fingerprint != expected_fingerprint:
             msg = (f"The fingerprint of your HTTPS certificate {connection_fingerprint} "
                    f"does not match the expected fingerprint {expected_fingerprint}")
             raise errors.NotRegisteredOrAuthorizedError(msg)
 
-        print("Checking message fingerprint")
         message_cert = utils.extract_pem_cert(message_tree)
         message_fingerprint = utils.certificate_fingerprint(message_cert)
         if message_fingerprint != expected_fingerprint:
@@ -149,7 +145,6 @@ async def authenticate_message(request, message_tree, message_payload, fingerpri
                    "certificate to sign your messages.")
             raise errors.NotRegisteredOrAuthorizedError(msg)
 
-        print("Validating XML signature")
         try:
             validate_xml_signature(message_tree)
         except ValueError:

+ 9 - 7
openleadr/objects.py

@@ -17,7 +17,7 @@
 from dataclasses import dataclass, field, asdict, is_dataclass
 from typing import List, Dict
 from datetime import datetime, timezone, timedelta
-from openleadr.utils import group_targets_by_type, ungroup_targets_by_type
+from openleadr import utils
 
 
 @dataclass
@@ -144,12 +144,12 @@ class EventSignal:
             return
         elif self.targets_by_type is None:
             list_of_targets = [asdict(target) if is_dataclass(target) else target for target in self.targets]
-            self.targets_by_type = group_targets_by_type(list_of_targets)
+            self.targets_by_type = utils.group_targets_by_type(list_of_targets)
         elif self.targets is None:
-            self.targets = [Target(**target) for target in ungroup_targets_by_type(self.targets_by_type)]
+            self.targets = [Target(**target) for target in utils.ungroup_targets_by_type(self.targets_by_type)]
         elif self.targets is not None and self.targets_by_type is not None:
             list_of_targets = [asdict(target) if is_dataclass(target) else target for target in self.targets]
-            if group_targets_by_type(list_of_targets) != self.targets_by_type:
+            if utils.group_targets_by_type(list_of_targets) != self.targets_by_type:
                 raise ValueError("You assigned both 'targets' and 'targets_by_type' in your event, "
                                  "but the two were not consistent with each other. "
                                  f"You supplied 'targets' = {self.targets} and "
@@ -178,16 +178,18 @@ class Event:
             raise ValueError("You must supply either 'targets' or 'targets_by_type'.")
         elif self.targets_by_type is None:
             list_of_targets = [asdict(target) if is_dataclass(target) else target for target in self.targets]
-            self.targets_by_type = group_targets_by_type(list_of_targets)
+            self.targets_by_type = utils.group_targets_by_type(list_of_targets)
         elif self.targets is None:
-            self.targets = [Target(**target) for target in ungroup_targets_by_type(self.targets_by_type)]
+            self.targets = [Target(**target) for target in utils.ungroup_targets_by_type(self.targets_by_type)]
         elif self.targets is not None and self.targets_by_type is not None:
             list_of_targets = [asdict(target) if is_dataclass(target) else target for target in self.targets]
-            if group_targets_by_type(list_of_targets) != self.targets_by_type:
+            if utils.group_targets_by_type(list_of_targets) != self.targets_by_type:
                 raise ValueError("You assigned both 'targets' and 'targets_by_type' in your event, "
                                  "but the two were not consistent with each other. "
                                  f"You supplied 'targets' = {self.targets} and "
                                  f"'targets_by_type' = {self.targets_by_type}")
+        # Set the event status
+        self.event_descriptor.event_status = utils.determine_event_status(self.active_period)
 
 
 @dataclass

+ 24 - 4
openleadr/server.py

@@ -151,7 +151,8 @@ class OpenADRServer:
         await self.app_runner.cleanup()
 
     def add_event(self, ven_id, signal_name, signal_type, intervals, callback, targets=None,
-                  targets_by_type=None, target=None, market_context="oadr://unknown.context"):
+                  targets_by_type=None, target=None, market_context="oadr://unknown.context",
+                  notification_period=None, ramp_up_period=None, recovery_period=None):
         """
         Convenience method to add an event with a single signal.
         :param str ven_id: The ven_id to whom this event must be delivered.
@@ -163,6 +164,9 @@ class OpenADRServer:
         :param target: A single target for this event.
         :param dict targets_by_type: A dict of targets, grouped by type.
         :param str market_context: A URI for the DR program that this event belongs to.
+        :param timedelta notification_period: The Notification period for the Event's Active Period.
+        :param timedelta ramp_up_period: The Ramp Up period for the Event's Active Period.
+        :param timedelta recovery_period: The Recovery period for the Event's Active Period.
 
         If you don't provide a target using any of the three arguments, the target will be set to the given ven_id.
         """
@@ -190,20 +194,36 @@ class OpenADRServer:
         event_descriptor = objects.EventDescriptor(event_id=event_id,
                                                    modification_number=0,
                                                    market_context=market_context,
-                                                   event_status="near",
+                                                   event_status="far",
                                                    created_date_time=datetime.now(timezone.utc))
         event_signal = objects.EventSignal(intervals=intervals,
                                            signal_name=signal_name,
                                            signal_type=signal_type,
                                            signal_id=utils.generate_id(),
                                            targets=targets)
-        event = objects.Event(event_descriptor=event_descriptor,
+        # Make sure the intervals carry timezone-aware timestamps
+        for interval in intervals:
+            if utils.getmember(interval, 'dtstart').tzinfo is None:
+                utils.setmember(interval, 'dtstart',
+                                utils.getmember(interval, 'dtstart').astimezone(timezone.utc))
+                logger.warning("You supplied a naive datetime object to your interval's dtstart. "
+                               "This will be interpreted as a timestamp in your local timezone "
+                               "and then converted to UTC before sending. Please supply timezone-"
+                               "aware timestamps like datetime.datetime.new(timezone.utc) or "
+                               "datetime.datetime(..., tzinfo=datetime.timezone.utc)")
+        active_period = utils.get_active_period_from_intervals(intervals, False)
+        active_period.ramp_up_period = ramp_up_period
+        active_period.notification_period = notification_period
+        active_period.recovery_period = recovery_period
+        event = objects.Event(active_period=active_period,
+                              event_descriptor=event_descriptor,
                               event_signals=[event_signal],
                               targets=targets)
         if ven_id not in self.message_queues:
             self.message_queues[ven_id] = asyncio.Queue()
         self.message_queues[ven_id].put_nowait(event)
-        self.services['event_service'].pending_events[event_id] = callback
+        self.services['event_service'].pending_events[event_id] = (event, callback)
+        return event_id
 
     def add_raw_event(self, ven_id, event):
         """

+ 46 - 9
openleadr/service/event_service.py

@@ -15,9 +15,11 @@
 # limitations under the License.
 
 from . import service, handler, VTNService
-from asyncio import iscoroutine
-from .. import objects
+import asyncio
+from openleadr import objects, utils, enums
 import logging
+from datetime import datetime, timezone
+from functools import partial
 logger = logging.getLogger('openleadr')
 
 
@@ -29,6 +31,7 @@ class EventService(VTNService):
         self.polling_method = polling_method
         self.message_queues = message_queues
         self.pending_events = {}        # Holds the event callbacks
+        self.running_events = {}        # Holds the event callbacks for accepted events
 
     @handler('oadrRequestEvent')
     async def request_event(self, payload):
@@ -36,7 +39,7 @@ class EventService(VTNService):
         The VEN requests us to send any events we have.
         """
         result = self.on_request_event(payload['ven_id'])
-        if iscoroutine(result):
+        if asyncio.iscoroutine(result):
             result = await result
         if result is None:
             return 'oadrDistributeEvent', {'events': []}
@@ -64,19 +67,44 @@ class EventService(VTNService):
         """
         The VEN informs us that they created an EiEvent.
         """
+        loop = asyncio.get_event_loop()
+        ven_id = payload['ven_id']
         if self.polling_method == 'internal':
             for event_response in payload['event_responses']:
+                event_id = event_response['event_id']
+                opt_type = event_response['opt_type']
                 if event_response['event_id'] in self.pending_events:
-                    ven_id = payload['ven_id']
-                    event_id = event_response['event_id']
-                    opt_type = event_response['opt_type']
-                    callback = self.pending_events.pop(event_id)
+                    event, callback = self.pending_events.pop(event_id)
                     result = callback(ven_id=ven_id, event_id=event_id, opt_type=opt_type)
-                    if iscoroutine(result):
+                    if asyncio.iscoroutine(result):
+                        result = await result
+                    if opt_type == 'optIn':
+                        self.running_events[event_id] = (event, callback)
+                        now = datetime.now(timezone.utc)
+                        active_period = event.active_period
+                        # Schedule status update to 'near' if applicable
+                        if active_period.ramp_up_period is not None and event.event_descriptor.event_status == 'far':
+                            ramp_up_start_delay = (active_period.dtstart - active_period.ramp_up_period) - now
+                            update_coro = partial(self._update_event_status, ven_id, event, 'near')
+                            loop.create_task(utils.delayed_call(func=update_coro, delay=ramp_up_start_delay))
+                        # Schedule status update to 'active'
+                        if event.event_descriptor.event_status in ('near', 'far'):
+                            active_start_delay = active_period.dtstart - now
+                            update_coro = partial(self._update_event_status, ven_id, event, 'active')
+                            loop.create_task(utils.delayed_call(func=update_coro, delay=active_start_delay))
+                        # Schedule status update to 'completed'
+                        if event.event_descriptor.event_status in ('near', 'far', 'active'):
+                            active_end_delay = active_period.dtstart + active_period.duration - now
+                            update_coro = partial(self._update_event_status, ven_id, event, 'completed')
+                            loop.create_task(utils.delayed_call(func=update_coro, delay=active_end_delay))
+                elif event_response['event_id'] in self.running_events:
+                    event, callback = self.running_events.pop(event_id)
+                    result = callback(ven_id=ven_id, event_id=event_id, opt_type=opt_type)
+                    if asyncio.iscoroutine(result):
                         result = await result
         else:
             result = self.on_created_event(ven_id=ven_id, event_id=event_id, opt_type=opt_type)
-            if iscoroutine(result):
+            if asyncio.iscoroutine(result):
                 result = await(result)
         return 'oadrResponse', {}
 
@@ -89,3 +117,12 @@ class EventService(VTNService):
                        "handler will receive a ven_id, event_id and opt_status. "
                        "You don't need to return anything from this handler.")
         return None
+
+    def _update_event_status(self, ven_id, event, event_status):
+        """
+        Update the event to the given Status.
+        """
+        event.event_descriptor.event_status = event_status
+        if event_status == enums.EVENT_STATUS.CANCELLED:
+            event.event_descriptor.modification_number += 1
+        self.message_queues[ven_id].put_nowait(event)

+ 6 - 6
openleadr/templates/parts/eiActivePeriod.xml

@@ -13,19 +13,19 @@
             </tolerate>
         </tolerance>
         {% endif %}
-        {% if event.active_period.notification %}
+        {% if event.active_period.notification_period %}
         <ei:x-eiNotification>
-            <duration>{{ event.active_period.notification|timedeltaformat }}</duration>
+            <duration>{{ event.active_period.notification_period|timedeltaformat }}</duration>
         </ei:x-eiNotification>
         {% endif %}
-        {% if event.active_period.ramp_up %}
+        {% if event.active_period.ramp_up_period %}
         <ei:x-eiRampUp>
-            <duration>{{ event.active_period.ramp_up|timedeltaformat }}</duration>
+            <duration>{{ event.active_period.ramp_up_period|timedeltaformat }}</duration>
         </ei:x-eiRampUp>
         {% endif %}
-        {% if event.active_period.recovery %}
+        {% if event.active_period.recovery_period %}
         <ei:x-eiRecovery>
-            <duration>{{ event.active_period.recovery|timedeltaformat }}</duration>
+            <duration>{{ event.active_period.recovery_period|timedeltaformat }}</duration>
         </ei:x-eiRecovery>
         {% endif %}
     </properties>

+ 79 - 1
openleadr/utils.py

@@ -17,6 +17,7 @@
 from datetime import datetime, timedelta, timezone
 from dataclasses import is_dataclass, asdict
 from collections import OrderedDict
+import asyncio
 import itertools
 import re
 import ssl
@@ -298,7 +299,7 @@ def parse_datetime(value):
         year, month, day, hour, minute, second, micro = (int(value) for value in matches.groups())
         return datetime(year, month, day, hour, minute, second, micro, tzinfo=timezone.utc)
     else:
-        print(f"{value} did not match format")
+        logger.warning(f"parse_datetime: {value} did not match format")
         return value
 
 
@@ -599,3 +600,80 @@ def validate_report_measurement_dict(measurement):
             raise ValueError("A 'power' related measurement must contain a "
                              "'power_attributes' section that contains the following "
                              "keys: 'voltage' (int), 'ac' (boolean), 'hertz' (int)")
+
+
+def get_active_period_from_intervals(intervals, as_dict=True):
+    if is_dataclass(intervals[0]):
+        intervals = [asdict(i) for i in intervals]
+    period_start = min([i['dtstart'] for i in intervals])
+    period_duration = max([i['dtstart'] + i['duration'] - period_start for i in intervals])
+    if as_dict:
+        return {'dtstart': period_start,
+                'duration': period_duration}
+    else:
+        from openleadr.objects import ActivePeriod
+        return ActivePeriod(dtstart=period_start, duration=period_duration)
+
+
+def determine_event_status(active_period):
+    if is_dataclass(active_period):
+        active_period = asdict(active_period)
+    now = datetime.now(timezone.utc)
+    if active_period['dtstart'].tzinfo is None:
+        active_period['dtstart'] = active_period['dtstart'].astimezone(timezone.utc)
+    active_period_start = active_period['dtstart']
+    active_period_end = active_period['dtstart'] + active_period['duration']
+    if now >= active_period_end:
+        return 'completed'
+    if now >= active_period_start:
+        return 'active'
+    if active_period.get('ramp_up_duration') is not None:
+        ramp_up_start = active_period_start - active_period['ramp_up_duration']
+        if now >= ramp_up_start:
+            return 'near'
+    return 'far'
+
+
+async def delayed_call(func, delay):
+    if isinstance(delay, timedelta):
+        delay = delay.total_seconds()
+    await asyncio.sleep(delay)
+    if asyncio.iscoroutinefunction(func):
+        await func()
+    elif asyncio.iscoroutine(func):
+        await func
+    else:
+        func()
+
+
+def hasmember(obj, member):
+    """
+    Check if a dict or dataclass has the given member
+    """
+    if is_dataclass(obj):
+        if hasattr(obj, member):
+            return True
+    else:
+        if member in obj:
+            return True
+    return False
+
+
+def getmember(obj, member):
+    """
+    Get a member from a dict or dataclass
+    """
+    if is_dataclass(obj):
+        return getattr(obj, member)
+    else:
+        return obj[member]
+
+
+def setmember(obj, member, value):
+    """
+    Set a member of a dict of dataclass
+    """
+    if is_dataclass(obj):
+        setattr(obj, member, value)
+    else:
+        obj[member] = value

+ 72 - 24
test/integration_tests/test_event_warnings_errors.py

@@ -1,15 +1,16 @@
-from openleadr import OpenADRClient, OpenADRServer, enable_default_logging
+from openleadr import OpenADRClient, OpenADRServer, enable_default_logging, utils
 import pytest
 from functools import partial
 import asyncio
-from datetime import datetime, timedelta
+from datetime import datetime, timedelta, timezone
 import logging
 
 async def on_create_party_registration(ven_name):
     return 'venid', 'regid'
 
-async def on_event_accepted(ven_id, event_id, opt_type, future):
-    future.set_result(opt_type)
+async def on_event_accepted(ven_id, event_id, opt_type, future=None):
+    if future and future.done() is False:
+        future.set_result(opt_type)
 
 async def good_on_event(event):
     return 'optIn'
@@ -41,8 +42,8 @@ async def test_client_no_event_handler(caplog):
     server.add_event(ven_id='venid',
                      signal_name='simple',
                      signal_type='level',
-                     intervals=[{'dtstart': datetime.now(),
-                                 'duration': timedelta(seconds=10),
+                     intervals=[{'dtstart': datetime.now(timezone.utc),
+                                 'duration': timedelta(seconds=1),
                                  'signal_payload': 1.1}],
                      target={'ven_id': 'venid'},
                      callback=partial(on_event_accepted, future=event_confirm_future))
@@ -55,7 +56,7 @@ async def test_client_no_event_handler(caplog):
             "choice. Will opt out of the event for now.") in [rec.message for rec in caplog.records]
     await client.stop()
     await server.stop()
-    await asyncio.sleep(0)
+    await asyncio.gather(*[t for t in asyncio.all_tasks()][1:])
 
 @pytest.mark.asyncio
 async def test_client_faulty_event_handler(caplog):
@@ -79,8 +80,8 @@ async def test_client_faulty_event_handler(caplog):
     server.add_event(ven_id='venid',
                      signal_name='simple',
                      signal_type='level',
-                     intervals=[{'dtstart': datetime.now(),
-                                 'duration': timedelta(seconds=10),
+                     intervals=[{'dtstart': datetime.now(timezone.utc),
+                                 'duration': timedelta(seconds=1),
                                  'signal_payload': 1.1}],
                      target={'ven_id': 'venid'},
                      callback=partial(on_event_accepted, future=event_confirm_future))
@@ -88,11 +89,11 @@ async def test_client_faulty_event_handler(caplog):
     print("Waiting for a response to the event")
     result = await event_confirm_future
     assert result == 'optOut'
-    assert ("Your on_event handler must return 'optIn' or 'optOut'; "
+    assert ("Your on_event or on_update_event handler must return 'optIn' or 'optOut'; "
            f"you supplied {None}. Please fix your on_event handler.") in [rec.message for rec in caplog.records]
     await client.stop()
     await server.stop()
-    await asyncio.sleep(0)
+    await asyncio.gather(*[t for t in asyncio.all_tasks()][1:])
 
 @pytest.mark.asyncio
 async def test_client_exception_event_handler(caplog):
@@ -116,8 +117,8 @@ async def test_client_exception_event_handler(caplog):
     server.add_event(ven_id='venid',
                      signal_name='simple',
                      signal_type='level',
-                     intervals=[{'dtstart': datetime.now(),
-                                 'duration': timedelta(seconds=10),
+                     intervals=[{'dtstart': datetime.now(timezone.utc),
+                                 'duration': timedelta(seconds=1),
                                  'signal_payload': 1.1}],
                      target={'ven_id': 'venid'},
                      callback=partial(on_event_accepted, future=event_confirm_future))
@@ -131,7 +132,7 @@ async def test_client_exception_event_handler(caplog):
            f"The error was {err.__class__.__name__}: {str(err)}") in [rec.message for rec in caplog.records]
     await client.stop()
     await server.stop()
-    await asyncio.sleep(0)
+    await asyncio.gather(*[t for t in asyncio.all_tasks()][1:])
 
 @pytest.mark.asyncio
 async def test_client_good_event_handler(caplog):
@@ -155,8 +156,8 @@ async def test_client_good_event_handler(caplog):
     server.add_event(ven_id='venid',
                      signal_name='simple',
                      signal_type='level',
-                     intervals=[{'dtstart': datetime.now(),
-                                 'duration': timedelta(seconds=10),
+                     intervals=[{'dtstart': datetime.now(timezone.utc),
+                                 'duration': timedelta(seconds=1),
                                  'signal_payload': 1.1}],
                      target={'ven_id': 'venid'},
                      callback=partial(on_event_accepted, future=event_confirm_future))
@@ -165,7 +166,6 @@ async def test_client_good_event_handler(caplog):
     result = await event_confirm_future
     assert result == 'optIn'
     assert len(caplog.records) == 0
-    await asyncio.sleep(1)
     await client.stop()
     await server.stop()
     await asyncio.sleep(1)
@@ -178,21 +178,69 @@ async def test_server_warning_conflicting_poll_methods(caplog):
     logger.setLevel(logging.DEBUG)
     server = OpenADRServer(vtn_id='myvtn', requested_poll_freq=timedelta(seconds=1))
     server.add_handler('on_poll', print)
-    print("Running server")
-    await server.run_async()
-    await asyncio.sleep(0.1)
     server.add_event(ven_id='venid',
                      signal_name='simple',
                      signal_type='level',
-                     intervals=[{'dtstart': datetime.now(),
-                                 'duration': timedelta(seconds=10),
+                     intervals=[{'dtstart': datetime.now(timezone.utc),
+                                 'duration': timedelta(seconds=1),
                                  'signal_payload': 1.1}],
                      target={'ven_id': 'venid'},
-                     callback=print)
+                     callback=on_event_accepted)
     assert ("You cannot use the add_event method after you assign your own on_poll "
             "handler. If you use your own on_poll handler, you are responsible for "
             "delivering events from that handler. If you want to use OpenLEADRs "
             "message queuing system, you should not assign an on_poll handler. "
             "Your Event will NOT be added.") in [record.msg for record in caplog.records]
+
+
+@pytest.mark.asyncio
+async def test_server_warning_naive_datetimes_in_event(caplog):
+    caplog.set_level(logging.WARNING)
+    enable_default_logging()
+    logger = logging.getLogger('openleadr')
+    logger.setLevel(logging.DEBUG)
+    server = OpenADRServer(vtn_id='myvtn', requested_poll_freq=timedelta(seconds=1))
+    server.add_event(ven_id='venid',
+                     signal_name='simple',
+                     signal_type='level',
+                     intervals=[{'dtstart': datetime.now(),
+                                 'duration': timedelta(seconds=1),
+                                 'signal_payload': 1.1}],
+                     target={'ven_id': 'venid'},
+                     callback=on_event_accepted)
+    assert ("You supplied a naive datetime object to your interval's dtstart. "
+            "This will be interpreted as a timestamp in your local timezone "
+            "and then converted to UTC before sending. Please supply timezone-"
+            "aware timestamps like datetime.datetime.new(timezone.utc) or "
+            "datetime.datetime(..., tzinfo=datetime.timezone.utc)") in [record.msg for record in caplog.records]
+
+
+@pytest.mark.asyncio
+async def test_client_warning_no_update_event_handler(caplog):
+    caplog.set_level(logging.WARNING)
+    enable_default_logging()
+    logger = logging.getLogger('openleadr')
+    logger.setLevel(logging.DEBUG)
+    server = OpenADRServer(vtn_id='myvtn', requested_poll_freq=timedelta(seconds=1))
+    server.add_handler('on_create_party_registration', on_create_party_registration)
+    server.add_event(ven_id='venid',
+                     signal_name='simple',
+                     signal_type='level',
+                     intervals=[{'dtstart': datetime.now(timezone.utc),
+                                 'duration': timedelta(seconds=1),
+                                 'signal_payload': 1.1}],
+                     target={'ven_id': 'venid'},
+                     callback=on_event_accepted)
+    client = OpenADRClient(ven_name='myven',
+                           vtn_url='http://localhost:8080/OpenADR2/Simple/2.0b')
+    client.add_handler('on_event', good_on_event)
+    await server.run_async()
+    await asyncio.sleep(0.5)
+    await client.run()
+    await asyncio.sleep(2)
+    assert ("You should implement your own on_update_event handler. This handler receives "
+            "an Event dict and should return either 'optIn' or 'optOut' based on your "
+            "choice. Will re-use the previous opt status for this event_id for now") in [record.msg for record in caplog.records]
+    await client.stop()
     await server.stop()
-    await asyncio.sleep(0)
+    await asyncio.gather(*[t for t in asyncio.all_tasks()][1:])

+ 83 - 0
test/test_client_misc.py

@@ -0,0 +1,83 @@
+import pytest
+from openleadr import OpenADRClient
+from openleadr import enums
+
+def test_trailing_slash_on_vtn_url():
+    client = OpenADRClient(ven_name='myven', vtn_url='http://localhost/')
+    assert client.vtn_url == 'http://localhost'
+
+def test_wrong_handler_supplied(caplog):
+    client = OpenADRClient(ven_name='myven', vtn_url='http://localhost')
+    client.add_handler('non_existant', print)
+    assert ("'handler' must be either on_event or on_update_event") in [rec.message for rec in caplog.records]
+
+def test_invalid_report_name(caplog):
+    client = OpenADRClient(ven_name='myven', vtn_url='http://localhost')
+    with pytest.raises(ValueError):
+        client.add_report(callback=print,
+                          resource_id='myresource',
+                          measurement='voltage',
+                          report_name='non_existant')
+        # assert (f"non_existant is not a valid report_name. Valid options are "
+        #         f"{', '.join(enums.REPORT_NAME.values)}",
+        #         " or any name starting with 'x-'.") in [rec.message for rec in caplog.records]
+
+def test_invalid_reading_type():
+    client = OpenADRClient(ven_name='myven', vtn_url='http://localhost')
+    with pytest.raises(ValueError):
+        client.add_report(callback=print,
+                          resource_id='myresource',
+                          measurement='voltage',
+                          reading_type='non_existant')
+            # assert (f"non_existant is not a valid reading_type. Valid options are "
+            # f"{', '.join(enums.READING_TYPE.values)}",
+            # " or any name starting with 'x-'.") in [rec.message for rec in caplog.records]
+
+def test_invalid_report_type():
+    client = OpenADRClient(ven_name='myven', vtn_url='http://localhost')
+    with pytest.raises(ValueError):
+        client.add_report(callback=print,
+                          resource_id='myresource',
+                          measurement='voltage',
+                          report_type='non_existant')
+        # assert (f"non_existant is not a valid report_type. Valid options are "
+        #         f"{', '.join(enums.REPORT_TYPE.values)}",
+        #         " or any name starting with 'x-'.") in [rec.message for rec in caplog.records]
+
+def test_invalid_data_collection_mode():
+    client = OpenADRClient(ven_name='myven', vtn_url='http://localhost')
+    with pytest.raises(ValueError):
+        client.add_report(callback=print,
+                          resource_id='myresource',
+                          measurement='voltage',
+                          data_collection_mode='non_existant')
+        # assert ("The data_collection_mode should be 'incremental' or 'full'.") in [rec.message for rec in caplog.records]
+
+def test_invalid_scale():
+    client = OpenADRClient(ven_name='myven', vtn_url='http://localhost')
+    with pytest.raises(ValueError):
+        client.add_report(callback=print,
+                          resource_id='myresource',
+                          measurement='voltage',
+                          scale='non_existant')
+
+def test_add_report_without_specifier_id():
+    client = OpenADRClient(ven_name='myven', vtn_url='http://localhost')
+    client.add_report(callback=print,
+                      resource_id='myresource1',
+                      measurement='voltage')
+    client.add_report(callback=print,
+                      resource_id='myresource2',
+                      measurement='voltage')
+    assert len(client.reports) == 1
+
+async def wrong_sig(param1):
+    pass
+
+def test_add_report_with_invalid_callback_signature():
+    client = OpenADRClient(ven_name='myven', vtn_url='http://localhost')
+    with pytest.raises(TypeError):
+        client.add_report(callback=wrong_sig,
+                          data_collection_mode='full',
+                          resource_id='myresource1',
+                          measurement='voltage')

+ 10 - 14
test/test_failures.py

@@ -27,7 +27,7 @@ async def test_signature_error(start_server_with_signatures):
     client.on_event = _client_on_event
     await client.run()
     await asyncio.sleep(0)
-    await client.client_session.close()
+    await client.stop()
 
 
 ##########################################################################################
@@ -61,25 +61,21 @@ async def _client_on_report(report):
 
 @pytest.fixture
 async def start_server():
-    server = OpenADRServer(vtn_id=VTN_ID)
+    server = OpenADRServer(vtn_id=VTN_ID,
+                           http_host='localhost',
+                           http_port=SERVER_PORT)
     server.add_handler('on_create_party_registration', _on_create_party_registration)
 
-    runner = web.AppRunner(server.app)
-    await runner.setup()
-    site = web.TCPSite(runner, 'localhost', SERVER_PORT)
-    await site.start()
+    await server.run_async()
     yield
-    await runner.cleanup()
+    await server.stop()
 
 @pytest.fixture
 async def start_server_with_signatures():
-    server = OpenADRServer(vtn_id=VTN_ID, cert=CERTFILE, key=KEYFILE, passphrase='openadr')
+    server = OpenADRServer(vtn_id=VTN_ID, cert=CERTFILE, key=KEYFILE, passphrase='openadr',
+                           http_host='localhost', http_port=SERVER_PORT)
     server.add_handler('on_create_party_registration', _on_create_party_registration)
 
-    runner = web.AppRunner(server.app)
-    await runner.setup()
-    site = web.TCPSite(runner, 'localhost', SERVER_PORT)
-    await site.start()
+    await server.run_async()
     yield
-    await runner.cleanup()
-
+    await server.stop()

+ 125 - 4
test/test_queues.py

@@ -13,8 +13,28 @@ def on_create_party_registration(registration_info):
 async def on_event(event):
     return 'optIn'
 
+async def on_event_opt_in(event, future):
+    if future.done() is False:
+        future.set_result(event)
+    return 'optIn'
+
+async def on_update_event(event, futures):
+    for future in futures:
+        if future.done() is False:
+            future.set_result(event)
+            break
+    return 'optIn'
+
+async def on_event_opt_out(event, futures):
+    for future in futures:
+        if future.done() is False:
+            future.set_result(event)
+            break
+    return 'optOut'
+
 async def event_callback(ven_id, event_id, opt_type, future):
-    future.set_result(opt_type)
+    if future.done() is False:
+        future.set_result(opt_type)
 
 @pytest.mark.asyncio
 async def test_internal_message_queue():
@@ -28,8 +48,8 @@ async def test_internal_message_queue():
     server.add_event(ven_id='ven123',
                      signal_name='simple',
                      signal_type='level',
-                     intervals=[{'dtstart': datetime.datetime.now(),
-                                 'duration': datetime.timedelta(minutes=5),
+                     intervals=[{'dtstart': datetime.datetime.now(datetime.timezone.utc),
+                                 'duration': datetime.timedelta(seconds=3),
                                  'signal_payload': 1}],
                      callback=partial(event_callback, future=event_callback_future))
 
@@ -46,5 +66,106 @@ async def test_internal_message_queue():
     message_type, message_payload = await asyncio.wait_for(client.poll(), 0.5)
     assert message_type == 'oadrResponse'
 
+    await asyncio.sleep(1)  # Wait for the event to be completed
+    await client.stop()
+    await server.stop()
+
+
+@pytest.mark.asyncio
+async def test_event_status_opt_in():
+    loop = asyncio.get_event_loop()
+    client = OpenADRClient(ven_name='myven',
+                           vtn_url='http://localhost:8080/OpenADR2/Simple/2.0b')
+    distribute_event_future = loop.create_future()
+    event_update_futures = [loop.create_future() for i in range(2)]
+    client.add_handler('on_event', partial(on_event_opt_in, future=distribute_event_future))
+    client.add_handler('on_update_event', partial(on_update_event, futures=event_update_futures))
+
+    server = OpenADRServer(vtn_id='myvtn', requested_poll_freq=datetime.timedelta(seconds=1))
+    server.add_handler('on_create_party_registration', on_create_party_registration)
+
+    event_callback_future = loop.create_future()
+    event_id = server.add_event(ven_id='ven123',
+                                signal_name='simple',
+                                signal_type='level',
+                                intervals=[{'dtstart': datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(seconds=2),
+                                            'duration': datetime.timedelta(seconds=2),
+                                            'signal_payload': 1}],
+                                callback=partial(event_callback, future=event_callback_future))
+
+    assert server.services['event_service'].pending_events[event_id][0].event_descriptor.event_status == 'far'
+    await server.run_async()
+    await asyncio.sleep(0.5)
+    await client.run()
+
+    await event_callback_future
+
+    print("Waiting for event future 1")
+    result = await distribute_event_future
+    assert result['event_descriptor']['event_status'] == 'far'
+    assert len(client.responded_events) == 1
+
+    print("Watiting for event future 2")
+    result = await event_update_futures[0]
+    assert result['event_descriptor']['event_status'] == 'active'
+    assert len(client.responded_events) == 1
+
+    print("Watiting for event future 3")
+    result = await event_update_futures[1]
+    assert result['event_descriptor']['event_status'] == 'completed'
+    assert len(client.responded_events) == 0
+
     await client.stop()
-    await server.stop()
+    await server.stop()
+    await asyncio.sleep(0)
+
+@pytest.mark.asyncio
+async def test_event_status_opt_in_with_ramp_up():
+    loop = asyncio.get_event_loop()
+    client = OpenADRClient(ven_name='myven',
+                           vtn_url='http://localhost:8080/OpenADR2/Simple/2.0b')
+    distribute_event_future = loop.create_future()
+    event_update_futures = [loop.create_future() for i in range(3)]
+    client.add_handler('on_event', partial(on_event_opt_in, future=distribute_event_future))
+    client.add_handler('on_update_event', partial(on_update_event, futures=event_update_futures))
+
+    server = OpenADRServer(vtn_id='myvtn', requested_poll_freq=datetime.timedelta(seconds=1))
+    server.add_handler('on_create_party_registration', on_create_party_registration)
+
+    event_callback_future = loop.create_future()
+    event_id = server.add_event(ven_id='ven123',
+                                signal_name='simple',
+                                signal_type='level',
+                                intervals=[{'dtstart': datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(seconds=4),
+                                            'duration': datetime.timedelta(seconds=2),
+                                            'signal_payload': 1}],
+                                ramp_up_period=datetime.timedelta(seconds=2),
+                                callback=partial(event_callback, future=event_callback_future))
+
+    assert server.services['event_service'].pending_events[event_id][0].event_descriptor.event_status == 'far'
+    await server.run_async()
+    await asyncio.sleep(0.5)
+    await client.run()
+
+    await event_callback_future
+
+    print("Waiting for event future 1")
+    result = await distribute_event_future
+    assert result['event_descriptor']['event_status'] == 'far'
+
+    print("Watiting for event future 2")
+    result = await event_update_futures[0]
+    assert result['event_descriptor']['event_status'] == 'near'
+
+    print("Watiting for event future 3")
+    result = await event_update_futures[1]
+    assert result['event_descriptor']['event_status'] == 'active'
+
+    print("Watiting for event future 4")
+    result = await event_update_futures[2]
+    assert result['event_descriptor']['event_status'] == 'completed'
+    await asyncio.sleep(0.5)
+
+    await client.stop()
+    await server.stop()
+    await asyncio.sleep(1)