vtn_service.py 3.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. # SPDX-License-Identifier: Apache-2.0
  2. # Copyright 2020 Contributors to OpenLEADR
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. # http://www.apache.org/licenses/LICENSE-2.0
  7. # Unless required by applicable law or agreed to in writing, software
  8. # distributed under the License is distributed on an "AS IS" BASIS,
  9. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. # See the License for the specific language governing permissions and
  11. # limitations under the License.
  12. from asyncio import iscoroutine
  13. from http import HTTPStatus
  14. import os
  15. from aiohttp import web
  16. from jinja2 import Environment, PackageLoader, select_autoescape
  17. from .. import errors
  18. from ..messaging import create_message, parse_message
  19. from ..utils import generate_id
  20. from dataclasses import is_dataclass, asdict
  21. class VTNService:
  22. def __init__(self, vtn_id):
  23. self.vtn_id = vtn_id
  24. self.handlers = {}
  25. for method in [getattr(self, attr) for attr in dir(self) if callable(getattr(self, attr))]:
  26. if hasattr(method, '__message_type__'):
  27. self.handlers[method.__message_type__] = method
  28. async def handler(self, request):
  29. """
  30. Handle all incoming POST requests.
  31. """
  32. content = await request.read()
  33. message_type, message_payload = self._parse_message(content)
  34. if message_type in self.handlers:
  35. handler = self.handlers[message_type]
  36. result = handler(message_payload)
  37. if iscoroutine(result):
  38. result = await result
  39. if result is not None:
  40. response_type, response_payload = result
  41. if is_dataclass(response_payload):
  42. response_payload = asdict(response_payload)
  43. else:
  44. response_type, response_payload = 'oadrResponse', {}
  45. response_payload['vtn_id'] = self.vtn_id
  46. if 'ven_id' in message_payload:
  47. response_payload['ven_id'] = message_payload['ven_id']
  48. response_payload['response'] = {'request_id': message_payload.get('request_id', None),
  49. 'response_code': 200,
  50. 'response_description': 'OK'}
  51. response_payload['request_id'] = generate_id()
  52. # Create the XML response
  53. msg = self._create_message(response_type, **response_payload)
  54. response = web.Response(text=msg,
  55. status=HTTPStatus.OK,
  56. content_type='application/xml')
  57. else:
  58. msg = self._create_message('oadrResponse',
  59. ven_id=message_payload.get('ven_id'),
  60. status_code=errorcodes.COMPLIANCE_ERROR,
  61. status_description=f"A message of type {message_type} should not be sent to this endpoint")
  62. response = web.Response(
  63. text=msg,
  64. status=HTTPStatus.BAD_REQUEST,
  65. content_type='application/xml')
  66. return response