Переглянути джерело

More careful checking of the on_register_report returned data.

There was a problem that was reported in #38 which caused the VTN's register_report handler to not work properly if something unexpected was returned from the on_register_report handler. This should provide the user of openleadr with more helpful error messages in those cases.

Signed-off-by: Stan Janssen <stan.janssen@elaad.nl>
Stan Janssen 4 роки тому
батько
коміт
5f99039df0

+ 16 - 13
openleadr/service/report_service.py

@@ -16,8 +16,7 @@
 
 from . import service, handler, VTNService
 from asyncio import iscoroutine, gather
-from openleadr.utils import generate_id, find_by, group_by
-from openleadr import objects
+from openleadr import objects, utils
 import logging
 import inspect
 logger = logging.getLogger('openleadr')
@@ -112,22 +111,26 @@ class ReportService(VTNService):
 
                 if iscoroutine(result[0]):
                     result = await gather(*result)
+                for i, r in enumerate(result):
+                    if not isinstance(r, tuple):
+                        logger.error(f"Your on_register_report handler must return a tuple; it returned '{r}' ({r.__class__.__name__}).")
+                        result[i] = None
                 result = [(report['report_descriptions'][i]['r_id'], *result[i])
-                          for i in range(len(report['report_descriptions'])) if result[i] is not None]
+                          for i in range(len(report['report_descriptions'])) if isinstance(result[i], tuple)]
                 report_requests.append(result)
         else:
             # Use the 'full' mode for openADR reporting
             result = [self.on_register_report(report) for report in payload['reports']]
             if iscoroutine(result[0]):
                 result = await gather(*result)      # Now we have r_id, callback, sampling_rate
+            for i, r in enumerate(result):
+                if not isinstance(r, list):
+                    logger.error(f"Your on_register_report handler must return a list of tuples. It returned '{r}' ({r.__class__.__name__}).")
+                    result[i] = None
             report_requests = result
 
-        for i, report_request in enumerate(report_requests):
-            if report_request is not None:
-                if not all(len(rrq) in (3, 4) for rrq in report_request):
-                    logger.error("Your on_register_report handler did not return a valid response")
-
-        # Validate the report requests
+        # Validate the report requests for being of the proper type and lengs
+        utils.validate_report_request_tuples(report_requests)
         for i, report_request in enumerate(report_requests):
             if report_request is None or len(report_request) == 0:
                 continue
@@ -141,12 +144,12 @@ class ReportService(VTNService):
         # Form the report request
         oadr_report_requests = []
         for i, report_request in enumerate(report_requests):
-            if report_request is None:
+            if report_request is None or len(report_request) == 0:
                 continue
 
             orig_report = payload['reports'][i]
             report_specifier_id = orig_report['report_specifier_id']
-            report_request_id = generate_id()
+            report_request_id = utils.generate_id()
             specifier_payloads = []
             for rrq in report_request:
                 if len(rrq) == 3:
@@ -155,7 +158,7 @@ class ReportService(VTNService):
                 elif len(rrq) == 4:
                     r_id, callback, sampling_interval, report_interval = rrq
 
-                report_description = find_by(orig_report['report_descriptions'], 'r_id', r_id)
+                report_description = utils.find_by(orig_report['report_descriptions'], 'r_id', r_id)
                 reading_type = report_description['reading_type']
                 specifier_payloads.append(objects.SpecifierPayload(r_id=r_id,
                                                                    reading_type=reading_type))
@@ -202,7 +205,7 @@ class ReportService(VTNService):
                 if iscoroutine(result):
                     result = await result
                 continue
-            for r_id, values in group_by(report['intervals'], 'report_payload.r_id').items():
+            for r_id, values in utils.group_by(report['intervals'], 'report_payload.r_id').items():
                 # Find the callback that was registered.
                 if (report_request_id, r_id) in self.report_callbacks:
                     # Collect the values

+ 48 - 0
openleadr/utils.py

@@ -646,3 +646,51 @@ def get_next_event_from_deque(deque):
             unused_elements.append(msg)
     deque.extend(unused_elements)
     return event
+
+def validate_report_request_tuples(list_of_report_requests):
+    if len(list_of_report_requests) == 0:
+        return
+    for report_requests in list_of_report_requests:
+        if report_requests is None:
+            continue
+        for i, rrq in enumerate(report_requests):
+            if rrq is None:
+                continue
+
+            # Check if it is a tuple
+            elif not isinstance(rrq, tuple):
+                report_requests[i] = None
+                logger.error(f"Your on_register_report did not return a tuple. It returned '{rrq}'.")
+
+            # Check if it has the correct length
+            elif not len(rrq) in (3, 4):
+                report_requests[i] = None
+                logger.error("Your on_register_report returned a tuple of the wrong length. "
+                             f"It should be 2 or 3. It returned '{rrq}'.")
+
+            # Check if the first element is callable
+            elif not callable(rrq[1]):
+                report_requests[i] = None
+                logger.error(f"Your on_register_report did not return the correct tuple. "
+                             "It should return a (callback, sampling_interval) or "
+                             "(callback, sampling_interval, reporting_interval) tuple, where "
+                             "sampling_interval and reporting_interval are of type datetime.timedelta. "
+                             f"It returned: '{rrq}'. The first element was not callable.")
+
+            # Check if the second element is a timedelta
+            elif not isinstance(rrq[2], timedelta):
+                report_requests[i] = None
+                logger.error(f"Your on_register_report did not return the correct tuple. "
+                             "It should return a (callback, sampling_interval) or "
+                             "(callback, sampling_interval, reporting_interval) tuple, where "
+                             "sampling_interval and reporting_interval are of type datetime.timedelta. "
+                             f"It returned: '{rrq}'. The second element was not of type timedelta.")
+
+            # Check if the third element is a timedelta (if it exists)
+            elif len(rrq) == 4 and not isinstance(rrq[3], timedelta):
+                report_requests[i] = None
+                logger.error(f"Your on_register_report did not return the correct tuple. "
+                             "It should return a (callback, sampling_interval) or "
+                             "(callback, sampling_interval, reporting_interval) tuple, where "
+                             "sampling_interval and reporting_interval are of type datetime.timedelta. "
+                             f"It returned: '{rrq}'. The third element was not of type timedelta.")

+ 3 - 3
test/test_message_conversion.py

@@ -73,7 +73,7 @@ def create_dummy_event(ven_id):
 reports = [{'report_id': generate_id(),
             'duration': timedelta(seconds=3600),
             'report_descriptions': [{'r_id': generate_id(),
-                                     'report_subject': {'resource_id': 'resource001'},
+                                     'report_subject': {'end_device_asset': {'mrid': 'meter001'}},
                                      'report_data_source': {'resource_id': 'resource001'},
                                      'report_type': 'usage',
                                      'measurement': asdict(measurement),
@@ -174,7 +174,7 @@ testcases = [
 ('oadrRegisterReport', dict(request_id=generate_id(), reports=[{'report_id': generate_id(),
                                                                 'report_descriptions': [{
                                                                      'r_id': generate_id(),
-                                                                     'report_subject': {'resource_id': '123ABC'},
+                                                                     'report_subject': {'end_device_asset': {'mrid': 'meter001'}},
                                                                      'report_data_source': {'resource_id': '123ABC'},
                                                                      'report_type': 'reading',
                                                                      'reading_type': 'Direct Read',
@@ -266,7 +266,7 @@ testcases = [
                                                               'report_request_id': generate_id(),
                                                               'report_specifier_id': generate_id(),
                                                               'report_descriptions': [{'r_id': generate_id(),
-                                                                                       'report_subject': {'resource_id': '123ABC'},
+                                                                                       'report_subject': {'end_device_asset': {'mrid': 'meter001'}},
                                                                                        'report_data_source': {'resource_id': '123ABC'},
                                                                                        'report_type': enums.REPORT_TYPE.values[0],
                                                                                        'reading_type': enums.READING_TYPE.values[0],

+ 130 - 6
test/test_reports.py

@@ -1,6 +1,7 @@
 from openleadr import OpenADRClient, OpenADRServer, enable_default_logging
 import asyncio
 import pytest
+import aiohttp
 from datetime import datetime, timedelta
 from functools import partial
 import logging
@@ -46,6 +47,7 @@ async def on_register_report(ven_id, resource_id, measurement, unit, scale,
     Deal with this report.
     """
     print(f"Called on register report {ven_id}, {resource_id}, {measurement}, {unit}, {scale}, {min_sampling_interval}, {max_sampling_interval}")
+    assert resource_id in ('Device001', 'Device002')
     if futures:
         futures.pop(0).set_result(True)
     if receive_futures:
@@ -345,7 +347,7 @@ async def test_incremental_reports():
     client.add_report(callback=partial(collect_data_multi, futures=collect_futures),
                       report_specifier_id='myhistory',
                       measurement='voltage',
-                      resource_id='mydevice',
+                      resource_id='Device001',
                       sampling_rate=timedelta(seconds=2))
 
     server = OpenADRServer(vtn_id='myvtn')
@@ -457,7 +459,7 @@ def test_add_report_invalid_unit(caplog):
     client.add_report(callback=print,
                       report_specifier_id='myreport',
                       measurement='voltage',
-                      resource_id='mydevice',
+                      resource_id='Device001',
                       sampling_rate=timedelta(seconds=10),
                       unit='A')
     assert caplog.record_tuples == [("openleadr", logging.WARNING, f"The supplied unit A for measurement voltage will be ignored, V will be used instead. Allowed units for this measurement are: V")]
@@ -468,7 +470,7 @@ def test_add_report_invalid_scale():
         client.add_report(callback=print,
                           report_specifier_id='myreport',
                           measurement='power_real',
-                          resource_id='mydevice',
+                          resource_id='Device001',
                           sampling_rate=timedelta(seconds=10),
                           unit='W',
                           scale='xxx')
@@ -478,7 +480,7 @@ def test_add_report_invalid_description(caplog):
     client.add_report(callback=print,
                       report_specifier_id='myreport',
                       measurement={'name': 'voltage', 'description': 'SomethingWrong', 'unit': 'V'},
-                      resource_id='mydevice',
+                      resource_id='Device001',
                       sampling_rate=timedelta(seconds=10))
     msg = create_message('oadrRegisterReport', reports=client.reports)
 
@@ -489,7 +491,7 @@ def test_add_report_invalid_description(caplog):
         client.add_report(callback=print,
                           report_specifier_id='myreport',
                           measurement={'name': 'voltage', 'description': 'SomethingWrong', 'unit': 'V'},
-                          resource_id='mydevice',
+                          resource_id='Device001',
                           sampling_rate=timedelta(seconds=10))
 
 
@@ -498,7 +500,7 @@ def test_add_report_non_standard_measurement():
     client.add_report(callback=print,
                       report_specifier_id='myreport',
                       measurement='rainbows',
-                      resource_id='mydevice',
+                      resource_id='Device001',
                       sampling_rate=timedelta(seconds=10),
                       unit='A')
     assert len(client.reports) == 1
@@ -506,6 +508,128 @@ def test_add_report_non_standard_measurement():
     assert client.reports[0].report_descriptions[0].measurement.description == 'rainbows'
 
 
+async def test_report_registration_broken_handlers(caplog):
+    msg = """<?xml version="1.0" encoding="UTF-8" standalone="no" ?>
+<p1:oadrPayload xmlns:p1="http://openadr.org/oadr-2.0b/2012/07">
+  <p1:oadrSignedObject>
+    <p1:oadrRegisterReport xmlns:p3="http://docs.oasis-open.org/ns/energyinterop/201110" p3:schemaVersion="2.0b" xmlns:p2="http://docs.oasis-open.org/ns/energyinterop/201110/payloads">
+      <p2:requestID>B8A6E0D2D4</p2:requestID>
+      <p1:oadrReport xmlns:p3="urn:ietf:params:xml:ns:icalendar-2.0" xmlns:p4="http://docs.oasis-open.org/ns/energyinterop/201110">
+        <p3:duration>
+          <p3:duration>PT120M</p3:duration>
+        </p3:duration>
+        <p1:oadrReportDescription xmlns:p4="http://docs.oasis-open.org/ns/energyinterop/201110" xmlns:p5="http://docs.oasis-open.org/ns/emix/2011/06/power" xmlns:p6="http://docs.oasis-open.org/ns/emix/2011/06">
+          <p4:rID>rid_energy_4184bb93</p4:rID>
+          <p4:reportDataSource>
+            <p4:resourceID>DEVICE1</p4:resourceID>
+          </p4:reportDataSource>
+          <p4:reportType>reading</p4:reportType>
+          <p5:energyReal xmlns:p6="http://docs.oasis-open.org/ns/emix/2011/06/siscale">
+            <p5:itemDescription/>
+            <p5:itemUnits>Wh</p5:itemUnits>
+            <p6:siScaleCode>none</p6:siScaleCode>
+          </p5:energyReal>
+          <p4:readingType>Direct Read</p4:readingType>
+          <p6:marketContext/>
+          <p1:oadrSamplingRate>
+            <p1:oadrMinPeriod>PT1M</p1:oadrMinPeriod>
+            <p1:oadrMaxPeriod>PT1M</p1:oadrMaxPeriod>
+            <p1:oadrOnChange>false</p1:oadrOnChange>
+          </p1:oadrSamplingRate>
+        </p1:oadrReportDescription>
+        <p1:oadrReportDescription xmlns:p4="http://docs.oasis-open.org/ns/energyinterop/201110" xmlns:p5="http://docs.oasis-open.org/ns/emix/2011/06/power" xmlns:p6="http://docs.oasis-open.org/ns/emix/2011/06">
+          <p4:rID>rid_power_4184bb93</p4:rID>
+          <p4:reportDataSource>
+            <p4:resourceID>DEVICE1</p4:resourceID>
+          </p4:reportDataSource>
+          <p4:reportType>reading</p4:reportType>
+          <p5:powerReal xmlns:p6="http://docs.oasis-open.org/ns/emix/2011/06/siscale">
+            <p5:itemDescription/>
+            <p5:itemUnits>W</p5:itemUnits>
+            <p6:siScaleCode>none</p6:siScaleCode>
+            <p5:powerAttributes>
+              <p5:hertz>60</p5:hertz>
+              <p5:voltage>120</p5:voltage>
+              <p5:ac>true</p5:ac>
+            </p5:powerAttributes>
+          </p5:powerReal>
+          <p4:readingType>Direct Read</p4:readingType>
+          <p6:marketContext/>
+          <p1:oadrSamplingRate>
+            <p1:oadrMinPeriod>PT1M</p1:oadrMinPeriod>
+            <p1:oadrMaxPeriod>PT1M</p1:oadrMaxPeriod>
+            <p1:oadrOnChange>false</p1:oadrOnChange>
+          </p1:oadrSamplingRate>
+        </p1:oadrReportDescription>
+        <p4:reportRequestID>0</p4:reportRequestID>
+        <p4:reportSpecifierID>DEMO_TELEMETRY_USAGE</p4:reportSpecifierID>
+        <p4:reportName>METADATA_TELEMETRY_USAGE</p4:reportName>
+        <p4:createdDateTime>2020-12-15T14:10:32Z</p4:createdDateTime>
+      </p1:oadrReport>
+      <p3:venID>ven_id</p3:venID>
+    </p1:oadrRegisterReport>
+  </p1:oadrSignedObject>
+</p1:oadrPayload>"""
+    server = OpenADRServer(vtn_id='myvtn')
+    await server.run()
+
+
+    # Test with no configured callbacks
+
+    from aiohttp import ClientSession
+    async with ClientSession() as session:
+        async with session.post("http://localhost:8080/OpenADR2/Simple/2.0b/EiReport",
+                                  headers={'content-type': 'Application/XML'},
+                                  data=msg.encode('utf-8')) as resp:
+            assert resp.status == 200
+
+
+    # Test with a working callback
+
+    def report_callback(data):
+        print(data)
+
+    def working_on_register_report(ven_id, resource_id, measurement, unit, scale, min_sampling_interval, max_sampling_interval):
+        return report_callback, min_sampling_interval
+
+    server.add_handler('on_register_report', working_on_register_report)
+    async with ClientSession() as session:
+        async with session.post("http://localhost:8080/OpenADR2/Simple/2.0b/EiReport",
+                                  headers={'content-type': 'Application/XML'},
+                                  data=msg.encode('utf-8')) as resp:
+            assert resp.status == 200
+
+
+    # Test with a broken callback
+
+    def broken_on_register_report(ven_id, resource_id, measurement, unit, scale, min_sampling_interval, max_sampling_interval):
+        return "Hello There"
+
+    server.add_handler('on_register_report', broken_on_register_report)
+    async with ClientSession() as session:
+        async with session.post("http://localhost:8080/OpenADR2/Simple/2.0b/EiReport",
+                                  headers={'content-type': 'Application/XML'},
+                                  data=msg.encode('utf-8')) as resp:
+            assert resp.status == 200
+
+    # assert "Your on_register_report handler must return a tuple; it returned 'Hello There' (str)." in caplog.messages
+
+
+    # Test with a broken full callback
+
+    def broken_on_register_report_full(report):
+        return "Hello There Again"
+
+    server.add_handler('on_register_report', broken_on_register_report_full)
+    async with ClientSession() as session:
+        async with session.post("http://localhost:8080/OpenADR2/Simple/2.0b/EiReport",
+                                  headers={'content-type': 'Application/XML'},
+                                  data=msg.encode('utf-8')) as resp:
+            assert resp.status == 200
+
+    assert f"Your on_register_report handler must return a list of tuples. It returned 'Hello There Again' (str)." in caplog.messages
+
+    await server.stop()
 
 if __name__ == "__main__":
     asyncio.run(test_update_reports())