test_certificates.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. import asyncio
  2. import pytest
  3. import os
  4. from functools import partial
  5. from openleadr import OpenADRServer, OpenADRClient, enable_default_logging
  6. from openleadr.utils import certificate_fingerprint
  7. from openleadr import errors
  8. from async_timeout import timeout
  9. enable_default_logging()
  10. CA_CERT = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'certificates', 'dummy_ca.crt')
  11. VTN_CERT = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'certificates', 'dummy_vtn.crt')
  12. VTN_KEY = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'certificates', 'dummy_vtn.key')
  13. VEN_CERT = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'certificates', 'dummy_ven.crt')
  14. VEN_KEY = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'certificates', 'dummy_ven.key')
  15. with open(VEN_CERT) as file:
  16. ven_fingerprint = certificate_fingerprint(file.read())
  17. with open(VTN_CERT) as file:
  18. vtn_fingerprint = certificate_fingerprint(file.read())
  19. async def lookup_fingerprint(ven_id):
  20. return ven_fingerprint
  21. async def on_create_party_registration(payload, future):
  22. if payload['fingerprint'] != ven_fingerprint:
  23. raise errors.FingerprintMismatch("The fingerprint of your TLS connection does not match the expected fingerprint. Your VEN is not allowed to register.")
  24. else:
  25. future.set_result(True)
  26. return 'ven1234', 'reg5678'
  27. @pytest.mark.asyncio
  28. async def test_ssl_certificates():
  29. loop = asyncio.get_event_loop()
  30. registration_future = loop.create_future()
  31. server = OpenADRServer(vtn_id='myvtn',
  32. http_cert=VTN_CERT,
  33. http_key=VTN_KEY,
  34. http_ca_file=CA_CERT,
  35. cert=VTN_CERT,
  36. key=VTN_KEY,
  37. fingerprint_lookup=lookup_fingerprint)
  38. server.add_handler('on_create_party_registration', partial(on_create_party_registration,
  39. future=registration_future))
  40. await server.run_async()
  41. await asyncio.sleep(1)
  42. # Run the client
  43. client = OpenADRClient(ven_name='myven',
  44. vtn_url='https://localhost:8080/OpenADR2/Simple/2.0b',
  45. cert=VEN_CERT,
  46. key=VEN_KEY,
  47. ca_file=CA_CERT,
  48. vtn_fingerprint=vtn_fingerprint)
  49. await client.run()
  50. # Wait for the registration to be triggered
  51. result = await asyncio.wait_for(registration_future, 1.0)
  52. assert client.registration_id == 'reg5678'
  53. await client.stop()
  54. await server.stop()
  55. await asyncio.sleep(0)
  56. @pytest.mark.asyncio
  57. async def test_ssl_certificates_wrong_cert():
  58. loop = asyncio.get_event_loop()
  59. registration_future = loop.create_future()
  60. server = OpenADRServer(vtn_id='myvtn',
  61. http_cert=VTN_CERT,
  62. http_key=VTN_KEY,
  63. http_ca_file=CA_CERT,
  64. cert=VTN_CERT,
  65. key=VTN_KEY,
  66. fingerprint_lookup=lookup_fingerprint)
  67. server.add_handler('on_create_party_registration', partial(on_create_party_registration,
  68. future=registration_future))
  69. await server.run_async()
  70. await asyncio.sleep(1)
  71. # Run the client
  72. client = OpenADRClient(ven_name='myven',
  73. vtn_url='https://localhost:8080/OpenADR2/Simple/2.0b',
  74. cert=VTN_CERT,
  75. key=VTN_KEY,
  76. ca_file=CA_CERT,
  77. vtn_fingerprint=vtn_fingerprint)
  78. await client.run()
  79. # Wait for the registration to be triggered
  80. with pytest.raises(asyncio.TimeoutError):
  81. await asyncio.wait_for(registration_future, timeout=0.5)
  82. assert client.registration_id is None
  83. await client.stop()
  84. await server.stop()
  85. await asyncio.sleep(0)