Pārlūkot izejas kodu

Converted client to use aiohttp for requests
This also means that synchronous mode is (for now) unsupported.
Closes #2 and addresses part of #1.

Stan Janssen 4 gadi atpakaļ
vecāks
revīzija
06aa4b6376
1 mainītis faili ar 29 papildinājumiem un 28 dzēšanām
  1. 29 28
      pyopenadr/client.py

+ 29 - 28
pyopenadr/client.py

@@ -2,7 +2,7 @@
 
 import xmltodict
 import random
-import requests
+import aiohttp
 from jinja2 import Environment, PackageLoader, select_autoescape
 from pyopenadr.utils import parse_message, create_message, new_request_id, peek, generate_id
 from pyopenadr import enums
@@ -30,22 +30,23 @@ class OpenADRClient:
         self.report_requests = {}   # Mapping of the reports requested by the VTN
         self.report_schedulers = {} # Mapping between reportRequestIDs and our internal report schedulers
         self.scheduler = AsyncIOScheduler()
+        self.client_session = aiohttp.ClientSession()
 
-    def run(self):
+    async def run(self):
         """
         Run the client in full-auto mode.
         """
         if not hasattr(self, 'on_event') or not hasattr(self, 'on_report'):
             raise NotImplementedError("You must implement both the on_event and and_report functions or coroutines.")
 
-        self.create_party_registration()
+        await self.create_party_registration()
 
         if not self.ven_id:
             print("No VEN ID received from the VTN, aborting registration.")
             return
 
         if self.reports:
-            self.register_report()
+            await self.register_report()
 
         # Set up automatic polling
         if self.poll_frequency.total_seconds() < 60:
@@ -119,17 +120,17 @@ class OpenADRClient:
         self.reports[report_id] = report
         self.report_ids[resource_id] = {'item_base': measurand}
 
-    def query_registration(self):
+    async def query_registration(self):
         """
         Request information about the VTN.
         """
         request_id = new_request_id()
         service = 'EiRegisterParty'
         message = create_message('oadrQueryRegistration', request_id=request_id)
-        response_type, response_payload = self._perform_request(service, message)
+        response_type, response_payload = await self._perform_request(service, message)
         return response_type, response_payload
 
-    def create_party_registration(self, http_pull_model=True, xml_signature=False,
+    async def create_party_registration(self, http_pull_model=True, xml_signature=False,
                                   report_only=False, profile_name='2.0b',
                                   transport_name='simpleHttp', transport_address=None, ven_id=None):
         request_id = new_request_id()
@@ -144,7 +145,7 @@ class OpenADRClient:
         if ven_id:
             payload['ven_id'] = ven_id
         message = create_message('oadrCreatePartyRegistration', request_id=new_request_id(), **payload)
-        response_type, response_payload = self._perform_request(service, message)
+        response_type, response_payload = await self._perform_request(service, message)
         if response_payload['response']['response_code'] != 200:
             status_code = response_payload['response']['response_code']
             status_description = response_payload['response']['response_description']
@@ -156,10 +157,10 @@ class OpenADRClient:
         print(f"The polling frequency is {self.poll_frequency}")
         return response_type, response_payload
 
-    def cancel_party_registration(self):
+    async def cancel_party_registration(self):
         raise NotImplementedError("Cancel Registration is not yet implemented")
 
-    def request_event(self, reply_limit=1):
+    async def request_event(self, reply_limit=1):
         """
         Request the next Event from the VTN, if it has any.
         """
@@ -168,10 +169,10 @@ class OpenADRClient:
                    'reply_limit': reply_limit}
         message = create_message('oadrRequestEvent', **payload)
         service = 'EiEvent'
-        response_type, response_payload = self._perform_request(service, message)
+        response_type, response_payload = await self._perform_request(service, message)
         return response_type, response_payload
 
-    def created_event(self, request_id, event_id, opt_type, modification_number=1):
+    async def created_event(self, request_id, event_id, opt_type, modification_number=1):
         """
         Inform the VTN that we created an event.
         """
@@ -187,10 +188,10 @@ class OpenADRClient:
                                         'modification_number': modification_number,
                                         'opt_type': opt_type}]}
         message = create_message('oadrCreatedEvent', **payload)
-        response_type, response_payload = self._perform_request(service, message)
+        response_type, response_payload = await self._perform_request(service, message)
         return response_type, response_payload
 
-    def register_report(self):
+    async def register_report(self):
         """
         Tell the VTN about our reporting capabilities.
         """
@@ -202,19 +203,19 @@ class OpenADRClient:
 
         service = 'EiReport'
         message = create_message('oadrRegisterReport', **payload)
-        response_type, response_payload = self._perform_request(service, message)
+        response_type, response_payload = await self._perform_request(service, message)
 
         # Remember which reports the VTN is interested in
 
         return response_type, response_payload
 
-    def created_report(self):
+    async def created_report(self):
         pass
 
-    def poll(self):
+    async def poll(self):
         service = 'OadrPoll'
         message = create_message('oadrPoll', ven_id=self.ven_id)
-        response_type, response_payload = self._perform_request(service, message)
+        response_type, response_payload = await self._perform_request(service, message)
         return response_type, response_payload
 
     async def update_report(self, report_id, resource_id=None):
@@ -255,17 +256,17 @@ class OpenADRClient:
             print("TODO: cancel this report")
 
 
-    def _perform_request(self, service, message):
+    async def _perform_request(self, service, message):
         if self.debug:
             print(f"Sending {message}")
         url = f"{self.vtn_url}/{service}"
-        r = requests.post(url,
-                          data=message)
-        if r.status_code != HTTPStatus.OK:
-            raise Exception(f"Received non-OK status in request: {r.status_code}")
-        if self.debug:
-            print(r.content.decode('utf-8'))
-        return parse_message(r.content)
+        async with self.client_session.post(url, data=message) as req:
+            if req.status != HTTPStatus.OK:
+                raise Exception(f"Received non-OK status in request: {req.status_code}")
+            content = await req.read()
+            if self.debug:
+                print(content.decode('utf-8'))
+        return parse_message(content)
 
     async def _on_event(self, message):
         if self.debug:
@@ -288,13 +289,13 @@ class OpenADRClient:
         return result
 
     async def _poll(self):
-        response_type, response_payload = self.poll()
+        response_type, response_payload = await self.poll()
         if response_type == 'oadrResponse':
             print("No events or reports available")
             return
 
         if response_type == 'oadrRequestReregistration':
-            result = self.create_party_registration()
+            result = await self.create_party_registration()
 
         if response_type == 'oadrDistributeEvent':
             result = await self._on_event(response_payload)