Prechádzať zdrojové kódy

Add await_if_required and gather_if_required utilities

Signed-off-by: Stan Janssen <stan.janssen@elaad.nl>
Stan Janssen 3 rokov pred
rodič
commit
d52e6be901
2 zmenil súbory, kde vykonal 61 pridanie a 0 odobranie
  1. 19 0
      openleadr/utils.py
  2. 42 0
      test/test_utils.py

+ 19 - 0
openleadr/utils.py

@@ -735,6 +735,25 @@ def validate_report_request_tuples(list_of_report_requests, full_mode=False):
                                  f"It returned: '{rrq[1:]}'. The third element was not of type timedelta.")
 
 
+async def await_if_required(result):
+    if asyncio.iscoroutine(result):
+        result = await result
+    return result
+
+
+async def gather_if_required(results):
+    if results is None:
+        return results
+    if len(results) > 0:
+        if not any([asyncio.iscoroutine(r) for r in results]):
+            results = results
+        elif all([asyncio.iscoroutine(r) for r in results]):
+            results = await asyncio.gather(*results)
+        else:
+            results = [await await_if_required(result) for result in results]
+    return results
+
+
 def order_events(events, limit=None, offset=None):
     """
     Order the events according to the OpenADR rules:

+ 42 - 0
test/test_utils.py

@@ -290,6 +290,48 @@ def test_parse_datetime():
     assert utils.parse_datetime("2020-12-15T11:29:34.123456Z") == datetime(2020, 12, 15, 11, 29, 34, 123456, tzinfo=timezone.utc)
     assert utils.parse_datetime("2020-12-15T11:29:34.123Z") == datetime(2020, 12, 15, 11, 29, 34, 123000, tzinfo=timezone.utc)
     assert utils.parse_datetime("2020-12-15T11:29:34.123456789Z") == datetime(2020, 12, 15, 11, 29, 34, 123456, tzinfo=timezone.utc)
+
+@pytest.mark.asyncio
+async def test_await_if_required():
+    def normal_func():
+        return 123
+
+    async def coro_func():
+        return 456
+
+    result = await utils.await_if_required(normal_func())
+    assert result == 123
+
+    result = await utils.await_if_required(coro_func())
+    assert result == 456
+
+    result = await utils.await_if_required(None)
+    assert result == None
+
+@pytest.mark.asyncio
+async def test_gather_if_required():
+    def normal_func():
+        return 123
+
+    async def coro_func():
+        return 456
+
+    raw_results = [normal_func(), normal_func(), normal_func()]
+    results = await utils.gather_if_required(raw_results)
+    assert results == [123, 123, 123]
+
+    raw_results = [coro_func(), coro_func(), coro_func()]
+    results = await utils.gather_if_required(raw_results)
+    assert results == [456, 456, 456]
+
+    raw_results = [coro_func(), normal_func(), None]
+    results = await utils.gather_if_required(raw_results)
+    assert results == [456, 123, None]
+
+    raw_results = []
+    results = await utils.gather_if_required(raw_results)
+    assert results == []
+
 def test_order_events():
     now = datetime.now(timezone.utc)
     event_1_active_high_prio = objects.Event(event_descriptor=objects.EventDescriptor(event_id='event001',