浏览代码

Improved testing of events and signatures

Signed-off-by: Stan Janssen <stan.janssen@elaad.nl>
Stan Janssen 4 年之前
父节点
当前提交
29959c8e2b
共有 4 个文件被更改,包括 230 次插入24 次删除
  1. 2 0
      openleadr/preflight.py
  2. 2 2
      openleadr/service/vtn_service.py
  3. 87 8
      test/integration_tests/test_event_warnings_errors.py
  4. 139 14
      test/test_failures.py

+ 2 - 0
openleadr/preflight.py

@@ -133,6 +133,8 @@ def _preflight_oadrDistributeEvent(message_payload):
     for event in message_payload['events']:
         if 'created_date_time' not in event['event_descriptor'] \
                 or not event['event_descriptor']['created_date_time']:
+            logger.warning("Your event descriptor did not contain a created_date_time. "
+                           "This will be automatically added.")
             event['event_descriptor']['created_date_time'] = datetime.now(timezone.utc)
 
     # Check that the target designations are correct and consistent

+ 2 - 2
openleadr/service/vtn_service.py

@@ -136,7 +136,7 @@ class VTNService:
             response_payload['request_id'] = generate_id()
 
         else:
-            response_type, response_payload = self.error_response(message_type,
+            response_type, response_payload = self.error_response('oadrResponse',
                                                                   STATUS_CODES.COMPLIANCE_ERROR,
                                                                   "A message of type "
                                                                   f"{message_type} should not be "
@@ -151,5 +151,5 @@ class VTNService:
         else:
             response_type = 'oadrResponse'
         response_payload = {'response': {'response_code': error_code,
-                                         'response_description': 'Certificate fingerprint mismatch'}}
+                                         'response_description': error_description}}
         return response_type, response_payload

+ 87 - 8
test/integration_tests/test_event_warnings_errors.py

@@ -1,10 +1,12 @@
-from openleadr import OpenADRClient, OpenADRServer, enable_default_logging, utils
+from openleadr import OpenADRClient, OpenADRServer, enable_default_logging, utils, messaging
 import pytest
 from functools import partial
 import asyncio
 from datetime import datetime, timedelta, timezone
 import logging
 
+enable_default_logging()
+
 async def on_create_party_registration(ven_name):
     return 'venid', 'regid'
 
@@ -24,7 +26,6 @@ async def broken_on_event(event):
 @pytest.mark.asyncio
 async def test_client_no_event_handler(caplog):
     caplog.set_level(logging.WARNING)
-    enable_default_logging()
     logger = logging.getLogger('openleadr')
     logger.setLevel(logging.DEBUG)
     client = OpenADRClient(ven_name='myven',
@@ -61,7 +62,6 @@ async def test_client_no_event_handler(caplog):
 @pytest.mark.asyncio
 async def test_client_faulty_event_handler(caplog):
     caplog.set_level(logging.WARNING)
-    enable_default_logging()
     logger = logging.getLogger('openleadr')
     logger.setLevel(logging.DEBUG)
     client = OpenADRClient(ven_name='myven',
@@ -98,7 +98,6 @@ async def test_client_faulty_event_handler(caplog):
 @pytest.mark.asyncio
 async def test_client_exception_event_handler(caplog):
     caplog.set_level(logging.WARNING)
-    enable_default_logging()
     logger = logging.getLogger('openleadr')
     logger.setLevel(logging.DEBUG)
     client = OpenADRClient(ven_name='myven',
@@ -137,7 +136,6 @@ async def test_client_exception_event_handler(caplog):
 @pytest.mark.asyncio
 async def test_client_good_event_handler(caplog):
     caplog.set_level(logging.WARNING)
-    enable_default_logging()
     logger = logging.getLogger('openleadr')
     logger.setLevel(logging.DEBUG)
     client = OpenADRClient(ven_name='myven',
@@ -173,7 +171,6 @@ async def test_client_good_event_handler(caplog):
 @pytest.mark.asyncio
 async def test_server_warning_conflicting_poll_methods(caplog):
     caplog.set_level(logging.WARNING)
-    enable_default_logging()
     logger = logging.getLogger('openleadr')
     logger.setLevel(logging.DEBUG)
     server = OpenADRServer(vtn_id='myvtn', requested_poll_freq=timedelta(seconds=1))
@@ -196,7 +193,6 @@ async def test_server_warning_conflicting_poll_methods(caplog):
 @pytest.mark.asyncio
 async def test_server_warning_naive_datetimes_in_event(caplog):
     caplog.set_level(logging.WARNING)
-    enable_default_logging()
     logger = logging.getLogger('openleadr')
     logger.setLevel(logging.DEBUG)
     server = OpenADRServer(vtn_id='myvtn', requested_poll_freq=timedelta(seconds=1))
@@ -215,10 +211,93 @@ async def test_server_warning_naive_datetimes_in_event(caplog):
             "datetime.datetime(..., tzinfo=datetime.timezone.utc)") in [record.msg for record in caplog.records]
 
 
+def test_event_with_wrong_response_required(caplog):
+    now = datetime.now(timezone.utc)
+    event = {'active_period': {'dtstart': now, 'duration': timedelta(seconds=10)},
+             'event_descriptor': {'event_id': 'event123',
+                                  'modification_number': 1,
+                                  'priority': 0,
+                                  'event_status': 'far',
+                                  'created_date_time': now},
+             'event_signals': [{'signal_name': 'simple',
+                                'signal_type': 'level',
+                                'intervals': [{'dtstart': now,
+                                               'duration': timedelta(seconds=10),
+                                               'signal_payload': 1}]}],
+             'targets': [{'ven_id': 'ven123'}],
+             'response_required': 'blabla'}
+    msg = messaging.create_message('oadrDistributeEvent', events=[event])
+    assert ("The response_required property in an Event should be "
+            "'never' or 'always', not blabla. Changing to 'always'.") in caplog.messages
+    message_type, message_payload= messaging.parse_message(msg)
+    assert message_payload['events'][0]['response_required'] == 'always'
+
+
+def test_event_missing_created_date_time(caplog):
+    now = datetime.now(timezone.utc)
+    event = {'active_period': {'dtstart': now, 'duration': timedelta(seconds=10)},
+             'event_descriptor': {'event_id': 'event123',
+                                  'modification_number': 1,
+                                  'priority': 0,
+                                  'event_status': 'far'},
+             'event_signals': [{'signal_name': 'simple',
+                                'signal_type': 'level',
+                                'intervals': [{'dtstart': now,
+                                               'duration': timedelta(seconds=10),
+                                               'signal_payload': 1}]}],
+             'targets': [{'ven_id': 'ven123'}],
+             'response_required': 'always'}
+    msg = messaging.create_message('oadrDistributeEvent', events=[event])
+    assert ("Your event descriptor did not contain a created_date_time. "
+            "This will be automatically added.") in caplog.messages
+
+
+def test_event_incongruent_targets(caplog):
+    now = datetime.now(timezone.utc)
+    event = {'active_period': {'dtstart': now, 'duration': timedelta(seconds=10)},
+             'event_descriptor': {'event_id': 'event123',
+                                  'modification_number': 1,
+                                  'priority': 0,
+                                  'event_status': 'far',
+                                  'created_date_time': now},
+             'event_signals': [{'signal_name': 'simple',
+                                'signal_type': 'level',
+                                'intervals': [{'dtstart': now,
+                                               'duration': timedelta(seconds=10),
+                                               'signal_payload': 1}]}],
+             'targets': [{'ven_id': 'ven123'}],
+             'targets_by_type': {'ven_id': ['ven456']},
+             'response_required': 'always'}
+    with pytest.raises(ValueError) as err:
+        msg = messaging.create_message('oadrDistributeEvent', events=[event])
+    assert str(err.value) == ("You assigned both 'targets' and 'targets_by_type' in your event, "
+                "but the two were not consistent with each other. "
+                f"You supplied 'targets' = {event['targets']} and "
+                f"'targets_by_type' = {event['targets_by_type']}")
+
+
+def test_event_only_targets_by_type(caplog):
+    now = datetime.now(timezone.utc)
+    event = {'active_period': {'dtstart': now, 'duration': timedelta(seconds=10)},
+             'event_descriptor': {'event_id': 'event123',
+                                  'modification_number': 1,
+                                  'priority': 0,
+                                  'event_status': 'far',
+                                  'created_date_time': now},
+             'event_signals': [{'signal_name': 'simple',
+                                'signal_type': 'level',
+                                'intervals': [{'dtstart': now,
+                                               'duration': timedelta(seconds=10),
+                                               'signal_payload': 1}]}],
+             'targets_by_type': {'ven_id': ['ven456']},
+             'response_required': 'always'}
+    msg = messaging.create_message('oadrDistributeEvent', events=[event])
+    message_type, message_payload = messaging.parse_message(msg)
+    assert message_payload['events'][0]['targets'] == [{'ven_id': 'ven456'}]
+
 @pytest.mark.asyncio
 async def test_client_warning_no_update_event_handler(caplog):
     caplog.set_level(logging.WARNING)
-    enable_default_logging()
     logger = logging.getLogger('openleadr')
     logger.setLevel(logging.DEBUG)
     server = OpenADRServer(vtn_id='myvtn', requested_poll_freq=timedelta(seconds=1))

+ 139 - 14
test/test_failures.py

@@ -1,10 +1,14 @@
 from openleadr import OpenADRClient, OpenADRServer
 from openleadr.utils import generate_id
+from openleadr import messaging, errors
 import pytest
 from aiohttp import web
 import os
+import logging
 import asyncio
 from datetime import timedelta
+from base64 import b64encode
+import re
 
 @pytest.mark.asyncio
 async def test_http_level_error(start_server):
@@ -13,21 +17,131 @@ async def test_http_level_error(start_server):
     await client.run()
     await client.client_session.close()
 
+
 @pytest.mark.asyncio
-async def test_openadr_error(start_server):
-    client = OpenADRClient(vtn_url=f"http://localhost:{SERVER_PORT}/OpenADR2/Simple/2.0b", ven_name=VEN_NAME)
-    client.on_event = _client_on_event
+async def test_xml_schema_error(start_server, caplog):
+    message = messaging.create_message("oadrQueryRegistration", request_id='req1234')
+    message = message.replace('<requestID xmlns="http://docs.oasis-open.org/ns/energyinterop/201110/payloads">req1234</requestID>', '')
+    client = OpenADRClient(ven_name='myven', vtn_url=f'http://localhost:{SERVER_PORT}/OpenADR2/Simple/2.0b')
+    result = await client._perform_request('EiRegisterParty', message)
+    assert result == (None, {})
+
+    logs = [rec.message for rec in caplog.records]
+    for log in logs:
+        if log.startswith("Non-OK status 400"):
+          assert "XML failed validation" in log
+          break
+    else:
+        assert False
+
+@pytest.mark.asyncio
+async def test_wrong_endpoint(start_server, caplog):
+    message = messaging.create_message("oadrQueryRegistration", request_id='req1234')
+    client = OpenADRClient(ven_name='myven', vtn_url=f'http://localhost:{SERVER_PORT}/OpenADR2/Simple/2.0b')
+    response_type, response_payload = await client._perform_request('OadrPoll', message)
+    assert response_type == 'oadrResponse'
+    assert response_payload['response']['response_code'] == 459
+
+@pytest.mark.asyncio
+async def test_vtn_no_create_party_registration_handler(caplog):
+    caplog.set_level(logging.WARNING)
+    server = OpenADRServer(vtn_id='myvtn')
+    client = OpenADRClient(ven_name='myven', vtn_url='http://localhost:8080/OpenADR2/Simple/2.0b')
+    await server.run_async()
     await client.run()
-    await client.client_session.close()
+    await asyncio.sleep(0.5)
+    await server.stop()
+    await client.stop()
+    assert len(caplog.messages) == 2
+    assert 'No VEN ID received from the VTN, aborting.' in caplog.messages
+    assert ("You should implement and register your own on_create_party_registration "
+            "handler if you want VENs to be able to connect to you. This handler will "
+            "receive a registration request and should return either 'False' (if the "
+            "registration is denied) or a (ven_id, registration_id) tuple if the "
+            "registration is accepted.") in caplog.messages
 
 @pytest.mark.asyncio
-async def test_signature_error(start_server_with_signatures):
-    client = OpenADRClient(vtn_url=f"http://localhost:{SERVER_PORT}/OpenADR2/Simple/2.0b", ven_name=VEN_NAME,
-                           vtn_fingerprint="INVALID")
-    client.on_event = _client_on_event
+async def test_invalid_signature_error(start_server_with_signatures, caplog):
+    client = OpenADRClient(ven_name='myven',
+                           vtn_url=f'https://localhost:{SERVER_PORT}/OpenADR2/Simple/2.0b',
+                           cert=VEN_CERT,
+                           key=VEN_KEY,
+                           vtn_fingerprint='EE:44:C5:78:7E:4B:B8:DC:84:1F')
+    message = client._create_message('oadrPoll', ven_id='ven123')
+    fake_sig = b64encode("HelloThere".encode('utf-8')).decode('utf-8')
+    message = re.sub(r'<ds:SignatureValue>.*?</ds:SignatureValue>', f'<ds:SignatureValue>{fake_sig}</ds:SignatureValue>', message)
+    result = await client._perform_request('OadrPoll', message)
+    assert result == (None, {})
+
+    logs = [rec.message for rec in caplog.records]
+    for log in logs:
+        if log.startswith("Non-OK status 403 when performing a request"):
+          assert "Invalid Signature" in log
+          break
+    else:
+        assert False
+
+def problematic_handler(*args, **kwargs):
+    raise Exception("BOOM")
+
+@pytest.mark.asyncio
+async def test_server_handler_exception(caplog):
+    server = OpenADRServer(vtn_id=VTN_ID,
+                           http_port=SERVER_PORT)
+    server.add_handler('on_create_party_registration', problematic_handler)
+    await server.run_async()
+    client = OpenADRClient(ven_name='myven',
+                           vtn_url=f'http://localhost:{SERVER_PORT}/OpenADR2/Simple/2.0b')
     await client.run()
-    await asyncio.sleep(0)
+    await asyncio.sleep(0.5)
     await client.stop()
+    await server.stop()
+    for message in caplog.messages:
+        if message.startswith('Non-OK status 500 when performing a request'):
+            break
+    else:
+        assert False
+
+def protocol_error_handler(*args, **kwargs):
+    raise errors.OutOfSequenceError()
+
+
+@pytest.mark.asyncio
+async def test_throw_protocol_error(caplog):
+    server = OpenADRServer(vtn_id=VTN_ID,
+                           http_port=SERVER_PORT)
+    server.add_handler('on_create_party_registration', protocol_error_handler)
+    await server.run_async()
+    client = OpenADRClient(ven_name='myven',
+                           vtn_url=f'http://localhost:{SERVER_PORT}/OpenADR2/Simple/2.0b')
+    await client.run()
+    await asyncio.sleep(0.5)
+    await client.stop()
+    await server.stop()
+    assert 'We got a non-OK OpenADR response from the server: 450: OUT OF SEQUENCE' in caplog.messages
+
+@pytest.mark.asyncio
+async def test_invalid_signature_error(start_server_with_signatures, caplog):
+    client = OpenADRClient(ven_name='myven',
+                           vtn_url=f'https://localhost:{SERVER_PORT}/OpenADR2/Simple/2.0b',
+                           cert=VEN_CERT,
+                           key=VEN_KEY,
+                           vtn_fingerprint='EE:44:C5:78:7E:4B:B8:DC:84:1F')
+    message = client._create_message('oadrPoll', ven_id='ven123')
+    fake_sig = b64encode("HelloThere".encode('utf-8')).decode('utf-8')
+    message = re.sub(r'<ds:SignatureValue>.*?</ds:SignatureValue>', f'<ds:SignatureValue>{fake_sig}</ds:SignatureValue>', message)
+    result = await client._perform_request('OadrPoll', message)
+    assert result == (None, {})
+
+    logs = [rec.message for rec in caplog.records]
+    for log in logs:
+        if log.startswith("Non-OK status 403 when performing a request"):
+          assert "Invalid Signature" in log
+          break
+    else:
+        assert False
+
+
 
 
 ##########################################################################################
@@ -37,9 +151,11 @@ VEN_NAME = 'myven'
 VEN_ID = '1234abcd'
 VTN_ID = "TestVTN"
 
-CERTFILE = os.path.join(os.path.dirname(__file__), "cert.pem")
-KEYFILE =  os.path.join(os.path.dirname(__file__), "key.pem")
-
+VEN_CERT = os.path.join(os.path.dirname(os.path.dirname(__file__)), "certificates", "dummy_ven.crt")
+VEN_KEY = os.path.join(os.path.dirname(os.path.dirname(__file__)), "certificates", "dummy_ven.key")
+VTN_CERT = os.path.join(os.path.dirname(os.path.dirname(__file__)), "certificates", "dummy_vtn.crt")
+VTN_KEY = os.path.join(os.path.dirname(os.path.dirname(__file__)), "certificates", "dummy_vtn.key")
+CA_FILE = os.path.join(os.path.dirname(os.path.dirname(__file__)), "certificates", "dummy_ca.crt")
 
 async def _on_create_party_registration(payload):
     registration_id = generate_id()
@@ -59,6 +175,9 @@ async def _client_on_event(event):
 async def _client_on_report(report):
     pass
 
+def fingerprint_lookup(ven_id):
+    return "6B:C8:4E:47:15:AA:30:EB:CE:0E"
+
 @pytest.fixture
 async def start_server():
     server = OpenADRServer(vtn_id=VTN_ID, http_port=SERVER_PORT)
@@ -69,8 +188,14 @@ async def start_server():
 
 @pytest.fixture
 async def start_server_with_signatures():
-    server = OpenADRServer(vtn_id=VTN_ID, cert=CERTFILE, key=KEYFILE, passphrase='openadr',
-                           http_port=SERVER_PORT)
+    server = OpenADRServer(vtn_id=VTN_ID,
+                           cert=VTN_CERT,
+                           key=VTN_KEY,
+                           http_cert=VTN_CERT,
+                           http_key=VTN_KEY,
+                           http_ca_file=CA_FILE,
+                           http_port=SERVER_PORT,
+                           fingerprint_lookup=fingerprint_lookup)
     server.add_handler('on_create_party_registration', _on_create_party_registration)
 
     await server.run_async()