Browse Source

Improved testing of events and signatures

Signed-off-by: Stan Janssen <stan.janssen@elaad.nl>
Stan Janssen 4 years ago
parent
commit
29959c8e2b

+ 2 - 0
openleadr/preflight.py

@@ -133,6 +133,8 @@ def _preflight_oadrDistributeEvent(message_payload):
     for event in message_payload['events']:
     for event in message_payload['events']:
         if 'created_date_time' not in event['event_descriptor'] \
         if 'created_date_time' not in event['event_descriptor'] \
                 or not event['event_descriptor']['created_date_time']:
                 or not event['event_descriptor']['created_date_time']:
+            logger.warning("Your event descriptor did not contain a created_date_time. "
+                           "This will be automatically added.")
             event['event_descriptor']['created_date_time'] = datetime.now(timezone.utc)
             event['event_descriptor']['created_date_time'] = datetime.now(timezone.utc)
 
 
     # Check that the target designations are correct and consistent
     # Check that the target designations are correct and consistent

+ 2 - 2
openleadr/service/vtn_service.py

@@ -136,7 +136,7 @@ class VTNService:
             response_payload['request_id'] = generate_id()
             response_payload['request_id'] = generate_id()
 
 
         else:
         else:
-            response_type, response_payload = self.error_response(message_type,
+            response_type, response_payload = self.error_response('oadrResponse',
                                                                   STATUS_CODES.COMPLIANCE_ERROR,
                                                                   STATUS_CODES.COMPLIANCE_ERROR,
                                                                   "A message of type "
                                                                   "A message of type "
                                                                   f"{message_type} should not be "
                                                                   f"{message_type} should not be "
@@ -151,5 +151,5 @@ class VTNService:
         else:
         else:
             response_type = 'oadrResponse'
             response_type = 'oadrResponse'
         response_payload = {'response': {'response_code': error_code,
         response_payload = {'response': {'response_code': error_code,
-                                         'response_description': 'Certificate fingerprint mismatch'}}
+                                         'response_description': error_description}}
         return response_type, response_payload
         return response_type, response_payload

+ 87 - 8
test/integration_tests/test_event_warnings_errors.py

@@ -1,10 +1,12 @@
-from openleadr import OpenADRClient, OpenADRServer, enable_default_logging, utils
+from openleadr import OpenADRClient, OpenADRServer, enable_default_logging, utils, messaging
 import pytest
 import pytest
 from functools import partial
 from functools import partial
 import asyncio
 import asyncio
 from datetime import datetime, timedelta, timezone
 from datetime import datetime, timedelta, timezone
 import logging
 import logging
 
 
+enable_default_logging()
+
 async def on_create_party_registration(ven_name):
 async def on_create_party_registration(ven_name):
     return 'venid', 'regid'
     return 'venid', 'regid'
 
 
@@ -24,7 +26,6 @@ async def broken_on_event(event):
 @pytest.mark.asyncio
 @pytest.mark.asyncio
 async def test_client_no_event_handler(caplog):
 async def test_client_no_event_handler(caplog):
     caplog.set_level(logging.WARNING)
     caplog.set_level(logging.WARNING)
-    enable_default_logging()
     logger = logging.getLogger('openleadr')
     logger = logging.getLogger('openleadr')
     logger.setLevel(logging.DEBUG)
     logger.setLevel(logging.DEBUG)
     client = OpenADRClient(ven_name='myven',
     client = OpenADRClient(ven_name='myven',
@@ -61,7 +62,6 @@ async def test_client_no_event_handler(caplog):
 @pytest.mark.asyncio
 @pytest.mark.asyncio
 async def test_client_faulty_event_handler(caplog):
 async def test_client_faulty_event_handler(caplog):
     caplog.set_level(logging.WARNING)
     caplog.set_level(logging.WARNING)
-    enable_default_logging()
     logger = logging.getLogger('openleadr')
     logger = logging.getLogger('openleadr')
     logger.setLevel(logging.DEBUG)
     logger.setLevel(logging.DEBUG)
     client = OpenADRClient(ven_name='myven',
     client = OpenADRClient(ven_name='myven',
@@ -98,7 +98,6 @@ async def test_client_faulty_event_handler(caplog):
 @pytest.mark.asyncio
 @pytest.mark.asyncio
 async def test_client_exception_event_handler(caplog):
 async def test_client_exception_event_handler(caplog):
     caplog.set_level(logging.WARNING)
     caplog.set_level(logging.WARNING)
-    enable_default_logging()
     logger = logging.getLogger('openleadr')
     logger = logging.getLogger('openleadr')
     logger.setLevel(logging.DEBUG)
     logger.setLevel(logging.DEBUG)
     client = OpenADRClient(ven_name='myven',
     client = OpenADRClient(ven_name='myven',
@@ -137,7 +136,6 @@ async def test_client_exception_event_handler(caplog):
 @pytest.mark.asyncio
 @pytest.mark.asyncio
 async def test_client_good_event_handler(caplog):
 async def test_client_good_event_handler(caplog):
     caplog.set_level(logging.WARNING)
     caplog.set_level(logging.WARNING)
-    enable_default_logging()
     logger = logging.getLogger('openleadr')
     logger = logging.getLogger('openleadr')
     logger.setLevel(logging.DEBUG)
     logger.setLevel(logging.DEBUG)
     client = OpenADRClient(ven_name='myven',
     client = OpenADRClient(ven_name='myven',
@@ -173,7 +171,6 @@ async def test_client_good_event_handler(caplog):
 @pytest.mark.asyncio
 @pytest.mark.asyncio
 async def test_server_warning_conflicting_poll_methods(caplog):
 async def test_server_warning_conflicting_poll_methods(caplog):
     caplog.set_level(logging.WARNING)
     caplog.set_level(logging.WARNING)
-    enable_default_logging()
     logger = logging.getLogger('openleadr')
     logger = logging.getLogger('openleadr')
     logger.setLevel(logging.DEBUG)
     logger.setLevel(logging.DEBUG)
     server = OpenADRServer(vtn_id='myvtn', requested_poll_freq=timedelta(seconds=1))
     server = OpenADRServer(vtn_id='myvtn', requested_poll_freq=timedelta(seconds=1))
@@ -196,7 +193,6 @@ async def test_server_warning_conflicting_poll_methods(caplog):
 @pytest.mark.asyncio
 @pytest.mark.asyncio
 async def test_server_warning_naive_datetimes_in_event(caplog):
 async def test_server_warning_naive_datetimes_in_event(caplog):
     caplog.set_level(logging.WARNING)
     caplog.set_level(logging.WARNING)
-    enable_default_logging()
     logger = logging.getLogger('openleadr')
     logger = logging.getLogger('openleadr')
     logger.setLevel(logging.DEBUG)
     logger.setLevel(logging.DEBUG)
     server = OpenADRServer(vtn_id='myvtn', requested_poll_freq=timedelta(seconds=1))
     server = OpenADRServer(vtn_id='myvtn', requested_poll_freq=timedelta(seconds=1))
@@ -215,10 +211,93 @@ async def test_server_warning_naive_datetimes_in_event(caplog):
             "datetime.datetime(..., tzinfo=datetime.timezone.utc)") in [record.msg for record in caplog.records]
             "datetime.datetime(..., tzinfo=datetime.timezone.utc)") in [record.msg for record in caplog.records]
 
 
 
 
+def test_event_with_wrong_response_required(caplog):
+    now = datetime.now(timezone.utc)
+    event = {'active_period': {'dtstart': now, 'duration': timedelta(seconds=10)},
+             'event_descriptor': {'event_id': 'event123',
+                                  'modification_number': 1,
+                                  'priority': 0,
+                                  'event_status': 'far',
+                                  'created_date_time': now},
+             'event_signals': [{'signal_name': 'simple',
+                                'signal_type': 'level',
+                                'intervals': [{'dtstart': now,
+                                               'duration': timedelta(seconds=10),
+                                               'signal_payload': 1}]}],
+             'targets': [{'ven_id': 'ven123'}],
+             'response_required': 'blabla'}
+    msg = messaging.create_message('oadrDistributeEvent', events=[event])
+    assert ("The response_required property in an Event should be "
+            "'never' or 'always', not blabla. Changing to 'always'.") in caplog.messages
+    message_type, message_payload= messaging.parse_message(msg)
+    assert message_payload['events'][0]['response_required'] == 'always'
+
+
+def test_event_missing_created_date_time(caplog):
+    now = datetime.now(timezone.utc)
+    event = {'active_period': {'dtstart': now, 'duration': timedelta(seconds=10)},
+             'event_descriptor': {'event_id': 'event123',
+                                  'modification_number': 1,
+                                  'priority': 0,
+                                  'event_status': 'far'},
+             'event_signals': [{'signal_name': 'simple',
+                                'signal_type': 'level',
+                                'intervals': [{'dtstart': now,
+                                               'duration': timedelta(seconds=10),
+                                               'signal_payload': 1}]}],
+             'targets': [{'ven_id': 'ven123'}],
+             'response_required': 'always'}
+    msg = messaging.create_message('oadrDistributeEvent', events=[event])
+    assert ("Your event descriptor did not contain a created_date_time. "
+            "This will be automatically added.") in caplog.messages
+
+
+def test_event_incongruent_targets(caplog):
+    now = datetime.now(timezone.utc)
+    event = {'active_period': {'dtstart': now, 'duration': timedelta(seconds=10)},
+             'event_descriptor': {'event_id': 'event123',
+                                  'modification_number': 1,
+                                  'priority': 0,
+                                  'event_status': 'far',
+                                  'created_date_time': now},
+             'event_signals': [{'signal_name': 'simple',
+                                'signal_type': 'level',
+                                'intervals': [{'dtstart': now,
+                                               'duration': timedelta(seconds=10),
+                                               'signal_payload': 1}]}],
+             'targets': [{'ven_id': 'ven123'}],
+             'targets_by_type': {'ven_id': ['ven456']},
+             'response_required': 'always'}
+    with pytest.raises(ValueError) as err:
+        msg = messaging.create_message('oadrDistributeEvent', events=[event])
+    assert str(err.value) == ("You assigned both 'targets' and 'targets_by_type' in your event, "
+                "but the two were not consistent with each other. "
+                f"You supplied 'targets' = {event['targets']} and "
+                f"'targets_by_type' = {event['targets_by_type']}")
+
+
+def test_event_only_targets_by_type(caplog):
+    now = datetime.now(timezone.utc)
+    event = {'active_period': {'dtstart': now, 'duration': timedelta(seconds=10)},
+             'event_descriptor': {'event_id': 'event123',
+                                  'modification_number': 1,
+                                  'priority': 0,
+                                  'event_status': 'far',
+                                  'created_date_time': now},
+             'event_signals': [{'signal_name': 'simple',
+                                'signal_type': 'level',
+                                'intervals': [{'dtstart': now,
+                                               'duration': timedelta(seconds=10),
+                                               'signal_payload': 1}]}],
+             'targets_by_type': {'ven_id': ['ven456']},
+             'response_required': 'always'}
+    msg = messaging.create_message('oadrDistributeEvent', events=[event])
+    message_type, message_payload = messaging.parse_message(msg)
+    assert message_payload['events'][0]['targets'] == [{'ven_id': 'ven456'}]
+
 @pytest.mark.asyncio
 @pytest.mark.asyncio
 async def test_client_warning_no_update_event_handler(caplog):
 async def test_client_warning_no_update_event_handler(caplog):
     caplog.set_level(logging.WARNING)
     caplog.set_level(logging.WARNING)
-    enable_default_logging()
     logger = logging.getLogger('openleadr')
     logger = logging.getLogger('openleadr')
     logger.setLevel(logging.DEBUG)
     logger.setLevel(logging.DEBUG)
     server = OpenADRServer(vtn_id='myvtn', requested_poll_freq=timedelta(seconds=1))
     server = OpenADRServer(vtn_id='myvtn', requested_poll_freq=timedelta(seconds=1))

+ 139 - 14
test/test_failures.py

@@ -1,10 +1,14 @@
 from openleadr import OpenADRClient, OpenADRServer
 from openleadr import OpenADRClient, OpenADRServer
 from openleadr.utils import generate_id
 from openleadr.utils import generate_id
+from openleadr import messaging, errors
 import pytest
 import pytest
 from aiohttp import web
 from aiohttp import web
 import os
 import os
+import logging
 import asyncio
 import asyncio
 from datetime import timedelta
 from datetime import timedelta
+from base64 import b64encode
+import re
 
 
 @pytest.mark.asyncio
 @pytest.mark.asyncio
 async def test_http_level_error(start_server):
 async def test_http_level_error(start_server):
@@ -13,21 +17,131 @@ async def test_http_level_error(start_server):
     await client.run()
     await client.run()
     await client.client_session.close()
     await client.client_session.close()
 
 
+
 @pytest.mark.asyncio
 @pytest.mark.asyncio
-async def test_openadr_error(start_server):
-    client = OpenADRClient(vtn_url=f"http://localhost:{SERVER_PORT}/OpenADR2/Simple/2.0b", ven_name=VEN_NAME)
-    client.on_event = _client_on_event
+async def test_xml_schema_error(start_server, caplog):
+    message = messaging.create_message("oadrQueryRegistration", request_id='req1234')
+    message = message.replace('<requestID xmlns="http://docs.oasis-open.org/ns/energyinterop/201110/payloads">req1234</requestID>', '')
+    client = OpenADRClient(ven_name='myven', vtn_url=f'http://localhost:{SERVER_PORT}/OpenADR2/Simple/2.0b')
+    result = await client._perform_request('EiRegisterParty', message)
+    assert result == (None, {})
+
+    logs = [rec.message for rec in caplog.records]
+    for log in logs:
+        if log.startswith("Non-OK status 400"):
+          assert "XML failed validation" in log
+          break
+    else:
+        assert False
+
+@pytest.mark.asyncio
+async def test_wrong_endpoint(start_server, caplog):
+    message = messaging.create_message("oadrQueryRegistration", request_id='req1234')
+    client = OpenADRClient(ven_name='myven', vtn_url=f'http://localhost:{SERVER_PORT}/OpenADR2/Simple/2.0b')
+    response_type, response_payload = await client._perform_request('OadrPoll', message)
+    assert response_type == 'oadrResponse'
+    assert response_payload['response']['response_code'] == 459
+
+@pytest.mark.asyncio
+async def test_vtn_no_create_party_registration_handler(caplog):
+    caplog.set_level(logging.WARNING)
+    server = OpenADRServer(vtn_id='myvtn')
+    client = OpenADRClient(ven_name='myven', vtn_url='http://localhost:8080/OpenADR2/Simple/2.0b')
+    await server.run_async()
     await client.run()
     await client.run()
-    await client.client_session.close()
+    await asyncio.sleep(0.5)
+    await server.stop()
+    await client.stop()
+    assert len(caplog.messages) == 2
+    assert 'No VEN ID received from the VTN, aborting.' in caplog.messages
+    assert ("You should implement and register your own on_create_party_registration "
+            "handler if you want VENs to be able to connect to you. This handler will "
+            "receive a registration request and should return either 'False' (if the "
+            "registration is denied) or a (ven_id, registration_id) tuple if the "
+            "registration is accepted.") in caplog.messages
 
 
 @pytest.mark.asyncio
 @pytest.mark.asyncio
-async def test_signature_error(start_server_with_signatures):
-    client = OpenADRClient(vtn_url=f"http://localhost:{SERVER_PORT}/OpenADR2/Simple/2.0b", ven_name=VEN_NAME,
-                           vtn_fingerprint="INVALID")
-    client.on_event = _client_on_event
+async def test_invalid_signature_error(start_server_with_signatures, caplog):
+    client = OpenADRClient(ven_name='myven',
+                           vtn_url=f'https://localhost:{SERVER_PORT}/OpenADR2/Simple/2.0b',
+                           cert=VEN_CERT,
+                           key=VEN_KEY,
+                           vtn_fingerprint='EE:44:C5:78:7E:4B:B8:DC:84:1F')
+    message = client._create_message('oadrPoll', ven_id='ven123')
+    fake_sig = b64encode("HelloThere".encode('utf-8')).decode('utf-8')
+    message = re.sub(r'<ds:SignatureValue>.*?</ds:SignatureValue>', f'<ds:SignatureValue>{fake_sig}</ds:SignatureValue>', message)
+    result = await client._perform_request('OadrPoll', message)
+    assert result == (None, {})
+
+    logs = [rec.message for rec in caplog.records]
+    for log in logs:
+        if log.startswith("Non-OK status 403 when performing a request"):
+          assert "Invalid Signature" in log
+          break
+    else:
+        assert False
+
+def problematic_handler(*args, **kwargs):
+    raise Exception("BOOM")
+
+@pytest.mark.asyncio
+async def test_server_handler_exception(caplog):
+    server = OpenADRServer(vtn_id=VTN_ID,
+                           http_port=SERVER_PORT)
+    server.add_handler('on_create_party_registration', problematic_handler)
+    await server.run_async()
+    client = OpenADRClient(ven_name='myven',
+                           vtn_url=f'http://localhost:{SERVER_PORT}/OpenADR2/Simple/2.0b')
     await client.run()
     await client.run()
-    await asyncio.sleep(0)
+    await asyncio.sleep(0.5)
     await client.stop()
     await client.stop()
+    await server.stop()
+    for message in caplog.messages:
+        if message.startswith('Non-OK status 500 when performing a request'):
+            break
+    else:
+        assert False
+
+def protocol_error_handler(*args, **kwargs):
+    raise errors.OutOfSequenceError()
+
+
+@pytest.mark.asyncio
+async def test_throw_protocol_error(caplog):
+    server = OpenADRServer(vtn_id=VTN_ID,
+                           http_port=SERVER_PORT)
+    server.add_handler('on_create_party_registration', protocol_error_handler)
+    await server.run_async()
+    client = OpenADRClient(ven_name='myven',
+                           vtn_url=f'http://localhost:{SERVER_PORT}/OpenADR2/Simple/2.0b')
+    await client.run()
+    await asyncio.sleep(0.5)
+    await client.stop()
+    await server.stop()
+    assert 'We got a non-OK OpenADR response from the server: 450: OUT OF SEQUENCE' in caplog.messages
+
+@pytest.mark.asyncio
+async def test_invalid_signature_error(start_server_with_signatures, caplog):
+    client = OpenADRClient(ven_name='myven',
+                           vtn_url=f'https://localhost:{SERVER_PORT}/OpenADR2/Simple/2.0b',
+                           cert=VEN_CERT,
+                           key=VEN_KEY,
+                           vtn_fingerprint='EE:44:C5:78:7E:4B:B8:DC:84:1F')
+    message = client._create_message('oadrPoll', ven_id='ven123')
+    fake_sig = b64encode("HelloThere".encode('utf-8')).decode('utf-8')
+    message = re.sub(r'<ds:SignatureValue>.*?</ds:SignatureValue>', f'<ds:SignatureValue>{fake_sig}</ds:SignatureValue>', message)
+    result = await client._perform_request('OadrPoll', message)
+    assert result == (None, {})
+
+    logs = [rec.message for rec in caplog.records]
+    for log in logs:
+        if log.startswith("Non-OK status 403 when performing a request"):
+          assert "Invalid Signature" in log
+          break
+    else:
+        assert False
+
+
 
 
 
 
 ##########################################################################################
 ##########################################################################################
@@ -37,9 +151,11 @@ VEN_NAME = 'myven'
 VEN_ID = '1234abcd'
 VEN_ID = '1234abcd'
 VTN_ID = "TestVTN"
 VTN_ID = "TestVTN"
 
 
-CERTFILE = os.path.join(os.path.dirname(__file__), "cert.pem")
-KEYFILE =  os.path.join(os.path.dirname(__file__), "key.pem")
-
+VEN_CERT = os.path.join(os.path.dirname(os.path.dirname(__file__)), "certificates", "dummy_ven.crt")
+VEN_KEY = os.path.join(os.path.dirname(os.path.dirname(__file__)), "certificates", "dummy_ven.key")
+VTN_CERT = os.path.join(os.path.dirname(os.path.dirname(__file__)), "certificates", "dummy_vtn.crt")
+VTN_KEY = os.path.join(os.path.dirname(os.path.dirname(__file__)), "certificates", "dummy_vtn.key")
+CA_FILE = os.path.join(os.path.dirname(os.path.dirname(__file__)), "certificates", "dummy_ca.crt")
 
 
 async def _on_create_party_registration(payload):
 async def _on_create_party_registration(payload):
     registration_id = generate_id()
     registration_id = generate_id()
@@ -59,6 +175,9 @@ async def _client_on_event(event):
 async def _client_on_report(report):
 async def _client_on_report(report):
     pass
     pass
 
 
+def fingerprint_lookup(ven_id):
+    return "6B:C8:4E:47:15:AA:30:EB:CE:0E"
+
 @pytest.fixture
 @pytest.fixture
 async def start_server():
 async def start_server():
     server = OpenADRServer(vtn_id=VTN_ID, http_port=SERVER_PORT)
     server = OpenADRServer(vtn_id=VTN_ID, http_port=SERVER_PORT)
@@ -69,8 +188,14 @@ async def start_server():
 
 
 @pytest.fixture
 @pytest.fixture
 async def start_server_with_signatures():
 async def start_server_with_signatures():
-    server = OpenADRServer(vtn_id=VTN_ID, cert=CERTFILE, key=KEYFILE, passphrase='openadr',
-                           http_port=SERVER_PORT)
+    server = OpenADRServer(vtn_id=VTN_ID,
+                           cert=VTN_CERT,
+                           key=VTN_KEY,
+                           http_cert=VTN_CERT,
+                           http_key=VTN_KEY,
+                           http_ca_file=CA_FILE,
+                           http_port=SERVER_PORT,
+                           fingerprint_lookup=fingerprint_lookup)
     server.add_handler('on_create_party_registration', _on_create_party_registration)
     server.add_handler('on_create_party_registration', _on_create_party_registration)
 
 
     await server.run_async()
     await server.run_async()