Kaynağa Gözat

Use actual UUIDs, improve documentation of utilities

Stan Janssen 3 yıl önce
ebeveyn
işleme
f882b7ddaa
2 değiştirilmiş dosya ile 42 ekleme ve 11 silme
  1. 5 5
      openleadr/client.py
  2. 37 6
      openleadr/utils.py

+ 5 - 5
openleadr/client.py

@@ -21,7 +21,7 @@ OpenADR Client for Python
 import xmltodict
 import random
 import aiohttp
-from openleadr.utils import new_request_id, peek, generate_id, certificate_fingerprint
+from openleadr.utils import peek, generate_id, certificate_fingerprint
 from openleadr.messaging import create_message, parse_message
 from openleadr import enums
 from datetime import datetime, timedelta, timezone
@@ -180,7 +180,7 @@ class OpenADRClient:
         """
         Request information about the VTN.
         """
-        request_id = new_request_id()
+        request_id = generate_id()
         service = 'EiRegisterParty'
         message = self._create_message('oadrQueryRegistration', request_id=request_id)
         response_type, response_payload = await self._perform_request(service, message)
@@ -200,7 +200,7 @@ class OpenADRClient:
         :param str transport_address: Which public-facing address the server should use to communicate.
         :param str ven_id: The ID for this VEN. If you leave this blank, a VEN_ID will be assigned by the VTN.
         """
-        request_id = new_request_id()
+        request_id = generate_id()
         service = 'EiRegisterParty'
         payload = {'ven_name': self.ven_name,
                    'http_pull_model': http_pull_model,
@@ -211,7 +211,7 @@ class OpenADRClient:
                    'transport_address': transport_address}
         if ven_id:
             payload['ven_id'] = ven_id
-        message = self._create_message('oadrCreatePartyRegistration', request_id=new_request_id(), **payload)
+        message = self._create_message('oadrCreatePartyRegistration', request_id=generate_id(), **payload)
         response_type, response_payload = await self._perform_request(service, message)
         if response_type is None:
             return
@@ -233,7 +233,7 @@ class OpenADRClient:
         """
         Request the next Event from the VTN, if it has any.
         """
-        payload = {'request_id': new_request_id(),
+        payload = {'request_id': generate_id(),
                    'ven_id': self.ven_id,
                    'reply_limit': reply_limit}
         message = self._create_message('oadrRequestEvent', **payload)

+ 37 - 6
openleadr/utils.py

@@ -16,6 +16,7 @@
 
 from asyncio import iscoroutine
 from datetime import datetime, timedelta, timezone
+from dataclasses import is_dataclass, asdict
 import random
 import string
 from collections import OrderedDict
@@ -23,17 +24,18 @@ import itertools
 import re
 import ssl
 import hashlib
+import uuid
 
 from openleadr import config
 
 DATETIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ"
 DATETIME_FORMAT_NO_MICROSECONDS = "%Y-%m-%dT%H:%M:%SZ"
 
-def new_request_id(*args, **kwargs):
-    return random.choice(string.ascii_lowercase) + ''.join(random.choice(string.hexdigits) for _ in range(9)).lower()
-
 def generate_id(*args, **kwargs):
-    return new_request_id()
+    """
+    Generate a string that can be used as an identifier in OpenADR messages.
+    """
+    return str(uuid.uuid4())
 
 def indent_xml(message):
     """
@@ -53,6 +55,9 @@ def indent_xml(message):
     return "\n".join(lines)
 
 def flatten_xml(message):
+    """
+    Flatten the entire XML structure.
+    """
     lines = [line.strip() for line in message.split("\n") if line.strip() != ""]
     for line in lines:
         line = re.sub(r'\n', '', line)
@@ -61,8 +66,14 @@ def flatten_xml(message):
 
 def normalize_dict(ordered_dict):
     """
-    Convert the OrderedDict to a regular dict, snake_case the key names, and promote uniform lists.
+    Main conversion function for the output of xmltodict to the OpenLEADR
+    representation of OpenADR contents.
+
+    :param ordered_dict dict: The OrderedDict, dict or dataclass that you wish to convert.
     """
+    if is_dataclass(ordered_dict):
+        ordered_dict = asdict(ordered_dict)
+
     def normalize_key(key):
         if key.startswith('oadr'):
             key = key[4:]
@@ -231,7 +242,7 @@ def normalize_dict(ordered_dict):
 
 def parse_datetime(value):
     """
-    Parse an ISO8601 datetime
+    Parse an ISO8601 datetime into a datetime.datetime object.
     """
     matches = re.match(r'(\d{4})-(\d{2})-(\d{2})T(\d{2}):(\d{2}):(\d{2})\.?(\d{1,6})?\d*Z', value)
     if matches:
@@ -298,6 +309,9 @@ def peek(iterable):
         return itertools.chain([first], iterable)
 
 def datetimeformat(value, format=DATETIME_FORMAT):
+    """
+    Format a given datetime as a UTC ISO3339 string.
+    """
     if not isinstance(value, datetime):
         return value
     return value.astimezone(timezone.utc).strftime(format)
@@ -339,6 +353,9 @@ def booleanformat(value):
         raise ValueError(f"A boolean value must be provided, not {value}.")
 
 def ensure_bytes(obj):
+    """
+    Converts a utf-8 str object to bytes.
+    """
     if isinstance(obj, bytes):
         return obj
     if isinstance(obj, str):
@@ -347,6 +364,9 @@ def ensure_bytes(obj):
         raise TypeError("Must be bytes or str")
 
 def ensure_str(obj):
+    """
+    Converts bytes to a utf-8 string.
+    """
     if isinstance(obj, str):
         return obj
     if isinstance(obj, bytes):
@@ -355,10 +375,21 @@ def ensure_str(obj):
         raise TypeError("Must be bytes or str")
 
 def certificate_fingerprint(certificate_str):
+    """
+    Calculate the fingerprint for the given certificate, as defined by OpenADR.
+    """
     der_cert = ssl.PEM_cert_to_DER_cert(ensure_str(certificate_str))
     hash = hashlib.sha256(der_cert).digest().hex()
     return ":".join([hash[i-2:i].upper() for i in range(-20, 0, 2)])
 
 def extract_pem_cert(tree):
+    """
+    Extract a given X509 certificate inside an XML tree and return the standard
+    form of a PEM-encoded certificate.
+
+    :param tree lxml.etree: The tree that contains the X509 element. This is
+                            usually the KeyInfo element from the XMLDsig Signature
+                            part of the message.
+    """
     cert = tree.find('.//{http://www.w3.org/2000/09/xmldsig#}X509Certificate').text
     return "-----BEGIN CERTIFICATE-----\n" + cert + "-----END CERTIFICATE-----\n"