Quellcode durchsuchen

Added additional checks to the event callback

Signed-off-by: Stan Janssen <stan.janssen@elaad.nl>
Stan Janssen vor 4 Jahren
Ursprung
Commit
3029d96dad

+ 15 - 5
openleadr/server.py

@@ -27,6 +27,7 @@ import asyncio
 import logging
 import ssl
 import re
+import inspect
 logger = logging.getLogger('openleadr')
 
 
@@ -143,7 +144,8 @@ class OpenADRServer:
         await self.run()
 
     async def stop(self):
-        delayed_call_tasks = [task for task in asyncio.all_tasks() if task.get_name().startswith('DelayedCall')]
+        delayed_call_tasks = [task for task in asyncio.all_tasks()
+                              if task.get_name().startswith('DelayedCall')]
         for task in delayed_call_tasks:
             task.cancel()
         await self.app_runner.cleanup()
@@ -182,8 +184,8 @@ class OpenADRServer:
         event_id = event_id or utils.generate_id()
 
         if response_required not in ('always', 'never'):
-            raise ValueError(f"'response_required' should be either 'always' or 'never', "
-                             "you provided {response_required}.")
+            raise ValueError("'response_required' should be either 'always' or 'never'; "
+                             f"you provided '{response_required}'.")
 
         # Figure out the target for this Event
         if target is None and targets is None and targets_by_type is None:
@@ -205,6 +207,7 @@ class OpenADRServer:
                                            signal_type=signal_type,
                                            signal_id=utils.generate_id(),
                                            targets=targets)
+
         # Make sure the intervals carry timezone-aware timestamps
         for interval in intervals:
             if utils.getmember(interval, 'dtstart').tzinfo is None:
@@ -241,6 +244,13 @@ class OpenADRServer:
                 logger.warning("You did not provide a 'callback', which means you won't know if the "
                                "VEN will opt in or opt out of your event. You should consider adding "
                                "a callback for this.")
+            elif not asyncio.isfuture(callback):
+                args = inspect.signature(callback).parameters
+                if not all(['ven_id' in args, 'event_id' in args, 'opt_type' in args]):
+                    raise ValueError("The 'callback' must have at least the following parameters: "
+                                     "'ven_id' (str), 'event_id' (str), 'opt_type' (str). Please fix "
+                                     "your 'callback' handler.")
+
         if ven_id not in self.message_queues:
             self.message_queues[ven_id] = deque()
         event_id = utils.getmember(utils.getmember(event, 'event_descriptor'), 'event_id')
@@ -270,5 +280,5 @@ class OpenADRServer:
                 self.services['poll_service'].polling_method = 'external'
                 self.services['event_service'].polling_method = 'external'
         else:
-            raise NameError(f"Unknown handler {name}. "
-                            f"Correct handler names are: {self._MAP.keys()}")
+            raise NameError(f"""Unknown handler '{name}'. """
+                            f"""Correct handler names are: '{"', '".join(self._MAP.keys())}'.""")

+ 67 - 1
test/integration_tests/test_event_warnings_errors.py

@@ -41,6 +41,7 @@ async def test_client_no_event_handler(caplog):
     event_confirm_future = asyncio.get_event_loop().create_future()
     print("Adding event")
     server.add_event(ven_id='venid',
+                     event_id='test_client_no_event_handler',
                      signal_name='simple',
                      signal_type='level',
                      intervals=[{'dtstart': datetime.now(timezone.utc),
@@ -78,6 +79,7 @@ async def test_client_faulty_event_handler(caplog):
     event_confirm_future = asyncio.get_event_loop().create_future()
     print("Adding event")
     server.add_event(ven_id='venid',
+                     event_id='test_client_faulty_event_handler',
                      signal_name='simple',
                      signal_type='level',
                      intervals=[{'dtstart': datetime.now(timezone.utc),
@@ -114,6 +116,7 @@ async def test_client_exception_event_handler(caplog):
     event_confirm_future = asyncio.get_event_loop().create_future()
     print("Adding event")
     server.add_event(ven_id='venid',
+                     event_id='test_client_exception_event_handler',
                      signal_name='simple',
                      signal_type='level',
                      intervals=[{'dtstart': datetime.now(timezone.utc),
@@ -152,6 +155,7 @@ async def test_client_good_event_handler(caplog):
     event_confirm_future = asyncio.get_event_loop().create_future()
     print("Adding event")
     server.add_event(ven_id='venid',
+                     event_id='test_client_good_event_handler',
                      signal_name='simple',
                      signal_type='level',
                      intervals=[{'dtstart': datetime.now(timezone.utc),
@@ -176,6 +180,7 @@ async def test_server_warning_conflicting_poll_methods(caplog):
     server = OpenADRServer(vtn_id='myvtn', requested_poll_freq=timedelta(seconds=1))
     server.add_handler('on_poll', print)
     server.add_event(ven_id='venid',
+                     event_id='test_server_warning_conflicting_poll_methods',
                      signal_name='simple',
                      signal_type='level',
                      intervals=[{'dtstart': datetime.now(timezone.utc),
@@ -197,6 +202,7 @@ async def test_server_warning_naive_datetimes_in_event(caplog):
     logger.setLevel(logging.DEBUG)
     server = OpenADRServer(vtn_id='myvtn', requested_poll_freq=timedelta(seconds=1))
     server.add_event(ven_id='venid',
+                     event_id='test_server_warning_naive_datetimes_in_event',
                      signal_name='simple',
                      signal_type='level',
                      intervals=[{'dtstart': datetime.now(),
@@ -303,6 +309,7 @@ async def test_client_warning_no_update_event_handler(caplog):
     server = OpenADRServer(vtn_id='myvtn', requested_poll_freq=timedelta(seconds=1))
     server.add_handler('on_create_party_registration', on_create_party_registration)
     server.add_event(ven_id='venid',
+                     event_id='test_client_warning_no_update_event_handler',
                      signal_name='simple',
                      signal_type='level',
                      intervals=[{'dtstart': datetime.now(timezone.utc),
@@ -322,4 +329,63 @@ async def test_client_warning_no_update_event_handler(caplog):
             "choice. Will re-use the previous opt status for this event_id for now") in [record.msg for record in caplog.records]
     await client.stop()
     await server.stop()
-    await asyncio.gather(*[t for t in asyncio.all_tasks()][1:])
+
+@pytest.mark.asyncio
+async def test_server_add_event_with_wrong_callback_signature(caplog):
+    def dummy_callback(some_param):
+        pass
+    caplog.set_level(logging.WARNING)
+    logger = logging.getLogger('openleadr')
+    logger.setLevel(logging.DEBUG)
+    server = OpenADRServer(vtn_id='myvtn', requested_poll_freq=timedelta(seconds=1))
+    with pytest.raises(ValueError) as err:
+        server.add_event(ven_id='venid',
+                         event_id='test_server_add_event_with_wrong_callback_signature',
+                         signal_name='simple',
+                         signal_type='level',
+                         intervals=[{'dtstart': datetime.now(timezone.utc),
+                                     'duration': timedelta(seconds=1),
+                                     'signal_payload': 1.1}],
+                         target={'ven_id': 'venid'},
+                         callback=dummy_callback)
+
+@pytest.mark.asyncio
+async def test_server_add_event_with_no_callback(caplog):
+    def dummy_callback(some_param):
+        pass
+    caplog.set_level(logging.WARNING)
+    logger = logging.getLogger('openleadr')
+    logger.setLevel(logging.DEBUG)
+    server = OpenADRServer(vtn_id='myvtn')
+    server.add_event(ven_id='venid',
+                     event_id='test_server_add_event_with_no_callback',
+                     signal_name='simple',
+                     signal_type='level',
+                     intervals=[{'dtstart': datetime.now(timezone.utc),
+                                 'duration': timedelta(seconds=1),
+                                 'signal_payload': 1.1}],
+                     target={'ven_id': 'venid'})
+    assert ("You did not provide a 'callback', which means you won't know if the "
+            "VEN will opt in or opt out of your event. You should consider adding "
+            "a callback for this.") in caplog.messages
+
+@pytest.mark.asyncio
+async def test_server_add_event_with_no_callback_response_never_required(caplog):
+    caplog.set_level(logging.WARNING)
+    logger = logging.getLogger('openleadr')
+    logger.setLevel(logging.DEBUG)
+    server = OpenADRServer(vtn_id='myvtn')
+    server.add_event(ven_id='venid',
+                     event_id='test_server_add_event_with_no_callback_response_never_required',
+                     signal_name='simple',
+                     signal_type='level',
+                     intervals=[{'dtstart': datetime.now(timezone.utc),
+                                 'duration': timedelta(seconds=1),
+                                 'signal_payload': 1.1}],
+                     target={'ven_id': 'venid'},
+                     response_required='never')
+    await server.run()
+    await server.stop()
+    assert ("You did not provide a 'callback', which means you won't know if the "
+            "VEN will opt in or opt out of your event. You should consider adding "
+            "a callback for this.") not in caplog.messages

+ 12 - 0
test/test_failures.py

@@ -204,6 +204,18 @@ def test_replay_protect_malformed_nonce(caplog):
         messaging._verify_replay_protect(tree)
     assert str(err.value) == "Missing or malformed ReplayProtect element in the message signature."
 
+
+def test_server_add_unknown_handler(caplog):
+    server = OpenADRServer(vtn_id='myvtn')
+    with pytest.raises(NameError) as err:
+        server.add_handler('unknown_name', print)
+    assert str(err.value) == ("Unknown handler 'unknown_name'. Correct handler names are: "
+                              "'on_created_event', 'on_request_event', 'on_register_report', "
+                              "'on_create_report', 'on_created_report', 'on_request_report', "
+                              "'on_update_report', 'on_poll', 'on_query_registration', "
+                              "'on_create_party_registration', 'on_cancel_party_registration'.")
+
+
 ##########################################################################################
 
 SERVER_PORT = 8001

+ 1 - 1
test/test_queues.py

@@ -196,6 +196,7 @@ async def test_request_event():
     assert message_type == 'oadrDistributeEvent'
     message_type, message_payload = await client.request_event()
     assert message_type == 'oadrResponse'
+    await client.stop()
     await server.stop()
 
 
@@ -280,6 +281,5 @@ async def test_create_event_with_future_as_callback():
 
     result = await event_callback_future
     assert result == 'optIn'
-
     await client.stop()
     await server.stop()