vtn_service.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. # SPDX-License-Identifier: Apache-2.0
  2. # Copyright 2020 Contributors to OpenLEADR
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. # http://www.apache.org/licenses/LICENSE-2.0
  7. # Unless required by applicable law or agreed to in writing, software
  8. # distributed under the License is distributed on an "AS IS" BASIS,
  9. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. # See the License for the specific language governing permissions and
  11. # limitations under the License.
  12. from asyncio import iscoroutine
  13. from http import HTTPStatus
  14. import logging
  15. from aiohttp import web
  16. from lxml.etree import XMLSyntaxError
  17. from signxml.exceptions import InvalidSignature
  18. from .. import errors
  19. from ..enums import STATUS_CODES
  20. from ..messaging import parse_message, validate_xml_schema, authenticate_message
  21. from ..utils import generate_id, get_cert_fingerprint_from_request
  22. from dataclasses import is_dataclass, asdict
  23. logger = logging.getLogger('openleadr')
  24. class VTNService:
  25. def __init__(self, vtn_id):
  26. self.vtn_id = vtn_id
  27. self.handlers = {}
  28. for method in [getattr(self, attr) for attr in dir(self) if callable(getattr(self, attr))]:
  29. if hasattr(method, '__message_type__'):
  30. self.handlers[method.__message_type__] = method
  31. async def handler(self, request):
  32. """
  33. Handle all incoming POST requests.
  34. """
  35. content = await request.read()
  36. try:
  37. # Validate the message to the XML Schema
  38. message_tree = validate_xml_schema(content)
  39. # Parse the message to a type and payload dict
  40. message_type, message_payload = parse_message(content)
  41. # Authenticate the message
  42. if request.secure and 'ven_id' in message_payload:
  43. await authenticate_message(request, message_tree, message_payload,
  44. self.fingerprint_lookup)
  45. # Pass the message off to the handler and get the response type and payload
  46. try:
  47. # Add the request fingerprint to the message so that the handler can check for it.
  48. if request.secure and message_type == 'oadrCreatePartyRegistration':
  49. message_payload['fingerprint'] = get_cert_fingerprint_from_request(request)
  50. response_type, response_payload = await self.handle_message(message_type,
  51. message_payload)
  52. except Exception as err:
  53. logger.error("An exception occurred during the execution of your handler: "
  54. f"{err.__class__.__name__}: {err}")
  55. raise err
  56. if 'response' not in response_payload:
  57. response_payload['response'] = {'response_code': 200,
  58. 'response_description': 'OK',
  59. 'request_id': message_payload.get('request_id')}
  60. response_payload['vtn_id'] = self.vtn_id
  61. if 'ven_id' not in response_payload:
  62. response_payload['ven_id'] = message_payload.get('ven_id')
  63. except errors.ProtocolError as err:
  64. # In case of an OpenADR error, return a valid OpenADR message
  65. response_type, response_payload = self.error_response(message_type,
  66. err.response_code,
  67. err.response_description)
  68. msg = self._create_message(response_type, **response_payload)
  69. response = web.Response(text=msg,
  70. status=HTTPStatus.OK,
  71. content_type='application/xml')
  72. except errors.HTTPError as err:
  73. # If we throw a http-related error, deal with it here
  74. response = web.Response(text=err.response_description,
  75. status=err.response_code)
  76. except XMLSyntaxError as err:
  77. logger.warning(f"XML schema validation of incoming message failed: {err}.")
  78. response = web.Response(text=f'XML failed validation: {err}',
  79. status=HTTPStatus.BAD_REQUEST)
  80. except errors.FingerprintMismatch as err:
  81. logger.warning(err)
  82. response = web.Response(text=str(err),
  83. status=HTTPStatus.FORBIDDEN)
  84. except InvalidSignature:
  85. logger.warning("Incoming message had invalid signature, ignoring.")
  86. response = web.Response(text='Invalid Signature',
  87. status=HTTPStatus.FORBIDDEN)
  88. except Exception as err:
  89. # In case of some other error, return a HTTP 500
  90. logger.error(f"The VTN server encountered an error: {err.__class__.__name__}: {err}")
  91. response = web.Response(status=HTTPStatus.INTERNAL_SERVER_ERROR)
  92. else:
  93. # We've successfully handled this message
  94. msg = self._create_message(response_type, **response_payload)
  95. response = web.Response(text=msg,
  96. status=HTTPStatus.OK,
  97. content_type='application/xml')
  98. return response
  99. async def handle_message(self, message_type, message_payload):
  100. if message_type in self.handlers:
  101. handler = self.handlers[message_type]
  102. result = handler(message_payload)
  103. if iscoroutine(result):
  104. result = await result
  105. if result is not None:
  106. response_type, response_payload = result
  107. if is_dataclass(response_payload):
  108. response_payload = asdict(response_payload)
  109. else:
  110. response_type, response_payload = 'oadrResponse', {}
  111. response_payload['vtn_id'] = self.vtn_id
  112. if 'ven_id' in message_payload:
  113. response_payload['ven_id'] = message_payload['ven_id']
  114. response_payload['response'] = {'request_id': message_payload.get('request_id', None),
  115. 'response_code': 200,
  116. 'response_description': 'OK'}
  117. response_payload['request_id'] = generate_id()
  118. else:
  119. response_type, response_payload = self.error_response(message_type,
  120. STATUS_CODES.COMPLIANCE_ERROR,
  121. "A message of type "
  122. f"{message_type} should not be "
  123. "sent to this endpoint")
  124. return response_type, response_payload
  125. def error_response(self, message_type, error_code, error_description):
  126. if message_type == 'oadrCreatePartyRegistration':
  127. response_type = 'oadrCreatedPartyRegistration'
  128. if message_type == 'oadrRequestEvent':
  129. response_type = 'oadrDistributeEvent'
  130. else:
  131. response_type = 'oadrResponse'
  132. response_payload = {'response': {'response_code': error_code,
  133. 'response_description': 'Certificate fingerprint mismatch'}}
  134. return response_type, response_payload