Browse Source

Silently cancel running tasks on client and server stop

This prevents a "Task was destroyed but it is pending!" after shutting down the client or server.

Signed-off-by: Stan Janssen <stan.janssen@elaad.nl>
Stan Janssen 4 years ago
parent
commit
ad519b6131
4 changed files with 33 additions and 14 deletions
  1. 8 2
      openleadr/client.py
  2. 7 0
      openleadr/server.py
  3. 6 3
      openleadr/service/event_service.py
  4. 12 9
      openleadr/utils.py

+ 8 - 2
openleadr/client.py

@@ -151,7 +151,11 @@ class OpenADRClient:
             self.scheduler.shutdown()
             self.scheduler.shutdown()
         if self.report_queue_task:
         if self.report_queue_task:
             self.report_queue_task.cancel()
             self.report_queue_task.cancel()
+        delayed_call_tasks = [task for task in asyncio.all_tasks() if task.get_name().startswith('DelayedCall')]
+        for task in delayed_call_tasks:
+            task.cancel()
         await self.client_session.close()
         await self.client_session.close()
+        await asyncio.sleep(0)
 
 
     def add_handler(self, handler, callback):
     def add_handler(self, handler, callback):
         """
         """
@@ -600,7 +604,8 @@ class OpenADRClient:
                 if self.allow_jitter:
                 if self.allow_jitter:
                     delay = random.uniform(0, min(30, report_interval / 2))
                     delay = random.uniform(0, min(30, report_interval / 2))
                     self.loop.create_task(utils.delayed_call(func=self.pending_reports.put(outgoing_report),
                     self.loop.create_task(utils.delayed_call(func=self.pending_reports.put(outgoing_report),
-                                                             delay=delay))
+                                                             delay=delay),
+                                          name=f'DelayedCall-{utils.generate_id()}')
                 else:
                 else:
                     await self.pending_reports.put(self.incomplete_reports.pop(report_request_id))
                     await self.pending_reports.put(self.incomplete_reports.pop(report_request_id))
             else:
             else:
@@ -611,7 +616,8 @@ class OpenADRClient:
             if self.allow_jitter:
             if self.allow_jitter:
                 delay = random.uniform(0, min(30, granularity.total_seconds() / 2))
                 delay = random.uniform(0, min(30, granularity.total_seconds() / 2))
                 self.loop.create_task(utils.delayed_call(func=self.pending_reports.put(outgoing_report),
                 self.loop.create_task(utils.delayed_call(func=self.pending_reports.put(outgoing_report),
-                                                         delay=delay))
+                                                         delay=delay),
+                                      name=f'DelayedCall-{utils.generate_id()}')
             else:
             else:
                 await self.pending_reports.put(outgoing_report)
                 await self.pending_reports.put(outgoing_report)
 
 

+ 7 - 0
openleadr/server.py

@@ -23,6 +23,7 @@ from openleadr import utils
 from functools import partial
 from functools import partial
 from datetime import datetime, timedelta, timezone
 from datetime import datetime, timedelta, timezone
 from collections import deque
 from collections import deque
+import asyncio
 import logging
 import logging
 import ssl
 import ssl
 import re
 import re
@@ -147,7 +148,13 @@ class OpenADRServer:
         print("*" * 80)
         print("*" * 80)
         print("")
         print("")
 
 
+    async def run_async(self):
+        await self.run()
+
     async def stop(self):
     async def stop(self):
+        delayed_call_tasks = [task for task in asyncio.all_tasks() if task.get_name().startswith('DelayedCall')]
+        for task in delayed_call_tasks:
+            task.cancel()
         await self.app_runner.cleanup()
         await self.app_runner.cleanup()
 
 
     def add_event(self, ven_id, signal_name, signal_type, intervals, callback, targets=None,
     def add_event(self, ven_id, signal_name, signal_type, intervals, callback, targets=None,

+ 6 - 3
openleadr/service/event_service.py

@@ -94,17 +94,20 @@ class EventService(VTNService):
                         if active_period.ramp_up_period is not None and event.event_descriptor.event_status == 'far':
                         if active_period.ramp_up_period is not None and event.event_descriptor.event_status == 'far':
                             ramp_up_start_delay = (active_period.dtstart - active_period.ramp_up_period) - now
                             ramp_up_start_delay = (active_period.dtstart - active_period.ramp_up_period) - now
                             update_coro = partial(self._update_event_status, ven_id, event, 'near')
                             update_coro = partial(self._update_event_status, ven_id, event, 'near')
-                            loop.create_task(utils.delayed_call(func=update_coro, delay=ramp_up_start_delay))
+                            loop.create_task(utils.delayed_call(func=update_coro, delay=ramp_up_start_delay),
+                                             name=f'DelayedCall-{utils.generate_id()}')
                         # Schedule status update to 'active'
                         # Schedule status update to 'active'
                         if event.event_descriptor.event_status in ('near', 'far'):
                         if event.event_descriptor.event_status in ('near', 'far'):
                             active_start_delay = active_period.dtstart - now
                             active_start_delay = active_period.dtstart - now
                             update_coro = partial(self._update_event_status, ven_id, event, 'active')
                             update_coro = partial(self._update_event_status, ven_id, event, 'active')
-                            loop.create_task(utils.delayed_call(func=update_coro, delay=active_start_delay))
+                            loop.create_task(utils.delayed_call(func=update_coro, delay=active_start_delay),
+                                             name=f'DelayedCall-{utils.generate_id()}')
                         # Schedule status update to 'completed'
                         # Schedule status update to 'completed'
                         if event.event_descriptor.event_status in ('near', 'far', 'active'):
                         if event.event_descriptor.event_status in ('near', 'far', 'active'):
                             active_end_delay = active_period.dtstart + active_period.duration - now
                             active_end_delay = active_period.dtstart + active_period.duration - now
                             update_coro = partial(self._update_event_status, ven_id, event, 'completed')
                             update_coro = partial(self._update_event_status, ven_id, event, 'completed')
-                            loop.create_task(utils.delayed_call(func=update_coro, delay=active_end_delay))
+                            loop.create_task(utils.delayed_call(func=update_coro, delay=active_end_delay),
+                                             name=f'DelayedCall-{utils.generate_id()}')
                 elif event_response['event_id'] in self.running_events:
                 elif event_response['event_id'] in self.running_events:
                     event, callback = self.running_events.pop(event_id)
                     event, callback = self.running_events.pop(event_id)
                     result = callback(ven_id=ven_id, event_id=event_id, opt_type=opt_type)
                     result = callback(ven_id=ven_id, event_id=event_id, opt_type=opt_type)

+ 12 - 9
openleadr/utils.py

@@ -582,15 +582,18 @@ def determine_event_status(active_period):
 
 
 
 
 async def delayed_call(func, delay):
 async def delayed_call(func, delay):
-    if isinstance(delay, timedelta):
-        delay = delay.total_seconds()
-    await asyncio.sleep(delay)
-    if asyncio.iscoroutinefunction(func):
-        await func()
-    elif asyncio.iscoroutine(func):
-        await func
-    else:
-        func()
+    try:
+        if isinstance(delay, timedelta):
+            delay = delay.total_seconds()
+        await asyncio.sleep(delay)
+        if asyncio.iscoroutinefunction(func):
+            await func()
+        elif asyncio.iscoroutine(func):
+            await func
+        else:
+            func()
+    except asyncio.CancelledError:
+        pass
 
 
 
 
 def hasmember(obj, member):
 def hasmember(obj, member):