vtn_service.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. # SPDX-License-Identifier: Apache-2.0
  2. # Copyright 2020 Contributors to OpenLEADR
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. # http://www.apache.org/licenses/LICENSE-2.0
  7. # Unless required by applicable law or agreed to in writing, software
  8. # distributed under the License is distributed on an "AS IS" BASIS,
  9. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. # See the License for the specific language governing permissions and
  11. # limitations under the License.
  12. from asyncio import iscoroutine
  13. from http import HTTPStatus
  14. import logging
  15. from aiohttp import web
  16. from lxml.etree import XMLSyntaxError
  17. from signxml.exceptions import InvalidSignature
  18. from .. import errors
  19. from ..enums import STATUS_CODES
  20. from ..messaging import parse_message, validate_xml_schema, authenticate_message
  21. from ..utils import generate_id, get_cert_fingerprint_from_request
  22. from dataclasses import is_dataclass, asdict
# Module-level logger; everything this module logs goes to the 'openleadr' logger.
logger = logging.getLogger('openleadr')
  24. class VTNService:
  25. def __init__(self, vtn_id):
  26. self.vtn_id = vtn_id
  27. self.handlers = {}
  28. for method in [getattr(self, attr) for attr in dir(self) if callable(getattr(self, attr))]:
  29. if hasattr(method, '__message_type__'):
  30. self.handlers[method.__message_type__] = method
  31. async def handler(self, request):
  32. """
  33. Handle all incoming POST requests.
  34. """
  35. try:
  36. # Check the Content-Type header
  37. content_type = request.headers.get('content-type', '')
  38. if not content_type.lower().startswith("application/xml"):
  39. raise errors.HTTPError(response_code=HTTPStatus.BAD_REQUEST,
  40. response_description="The Content-Type header must be application/xml, "
  41. "you provided {request.headers.get('content-type', '')}")
  42. content = await request.read()
  43. # Validate the message to the XML Schema
  44. message_tree = validate_xml_schema(content)
  45. # Parse the message to a type and payload dict
  46. message_type, message_payload = parse_message(content)
  47. # Authenticate the message
  48. if request.secure and 'ven_id' in message_payload:
  49. await authenticate_message(request, message_tree, message_payload,
  50. self.fingerprint_lookup)
  51. # Pass the message off to the handler and get the response type and payload
  52. try:
  53. # Add the request fingerprint to the message so that the handler can check for it.
  54. if request.secure and message_type == 'oadrCreatePartyRegistration':
  55. message_payload['fingerprint'] = get_cert_fingerprint_from_request(request)
  56. response_type, response_payload = await self.handle_message(message_type,
  57. message_payload)
  58. except Exception as err:
  59. logger.error("An exception occurred during the execution of your handler: "
  60. f"{err.__class__.__name__}: {err}")
  61. raise err
  62. if 'response' not in response_payload:
  63. response_payload['response'] = {'response_code': 200,
  64. 'response_description': 'OK',
  65. 'request_id': message_payload.get('request_id')}
  66. response_payload['vtn_id'] = self.vtn_id
  67. if 'ven_id' not in response_payload:
  68. response_payload['ven_id'] = message_payload.get('ven_id')
  69. except errors.ProtocolError as err:
  70. # In case of an OpenADR error, return a valid OpenADR message
  71. response_type, response_payload = self.error_response(message_type,
  72. err.response_code,
  73. err.response_description)
  74. msg = self._create_message(response_type, **response_payload)
  75. response = web.Response(text=msg,
  76. status=HTTPStatus.OK,
  77. content_type='application/xml')
  78. except errors.HTTPError as err:
  79. # If we throw a http-related error, deal with it here
  80. response = web.Response(text=err.response_description,
  81. status=err.response_code)
  82. except XMLSyntaxError as err:
  83. logger.warning(f"XML schema validation of incoming message failed: {err}.")
  84. response = web.Response(text=f'XML failed validation: {err}',
  85. status=HTTPStatus.BAD_REQUEST)
  86. except errors.FingerprintMismatch as err:
  87. logger.warning(err)
  88. response = web.Response(text=str(err),
  89. status=HTTPStatus.FORBIDDEN)
  90. except InvalidSignature:
  91. logger.warning("Incoming message had invalid signature, ignoring.")
  92. response = web.Response(text='Invalid Signature',
  93. status=HTTPStatus.FORBIDDEN)
  94. except Exception as err:
  95. # In case of some other error, return a HTTP 500
  96. logger.error(f"The VTN server encountered an error: {err.__class__.__name__}: {err}")
  97. response = web.Response(status=HTTPStatus.INTERNAL_SERVER_ERROR)
  98. else:
  99. # We've successfully handled this message
  100. msg = self._create_message(response_type, **response_payload)
  101. response = web.Response(text=msg,
  102. status=HTTPStatus.OK,
  103. content_type='application/xml')
  104. return response
  105. async def handle_message(self, message_type, message_payload):
  106. if message_type in self.handlers:
  107. handler = self.handlers[message_type]
  108. result = handler(message_payload)
  109. if iscoroutine(result):
  110. result = await result
  111. if result is not None:
  112. response_type, response_payload = result
  113. if is_dataclass(response_payload):
  114. response_payload = asdict(response_payload)
  115. elif response_payload is None:
  116. response_payload = {}
  117. else:
  118. response_type, response_payload = 'oadrResponse', {}
  119. response_payload['vtn_id'] = self.vtn_id
  120. if 'ven_id' in message_payload:
  121. response_payload['ven_id'] = message_payload['ven_id']
  122. response_payload['response'] = {'request_id': message_payload.get('request_id', None),
  123. 'response_code': 200,
  124. 'response_description': 'OK'}
  125. response_payload['request_id'] = generate_id()
  126. else:
  127. response_type, response_payload = self.error_response('oadrResponse',
  128. STATUS_CODES.COMPLIANCE_ERROR,
  129. "A message of type "
  130. f"{message_type} should not be "
  131. "sent to this endpoint")
  132. logger.info(f"Responding to {message_type} with a {response_type} message: {response_payload}.")
  133. return response_type, response_payload
  134. def error_response(self, message_type, error_code, error_description):
  135. if message_type == 'oadrCreatePartyRegistration':
  136. response_type = 'oadrCreatedPartyRegistration'
  137. if message_type == 'oadrRequestEvent':
  138. response_type = 'oadrDistributeEvent'
  139. else:
  140. response_type = 'oadrResponse'
  141. response_payload = {'response': {'response_code': error_code,
  142. 'response_description': error_description}}
  143. return response_type, response_payload