浏览代码

Add support for automatic Event status updates

This adds support for automatically updating the Event status from far to near to active to completed, and it implements the on_update_event handling on the client side.

It also adds support for the ramp_up_duration parameter on the server.add_event method.

Signed-off-by: Stan Janssen <stan.janssen@elaad.nl>
Stan Janssen 4 年之前
父节点
当前提交
535f7e55b6

+ 29 - 6
openleadr/client.py

@@ -82,6 +82,7 @@ class OpenADRClient:
         self.scheduler = AsyncIOScheduler()
         self.client_session = None
         self.report_queue_task = None
+        self.responded_events = {}              # Holds the events that we already saw.
 
         self.cert_path = cert
         self.key_path = key
@@ -654,6 +655,16 @@ class OpenADRClient:
                        "choice. Will opt out of the event for now.")
         return 'optOut'
 
+    async def on_update_event(self, event):
+        """
+        Placeholder for the on_update_event handler.
+        """
+        logger.warning("You should implement your own on_update_event handler. This handler receives "
+                       "an Event dict and should return either 'optIn' or 'optOut' based on your "
+                       "choice. Will re-use the previous opt status for this event_id for now")
+        if event['event_descriptor']['event_id'] in self.events:
+            return self.responded_events['event_id']
+
     ###########################################################################
     #                                                                         #
     #                                  LOW LEVEL                              #
@@ -703,12 +714,24 @@ class OpenADRClient:
         logger.debug("The VEN received an event")
         events = message['events']
         try:
-            results = [self.on_event(event) for event in message['events']]
-            if asyncio.iscoroutine(results[0]):
-                results = await asyncio.gather(*results, return_exceptions=False)
+            results = []
+            for event in message['events']:
+                event_id = event['event_descriptor']['event_id']
+                event_status = event['event_descriptor']['event_status']
+                if event_id in self.responded_events:
+                    result = self.on_update_event(event)
+                else:
+                    result = self.on_event(event)
+                if asyncio.iscoroutine(result):
+                    result = await result
+                results.append(result)
+                if event_status == 'completed':
+                    self.responded_events.pop(event_id)
+                else:
+                    self.responded_events[event_id] = result
             for i, result in enumerate(results):
                 if result not in ('optIn', 'optOut'):
-                    logger.error("Your on_event handler must return 'optIn' or 'optOut'; "
+                    logger.error("Your on_event or on_update_event handler must return 'optIn' or 'optOut'; "
                                  f"you supplied {result}. Please fix your on_event handler.")
                     results[i] = 'optOut'
         except Exception as err:
@@ -716,7 +739,6 @@ class OpenADRClient:
                          f"The error was {err.__class__.__name__}: {str(err)}")
             results = ['optOut'] * len(events)
 
-        print("Done executing on_event for events")
         if len(events) == 1:
             logger.debug(f"Now responding with {results[0]}")
         else:
@@ -735,7 +757,8 @@ class OpenADRClient:
                     'request_id': message['request_id']}
         message = self._create_message('oadrCreatedEvent',
                                        response=response,
-                                       event_responses=event_responses)
+                                       event_responses=event_responses,
+                                       ven_id=self.ven_id)
         service = 'EiEvent'
         response_type, response_payload = await self._perform_request(service, message)
         logger.info(response_type, response_payload)

+ 0 - 5
openleadr/messaging.py

@@ -109,9 +109,7 @@ def validate_xml_signature(xml_tree, cert_fingerprint=None):
 
 async def authenticate_message(request, message_tree, message_payload, fingerprint_lookup):
     if request.secure and 'ven_id' in message_payload:
-        print("Getting cert fingerprint from request")
         connection_fingerprint = utils.get_cert_fingerprint_from_request(request)
-        print("Checking cert fingerprint")
         if connection_fingerprint is None:
             msg = ("Your request must use a client side SSL certificate, of which the "
                    "fingerprint must match the fingerprint that you have given to this VTN")
@@ -133,13 +131,11 @@ async def authenticate_message(request, message_tree, message_payload, fingerpri
                    "following fingerprint to make this request:")
             raise errors.NotRegisteredOrAuthorizedError(msg)
 
-        print("Checking connection fingerprint")
         if connection_fingerprint != expected_fingerprint:
             msg = (f"The fingerprint of your HTTPS certificate {connection_fingerprint} "
                    f"does not match the expected fingerprint {expected_fingerprint}")
             raise errors.NotRegisteredOrAuthorizedError(msg)
 
-        print("Checking message fingerprint")
         message_cert = utils.extract_pem_cert(message_tree)
         message_fingerprint = utils.certificate_fingerprint(message_cert)
         if message_fingerprint != expected_fingerprint:
@@ -149,7 +145,6 @@ async def authenticate_message(request, message_tree, message_payload, fingerpri
                    "certificate to sign your messages.")
             raise errors.NotRegisteredOrAuthorizedError(msg)
 
-        print("Validating XML signature")
         try:
             validate_xml_signature(message_tree)
         except ValueError:

+ 9 - 7
openleadr/objects.py

@@ -17,7 +17,7 @@
 from dataclasses import dataclass, field, asdict, is_dataclass
 from typing import List, Dict
 from datetime import datetime, timezone, timedelta
-from openleadr.utils import group_targets_by_type, ungroup_targets_by_type
+from openleadr import utils
 
 
 @dataclass
@@ -144,12 +144,12 @@ class EventSignal:
             return
         elif self.targets_by_type is None:
             list_of_targets = [asdict(target) if is_dataclass(target) else target for target in self.targets]
-            self.targets_by_type = group_targets_by_type(list_of_targets)
+            self.targets_by_type = utils.group_targets_by_type(list_of_targets)
         elif self.targets is None:
-            self.targets = [Target(**target) for target in ungroup_targets_by_type(self.targets_by_type)]
+            self.targets = [Target(**target) for target in utils.ungroup_targets_by_type(self.targets_by_type)]
         elif self.targets is not None and self.targets_by_type is not None:
             list_of_targets = [asdict(target) if is_dataclass(target) else target for target in self.targets]
-            if group_targets_by_type(list_of_targets) != self.targets_by_type:
+            if utils.group_targets_by_type(list_of_targets) != self.targets_by_type:
                 raise ValueError("You assigned both 'targets' and 'targets_by_type' in your event, "
                                  "but the two were not consistent with each other. "
                                  f"You supplied 'targets' = {self.targets} and "
@@ -178,16 +178,18 @@ class Event:
             raise ValueError("You must supply either 'targets' or 'targets_by_type'.")
         elif self.targets_by_type is None:
             list_of_targets = [asdict(target) if is_dataclass(target) else target for target in self.targets]
-            self.targets_by_type = group_targets_by_type(list_of_targets)
+            self.targets_by_type = utils.group_targets_by_type(list_of_targets)
         elif self.targets is None:
-            self.targets = [Target(**target) for target in ungroup_targets_by_type(self.targets_by_type)]
+            self.targets = [Target(**target) for target in utils.ungroup_targets_by_type(self.targets_by_type)]
         elif self.targets is not None and self.targets_by_type is not None:
             list_of_targets = [asdict(target) if is_dataclass(target) else target for target in self.targets]
-            if group_targets_by_type(list_of_targets) != self.targets_by_type:
+            if utils.group_targets_by_type(list_of_targets) != self.targets_by_type:
                 raise ValueError("You assigned both 'targets' and 'targets_by_type' in your event, "
                                  "but the two were not consistent with each other. "
                                  f"You supplied 'targets' = {self.targets} and "
                                  f"'targets_by_type' = {self.targets_by_type}")
+        # Set the event status
+        self.event_descriptor.event_status = utils.determine_event_status(self.active_period)
 
 
 @dataclass

+ 24 - 4
openleadr/server.py

@@ -151,7 +151,8 @@ class OpenADRServer:
         await self.app_runner.cleanup()
 
     def add_event(self, ven_id, signal_name, signal_type, intervals, callback, targets=None,
-                  targets_by_type=None, target=None, market_context="oadr://unknown.context"):
+                  targets_by_type=None, target=None, market_context="oadr://unknown.context",
+                  notification_period=None, ramp_up_period=None, recovery_period=None):
         """
         Convenience method to add an event with a single signal.
         :param str ven_id: The ven_id to whom this event must be delivered.
@@ -163,6 +164,9 @@ class OpenADRServer:
         :param target: A single target for this event.
         :param dict targets_by_type: A dict of targets, grouped by type.
         :param str market_context: A URI for the DR program that this event belongs to.
+        :param timedelta notification_period: The Notification period for the Event's Active Period.
+        :param timedelta ramp_up_period: The Ramp Up period for the Event's Active Period.
+        :param timedelta recovery_period: The Recovery period for the Event's Active Period.
 
         If you don't provide a target using any of the three arguments, the target will be set to the given ven_id.
         """
@@ -190,20 +194,36 @@ class OpenADRServer:
         event_descriptor = objects.EventDescriptor(event_id=event_id,
                                                    modification_number=0,
                                                    market_context=market_context,
-                                                   event_status="near",
+                                                   event_status="far",
                                                    created_date_time=datetime.now(timezone.utc))
         event_signal = objects.EventSignal(intervals=intervals,
                                            signal_name=signal_name,
                                            signal_type=signal_type,
                                            signal_id=utils.generate_id(),
                                            targets=targets)
-        event = objects.Event(event_descriptor=event_descriptor,
+        # Make sure the intervals carry timezone-aware timestamps
+        for interval in intervals:
+            if utils.getmember(interval, 'dtstart').tzinfo is None:
+                utils.setmember(interval, 'dtstart',
+                                utils.getmember(interval, 'dtstart').astimezone(timezone.utc))
+                logger.warning("You supplied a naive datetime object to your interval's dtstart. "
+                               "This will be interpreted as a timestamp in your local timezone "
+                               "and then converted to UTC before sending. Please supply timezone-"
+                               "aware timestamps like datetime.datetime.new(timezone.utc) or "
+                               "datetime.datetime(..., tzinfo=datetime.timezone.utc)")
+        active_period = utils.get_active_period_from_intervals(intervals, False)
+        active_period.ramp_up_period = ramp_up_period
+        active_period.notification_period = notification_period
+        active_period.recovery_period = recovery_period
+        event = objects.Event(active_period=active_period,
+                              event_descriptor=event_descriptor,
                               event_signals=[event_signal],
                               targets=targets)
         if ven_id not in self.message_queues:
             self.message_queues[ven_id] = asyncio.Queue()
         self.message_queues[ven_id].put_nowait(event)
-        self.services['event_service'].pending_events[event_id] = callback
+        self.services['event_service'].pending_events[event_id] = (event, callback)
+        return event_id
 
     def add_raw_event(self, ven_id, event):
         """

+ 46 - 9
openleadr/service/event_service.py

@@ -15,9 +15,11 @@
 # limitations under the License.
 
 from . import service, handler, VTNService
-from asyncio import iscoroutine
-from .. import objects
+import asyncio
+from openleadr import objects, utils, enums
 import logging
+from datetime import datetime, timezone
+from functools import partial
 logger = logging.getLogger('openleadr')
 
 
@@ -29,6 +31,7 @@ class EventService(VTNService):
         self.polling_method = polling_method
         self.message_queues = message_queues
         self.pending_events = {}        # Holds the event callbacks
+        self.running_events = {}        # Holds the event callbacks for accepted events
 
     @handler('oadrRequestEvent')
     async def request_event(self, payload):
@@ -36,7 +39,7 @@ class EventService(VTNService):
         The VEN requests us to send any events we have.
         """
         result = self.on_request_event(payload['ven_id'])
-        if iscoroutine(result):
+        if asyncio.iscoroutine(result):
             result = await result
         if result is None:
             return 'oadrDistributeEvent', {'events': []}
@@ -64,19 +67,44 @@ class EventService(VTNService):
         """
         The VEN informs us that they created an EiEvent.
         """
+        loop = asyncio.get_event_loop()
+        ven_id = payload['ven_id']
         if self.polling_method == 'internal':
             for event_response in payload['event_responses']:
+                event_id = event_response['event_id']
+                opt_type = event_response['opt_type']
                 if event_response['event_id'] in self.pending_events:
-                    ven_id = payload['ven_id']
-                    event_id = event_response['event_id']
-                    opt_type = event_response['opt_type']
-                    callback = self.pending_events.pop(event_id)
+                    event, callback = self.pending_events.pop(event_id)
                     result = callback(ven_id=ven_id, event_id=event_id, opt_type=opt_type)
-                    if iscoroutine(result):
+                    if asyncio.iscoroutine(result):
+                        result = await result
+                    if opt_type == 'optIn':
+                        self.running_events[event_id] = (event, callback)
+                        now = datetime.now(timezone.utc)
+                        active_period = event.active_period
+                        # Schedule status update to 'near' if applicable
+                        if active_period.ramp_up_period is not None and event.event_descriptor.event_status == 'far':
+                            ramp_up_start_delay = (active_period.dtstart - active_period.ramp_up_period) - now
+                            update_coro = partial(self._update_event_status, ven_id, event, 'near')
+                            loop.create_task(utils.delayed_call(func=update_coro, delay=ramp_up_start_delay))
+                        # Schedule status update to 'active'
+                        if event.event_descriptor.event_status in ('near', 'far'):
+                            active_start_delay = active_period.dtstart - now
+                            update_coro = partial(self._update_event_status, ven_id, event, 'active')
+                            loop.create_task(utils.delayed_call(func=update_coro, delay=active_start_delay))
+                        # Schedule status update to 'completed'
+                        if event.event_descriptor.event_status in ('near', 'far', 'active'):
+                            active_end_delay = active_period.dtstart + active_period.duration - now
+                            update_coro = partial(self._update_event_status, ven_id, event, 'completed')
+                            loop.create_task(utils.delayed_call(func=update_coro, delay=active_end_delay))
+                elif event_response['event_id'] in self.running_events:
+                    event, callback = self.running_events.pop(event_id)
+                    result = callback(ven_id=ven_id, event_id=event_id, opt_type=opt_type)
+                    if asyncio.iscoroutine(result):
                         result = await result
         else:
             result = self.on_created_event(ven_id=ven_id, event_id=event_id, opt_type=opt_type)
-            if iscoroutine(result):
+            if asyncio.iscoroutine(result):
                 result = await(result)
         return 'oadrResponse', {}
 
@@ -89,3 +117,12 @@ class EventService(VTNService):
                        "handler will receive a ven_id, event_id and opt_status. "
                        "You don't need to return anything from this handler.")
         return None
+
+    def _update_event_status(self, ven_id, event, event_status):
+        """
+        Update the event to the given Status.
+        """
+        event.event_descriptor.event_status = event_status
+        if event_status == enums.EVENT_STATUS.CANCELLED:
+            event.event_descriptor.modification_number += 1
+        self.message_queues[ven_id].put_nowait(event)

+ 6 - 6
openleadr/templates/parts/eiActivePeriod.xml

@@ -13,19 +13,19 @@
             </tolerate>
         </tolerance>
         {% endif %}
-        {% if event.active_period.notification %}
+        {% if event.active_period.notification_period %}
         <ei:x-eiNotification>
-            <duration>{{ event.active_period.notification|timedeltaformat }}</duration>
+            <duration>{{ event.active_period.notification_period|timedeltaformat }}</duration>
         </ei:x-eiNotification>
         {% endif %}
-        {% if event.active_period.ramp_up %}
+        {% if event.active_period.ramp_up_period %}
         <ei:x-eiRampUp>
-            <duration>{{ event.active_period.ramp_up|timedeltaformat }}</duration>
+            <duration>{{ event.active_period.ramp_up_period|timedeltaformat }}</duration>
         </ei:x-eiRampUp>
         {% endif %}
-        {% if event.active_period.recovery %}
+        {% if event.active_period.recovery_period %}
         <ei:x-eiRecovery>
-            <duration>{{ event.active_period.recovery|timedeltaformat }}</duration>
+            <duration>{{ event.active_period.recovery_period|timedeltaformat }}</duration>
         </ei:x-eiRecovery>
         {% endif %}
     </properties>

+ 79 - 1
openleadr/utils.py

@@ -17,6 +17,7 @@
 from datetime import datetime, timedelta, timezone
 from dataclasses import is_dataclass, asdict
 from collections import OrderedDict
+import asyncio
 import itertools
 import re
 import ssl
@@ -298,7 +299,7 @@ def parse_datetime(value):
         year, month, day, hour, minute, second, micro = (int(value) for value in matches.groups())
         return datetime(year, month, day, hour, minute, second, micro, tzinfo=timezone.utc)
     else:
-        print(f"{value} did not match format")
+        logger.warning(f"parse_datetime: {value} did not match format")
         return value
 
 
@@ -599,3 +600,80 @@ def validate_report_measurement_dict(measurement):
             raise ValueError("A 'power' related measurement must contain a "
                              "'power_attributes' section that contains the following "
                              "keys: 'voltage' (int), 'ac' (boolean), 'hertz' (int)")
+
+
+def get_active_period_from_intervals(intervals, as_dict=True):
+    if is_dataclass(intervals[0]):
+        intervals = [asdict(i) for i in intervals]
+    period_start = min([i['dtstart'] for i in intervals])
+    period_duration = max([i['dtstart'] + i['duration'] - period_start for i in intervals])
+    if as_dict:
+        return {'dtstart': period_start,
+                'duration': period_duration}
+    else:
+        from openleadr.objects import ActivePeriod
+        return ActivePeriod(dtstart=period_start, duration=period_duration)
+
+
+def determine_event_status(active_period):
+    if is_dataclass(active_period):
+        active_period = asdict(active_period)
+    now = datetime.now(timezone.utc)
+    if active_period['dtstart'].tzinfo is None:
+        active_period['dtstart'] = active_period['dtstart'].astimezone(timezone.utc)
+    active_period_start = active_period['dtstart']
+    active_period_end = active_period['dtstart'] + active_period['duration']
+    if now >= active_period_end:
+        return 'completed'
+    if now >= active_period_start:
+        return 'active'
+    if active_period.get('ramp_up_duration') is not None:
+        ramp_up_start = active_period_start - active_period['ramp_up_duration']
+        if now >= ramp_up_start:
+            return 'near'
+    return 'far'
+
+
+async def delayed_call(func, delay):
+    if isinstance(delay, timedelta):
+        delay = delay.total_seconds()
+    await asyncio.sleep(delay)
+    if asyncio.iscoroutinefunction(func):
+        await func()
+    elif asyncio.iscoroutine(func):
+        await func
+    else:
+        func()
+
+
+def hasmember(obj, member):
+    """
+    Check if a dict or dataclass has the given member
+    """
+    if is_dataclass(obj):
+        if hasattr(obj, member):
+            return True
+    else:
+        if member in obj:
+            return True
+    return False
+
+
+def getmember(obj, member):
+    """
+    Get a member from a dict or dataclass
+    """
+    if is_dataclass(obj):
+        return getattr(obj, member)
+    else:
+        return obj[member]
+
+
+def setmember(obj, member, value):
+    """
+    Set a member of a dict of dataclass
+    """
+    if is_dataclass(obj):
+        setattr(obj, member, value)
+    else:
+        obj[member] = value

+ 72 - 24
test/integration_tests/test_event_warnings_errors.py

@@ -1,15 +1,16 @@
-from openleadr import OpenADRClient, OpenADRServer, enable_default_logging
+from openleadr import OpenADRClient, OpenADRServer, enable_default_logging, utils
 import pytest
 from functools import partial
 import asyncio
-from datetime import datetime, timedelta
+from datetime import datetime, timedelta, timezone
 import logging
 
 async def on_create_party_registration(ven_name):
     return 'venid', 'regid'
 
-async def on_event_accepted(ven_id, event_id, opt_type, future):
-    future.set_result(opt_type)
+async def on_event_accepted(ven_id, event_id, opt_type, future=None):
+    if future and future.done() is False:
+        future.set_result(opt_type)
 
 async def good_on_event(event):
     return 'optIn'
@@ -41,8 +42,8 @@ async def test_client_no_event_handler(caplog):
     server.add_event(ven_id='venid',
                      signal_name='simple',
                      signal_type='level',
-                     intervals=[{'dtstart': datetime.now(),
-                                 'duration': timedelta(seconds=10),
+                     intervals=[{'dtstart': datetime.now(timezone.utc),
+                                 'duration': timedelta(seconds=1),
                                  'signal_payload': 1.1}],
                      target={'ven_id': 'venid'},
                      callback=partial(on_event_accepted, future=event_confirm_future))
@@ -55,7 +56,7 @@ async def test_client_no_event_handler(caplog):
             "choice. Will opt out of the event for now.") in [rec.message for rec in caplog.records]
     await client.stop()
     await server.stop()
-    await asyncio.sleep(0)
+    await asyncio.gather(*[t for t in asyncio.all_tasks()][1:])
 
 @pytest.mark.asyncio
 async def test_client_faulty_event_handler(caplog):
@@ -79,8 +80,8 @@ async def test_client_faulty_event_handler(caplog):
     server.add_event(ven_id='venid',
                      signal_name='simple',
                      signal_type='level',
-                     intervals=[{'dtstart': datetime.now(),
-                                 'duration': timedelta(seconds=10),
+                     intervals=[{'dtstart': datetime.now(timezone.utc),
+                                 'duration': timedelta(seconds=1),
                                  'signal_payload': 1.1}],
                      target={'ven_id': 'venid'},
                      callback=partial(on_event_accepted, future=event_confirm_future))
@@ -88,11 +89,11 @@ async def test_client_faulty_event_handler(caplog):
     print("Waiting for a response to the event")
     result = await event_confirm_future
     assert result == 'optOut'
-    assert ("Your on_event handler must return 'optIn' or 'optOut'; "
+    assert ("Your on_event or on_update_event handler must return 'optIn' or 'optOut'; "
            f"you supplied {None}. Please fix your on_event handler.") in [rec.message for rec in caplog.records]
     await client.stop()
     await server.stop()
-    await asyncio.sleep(0)
+    await asyncio.gather(*[t for t in asyncio.all_tasks()][1:])
 
 @pytest.mark.asyncio
 async def test_client_exception_event_handler(caplog):
@@ -116,8 +117,8 @@ async def test_client_exception_event_handler(caplog):
     server.add_event(ven_id='venid',
                      signal_name='simple',
                      signal_type='level',
-                     intervals=[{'dtstart': datetime.now(),
-                                 'duration': timedelta(seconds=10),
+                     intervals=[{'dtstart': datetime.now(timezone.utc),
+                                 'duration': timedelta(seconds=1),
                                  'signal_payload': 1.1}],
                      target={'ven_id': 'venid'},
                      callback=partial(on_event_accepted, future=event_confirm_future))
@@ -131,7 +132,7 @@ async def test_client_exception_event_handler(caplog):
            f"The error was {err.__class__.__name__}: {str(err)}") in [rec.message for rec in caplog.records]
     await client.stop()
     await server.stop()
-    await asyncio.sleep(0)
+    await asyncio.gather(*[t for t in asyncio.all_tasks()][1:])
 
 @pytest.mark.asyncio
 async def test_client_good_event_handler(caplog):
@@ -155,8 +156,8 @@ async def test_client_good_event_handler(caplog):
     server.add_event(ven_id='venid',
                      signal_name='simple',
                      signal_type='level',
-                     intervals=[{'dtstart': datetime.now(),
-                                 'duration': timedelta(seconds=10),
+                     intervals=[{'dtstart': datetime.now(timezone.utc),
+                                 'duration': timedelta(seconds=1),
                                  'signal_payload': 1.1}],
                      target={'ven_id': 'venid'},
                      callback=partial(on_event_accepted, future=event_confirm_future))
@@ -165,7 +166,6 @@ async def test_client_good_event_handler(caplog):
     result = await event_confirm_future
     assert result == 'optIn'
     assert len(caplog.records) == 0
-    await asyncio.sleep(1)
     await client.stop()
     await server.stop()
     await asyncio.sleep(1)
@@ -178,21 +178,69 @@ async def test_server_warning_conflicting_poll_methods(caplog):
     logger.setLevel(logging.DEBUG)
     server = OpenADRServer(vtn_id='myvtn', requested_poll_freq=timedelta(seconds=1))
     server.add_handler('on_poll', print)
-    print("Running server")
-    await server.run_async()
-    await asyncio.sleep(0.1)
     server.add_event(ven_id='venid',
                      signal_name='simple',
                      signal_type='level',
-                     intervals=[{'dtstart': datetime.now(),
-                                 'duration': timedelta(seconds=10),
+                     intervals=[{'dtstart': datetime.now(timezone.utc),
+                                 'duration': timedelta(seconds=1),
                                  'signal_payload': 1.1}],
                      target={'ven_id': 'venid'},
-                     callback=print)
+                     callback=on_event_accepted)
     assert ("You cannot use the add_event method after you assign your own on_poll "
             "handler. If you use your own on_poll handler, you are responsible for "
             "delivering events from that handler. If you want to use OpenLEADRs "
             "message queuing system, you should not assign an on_poll handler. "
             "Your Event will NOT be added.") in [record.msg for record in caplog.records]
+
+
+@pytest.mark.asyncio
+async def test_server_warning_naive_datetimes_in_event(caplog):
+    caplog.set_level(logging.WARNING)
+    enable_default_logging()
+    logger = logging.getLogger('openleadr')
+    logger.setLevel(logging.DEBUG)
+    server = OpenADRServer(vtn_id='myvtn', requested_poll_freq=timedelta(seconds=1))
+    server.add_event(ven_id='venid',
+                     signal_name='simple',
+                     signal_type='level',
+                     intervals=[{'dtstart': datetime.now(),
+                                 'duration': timedelta(seconds=1),
+                                 'signal_payload': 1.1}],
+                     target={'ven_id': 'venid'},
+                     callback=on_event_accepted)
+    assert ("You supplied a naive datetime object to your interval's dtstart. "
+            "This will be interpreted as a timestamp in your local timezone "
+            "and then converted to UTC before sending. Please supply timezone-"
+            "aware timestamps like datetime.datetime.new(timezone.utc) or "
+            "datetime.datetime(..., tzinfo=datetime.timezone.utc)") in [record.msg for record in caplog.records]
+
+
+@pytest.mark.asyncio
+async def test_client_warning_no_update_event_handler(caplog):
+    caplog.set_level(logging.WARNING)
+    enable_default_logging()
+    logger = logging.getLogger('openleadr')
+    logger.setLevel(logging.DEBUG)
+    server = OpenADRServer(vtn_id='myvtn', requested_poll_freq=timedelta(seconds=1))
+    server.add_handler('on_create_party_registration', on_create_party_registration)
+    server.add_event(ven_id='venid',
+                     signal_name='simple',
+                     signal_type='level',
+                     intervals=[{'dtstart': datetime.now(timezone.utc),
+                                 'duration': timedelta(seconds=1),
+                                 'signal_payload': 1.1}],
+                     target={'ven_id': 'venid'},
+                     callback=on_event_accepted)
+    client = OpenADRClient(ven_name='myven',
+                           vtn_url='http://localhost:8080/OpenADR2/Simple/2.0b')
+    client.add_handler('on_event', good_on_event)
+    await server.run_async()
+    await asyncio.sleep(0.5)
+    await client.run()
+    await asyncio.sleep(2)
+    assert ("You should implement your own on_update_event handler. This handler receives "
+            "an Event dict and should return either 'optIn' or 'optOut' based on your "
+            "choice. Will re-use the previous opt status for this event_id for now") in [record.msg for record in caplog.records]
+    await client.stop()
     await server.stop()
-    await asyncio.sleep(0)
+    await asyncio.gather(*[t for t in asyncio.all_tasks()][1:])

+ 83 - 0
test/test_client_misc.py

@@ -0,0 +1,83 @@
+import pytest
+from openleadr import OpenADRClient
+from openleadr import enums
+
+def test_trailing_slash_on_vtn_url():
+    client = OpenADRClient(ven_name='myven', vtn_url='http://localhost/')
+    assert client.vtn_url == 'http://localhost'
+
+def test_wrong_handler_supplied(caplog):
+    client = OpenADRClient(ven_name='myven', vtn_url='http://localhost')
+    client.add_handler('non_existant', print)
+    assert ("'handler' must be either on_event or on_update_event") in [rec.message for rec in caplog.records]
+
+def test_invalid_report_name(caplog):
+    client = OpenADRClient(ven_name='myven', vtn_url='http://localhost')
+    with pytest.raises(ValueError):
+        client.add_report(callback=print,
+                          resource_id='myresource',
+                          measurement='voltage',
+                          report_name='non_existant')
+        # assert (f"non_existant is not a valid report_name. Valid options are "
+        #         f"{', '.join(enums.REPORT_NAME.values)}",
+        #         " or any name starting with 'x-'.") in [rec.message for rec in caplog.records]
+
+def test_invalid_reading_type():
+    client = OpenADRClient(ven_name='myven', vtn_url='http://localhost')
+    with pytest.raises(ValueError):
+        client.add_report(callback=print,
+                          resource_id='myresource',
+                          measurement='voltage',
+                          reading_type='non_existant')
+            # assert (f"non_existant is not a valid reading_type. Valid options are "
+            # f"{', '.join(enums.READING_TYPE.values)}",
+            # " or any name starting with 'x-'.") in [rec.message for rec in caplog.records]
+
+def test_invalid_report_type():
+    client = OpenADRClient(ven_name='myven', vtn_url='http://localhost')
+    with pytest.raises(ValueError):
+        client.add_report(callback=print,
+                          resource_id='myresource',
+                          measurement='voltage',
+                          report_type='non_existant')
+        # assert (f"non_existant is not a valid report_type. Valid options are "
+        #         f"{', '.join(enums.REPORT_TYPE.values)}",
+        #         " or any name starting with 'x-'.") in [rec.message for rec in caplog.records]
+
+def test_invalid_data_collection_mode():
+    client = OpenADRClient(ven_name='myven', vtn_url='http://localhost')
+    with pytest.raises(ValueError):
+        client.add_report(callback=print,
+                          resource_id='myresource',
+                          measurement='voltage',
+                          data_collection_mode='non_existant')
+        # assert ("The data_collection_mode should be 'incremental' or 'full'.") in [rec.message for rec in caplog.records]
+
+def test_invalid_scale():
+    client = OpenADRClient(ven_name='myven', vtn_url='http://localhost')
+    with pytest.raises(ValueError):
+        client.add_report(callback=print,
+                          resource_id='myresource',
+                          measurement='voltage',
+                          scale='non_existant')
+
+def test_add_report_without_specifier_id():
+    client = OpenADRClient(ven_name='myven', vtn_url='http://localhost')
+    client.add_report(callback=print,
+                      resource_id='myresource1',
+                      measurement='voltage')
+    client.add_report(callback=print,
+                      resource_id='myresource2',
+                      measurement='voltage')
+    assert len(client.reports) == 1
+
+async def wrong_sig(param1):
+    pass
+
+def test_add_report_with_invalid_callback_signature():
+    client = OpenADRClient(ven_name='myven', vtn_url='http://localhost')
+    with pytest.raises(TypeError):
+        client.add_report(callback=wrong_sig,
+                          data_collection_mode='full',
+                          resource_id='myresource1',
+                          measurement='voltage')

+ 10 - 14
test/test_failures.py

@@ -27,7 +27,7 @@ async def test_signature_error(start_server_with_signatures):
     client.on_event = _client_on_event
     await client.run()
     await asyncio.sleep(0)
-    await client.client_session.close()
+    await client.stop()
 
 
 ##########################################################################################
@@ -61,25 +61,21 @@ async def _client_on_report(report):
 
 @pytest.fixture
 async def start_server():
-    server = OpenADRServer(vtn_id=VTN_ID)
+    server = OpenADRServer(vtn_id=VTN_ID,
+                           http_host='localhost',
+                           http_port=SERVER_PORT)
     server.add_handler('on_create_party_registration', _on_create_party_registration)
 
-    runner = web.AppRunner(server.app)
-    await runner.setup()
-    site = web.TCPSite(runner, 'localhost', SERVER_PORT)
-    await site.start()
+    await server.run_async()
     yield
-    await runner.cleanup()
+    await server.stop()
 
 @pytest.fixture
 async def start_server_with_signatures():
-    server = OpenADRServer(vtn_id=VTN_ID, cert=CERTFILE, key=KEYFILE, passphrase='openadr')
+    server = OpenADRServer(vtn_id=VTN_ID, cert=CERTFILE, key=KEYFILE, passphrase='openadr',
+                           http_host='localhost', http_port=SERVER_PORT)
     server.add_handler('on_create_party_registration', _on_create_party_registration)
 
-    runner = web.AppRunner(server.app)
-    await runner.setup()
-    site = web.TCPSite(runner, 'localhost', SERVER_PORT)
-    await site.start()
+    await server.run_async()
     yield
-    await runner.cleanup()
-
+    await server.stop()

+ 125 - 4
test/test_queues.py

@@ -13,8 +13,28 @@ def on_create_party_registration(registration_info):
 async def on_event(event):
     return 'optIn'
 
+async def on_event_opt_in(event, future):
+    if future.done() is False:
+        future.set_result(event)
+    return 'optIn'
+
+async def on_update_event(event, futures):
+    for future in futures:
+        if future.done() is False:
+            future.set_result(event)
+            break
+    return 'optIn'
+
+async def on_event_opt_out(event, futures):
+    for future in futures:
+        if future.done() is False:
+            future.set_result(event)
+            break
+    return 'optOut'
+
 async def event_callback(ven_id, event_id, opt_type, future):
-    future.set_result(opt_type)
+    if future.done() is False:
+        future.set_result(opt_type)
 
 @pytest.mark.asyncio
 async def test_internal_message_queue():
@@ -28,8 +48,8 @@ async def test_internal_message_queue():
     server.add_event(ven_id='ven123',
                      signal_name='simple',
                      signal_type='level',
-                     intervals=[{'dtstart': datetime.datetime.now(),
-                                 'duration': datetime.timedelta(minutes=5),
+                     intervals=[{'dtstart': datetime.datetime.now(datetime.timezone.utc),
+                                 'duration': datetime.timedelta(seconds=3),
                                  'signal_payload': 1}],
                      callback=partial(event_callback, future=event_callback_future))
 
@@ -46,5 +66,106 @@ async def test_internal_message_queue():
     message_type, message_payload = await asyncio.wait_for(client.poll(), 0.5)
     assert message_type == 'oadrResponse'
 
+    await asyncio.sleep(1)  # Wait for the event to be completed
+    await client.stop()
+    await server.stop()
+
+
+@pytest.mark.asyncio
+async def test_event_status_opt_in():
+    loop = asyncio.get_event_loop()
+    client = OpenADRClient(ven_name='myven',
+                           vtn_url='http://localhost:8080/OpenADR2/Simple/2.0b')
+    distribute_event_future = loop.create_future()
+    event_update_futures = [loop.create_future() for i in range(2)]
+    client.add_handler('on_event', partial(on_event_opt_in, future=distribute_event_future))
+    client.add_handler('on_update_event', partial(on_update_event, futures=event_update_futures))
+
+    server = OpenADRServer(vtn_id='myvtn', requested_poll_freq=datetime.timedelta(seconds=1))
+    server.add_handler('on_create_party_registration', on_create_party_registration)
+
+    event_callback_future = loop.create_future()
+    event_id = server.add_event(ven_id='ven123',
+                                signal_name='simple',
+                                signal_type='level',
+                                intervals=[{'dtstart': datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(seconds=2),
+                                            'duration': datetime.timedelta(seconds=2),
+                                            'signal_payload': 1}],
+                                callback=partial(event_callback, future=event_callback_future))
+
+    assert server.services['event_service'].pending_events[event_id][0].event_descriptor.event_status == 'far'
+    await server.run_async()
+    await asyncio.sleep(0.5)
+    await client.run()
+
+    await event_callback_future
+
+    print("Waiting for event future 1")
+    result = await distribute_event_future
+    assert result['event_descriptor']['event_status'] == 'far'
+    assert len(client.responded_events) == 1
+
+    print("Watiting for event future 2")
+    result = await event_update_futures[0]
+    assert result['event_descriptor']['event_status'] == 'active'
+    assert len(client.responded_events) == 1
+
+    print("Watiting for event future 3")
+    result = await event_update_futures[1]
+    assert result['event_descriptor']['event_status'] == 'completed'
+    assert len(client.responded_events) == 0
+
     await client.stop()
-    await server.stop()
+    await server.stop()
+    await asyncio.sleep(0)
+
+@pytest.mark.asyncio
+async def test_event_status_opt_in_with_ramp_up():
+    loop = asyncio.get_event_loop()
+    client = OpenADRClient(ven_name='myven',
+                           vtn_url='http://localhost:8080/OpenADR2/Simple/2.0b')
+    distribute_event_future = loop.create_future()
+    event_update_futures = [loop.create_future() for i in range(3)]
+    client.add_handler('on_event', partial(on_event_opt_in, future=distribute_event_future))
+    client.add_handler('on_update_event', partial(on_update_event, futures=event_update_futures))
+
+    server = OpenADRServer(vtn_id='myvtn', requested_poll_freq=datetime.timedelta(seconds=1))
+    server.add_handler('on_create_party_registration', on_create_party_registration)
+
+    event_callback_future = loop.create_future()
+    event_id = server.add_event(ven_id='ven123',
+                                signal_name='simple',
+                                signal_type='level',
+                                intervals=[{'dtstart': datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(seconds=4),
+                                            'duration': datetime.timedelta(seconds=2),
+                                            'signal_payload': 1}],
+                                ramp_up_period=datetime.timedelta(seconds=2),
+                                callback=partial(event_callback, future=event_callback_future))
+
+    assert server.services['event_service'].pending_events[event_id][0].event_descriptor.event_status == 'far'
+    await server.run_async()
+    await asyncio.sleep(0.5)
+    await client.run()
+
+    await event_callback_future
+
+    print("Waiting for event future 1")
+    result = await distribute_event_future
+    assert result['event_descriptor']['event_status'] == 'far'
+
+    print("Watiting for event future 2")
+    result = await event_update_futures[0]
+    assert result['event_descriptor']['event_status'] == 'near'
+
+    print("Watiting for event future 3")
+    result = await event_update_futures[1]
+    assert result['event_descriptor']['event_status'] == 'active'
+
+    print("Watiting for event future 4")
+    result = await event_update_futures[2]
+    assert result['event_descriptor']['event_status'] == 'completed'
+    await asyncio.sleep(0.5)
+
+    await client.stop()
+    await server.stop()
+    await asyncio.sleep(1)