Selaa lähdekoodia

Implement request_event for the VTN Server

This allows the VEN to request the next event, and it will only return the next event, and no other messages in between.

Signed-off-by: Stan Janssen <stan.janssen@elaad.nl>
Stan Janssen 4 vuotta sitten
vanhempi
commit
486a6224de

+ 5 - 10
openleadr/server.py

@@ -14,7 +14,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import asyncio
 from aiohttp import web
 from openleadr.service import EventService, PollService, RegistrationService, ReportService, \
                               OptService, VTNService
@@ -23,6 +22,7 @@ from openleadr import objects
 from openleadr import utils
 from functools import partial
 from datetime import datetime, timedelta, timezone
+from collections import deque
 import logging
 import ssl
 import re
@@ -221,8 +221,8 @@ class OpenADRServer:
                               event_signals=[event_signal],
                               targets=targets)
         if ven_id not in self.message_queues:
-            self.message_queues[ven_id] = asyncio.Queue()
-        self.message_queues[ven_id].put_nowait(event)
+            self.message_queues[ven_id] = deque()
+        self.message_queues[ven_id].append(event)
         self.services['event_service'].pending_events[event_id] = (event, callback)
         return event_id
 
@@ -234,13 +234,8 @@ class OpenADRServer:
                            that contains the event details.
         """
         if ven_id not in self.message_queues:
-            self.message_queues[ven_id] = asyncio.Queue()
-        self.message_queues[ven_id].put_nowait(event)
-
-    async def request_report(self):
-        """
-        Request a report from the client.
-        """
+            self.message_queues[ven_id] = deque()
+        self.message_queues[ven_id].append(event)
 
     def add_handler(self, name, func):
         """

+ 20 - 12
openleadr/service/event_service.py

@@ -20,6 +20,7 @@ from openleadr import objects, utils, enums
 import logging
 from datetime import datetime, timezone
 from functools import partial
+from dataclasses import asdict
 logger = logging.getLogger('openleadr')
 
 
@@ -38,19 +39,26 @@ class EventService(VTNService):
         """
         The VEN requests us to send any events we have.
         """
-        result = self.on_request_event(payload['ven_id'])
-        if asyncio.iscoroutine(result):
-            result = await result
+        if self.polling_method == 'external':
+            result = self.on_request_event(ven_id=payload['ven_id'])
+            if asyncio.iscoroutine(result):
+                result = await result
+        elif payload['ven_id'] in self.message_queues:
+            queue = self.message_queues[payload['ven_id']]
+            result = utils.get_next_event_from_deque(queue)
+        else:
+            return 'oadrResponse', {}
+
         if result is None:
-            return 'oadrDistributeEvent', {'events': []}
-        if isinstance(result, dict):
+            return 'oadrResponse', {}
+        if isinstance(result, dict) and 'event_descriptor' in result:
             return 'oadrDistributeEvent', {'events': [result]}
-        if isinstance(result, objects.Event):
-            return 'oadrDistributeEvent', {'events': [result]}
-        if isinstance(result, list):
-            return 'oadrDistributeEvent', {'events': result}
-        else:
-            raise TypeError("Event handler should return None, a dict or a list")
+        elif isinstance(result, objects.Event):
+            return 'oadrDistributeEvent', {'events': [asdict(result)]}
+
+        logger.warning("Could not determine type of message "
+                       f"in response to oadrRequestEvent: {result}")
+        return 'oadrResponse', result
 
     def on_request_event(self, ven_id):
         """
@@ -125,4 +133,4 @@ class EventService(VTNService):
         event.event_descriptor.event_status = event_status
         if event_status == enums.EVENT_STATUS.CANCELLED:
             event.event_descriptor.modification_number += 1
-        self.message_queues[ven_id].put_nowait(event)
+        self.message_queues[ven_id].append(event)

+ 2 - 2
openleadr/service/poll_service.py

@@ -117,8 +117,8 @@ class PollService(VTNService):
             result = self.on_poll(ven_id=payload['ven_id'])
         elif payload['ven_id'] in self.message_queues:
             try:
-                result = self.message_queues[payload['ven_id']].get_nowait()
-            except asyncio.QueueEmpty:
+                result = self.message_queues[payload['ven_id']].popleft()
+            except IndexError:
                 return 'oadrResponse', {}
         else:
             return 'oadrResponse', {}

+ 3 - 0
openleadr/service/vtn_service.py

@@ -123,6 +123,8 @@ class VTNService:
                 response_type, response_payload = result
                 if is_dataclass(response_payload):
                     response_payload = asdict(response_payload)
+                elif response_payload is None:
+                    response_payload = {}
             else:
                 response_type, response_payload = 'oadrResponse', {}
 
@@ -141,6 +143,7 @@ class VTNService:
                                                                   "A message of type "
                                                                   f"{message_type} should not be "
                                                                   "sent to this endpoint")
+        logger.info(f"Responding to {message_type} with a {response_type} message: {response_payload}.")
         return response_type, response_payload
 
     def error_response(self, message_type, error_code, error_description):

+ 18 - 1
openleadr/utils.py

@@ -17,7 +17,7 @@
 from datetime import datetime, timedelta, timezone
 from dataclasses import is_dataclass, asdict
 from collections import OrderedDict
-from openleadr import enums
+from openleadr import enums, objects
 import asyncio
 import itertools
 import re
@@ -664,3 +664,20 @@ def setmember(obj, member, value):
         setattr(obj, member, value)
     else:
         obj[member] = value
+
+
+def get_next_event_from_deque(deque):
+    unused_elements = []
+    event = None
+    for i in range(len(deque)):
+        try:
+            msg = deque.popleft()
+            if isinstance(msg, objects.Event) or (isinstance(msg, dict) and 'event_descriptor' in msg):
+                event = msg
+                break
+            else:
+                unused_elements.append(msg)
+        except IndexError:
+            pass
+    deque.extend(unused_elements)
+    return event

+ 28 - 0
test/test_queues.py

@@ -169,3 +169,31 @@ async def test_event_status_opt_in_with_ramp_up():
     await client.stop()
     await server.stop()
     await asyncio.sleep(1)
+
+
+@pytest.mark.asyncio
+async def test_request_event():
+    loop = asyncio.get_event_loop()
+    client = OpenADRClient(ven_name='myven',
+                           vtn_url='http://localhost:8080/OpenADR2/Simple/2.0b')
+
+    server = OpenADRServer(vtn_id='myvtn', requested_poll_freq=datetime.timedelta(seconds=1))
+    server.add_handler('on_create_party_registration', on_create_party_registration)
+
+    event_id = server.add_event(ven_id='ven123',
+                                signal_name='simple',
+                                signal_type='level',
+                                intervals=[{'dtstart': datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(seconds=4),
+                                            'duration': datetime.timedelta(seconds=2),
+                                            'signal_payload': 1}],
+                                ramp_up_period=datetime.timedelta(seconds=2),
+                                callback=partial(event_callback))
+
+    assert server.services['event_service'].pending_events[event_id][0].event_descriptor.event_status == 'far'
+    await server.run_async()
+    await client.create_party_registration()
+    message_type, message_payload = await client.request_event()
+    assert message_type == 'oadrDistributeEvent'
+    message_type, message_payload = await client.request_event()
+    assert message_type == 'oadrResponse'
+    await server.stop()

+ 45 - 0
test/test_utils.py

@@ -2,6 +2,7 @@ from openleadr import utils, objects
 from dataclasses import dataclass
 import pytest
 from datetime import datetime, timezone, timedelta
+from collections import deque
 
 @dataclass
 class dc:
@@ -110,3 +111,47 @@ def test_cron_config():
                                                                                 'hour': '*',
                                                                                 'jitter': 1}
 
+def test_get_event_from_deque():
+    d = deque()
+    now = datetime.now(timezone.utc)
+    event1 = objects.Event(event_descriptor=objects.EventDescriptor(event_id='event123',
+                                                                    event_status='far',
+                                                                    modification_number='1',
+                                                                    market_context='http://marketcontext01'),
+                           event_signals=[objects.EventSignal(signal_name='simple',
+                                                              signal_type='level',
+                                                              signal_id=utils.generate_id(),
+                                                              intervals=[objects.Interval(dtstart=now,
+                                                                                          duration=timedelta(minutes=10),
+                                                                                          signal_payload=1)])],
+                            targets=[{'ven_id': 'ven123'}])
+    msg_one = {'message': 'one'}
+    msg_two = {'message': 'two'}
+    msg_three = {'message': 'three'}
+    event2 = objects.Event(event_descriptor=objects.EventDescriptor(event_id='event123',
+                                                                    event_status='far',
+                                                                    modification_number='1',
+                                                                    market_context='http://marketcontext01'),
+                           event_signals=[objects.EventSignal(signal_name='simple',
+                                                              signal_type='level',
+                                                              signal_id=utils.generate_id(),
+                                                              intervals=[objects.Interval(dtstart=now,
+                                                                                          duration=timedelta(minutes=10),
+                                                                                          signal_payload=1)])],
+                            targets=[{'ven_id': 'ven123'}])
+
+    d.append(event1)
+    d.append(msg_one)
+    d.append(msg_two)
+    d.append(msg_three)
+    d.append(event2)
+    assert utils.get_next_event_from_deque(d) is event1
+    assert utils.get_next_event_from_deque(d) is event2
+    assert utils.get_next_event_from_deque(d) is None
+    assert utils.get_next_event_from_deque(d) is None
+    assert len(d) == 3
+    assert d.popleft() is msg_one
+    assert d.popleft() is msg_two
+    assert d.popleft() is msg_three
+
+