diff --git a/charms/neutron-k8s/src/templates/cert_host.j2 b/charms/neutron-k8s/src/templates/cert_host.j2 index 5a642189..9cd290b8 100644 --- a/charms/neutron-k8s/src/templates/cert_host.j2 +++ b/charms/neutron-k8s/src/templates/cert_host.j2 @@ -1,3 +1,3 @@ {% if certificates -%} -{{ certificates.cert }} +{{ certificates.cert_main }} {% endif -%} diff --git a/charms/neutron-k8s/src/templates/key_host.j2 b/charms/neutron-k8s/src/templates/key_host.j2 index 47b94659..965890b5 100644 --- a/charms/neutron-k8s/src/templates/key_host.j2 +++ b/charms/neutron-k8s/src/templates/key_host.j2 @@ -1,3 +1,3 @@ {% if certificates -%} -{{ certificates.key }} +{{ certificates.key_main }} {% endif -%} diff --git a/charms/neutron-k8s/src/templates/neutron-ovn.crt.j2 b/charms/neutron-k8s/src/templates/neutron-ovn.crt.j2 index f339c6a9..988ab7a5 100644 --- a/charms/neutron-k8s/src/templates/neutron-ovn.crt.j2 +++ b/charms/neutron-k8s/src/templates/neutron-ovn.crt.j2 @@ -1,3 +1,3 @@ {% if certificates -%} -{{ certificates.ca_cert }} +{{ certificates.ca_cert_main }} {% endif -%} diff --git a/charms/octavia-k8s/src/templates/ovn_ca_cert.pem.j2 b/charms/octavia-k8s/src/templates/ovn_ca_cert.pem.j2 index f339c6a9..988ab7a5 100644 --- a/charms/octavia-k8s/src/templates/ovn_ca_cert.pem.j2 +++ b/charms/octavia-k8s/src/templates/ovn_ca_cert.pem.j2 @@ -1,3 +1,3 @@ {% if certificates -%} -{{ certificates.ca_cert }} +{{ certificates.ca_cert_main }} {% endif -%} diff --git a/charms/octavia-k8s/src/templates/ovn_certificate.pem.j2 b/charms/octavia-k8s/src/templates/ovn_certificate.pem.j2 index 5a642189..9cd290b8 100644 --- a/charms/octavia-k8s/src/templates/ovn_certificate.pem.j2 +++ b/charms/octavia-k8s/src/templates/ovn_certificate.pem.j2 @@ -1,3 +1,3 @@ {% if certificates -%} -{{ certificates.cert }} +{{ certificates.cert_main }} {% endif -%} diff --git a/charms/octavia-k8s/src/templates/ovn_private_key.pem.j2 b/charms/octavia-k8s/src/templates/ovn_private_key.pem.j2 index 47b94659..965890b5 100644 --- a/charms/octavia-k8s/src/templates/ovn_private_key.pem.j2 +++ b/charms/octavia-k8s/src/templates/ovn_private_key.pem.j2 @@ -1,3 +1,3 @@ {% if certificates -%} -{{ certificates.key }} +{{ certificates.key_main }} {% endif -%} diff --git a/charms/openstack-hypervisor/src/charm.py b/charms/openstack-hypervisor/src/charm.py index 6d41c66b..aaef581b 100755 --- a/charms/openstack-hypervisor/src/charm.py +++ b/charms/openstack-hypervisor/src/charm.py @@ -72,111 +72,68 @@ MTLS_USAGES = {x509.OID_SERVER_AUTH, x509.OID_CLIENT_AUTH} class MTlsCertificatesHandler(sunbeam_rhandlers.TlsCertificatesHandler): """Handler for certificates interface.""" - def update_relation_data(self): - """Update relation outside of relation context.""" - relations = self.model.relations[self.relation_name] - if len(relations) != 1: - logger.debug( - f"Unit has wrong number of {self.relation_name!r} relations." - ) - return - relation = relations[0] - csr = self._get_csr_from_relation_unit_data() - if not csr: - self._request_certificates() - return - certs = self._get_cert_from_relation_data(csr) - if "cert" not in certs or not self._has_certificate_mtls_extensions( - certs["cert"] - ): - logger.info( - "Requesting new certificates, current is missing mTLS extensions." - ) - relation.data[self.model.unit][ - "certificate_signing_requests" - ] = "[]" - self._request_certificates() + def csrs(self) -> dict[str, bytes]: + """Return a dict of generated csrs for self.key_names(). - def _has_certificate_mtls_extensions(self, certificate: str) -> bool: - """Check current certificate has mTLS extensions.""" - cert = x509.load_pem_x509_certificate(certificate.encode()) - for extension in cert.extensions: - if extension.oid != x509.OID_EXTENDED_KEY_USAGE: - continue - extension_oids = {ext.dotted_string for ext in extension.value} - mtls_oids = {oid.dotted_string for oid in MTLS_USAGES} - if mtls_oids.issubset(extension_oids): - return True - return False - - def _request_certificates(self): - """Request certificates from remote provider.""" + The method calling this method will ensure that all keys have a matching + csr. + """ # Lazy import to ensure this lib is only required if the charm # has this relation. - from charms.tls_certificates_interface.v1.tls_certificates import ( + from charms.tls_certificates_interface.v3.tls_certificates import ( generate_csr, ) - if self.ready: - logger.debug("Certificate request already complete.") - return + main_key = self._private_keys.get("main") + if not main_key: + return {} - if self.private_key: - logger.debug("Private key found, requesting certificates") - else: - logger.debug("Cannot request certificates, private key not found") - return - - csr = generate_csr( - private_key=self.private_key.encode(), - subject=socket.getfqdn(), - sans_dns=self.sans_dns, - sans_ip=self.sans_ips, - additional_critical_extensions=[ - x509.KeyUsage( - digital_signature=True, - content_commitment=False, - key_encipherment=True, - data_encipherment=False, - key_agreement=True, - key_cert_sign=False, - crl_sign=False, - encipher_only=False, - decipher_only=False, - ), - x509.ExtendedKeyUsage(MTLS_USAGES), - ], - ) - self.certificates.request_certificate_creation( - certificate_signing_request=csr - ) + return { + "main": generate_csr( + private_key=main_key.encode(), + subject=socket.getfqdn(), + sans_dns=self.sans_dns, + sans_ip=self.sans_ips, + additional_critical_extensions=[ + x509.KeyUsage( + digital_signature=True, + content_commitment=False, + key_encipherment=True, + data_encipherment=False, + key_agreement=True, + key_cert_sign=False, + crl_sign=False, + encipher_only=False, + decipher_only=False, + ), + x509.ExtendedKeyUsage(MTLS_USAGES), + ], + ) + } def context(self) -> dict: """Certificates context.""" - csr_from_unit = self._get_csr_from_relation_unit_data() - if not csr_from_unit: + certs = self.interface.get_assigned_certificates() + if len(certs) != len(self.key_names()): + return {} + # openstack-hypervisor only has a main key + csr = self.store.get_csr("main") + if csr is None: return {} - certs = self._get_cert_from_relation_data(csr_from_unit) - cert = certs["cert"] - ca_cert = certs["ca"] - ca_with_intermediates = certs["ca"] + "\n" + "\n".join(certs["chain"]) - - ctxt = { - "key": self.private_key, - "cert": cert, - "ca_cert": ca_cert, - "ca_with_intermediates": ca_with_intermediates, - } - return ctxt - - @property - def ready(self) -> bool: - """Whether handler ready for use.""" - try: - return super().ready - except KeyError: - return False + for cert in certs: + if cert.csr == csr: + return { + "key": self._private_keys["main"], + "cert": cert.certificate, + "ca_cert": cert.ca, + "ca_with_intermediates": cert.ca + + "\n" + + "\n".join(cert.chain), + } + else: + logger.warning("No certificate found for CSR main") + return {} class HypervisorOperatorCharm(sunbeam_charm.OSBaseOperatorCharm): diff --git a/charms/openstack-hypervisor/tests/unit/test_charm.py b/charms/openstack-hypervisor/tests/unit/test_charm.py index 70012753..3c6f71da 100644 --- a/charms/openstack-hypervisor/tests/unit/test_charm.py +++ b/charms/openstack-hypervisor/tests/unit/test_charm.py @@ -15,7 +15,6 @@ """Tests for Openstack hypervisor charm.""" import base64 -import json from unittest.mock import ( MagicMock, ) @@ -52,20 +51,9 @@ class TestCharm(test_utils.CharmTestCase): def initial_setup(self): """Setting up relations.""" - rel_id = self.harness.add_relation("certificates", "vault") - self.harness.add_relation_unit(rel_id, "vault/0") self.harness.update_config({"snap-channel": "essex/stable"}) self.harness.begin_with_initial_hooks() - csr = {"certificate_signing_request": test_utils.TEST_CSR} - self.harness.update_relation_data( - rel_id, - self.harness.charm.unit.name, - { - "ingress-address": "10.0.0.34", - "certificate_signing_requests": json.dumps([csr]), - }, - ) - test_utils.add_certificates_relation_certs(self.harness, rel_id) + test_utils.add_complete_certificates_relation(self.harness) ovs_rel_id = self.harness.add_relation("ovsdb-cms", "ovn-relay") self.harness.add_relation_unit(ovs_rel_id, "ovn-relay/0") self.harness.update_relation_data( diff --git a/charms/ovn-central-k8s/src/templates/cert_host.j2 b/charms/ovn-central-k8s/src/templates/cert_host.j2 index 0d92c108..b96ff66c 100644 --- a/charms/ovn-central-k8s/src/templates/cert_host.j2 +++ b/charms/ovn-central-k8s/src/templates/cert_host.j2 @@ -1,2 +1,2 @@ # {{ certificates }} -{{ certificates.cert }} +{{ certificates.cert_main }} diff --git a/charms/ovn-central-k8s/src/templates/key_host.j2 b/charms/ovn-central-k8s/src/templates/key_host.j2 index 63a9aebc..328db323 100644 --- a/charms/ovn-central-k8s/src/templates/key_host.j2 +++ b/charms/ovn-central-k8s/src/templates/key_host.j2 @@ -1 +1 @@ -{{ certificates.key }} +{{ certificates.key_main }} diff --git a/charms/ovn-central-k8s/src/templates/ovn-central.crt.j2 b/charms/ovn-central-k8s/src/templates/ovn-central.crt.j2 index 00cd0da9..c06330d2 100644 --- a/charms/ovn-central-k8s/src/templates/ovn-central.crt.j2 +++ b/charms/ovn-central-k8s/src/templates/ovn-central.crt.j2 @@ -1 +1 @@ -{{ certificates.ca_cert }} +{{ certificates.ca_cert_main }} diff --git a/charms/ovn-relay-k8s/src/templates/cert_host.j2 b/charms/ovn-relay-k8s/src/templates/cert_host.j2 index 0d92c108..b96ff66c 100644 --- a/charms/ovn-relay-k8s/src/templates/cert_host.j2 +++ b/charms/ovn-relay-k8s/src/templates/cert_host.j2 @@ -1,2 +1,2 @@ # {{ certificates }} -{{ certificates.cert }} +{{ certificates.cert_main }} diff --git a/charms/ovn-relay-k8s/src/templates/key_host.j2 b/charms/ovn-relay-k8s/src/templates/key_host.j2 index 63a9aebc..328db323 100644 --- a/charms/ovn-relay-k8s/src/templates/key_host.j2 +++ b/charms/ovn-relay-k8s/src/templates/key_host.j2 @@ -1 +1 @@ -{{ certificates.key }} +{{ certificates.key_main }} diff --git a/charms/ovn-relay-k8s/src/templates/ovn-central.crt.j2 b/charms/ovn-relay-k8s/src/templates/ovn-central.crt.j2 index 00cd0da9..c06330d2 100644 --- a/charms/ovn-relay-k8s/src/templates/ovn-central.crt.j2 +++ b/charms/ovn-relay-k8s/src/templates/ovn-central.crt.j2 @@ -1 +1 @@ -{{ certificates.ca_cert }} +{{ certificates.ca_cert_main }} diff --git a/fetch_libs.sh b/fetch_libs.sh index 4f5ce5eb..3b18cfeb 100755 --- a/fetch_libs.sh +++ b/fetch_libs.sh @@ -14,7 +14,7 @@ charmcraft fetch-lib charms.operator_libs_linux.v0.sysctl charmcraft fetch-lib charms.operator_libs_linux.v2.snap charmcraft fetch-lib charms.prometheus_k8s.v0.prometheus_scrape charmcraft fetch-lib charms.rabbitmq_k8s.v0.rabbitmq -charmcraft fetch-lib charms.tls_certificates_interface.v1.tls_certificates +charmcraft fetch-lib charms.tls_certificates_interface.v3.tls_certificates charmcraft fetch-lib charms.traefik_k8s.v2.ingress charmcraft fetch-lib charms.traefik_route_k8s.v0.traefik_route charmcraft fetch-lib charms.vault_k8s.v0.vault_kv diff --git a/libs/external/lib/charms/tls_certificates_interface/v1/tls_certificates.py b/libs/external/lib/charms/tls_certificates_interface/v3/tls_certificates.py similarity index 53% rename from libs/external/lib/charms/tls_certificates_interface/v1/tls_certificates.py rename to libs/external/lib/charms/tls_certificates_interface/v3/tls_certificates.py index be171d8e..33f34b62 100644 --- a/libs/external/lib/charms/tls_certificates_interface/v1/tls_certificates.py +++ b/libs/external/lib/charms/tls_certificates_interface/v3/tls_certificates.py @@ -1,4 +1,4 @@ -# Copyright 2021 Canonical Ltd. +# Copyright 2024 Canonical Ltd. # See LICENSE file for licensing details. @@ -7,16 +7,19 @@ This library contains the Requires and Provides classes for handling the tls-certificates interface. +Pre-requisites: + - Juju >= 3.0 + ## Getting Started From a charm directory, fetch the library using `charmcraft`: ```shell -charmcraft fetch-lib charms.tls_certificates_interface.v1.tls_certificates +charmcraft fetch-lib charms.tls_certificates_interface.v3.tls_certificates ``` Add the following libraries to the charm's `requirements.txt` file: - jsonschema -- cryptography +- cryptography >= 42.0.0 Add the following section to the charm's `charmcraft.yaml` file: ```yaml @@ -36,10 +39,10 @@ this example, the provider charm is storing its private key using a peer relatio Example: ```python -from charms.tls_certificates_interface.v1.tls_certificates import ( +from charms.tls_certificates_interface.v3.tls_certificates import ( CertificateCreationRequestEvent, CertificateRevocationRequestEvent, - TLSCertificatesProvidesV1, + TLSCertificatesProvidesV3, generate_private_key, ) from ops.charm import CharmBase, InstallEvent @@ -59,12 +62,14 @@ class ExampleProviderCharm(CharmBase): def __init__(self, *args): super().__init__(*args) - self.certificates = TLSCertificatesProvidesV1(self, "certificates") + self.certificates = TLSCertificatesProvidesV3(self, "certificates") self.framework.observe( - self.certificates.on.certificate_request, self._on_certificate_request + self.certificates.on.certificate_request, + self._on_certificate_request ) self.framework.observe( - self.certificates.on.certificate_revoked, self._on_certificate_revocation_request + self.certificates.on.certificate_revocation_request, + self._on_certificate_revocation_request ) self.framework.observe(self.on.install, self._on_install) @@ -106,6 +111,7 @@ class ExampleProviderCharm(CharmBase): ca=ca_certificate, chain=[ca_certificate, certificate], relation_id=event.relation_id, + recommended_expiry_notification_time=720, ) def _on_certificate_revocation_request(self, event: CertificateRevocationRequestEvent) -> None: @@ -124,17 +130,18 @@ this example, the requirer charm is storing its certificates using a peer relati Example: ```python -from charms.tls_certificates_interface.v1.tls_certificates import ( +from charms.tls_certificates_interface.v3.tls_certificates import ( CertificateAvailableEvent, CertificateExpiringEvent, CertificateRevokedEvent, - TLSCertificatesRequiresV1, + TLSCertificatesRequiresV3, generate_csr, generate_private_key, ) -from ops.charm import CharmBase, RelationJoinedEvent +from ops.charm import CharmBase, RelationCreatedEvent from ops.main import main from ops.model import ActiveStatus, WaitingStatus +from typing import Union class ExampleRequirerCharm(CharmBase): @@ -142,10 +149,10 @@ class ExampleRequirerCharm(CharmBase): def __init__(self, *args): super().__init__(*args) self.cert_subject = "whatever" - self.certificates = TLSCertificatesRequiresV1(self, "certificates") + self.certificates = TLSCertificatesRequiresV3(self, "certificates") self.framework.observe(self.on.install, self._on_install) self.framework.observe( - self.on.certificates_relation_joined, self._on_certificates_relation_joined + self.on.certificates_relation_created, self._on_certificates_relation_created ) self.framework.observe( self.certificates.on.certificate_available, self._on_certificate_available @@ -154,7 +161,11 @@ class ExampleRequirerCharm(CharmBase): self.certificates.on.certificate_expiring, self._on_certificate_expiring ) self.framework.observe( - self.certificates.on.certificate_revoked, self._on_certificate_revoked + self.certificates.on.certificate_invalidated, self._on_certificate_invalidated + ) + self.framework.observe( + self.certificates.on.all_certificates_invalidated, + self._on_all_certificates_invalidated ) def _on_install(self, event) -> None: @@ -169,7 +180,7 @@ class ExampleRequirerCharm(CharmBase): {"private_key_password": "banana", "private_key": private_key.decode()} ) - def _on_certificates_relation_joined(self, event: RelationJoinedEvent) -> None: + def _on_certificates_relation_created(self, event: RelationCreatedEvent) -> None: replicas_relation = self.model.get_relation("replicas") if not replicas_relation: self.unit.status = WaitingStatus("Waiting for peer relation to be created") @@ -196,7 +207,9 @@ class ExampleRequirerCharm(CharmBase): replicas_relation.data[self.app].update({"chain": event.chain}) self.unit.status = ActiveStatus() - def _on_certificate_expiring(self, event: CertificateExpiringEvent) -> None: + def _on_certificate_expiring( + self, event: Union[CertificateExpiringEvent, CertificateInvalidatedEvent] + ) -> None: replicas_relation = self.model.get_relation("replicas") if not replicas_relation: self.unit.status = WaitingStatus("Waiting for peer relation to be created") @@ -216,12 +229,7 @@ class ExampleRequirerCharm(CharmBase): ) replicas_relation.data[self.app].update({"csr": new_csr.decode()}) - def _on_certificate_revoked(self, event: CertificateRevokedEvent) -> None: - replicas_relation = self.model.get_relation("replicas") - if not replicas_relation: - self.unit.status = WaitingStatus("Waiting for peer relation to be created") - event.defer() - return + def _certificate_revoked(self) -> None: old_csr = replicas_relation.data[self.app].get("csr") private_key_password = replicas_relation.data[self.app].get("private_key_password") private_key = replicas_relation.data[self.app].get("private_key") @@ -240,44 +248,82 @@ class ExampleRequirerCharm(CharmBase): replicas_relation.data[self.app].pop("chain") self.unit.status = WaitingStatus("Waiting for new certificate") + def _on_certificate_invalidated(self, event: CertificateInvalidatedEvent) -> None: + replicas_relation = self.model.get_relation("replicas") + if not replicas_relation: + self.unit.status = WaitingStatus("Waiting for peer relation to be created") + event.defer() + return + if event.reason == "revoked": + self._certificate_revoked() + if event.reason == "expired": + self._on_certificate_expiring(event) + + def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEvent) -> None: + # Do what you want with this information, probably remove all certificates. + pass + if __name__ == "__main__": main(ExampleRequirerCharm) ``` + +You can relate both charms by running: + +```bash +juju relate +``` + """ # noqa: D405, D410, D411, D214, D416 import copy import json import logging import uuid -from datetime import datetime, timedelta +from contextlib import suppress +from dataclasses import dataclass +from datetime import datetime, timedelta, timezone from ipaddress import IPv4Address -from typing import Dict, List, Optional +from typing import List, Literal, Optional, Union from cryptography import x509 from cryptography.hazmat._oid import ExtensionOID from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives.asymmetric import rsa -from cryptography.hazmat.primitives.serialization import pkcs12 -from cryptography.x509.extensions import Extension, ExtensionNotFound -from jsonschema import exceptions, validate # type: ignore[import] -from ops.charm import CharmBase, CharmEvents, RelationChangedEvent, UpdateStatusEvent +from jsonschema import exceptions, validate +from ops.charm import ( + CharmBase, + CharmEvents, + RelationBrokenEvent, + RelationChangedEvent, + SecretExpiredEvent, +) from ops.framework import EventBase, EventSource, Handle, Object +from ops.jujuversion import JujuVersion +from ops.model import ( + Application, + ModelError, + Relation, + RelationDataContent, + SecretNotFoundError, + Unit, +) # The unique Charmhub library identifier, never change it LIBID = "afd8c2bccf834997afce12c2706d2ede" # Increment this major API version when introducing breaking changes -LIBAPI = 1 +LIBAPI = 3 # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 12 +LIBPATCH = 15 +PYDEPS = ["cryptography", "jsonschema"] REQUIRER_JSON_SCHEMA = { "$schema": "http://json-schema.org/draft-04/schema#", - "$id": "https://canonical.github.io/charm-relation-interfaces/tls_certificates/v1/schemas/requirer.json", # noqa: E501 + "$id": "https://canonical.github.io/charm-relation-interfaces/interfaces/tls_certificates/v1/schemas/requirer.json", "type": "object", "title": "`tls_certificates` requirer root schema", "description": "The `tls_certificates` root schema comprises the entire requirer databag for this interface.", # noqa: E501 @@ -298,7 +344,10 @@ REQUIRER_JSON_SCHEMA = { "type": "array", "items": { "type": "object", - "properties": {"certificate_signing_request": {"type": "string"}}, + "properties": { + "certificate_signing_request": {"type": "string"}, + "ca": {"type": "boolean"}, + }, "required": ["certificate_signing_request"], }, } @@ -309,7 +358,7 @@ REQUIRER_JSON_SCHEMA = { PROVIDER_JSON_SCHEMA = { "$schema": "http://json-schema.org/draft-04/schema#", - "$id": "https://canonical.github.io/charm-relation-interfaces/tls_certificates/v1/schemas/provider.json", # noqa: E501 + "$id": "https://canonical.github.io/charm-relation-interfaces/interfaces/tls_certificates/v1/schemas/provider.json", "type": "object", "title": "`tls_certificates` provider root schema", "description": "The `tls_certificates` root schema comprises the entire provider databag for this interface.", # noqa: E501 @@ -383,6 +432,58 @@ PROVIDER_JSON_SCHEMA = { logger = logging.getLogger(__name__) +@dataclass +class RequirerCSR: + """This class represents a certificate signing request from an interface Requirer.""" + + relation_id: int + application_name: str + unit_name: str + csr: str + is_ca: bool + + +@dataclass +class ProviderCertificate: + """This class represents a certificate from an interface Provider.""" + + relation_id: int + application_name: str + csr: str + certificate: str + ca: str + chain: List[str] + revoked: bool + expiry_time: datetime + expiry_notification_time: Optional[datetime] = None + + def chain_as_pem(self) -> str: + """Return full certificate chain as a PEM string.""" + return "\n\n".join(reversed(self.chain)) + + def to_json(self) -> str: + """Return the object as a JSON string. + + Returns: + str: JSON representation of the object + """ + return json.dumps( + { + "relation_id": self.relation_id, + "application_name": self.application_name, + "csr": self.csr, + "certificate": self.certificate, + "ca": self.ca, + "chain": self.chain, + "revoked": self.revoked, + "expiry_time": self.expiry_time.isoformat(), + "expiry_notification_time": self.expiry_notification_time.isoformat() + if self.expiry_notification_time + else None, + } + ) + + class CertificateAvailableEvent(EventBase): """Charm Event triggered when a TLS certificate is available.""" @@ -401,7 +502,7 @@ class CertificateAvailableEvent(EventBase): self.chain = chain def snapshot(self) -> dict: - """Returns snapshot.""" + """Return snapshot.""" return { "certificate": self.certificate, "certificate_signing_request": self.certificate_signing_request, @@ -410,12 +511,16 @@ class CertificateAvailableEvent(EventBase): } def restore(self, snapshot: dict): - """Restores snapshot.""" + """Restore snapshot.""" self.certificate = snapshot["certificate"] self.certificate_signing_request = snapshot["certificate_signing_request"] self.ca = snapshot["ca"] self.chain = snapshot["chain"] + def chain_as_pem(self) -> str: + """Return full certificate chain as a PEM string.""" + return "\n\n".join(reversed(self.chain)) + class CertificateExpiringEvent(EventBase): """Charm Event triggered when a TLS certificate is almost expired.""" @@ -426,7 +531,7 @@ class CertificateExpiringEvent(EventBase): Args: handle (Handle): Juju framework handle certificate (str): TLS Certificate - expiry (str): Datetime string reprensenting the time at which the certificate + expiry (str): Datetime string representing the time at which the certificate won't be valid anymore. """ super().__init__(handle) @@ -434,88 +539,96 @@ class CertificateExpiringEvent(EventBase): self.expiry = expiry def snapshot(self) -> dict: - """Returns snapshot.""" + """Return snapshot.""" return {"certificate": self.certificate, "expiry": self.expiry} def restore(self, snapshot: dict): - """Restores snapshot.""" + """Restore snapshot.""" self.certificate = snapshot["certificate"] self.expiry = snapshot["expiry"] -class CertificateExpiredEvent(EventBase): - """Charm Event triggered when a TLS certificate is expired.""" - - def __init__(self, handle: Handle, certificate: str): - super().__init__(handle) - self.certificate = certificate - - def snapshot(self) -> dict: - """Returns snapshot.""" - return {"certificate": self.certificate} - - def restore(self, snapshot: dict): - """Restores snapshot.""" - self.certificate = snapshot["certificate"] - - -class CertificateRevokedEvent(EventBase): - """Charm Event triggered when a TLS certificate is revoked.""" +class CertificateInvalidatedEvent(EventBase): + """Charm Event triggered when a TLS certificate is invalidated.""" def __init__( self, handle: Handle, + reason: Literal["expired", "revoked"], certificate: str, certificate_signing_request: str, ca: str, chain: List[str], - revoked: bool, ): super().__init__(handle) - self.certificate = certificate + self.reason = reason self.certificate_signing_request = certificate_signing_request + self.certificate = certificate self.ca = ca self.chain = chain - self.revoked = revoked def snapshot(self) -> dict: - """Returns snapshot.""" + """Return snapshot.""" return { - "certificate": self.certificate, + "reason": self.reason, "certificate_signing_request": self.certificate_signing_request, + "certificate": self.certificate, "ca": self.ca, "chain": self.chain, - "revoked": self.revoked, } def restore(self, snapshot: dict): - """Restores snapshot.""" - self.certificate = snapshot["certificate"] + """Restore snapshot.""" + self.reason = snapshot["reason"] self.certificate_signing_request = snapshot["certificate_signing_request"] + self.certificate = snapshot["certificate"] self.ca = snapshot["ca"] self.chain = snapshot["chain"] - self.revoked = snapshot["revoked"] + + +class AllCertificatesInvalidatedEvent(EventBase): + """Charm Event triggered when all TLS certificates are invalidated.""" + + def __init__(self, handle: Handle): + super().__init__(handle) + + def snapshot(self) -> dict: + """Return snapshot.""" + return {} + + def restore(self, snapshot: dict): + """Restore snapshot.""" + pass class CertificateCreationRequestEvent(EventBase): """Charm Event triggered when a TLS certificate is required.""" - def __init__(self, handle: Handle, certificate_signing_request: str, relation_id: int): + def __init__( + self, + handle: Handle, + certificate_signing_request: str, + relation_id: int, + is_ca: bool = False, + ): super().__init__(handle) self.certificate_signing_request = certificate_signing_request self.relation_id = relation_id + self.is_ca = is_ca def snapshot(self) -> dict: - """Returns snapshot.""" + """Return snapshot.""" return { "certificate_signing_request": self.certificate_signing_request, "relation_id": self.relation_id, + "is_ca": self.is_ca, } def restore(self, snapshot: dict): - """Restores snapshot.""" + """Restore snapshot.""" self.certificate_signing_request = snapshot["certificate_signing_request"] self.relation_id = snapshot["relation_id"] + self.is_ca = snapshot["is_ca"] class CertificateRevocationRequestEvent(EventBase): @@ -536,7 +649,7 @@ class CertificateRevocationRequestEvent(EventBase): self.chain = chain def snapshot(self) -> dict: - """Returns snapshot.""" + """Return snapshot.""" return { "certificate": self.certificate, "certificate_signing_request": self.certificate_signing_request, @@ -545,33 +658,100 @@ class CertificateRevocationRequestEvent(EventBase): } def restore(self, snapshot: dict): - """Restores snapshot.""" + """Restore snapshot.""" self.certificate = snapshot["certificate"] self.certificate_signing_request = snapshot["certificate_signing_request"] self.ca = snapshot["ca"] self.chain = snapshot["chain"] -def _load_relation_data(raw_relation_data: dict) -> dict: - """Loads relation data from the relation data bag. +def _load_relation_data(relation_data_content: RelationDataContent) -> dict: + """Load relation data from the relation data bag. Json loads all data. Args: - raw_relation_data: Relation data from the databag + relation_data_content: Relation data from the databag Returns: dict: Relation data in dict format. """ - certificate_data = dict() - for key in raw_relation_data: - try: - certificate_data[key] = json.loads(raw_relation_data[key]) - except (json.decoder.JSONDecodeError, TypeError): - certificate_data[key] = raw_relation_data[key] + certificate_data = {} + try: + for key in relation_data_content: + try: + certificate_data[key] = json.loads(relation_data_content[key]) + except (json.decoder.JSONDecodeError, TypeError): + certificate_data[key] = relation_data_content[key] + except ModelError: + pass return certificate_data +def _get_closest_future_time( + expiry_notification_time: datetime, expiry_time: datetime +) -> datetime: + """Return expiry_notification_time if not in the past, otherwise return expiry_time. + + Args: + expiry_notification_time (datetime): Notification time of impending expiration + expiry_time (datetime): Expiration time + + Returns: + datetime: expiry_notification_time if not in the past, expiry_time otherwise + """ + return ( + expiry_notification_time + if datetime.now(timezone.utc) < expiry_notification_time + else expiry_time + ) + + +def calculate_expiry_notification_time( + validity_start_time: datetime, + expiry_time: datetime, + provider_recommended_notification_time: Optional[int], + requirer_recommended_notification_time: Optional[int], +) -> datetime: + """Calculate a reasonable time to notify the user about the certificate expiry. + + It takes into account the time recommended by the provider and by the requirer. + Time recommended by the provider is preferred, + then time recommended by the requirer, + then dynamically calculated time. + + Args: + validity_start_time: Certificate validity time + expiry_time: Certificate expiry time + provider_recommended_notification_time: + Time in hours prior to expiry to notify the user. + Recommended by the provider. + requirer_recommended_notification_time: + Time in hours prior to expiry to notify the user. + Recommended by the requirer. + + Returns: + datetime: Time to notify the user about the certificate expiry. + """ + if provider_recommended_notification_time is not None: + provider_recommended_notification_time = abs(provider_recommended_notification_time) + provider_recommendation_time_delta = ( + expiry_time - timedelta(hours=provider_recommended_notification_time) + ) + if validity_start_time < provider_recommendation_time_delta: + return provider_recommendation_time_delta + + if requirer_recommended_notification_time is not None: + requirer_recommended_notification_time = abs(requirer_recommended_notification_time) + requirer_recommendation_time_delta = ( + expiry_time - timedelta(hours=requirer_recommended_notification_time) + ) + if validity_start_time < requirer_recommendation_time_delta: + return requirer_recommendation_time_delta + calculated_hours = (expiry_time - validity_start_time).total_seconds() / (3600 * 3) + return expiry_time - timedelta(hours=calculated_hours) + + def generate_ca( private_key: bytes, subject: str, @@ -579,11 +759,11 @@ def generate_ca( validity: int = 365, country: str = "US", ) -> bytes: - """Generates a CA Certificate. + """Generate a CA Certificate. Args: private_key (bytes): Private key - subject (str): Certificate subject + subject (str): Common Name that can be an IP or a Full Qualified Domain Name (FQDN). private_key_password (bytes): Private key password validity (int): Certificate validity time (in days) country (str): Certificate Issuing country @@ -594,7 +774,7 @@ def generate_ca( private_key_object = serialization.load_pem_private_key( private_key, password=private_key_password ) - subject = issuer = x509.Name( + subject_name = x509.Name( [ x509.NameAttribute(x509.NameOID.COUNTRY_NAME, country), x509.NameAttribute(x509.NameOID.COMMON_NAME, subject), @@ -604,14 +784,25 @@ def generate_ca( private_key_object.public_key() # type: ignore[arg-type] ) subject_identifier = key_identifier = subject_identifier_object.public_bytes() + key_usage = x509.KeyUsage( + digital_signature=True, + key_encipherment=True, + key_cert_sign=True, + key_agreement=False, + content_commitment=False, + data_encipherment=False, + crl_sign=False, + encipher_only=False, + decipher_only=False, + ) cert = ( x509.CertificateBuilder() - .subject_name(subject) - .issuer_name(issuer) + .subject_name(subject_name) + .issuer_name(subject_name) .public_key(private_key_object.public_key()) # type: ignore[arg-type] .serial_number(x509.random_serial_number()) - .not_valid_before(datetime.utcnow()) - .not_valid_after(datetime.utcnow() + timedelta(days=validity)) + .not_valid_before(datetime.now(timezone.utc)) + .not_valid_after(datetime.now(timezone.utc) + timedelta(days=validity)) .add_extension(x509.SubjectKeyIdentifier(digest=subject_identifier), critical=False) .add_extension( x509.AuthorityKeyIdentifier( @@ -621,6 +812,7 @@ def generate_ca( ), critical=False, ) + .add_extension(key_usage, critical=True) .add_extension( x509.BasicConstraints(ca=True, path_length=None), critical=True, @@ -630,6 +822,105 @@ def generate_ca( return cert.public_bytes(serialization.Encoding.PEM) +def get_certificate_extensions( + authority_key_identifier: bytes, + csr: x509.CertificateSigningRequest, + alt_names: Optional[List[str]], + is_ca: bool, +) -> List[x509.Extension]: + """Generate a list of certificate extensions from a CSR and other known information. + + Args: + authority_key_identifier (bytes): Authority key identifier + csr (x509.CertificateSigningRequest): CSR + alt_names (list): List of alt names to put on cert - prefer putting SANs in CSR + is_ca (bool): Whether the certificate is a CA certificate + + Returns: + List[x509.Extension]: List of extensions + """ + cert_extensions_list: List[x509.Extension] = [ + x509.Extension( + oid=ExtensionOID.AUTHORITY_KEY_IDENTIFIER, + value=x509.AuthorityKeyIdentifier( + key_identifier=authority_key_identifier, + authority_cert_issuer=None, + authority_cert_serial_number=None, + ), + critical=False, + ), + x509.Extension( + oid=ExtensionOID.SUBJECT_KEY_IDENTIFIER, + value=x509.SubjectKeyIdentifier.from_public_key(csr.public_key()), + critical=False, + ), + x509.Extension( + oid=ExtensionOID.BASIC_CONSTRAINTS, + critical=True, + value=x509.BasicConstraints(ca=is_ca, path_length=None), + ), + ] + + sans: List[x509.GeneralName] = [] + san_alt_names = [x509.DNSName(name) for name in alt_names] if alt_names else [] + sans.extend(san_alt_names) + try: + loaded_san_ext = csr.extensions.get_extension_for_class(x509.SubjectAlternativeName) + sans.extend( + [x509.DNSName(name) for name in loaded_san_ext.value.get_values_for_type(x509.DNSName)] + ) + sans.extend( + [x509.IPAddress(ip) for ip in loaded_san_ext.value.get_values_for_type(x509.IPAddress)] + ) + sans.extend( + [ + x509.RegisteredID(oid) + for oid in loaded_san_ext.value.get_values_for_type(x509.RegisteredID) + ] + ) + except x509.ExtensionNotFound: + pass + + if sans: + cert_extensions_list.append( + x509.Extension( + oid=ExtensionOID.SUBJECT_ALTERNATIVE_NAME, + critical=False, + value=x509.SubjectAlternativeName(sans), + ) + ) + + if is_ca: + cert_extensions_list.append( + x509.Extension( + ExtensionOID.KEY_USAGE, + critical=True, + value=x509.KeyUsage( + digital_signature=False, + content_commitment=False, + key_encipherment=False, + data_encipherment=False, + key_agreement=False, + key_cert_sign=True, + crl_sign=True, + encipher_only=False, + decipher_only=False, + ), + ) + ) + + existing_oids = {ext.oid for ext in cert_extensions_list} + for extension in csr.extensions: + if extension.oid == ExtensionOID.SUBJECT_ALTERNATIVE_NAME: + continue + if extension.oid in existing_oids: + logger.warning("Extension %s is managed by the TLS provider, ignoring.", extension.oid) + continue + cert_extensions_list.append(extension) + + return cert_extensions_list + + def generate_certificate( csr: bytes, ca: bytes, @@ -637,8 +928,9 @@ def generate_certificate( ca_key_password: Optional[bytes] = None, validity: int = 365, alt_names: Optional[List[str]] = None, + is_ca: bool = False, ) -> bytes: - """Generates a TLS certificate based on a CSR. + """Generate a TLS certificate based on a CSR. Args: csr (bytes): CSR @@ -647,13 +939,15 @@ def generate_certificate( ca_key_password: CA private key password validity (int): Certificate validity (in days) alt_names (list): List of alt names to put on cert - prefer putting SANs in CSR + is_ca (bool): Whether the certificate is a CA certificate Returns: bytes: Certificate """ csr_object = x509.load_pem_x509_csr(csr) subject = csr_object.subject - issuer = x509.load_pem_x509_certificate(ca).issuer + ca_pem = x509.load_pem_x509_certificate(ca) + issuer = ca_pem.issuer private_key = serialization.load_pem_private_key(ca_key, password=ca_key_password) certificate_builder = ( @@ -662,81 +956,36 @@ def generate_certificate( .issuer_name(issuer) .public_key(csr_object.public_key()) .serial_number(x509.random_serial_number()) - .not_valid_before(datetime.utcnow()) - .not_valid_after(datetime.utcnow() + timedelta(days=validity)) + .not_valid_before(datetime.now(timezone.utc)) + .not_valid_after(datetime.now(timezone.utc) + timedelta(days=validity)) ) - - extensions_list = csr_object.extensions - san_ext: Optional[x509.Extension] = None - if alt_names: - full_sans_dns = alt_names.copy() + extensions = get_certificate_extensions( + authority_key_identifier=ca_pem.extensions.get_extension_for_class( + x509.SubjectKeyIdentifier + ).value.key_identifier, + csr=csr_object, + alt_names=alt_names, + is_ca=is_ca, + ) + for extension in extensions: try: - loaded_san_ext = csr_object.extensions.get_extension_for_class( - x509.SubjectAlternativeName + certificate_builder = certificate_builder.add_extension( + extval=extension.value, + critical=extension.critical, ) - full_sans_dns.extend(loaded_san_ext.value.get_values_for_type(x509.DNSName)) - except ExtensionNotFound: - pass - finally: - san_ext = Extension( - ExtensionOID.SUBJECT_ALTERNATIVE_NAME, - False, - x509.SubjectAlternativeName([x509.DNSName(name) for name in full_sans_dns]), - ) - if not extensions_list: - extensions_list = x509.Extensions([san_ext]) + except ValueError as e: + logger.warning("Failed to add extension %s: %s", extension.oid, e) - for extension in extensions_list: - if extension.value.oid == ExtensionOID.SUBJECT_ALTERNATIVE_NAME and san_ext: - extension = san_ext - - certificate_builder = certificate_builder.add_extension( - extension.value, - critical=extension.critical, - ) - certificate_builder._version = x509.Version.v3 cert = certificate_builder.sign(private_key, hashes.SHA256()) # type: ignore[arg-type] return cert.public_bytes(serialization.Encoding.PEM) -def generate_pfx_package( - certificate: bytes, - private_key: bytes, - package_password: str, - private_key_password: Optional[bytes] = None, -) -> bytes: - """Generates a PFX package to contain the TLS certificate and private key. - - Args: - certificate (bytes): TLS certificate - private_key (bytes): Private key - package_password (str): Password to open the PFX package - private_key_password (bytes): Private key password - - Returns: - bytes: - """ - private_key_object = serialization.load_pem_private_key( - private_key, password=private_key_password - ) - certificate_object = x509.load_pem_x509_certificate(certificate) - name = certificate_object.subject.rfc4514_string() - pfx_bytes = pkcs12.serialize_key_and_certificates( - name=name.encode(), - cert=certificate_object, - key=private_key_object, # type: ignore[arg-type] - cas=None, - encryption_algorithm=serialization.BestAvailableEncryption(package_password.encode()), - ) - return pfx_bytes - - def generate_private_key( password: Optional[bytes] = None, key_size: int = 2048, public_exponent: int = 65537, ) -> bytes: - """Generates a private key. + """Generate a private key. Args: password (bytes): Password for decrypting the private key @@ -753,20 +1002,24 @@ def generate_private_key( key_bytes = private_key.private_bytes( encoding=serialization.Encoding.PEM, format=serialization.PrivateFormat.TraditionalOpenSSL, - encryption_algorithm=serialization.BestAvailableEncryption(password) - if password - else serialization.NoEncryption(), + encryption_algorithm=( + serialization.BestAvailableEncryption(password) + if password + else serialization.NoEncryption() + ), ) return key_bytes -def generate_csr( +def generate_csr( # noqa: C901 private_key: bytes, subject: str, add_unique_id_to_subject_name: bool = True, organization: Optional[str] = None, email_address: Optional[str] = None, country_name: Optional[str] = None, + state_or_province_name: Optional[str] = None, + locality_name: Optional[str] = None, private_key_password: Optional[bytes] = None, sans: Optional[List[str]] = None, sans_oid: Optional[List[str]] = None, @@ -774,24 +1027,26 @@ def generate_csr( sans_dns: Optional[List[str]] = None, additional_critical_extensions: Optional[List] = None, ) -> bytes: - """Generates a CSR using private key and subject. + """Generate a CSR using private key and subject. Args: private_key (bytes): Private key - subject (str): CSR Subject. + subject (str): CSR Common Name that can be an IP or a Full Qualified Domain Name (FQDN). add_unique_id_to_subject_name (bool): Whether a unique ID must be added to the CSR's subject name. Always leave to "True" when the CSR is used to request certificates using the tls-certificates relation. organization (str): Name of organization. email_address (str): Email address. country_name (str): Country Name. + state_or_province_name (str): State or Province Name. + locality_name (str): Locality Name. private_key_password (bytes): Private key password sans (list): Use sans_dns - this will be deprecated in a future release List of DNS subject alternative names (keeping it for now for backward compatibility) sans_oid (list): List of registered ID SANs sans_dns (list): List of DNS subject alternative names (similar to the arg: sans) sans_ip (list): List of IP subject alternative names - additional_critical_extensions (list): List if critical additional extension objects. + additional_critical_extensions (list): List of critical additional extension objects. Object must be a x509 ExtensionType. Returns: @@ -810,6 +1065,12 @@ def generate_csr( subject_name.append(x509.NameAttribute(x509.NameOID.EMAIL_ADDRESS, email_address)) if country_name: subject_name.append(x509.NameAttribute(x509.NameOID.COUNTRY_NAME, country_name)) + if state_or_province_name: + subject_name.append( + x509.NameAttribute(x509.NameOID.STATE_OR_PROVINCE_NAME, state_or_province_name) + ) + if locality_name: + subject_name.append(x509.NameAttribute(x509.NameOID.LOCALITY_NAME, locality_name)) csr = x509.CertificateSigningRequestBuilder(subject_name=x509.Name(subject_name)) _sans: List[x509.GeneralName] = [] @@ -832,6 +1093,66 @@ def generate_csr( return signed_certificate.public_bytes(serialization.Encoding.PEM) +def get_sha256_hex(data: str) -> str: + """Calculate the hash of the provided data and return the hexadecimal representation.""" + digest = hashes.Hash(hashes.SHA256()) + digest.update(data.encode()) + return digest.finalize().hex() + + +def csr_matches_certificate(csr: str, cert: str) -> bool: + """Check if a CSR matches a certificate. + + Args: + csr (str): Certificate Signing Request as a string + cert (str): Certificate as a string + Returns: + bool: True/False depending on whether the CSR matches the certificate. + """ + try: + csr_object = x509.load_pem_x509_csr(csr.encode("utf-8")) + cert_object = x509.load_pem_x509_certificate(cert.encode("utf-8")) + + if csr_object.public_key().public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) != cert_object.public_key().public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ): + return False + if ( + csr_object.public_key().public_numbers().n # type: ignore[union-attr] + != cert_object.public_key().public_numbers().n # type: ignore[union-attr] + ): + return False + except ValueError: + logger.warning("Could not load certificate or CSR.") + return False + return True + + +def _relation_data_is_valid( + relation: Relation, app_or_unit: Union[Application, Unit], json_schema: dict +) -> bool: + """Check whether relation data is valid based on json schema. + + Args: + relation (Relation): Relation object + app_or_unit (Union[Application, Unit]): Application or unit object + json_schema (dict): Json schema + + Returns: + bool: Whether relation data is valid. + """ + relation_data = _load_relation_data(relation.data[app_or_unit]) + try: + validate(instance=relation_data, schema=json_schema) + return True + except exceptions.ValidationError: + return False + + class CertificatesProviderCharmEvents(CharmEvents): """List of events that the TLS Certificates provider charm can leverage.""" @@ -844,14 +1165,14 @@ class CertificatesRequirerCharmEvents(CharmEvents): certificate_available = EventSource(CertificateAvailableEvent) certificate_expiring = EventSource(CertificateExpiringEvent) - certificate_expired = EventSource(CertificateExpiredEvent) - certificate_revoked = EventSource(CertificateRevokedEvent) + certificate_invalidated = EventSource(CertificateInvalidatedEvent) + all_certificates_invalidated = EventSource(AllCertificatesInvalidatedEvent) -class TLSCertificatesProvidesV1(Object): +class TLSCertificatesProvidesV3(Object): """TLS certificates provider class to be instantiated by TLS certificates providers.""" - on = CertificatesProviderCharmEvents() + on = CertificatesProviderCharmEvents() # type: ignore[reportAssignmentType] def __init__(self, charm: CharmBase, relationship_name: str): super().__init__(charm, relationship_name) @@ -861,6 +1182,22 @@ class TLSCertificatesProvidesV1(Object): self.charm = charm self.relationship_name = relationship_name + def _load_app_relation_data(self, relation: Relation) -> dict: + """Load relation data from the application relation data bag. + + Json loads all data. + + Args: + relation: Relation data from the application databag + + Returns: + dict: Relation data in dict format. + """ + # If unit is not leader, it does not try to reach relation data. + if not self.model.unit.is_leader(): + return {} + return _load_relation_data(relation.data[self.charm.app]) + def _add_certificate( self, relation_id: int, @@ -868,8 +1205,9 @@ class TLSCertificatesProvidesV1(Object): certificate_signing_request: str, ca: str, chain: List[str], + recommended_expiry_notification_time: Optional[int] = None, ) -> None: - """Adds certificate to relation data. + """Add certificate to relation data. Args: relation_id (int): Relation id @@ -877,6 +1215,8 @@ class TLSCertificatesProvidesV1(Object): certificate_signing_request (str): Certificate Signing Request ca (str): CA Certificate chain (list): CA Chain + recommended_expiry_notification_time (int): + Time in hours before the certificate expires to notify the user. Returns: None @@ -894,8 +1234,9 @@ class TLSCertificatesProvidesV1(Object): "certificate_signing_request": certificate_signing_request, "ca": ca, "chain": chain, + "recommended_expiry_notification_time": recommended_expiry_notification_time, } - provider_relation_data = _load_relation_data(relation.data[self.charm.app]) + provider_relation_data = self._load_app_relation_data(relation) provider_certificates = provider_relation_data.get("certificates", []) certificates = copy.deepcopy(provider_certificates) if new_certificate in certificates: @@ -910,7 +1251,7 @@ class TLSCertificatesProvidesV1(Object): certificate: Optional[str] = None, certificate_signing_request: Optional[str] = None, ) -> None: - """Removes certificate from a given relation based on user provided certificate or csr. + """Remove certificate from a given relation based on user provided certificate or csr. Args: relation_id (int): Relation id @@ -928,7 +1269,7 @@ class TLSCertificatesProvidesV1(Object): raise RuntimeError( f"Relation {self.relationship_name} with relation id {relation_id} does not exist" ) - provider_relation_data = _load_relation_data(relation.data[self.charm.app]) + provider_relation_data = self._load_app_relation_data(relation) provider_certificates = provider_relation_data.get("certificates", []) certificates = copy.deepcopy(provider_certificates) for certificate_dict in certificates: @@ -941,29 +1282,13 @@ class TLSCertificatesProvidesV1(Object): certificates.remove(certificate_dict) relation.data[self.model.app]["certificates"] = json.dumps(certificates) - @staticmethod - def _relation_data_is_valid(certificates_data: dict) -> bool: - """Uses JSON schema validator to validate relation data content. - - Args: - certificates_data (dict): Certificate data dictionary as retrieved from relation data. - - Returns: - bool: True/False depending on whether the relation data follows the json schema. - """ - try: - validate(instance=certificates_data, schema=REQUIRER_JSON_SCHEMA) - return True - except exceptions.ValidationError: - return False - def revoke_all_certificates(self) -> None: - """Revokes all certificates of this provider. + """Revoke all certificates of this provider. This method is meant to be used when the Root CA has changed. """ for relation in self.model.relations[self.relationship_name]: - provider_relation_data = _load_relation_data(relation.data[self.charm.app]) + provider_relation_data = self._load_app_relation_data(relation) provider_certificates = copy.deepcopy(provider_relation_data.get("certificates", [])) for certificate in provider_certificates: certificate["revoked"] = True @@ -976,8 +1301,9 @@ class TLSCertificatesProvidesV1(Object): ca: str, chain: List[str], relation_id: int, + recommended_expiry_notification_time: Optional[int] = None, ) -> None: - """Adds certificates to relation data. + """Add certificates to relation data. Args: certificate (str): Certificate @@ -985,10 +1311,14 @@ class TLSCertificatesProvidesV1(Object): ca (str): CA Certificate chain (list): CA Chain relation_id (int): Juju relation ID + recommended_expiry_notification_time (int): + Recommended time in hours before the certificate expires to notify the user. Returns: None """ + if not self.model.unit.is_leader(): + return certificates_relation = self.model.get_relation( relation_name=self.relationship_name, relation_id=relation_id ) @@ -1004,10 +1334,11 @@ class TLSCertificatesProvidesV1(Object): certificate_signing_request=certificate_signing_request.strip(), ca=ca.strip(), chain=[cert.strip() for cert in chain], + recommended_expiry_notification_time=recommended_expiry_notification_time, ) def remove_certificate(self, certificate: str) -> None: - """Removes a given certificate from relation data. + """Remove a given certificate from relation data. Args: certificate (str): TLS Certificate @@ -1021,8 +1352,67 @@ class TLSCertificatesProvidesV1(Object): for certificate_relation in certificates_relation: self._remove_certificate(certificate=certificate, relation_id=certificate_relation.id) + def get_issued_certificates( + self, relation_id: Optional[int] = None + ) -> List[ProviderCertificate]: + """Return a List of issued (non revoked) certificates. + + Returns: + List: List of ProviderCertificate objects + """ + provider_certificates = self.get_provider_certificates(relation_id=relation_id) + return [certificate for certificate in provider_certificates if not certificate.revoked] + + def get_provider_certificates( + self, relation_id: Optional[int] = None + ) -> List[ProviderCertificate]: + """Return a List of issued certificates. + + Returns: + List: List of ProviderCertificate objects + """ + certificates: List[ProviderCertificate] = [] + relations = ( + [ + relation + for relation in self.model.relations[self.relationship_name] + if relation.id == relation_id + ] + if relation_id is not None + else self.model.relations.get(self.relationship_name, []) + ) + for relation in relations: + if not relation.app: + logger.warning("Relation %s does not have an application", relation.id) + continue + provider_relation_data = self._load_app_relation_data(relation) + provider_certificates = provider_relation_data.get("certificates", []) + for certificate in provider_certificates: + try: + certificate_object = x509.load_pem_x509_certificate( + data=certificate["certificate"].encode() + ) + except ValueError as e: + logger.error("Could not load certificate - Skipping: %s", e) + continue + provider_certificate = ProviderCertificate( + relation_id=relation.id, + application_name=relation.app.name, + csr=certificate["certificate_signing_request"], + certificate=certificate["certificate"], + ca=certificate["ca"], + chain=certificate["chain"], + revoked=certificate.get("revoked", False), + expiry_time=certificate_object.not_valid_after_utc, + expiry_notification_time=certificate.get( + "recommended_expiry_notification_time" + ), + ) + certificates.append(provider_certificate) + return certificates + def _on_relation_changed(self, event: RelationChangedEvent) -> None: - """Handler triggerred on relation changed event. + """Handle relation changed event. Looks at the relation data and either emits: - certificate request event: If the unit relation data contains a CSR for which @@ -1036,120 +1426,258 @@ class TLSCertificatesProvidesV1(Object): Returns: None """ - assert event.unit is not None - requirer_relation_data = _load_relation_data(event.relation.data[event.unit]) - provider_relation_data = _load_relation_data(event.relation.data[self.charm.app]) - if not self._relation_data_is_valid(requirer_relation_data): - logger.warning( - f"Relation data did not pass JSON Schema validation: {requirer_relation_data}" - ) + if event.unit is None: + logger.error("Relation_changed event does not have a unit.") return - provider_certificates = provider_relation_data.get("certificates", []) - requirer_csrs = requirer_relation_data.get("certificate_signing_requests", []) + if not self.model.unit.is_leader(): + return + if not _relation_data_is_valid(event.relation, event.unit, REQUIRER_JSON_SCHEMA): + logger.debug("Relation data did not pass JSON Schema validation") + return + provider_certificates = self.get_provider_certificates(relation_id=event.relation.id) + requirer_csrs = self.get_requirer_csrs(relation_id=event.relation.id) provider_csrs = [ - certificate_creation_request["certificate_signing_request"] + certificate_creation_request.csr for certificate_creation_request in provider_certificates ] - requirer_unit_csrs = [ - certificate_creation_request["certificate_signing_request"] - for certificate_creation_request in requirer_csrs - ] - for certificate_signing_request in requirer_unit_csrs: - if certificate_signing_request not in provider_csrs: + for certificate_request in requirer_csrs: + if certificate_request.csr not in provider_csrs: self.on.certificate_creation_request.emit( - certificate_signing_request=certificate_signing_request, - relation_id=event.relation.id, + certificate_signing_request=certificate_request.csr, + relation_id=certificate_request.relation_id, + is_ca=certificate_request.is_ca, ) self._revoke_certificates_for_which_no_csr_exists(relation_id=event.relation.id) def _revoke_certificates_for_which_no_csr_exists(self, relation_id: int) -> None: - """Revokes certificates for which no unit has a CSR. + """Revoke certificates for which no unit has a CSR. - Goes through all generated certificates and compare agains the list of CSRS for all units - of a given relationship. + Goes through all generated certificates and compare against the list of CSRs for all units. + + Returns: + None + """ + provider_certificates = self.get_provider_certificates(relation_id) + requirer_csrs = self.get_requirer_csrs(relation_id) + list_of_csrs = [csr.csr for csr in requirer_csrs] + for certificate in provider_certificates: + if certificate.csr not in list_of_csrs: + self.on.certificate_revocation_request.emit( + certificate=certificate.certificate, + certificate_signing_request=certificate.csr, + ca=certificate.ca, + chain=certificate.chain, + ) + self.remove_certificate(certificate=certificate.certificate) + + def get_outstanding_certificate_requests( + self, relation_id: Optional[int] = None + ) -> List[RequirerCSR]: + """Return CSR's for which no certificate has been issued. Args: relation_id (int): Relation id Returns: - None + list: List of RequirerCSR objects. """ - certificates_relation = self.model.get_relation( - relation_name=self.relationship_name, relation_id=relation_id + requirer_csrs = self.get_requirer_csrs(relation_id=relation_id) + outstanding_csrs: List[RequirerCSR] = [] + for relation_csr in requirer_csrs: + if not self.certificate_issued_for_csr( + app_name=relation_csr.application_name, + csr=relation_csr.csr, + relation_id=relation_id, + ): + outstanding_csrs.append(relation_csr) + return outstanding_csrs + + def get_requirer_csrs(self, relation_id: Optional[int] = None) -> List[RequirerCSR]: + """Return a list of requirers' CSRs. + + It returns CSRs from all relations if relation_id is not specified. + CSRs are returned per relation id, application name and unit name. + + Returns: + list: List[RequirerCSR] + """ + relation_csrs: List[RequirerCSR] = [] + relations = ( + [ + relation + for relation in self.model.relations[self.relationship_name] + if relation.id == relation_id + ] + if relation_id is not None + else self.model.relations.get(self.relationship_name, []) ) - if not certificates_relation: - raise RuntimeError(f"Relation {self.relationship_name} does not exist") - provider_relation_data = _load_relation_data(certificates_relation.data[self.charm.app]) - list_of_csrs: List[str] = [] - for unit in certificates_relation.units: - requirer_relation_data = _load_relation_data(certificates_relation.data[unit]) - requirer_csrs = requirer_relation_data.get("certificate_signing_requests", []) - list_of_csrs.extend(csr["certificate_signing_request"] for csr in requirer_csrs) - provider_certificates = provider_relation_data.get("certificates", []) - for certificate in provider_certificates: - if certificate["certificate_signing_request"] not in list_of_csrs: - self.on.certificate_revocation_request.emit( - certificate=certificate["certificate"], - certificate_signing_request=certificate["certificate_signing_request"], - ca=certificate["ca"], - chain=certificate["chain"], - ) - self.remove_certificate(certificate=certificate["certificate"]) + + for relation in relations: + for unit in relation.units: + requirer_relation_data = _load_relation_data(relation.data[unit]) + unit_csrs_list = requirer_relation_data.get("certificate_signing_requests", []) + for unit_csr in unit_csrs_list: + csr = unit_csr.get("certificate_signing_request") + if not csr: + logger.warning("No CSR found in relation data - Skipping") + continue + ca = unit_csr.get("ca", False) + if not relation.app: + logger.warning("No remote app in relation - Skipping") + continue + relation_csr = RequirerCSR( + relation_id=relation.id, + application_name=relation.app.name, + unit_name=unit.name, + csr=csr, + is_ca=ca, + ) + relation_csrs.append(relation_csr) + return relation_csrs + + def certificate_issued_for_csr( + self, app_name: str, csr: str, relation_id: Optional[int] + ) -> bool: + """Check whether a certificate has been issued for a given CSR. + + Args: + app_name (str): Application name that the CSR belongs to. + csr (str): Certificate Signing Request. + relation_id (Optional[int]): Relation ID + + Returns: + bool: True/False depending on whether a certificate has been issued for the given CSR. + """ + issued_certificates_per_csr = self.get_issued_certificates(relation_id=relation_id) + for issued_certificate in issued_certificates_per_csr: + if issued_certificate.csr == csr and issued_certificate.application_name == app_name: + return csr_matches_certificate(csr, issued_certificate.certificate) + return False -class TLSCertificatesRequiresV1(Object): +class TLSCertificatesRequiresV3(Object): """TLS certificates requirer class to be instantiated by TLS certificates requirers.""" - on = CertificatesRequirerCharmEvents() + on = CertificatesRequirerCharmEvents() # type: ignore[reportAssignmentType] def __init__( self, charm: CharmBase, relationship_name: str, - expiry_notification_time: int = 168, + expiry_notification_time: Optional[int] = None, ): - """Generates/use private key and observes relation changed event. + """Generate/use private key and observes relation changed event. Args: charm: Charm object relationship_name: Juju relation name - expiry_notification_time (int): Time difference between now and expiry (in hours). - Used to trigger the CertificateExpiring event. Default: 7 days. + expiry_notification_time (int): Number of hours prior to certificate expiry. + Used to trigger the CertificateExpiring event. + This value is used as a recommendation only, + The actual value is calculated taking into account the provider's recommendation. """ super().__init__(charm, relationship_name) + if not JujuVersion.from_environ().has_secrets: + logger.warning("This version of the TLS library requires Juju secrets (Juju >= 3.0)") self.relationship_name = relationship_name self.charm = charm self.expiry_notification_time = expiry_notification_time self.framework.observe( charm.on[relationship_name].relation_changed, self._on_relation_changed ) - self.framework.observe(charm.on.update_status, self._on_update_status) + self.framework.observe( + charm.on[relationship_name].relation_broken, self._on_relation_broken + ) + self.framework.observe(charm.on.secret_expired, self._on_secret_expired) - @property - def _requirer_csrs(self) -> List[Dict[str, str]]: - """Returns list of requirer CSR's from relation data.""" + def get_requirer_csrs(self) -> List[RequirerCSR]: + """Return list of requirer's CSRs from relation unit data. + + Returns: + list: List of RequirerCSR objects. + """ relation = self.model.get_relation(self.relationship_name) if not relation: - raise RuntimeError(f"Relation {self.relationship_name} does not exist") + return [] + requirer_csrs = [] requirer_relation_data = _load_relation_data(relation.data[self.model.unit]) - return requirer_relation_data.get("certificate_signing_requests", []) + requirer_csrs_dict = requirer_relation_data.get("certificate_signing_requests", []) + for requirer_csr_dict in requirer_csrs_dict: + csr = requirer_csr_dict.get("certificate_signing_request") + if not csr: + logger.warning("No CSR found in relation data - Skipping") + continue + ca = requirer_csr_dict.get("ca", False) + relation_csr = RequirerCSR( + relation_id=relation.id, + application_name=self.model.app.name, + unit_name=self.model.unit.name, + csr=csr, + is_ca=ca, + ) + requirer_csrs.append(relation_csr) + return requirer_csrs - @property - def _provider_certificates(self) -> List[Dict[str, str]]: - """Returns list of provider CSR's from relation data.""" + def get_provider_certificates(self) -> List[ProviderCertificate]: + """Return list of certificates from the provider's relation data.""" + provider_certificates: List[ProviderCertificate] = [] relation = self.model.get_relation(self.relationship_name) if not relation: - raise RuntimeError(f"Relation {self.relationship_name} does not exist") + logger.debug("No relation: %s", self.relationship_name) + return [] if not relation.app: - raise RuntimeError(f"Remote app for relation {self.relationship_name} does not exist") + logger.debug("No remote app in relation: %s", self.relationship_name) + return [] provider_relation_data = _load_relation_data(relation.data[relation.app]) - return provider_relation_data.get("certificates", []) + provider_certificate_dicts = provider_relation_data.get("certificates", []) + for provider_certificate_dict in provider_certificate_dicts: + certificate = provider_certificate_dict.get("certificate") + if not certificate: + logger.warning("No certificate found in relation data - Skipping") + continue + try: + certificate_object = x509.load_pem_x509_certificate(data=certificate.encode()) + except ValueError as e: + logger.error("Could not load certificate - Skipping: %s", e) + continue + ca = provider_certificate_dict.get("ca") + chain = provider_certificate_dict.get("chain", []) + csr = provider_certificate_dict.get("certificate_signing_request") + recommended_expiry_notification_time = provider_certificate_dict.get( + "recommended_expiry_notification_time" + ) + expiry_time = certificate_object.not_valid_after_utc + validity_start_time = certificate_object.not_valid_before_utc + expiry_notification_time = calculate_expiry_notification_time( + validity_start_time=validity_start_time, + expiry_time=expiry_time, + provider_recommended_notification_time=recommended_expiry_notification_time, + requirer_recommended_notification_time=self.expiry_notification_time, + ) + if not csr: + logger.warning("No CSR found in relation data - Skipping") + continue + revoked = provider_certificate_dict.get("revoked", False) + provider_certificate = ProviderCertificate( + relation_id=relation.id, + application_name=relation.app.name, + csr=csr, + certificate=certificate, + ca=ca, + chain=chain, + revoked=revoked, + expiry_time=expiry_time, + expiry_notification_time=expiry_notification_time, + ) + provider_certificates.append(provider_certificate) + return provider_certificates - def _add_requirer_csr(self, csr: str) -> None: - """Adds CSR to relation data. + def _add_requirer_csr_to_relation_data(self, csr: str, is_ca: bool) -> None: + """Add CSR to relation data. Args: csr (str): Certificate Signing Request + is_ca (bool): Whether the certificate is a CA certificate Returns: None @@ -1160,16 +1688,24 @@ class TLSCertificatesRequiresV1(Object): f"Relation {self.relationship_name} does not exist - " f"The certificate request can't be completed" ) - new_csr_dict = {"certificate_signing_request": csr} - if new_csr_dict in self._requirer_csrs: - logger.info("CSR already in relation data - Doing nothing") - return - requirer_csrs = copy.deepcopy(self._requirer_csrs) - requirer_csrs.append(new_csr_dict) - relation.data[self.model.unit]["certificate_signing_requests"] = json.dumps(requirer_csrs) + for requirer_csr in self.get_requirer_csrs(): + if requirer_csr.csr == csr and requirer_csr.is_ca == is_ca: + logger.info("CSR already in relation data - Doing nothing") + return + new_csr_dict = { + "certificate_signing_request": csr, + "ca": is_ca, + } + requirer_relation_data = _load_relation_data(relation.data[self.model.unit]) + existing_relation_data = requirer_relation_data.get("certificate_signing_requests", []) + new_relation_data = copy.deepcopy(existing_relation_data) + new_relation_data.append(new_csr_dict) + relation.data[self.model.unit]["certificate_signing_requests"] = json.dumps( + new_relation_data + ) - def _remove_requirer_csr(self, csr: str) -> None: - """Removes CSR from relation data. + def _remove_requirer_csr_from_relation_data(self, csr: str) -> None: + """Remove CSR from relation data. Args: csr (str): Certificate signing request @@ -1183,36 +1719,44 @@ class TLSCertificatesRequiresV1(Object): f"Relation {self.relationship_name} does not exist - " f"The certificate request can't be completed" ) - requirer_csrs = copy.deepcopy(self._requirer_csrs) - csr_dict = {"certificate_signing_request": csr} - if csr_dict not in requirer_csrs: - logger.info("CSR not in relation data - Doing nothing") + if not self.get_requirer_csrs(): + logger.info("No CSRs in relation data - Doing nothing") return - requirer_csrs.remove(csr_dict) - relation.data[self.model.unit]["certificate_signing_requests"] = json.dumps(requirer_csrs) + requirer_relation_data = _load_relation_data(relation.data[self.model.unit]) + existing_relation_data = requirer_relation_data.get("certificate_signing_requests", []) + new_relation_data = copy.deepcopy(existing_relation_data) + for requirer_csr in new_relation_data: + if requirer_csr["certificate_signing_request"] == csr: + new_relation_data.remove(requirer_csr) + relation.data[self.model.unit]["certificate_signing_requests"] = json.dumps( + new_relation_data + ) - def request_certificate_creation(self, certificate_signing_request: bytes) -> None: + def request_certificate_creation( + self, certificate_signing_request: bytes, is_ca: bool = False + ) -> None: """Request TLS certificate to provider charm. Args: certificate_signing_request (bytes): Certificate Signing Request + is_ca (bool): Whether the certificate is a CA certificate Returns: None """ relation = self.model.get_relation(self.relationship_name) if not relation: - message = ( + raise RuntimeError( f"Relation {self.relationship_name} does not exist - " f"The certificate request can't be completed" ) - logger.error(message) - raise RuntimeError(message) - self._add_requirer_csr(certificate_signing_request.decode().strip()) + self._add_requirer_csr_to_relation_data( + certificate_signing_request.decode().strip(), is_ca=is_ca + ) logger.info("Certificate request sent to provider") def request_certificate_revocation(self, certificate_signing_request: bytes) -> None: - """Removes CSR from relation data. + """Remove CSR from relation data. The provider of this relation is then expected to remove certificates associated to this CSR from the relation data as well and emit a request_certificate_revocation event for the @@ -1224,13 +1768,13 @@ class TLSCertificatesRequiresV1(Object): Returns: None """ - self._remove_requirer_csr(certificate_signing_request.decode().strip()) + self._remove_requirer_csr_from_relation_data(certificate_signing_request.decode().strip()) logger.info("Certificate revocation sent to provider") def request_certificate_renewal( self, old_certificate_signing_request: bytes, new_certificate_signing_request: bytes ) -> None: - """Renews certificate. + """Renew certificate. Removes old CSR from relation data and adds new one. @@ -1252,24 +1796,69 @@ class TLSCertificatesRequiresV1(Object): ) logger.info("Certificate renewal request completed.") - @staticmethod - def _relation_data_is_valid(certificates_data: dict) -> bool: - """Checks whether relation data is valid based on json schema. - - Args: - certificates_data: Certificate data in dict format. + def get_assigned_certificates(self) -> List[ProviderCertificate]: + """Get a list of certificates that were assigned to this unit. Returns: - bool: Whether relation data is valid. + List: List[ProviderCertificate] """ - try: - validate(instance=certificates_data, schema=PROVIDER_JSON_SCHEMA) - return True - except exceptions.ValidationError: - return False + assigned_certificates = [] + for requirer_csr in self.get_certificate_signing_requests(fulfilled_only=True): + if cert := self._find_certificate_in_relation_data(requirer_csr.csr): + assigned_certificates.append(cert) + return assigned_certificates + + def get_expiring_certificates(self) -> List[ProviderCertificate]: + """Get a list of certificates that were assigned to this unit that are expiring or expired. + + Returns: + List: List[ProviderCertificate] + """ + expiring_certificates: List[ProviderCertificate] = [] + for requirer_csr in self.get_certificate_signing_requests(fulfilled_only=True): + if cert := self._find_certificate_in_relation_data(requirer_csr.csr): + if not cert.expiry_time or not cert.expiry_notification_time: + continue + if datetime.now(timezone.utc) > cert.expiry_notification_time: + expiring_certificates.append(cert) + return expiring_certificates + + def get_certificate_signing_requests( + self, + fulfilled_only: bool = False, + unfulfilled_only: bool = False, + ) -> List[RequirerCSR]: + """Get the list of CSR's that were sent to the provider. + + You can choose to get only the CSR's that have a certificate assigned or only the CSR's + that don't. + + Args: + fulfilled_only (bool): This option will discard CSRs that don't have certificates yet. + unfulfilled_only (bool): This option will discard CSRs that have certificates signed. + + Returns: + List of RequirerCSR objects. + """ + csrs = [] + for requirer_csr in self.get_requirer_csrs(): + cert = self._find_certificate_in_relation_data(requirer_csr.csr) + if (unfulfilled_only and cert) or (fulfilled_only and not cert): + continue + csrs.append(requirer_csr) + + return csrs def _on_relation_changed(self, event: RelationChangedEvent) -> None: - """Handler triggered on relation changed events. + """Handle relation changed event. + + Goes through all providers certificates that match a requested CSR. + + If the provider certificate is revoked, emit a CertificateInvalidateEvent, + otherwise emit a CertificateAvailableEvent. + + Remove the secret for revoked certificate, or add a secret with the correct expiry + time for new certificates. Args: event: Juju event @@ -1277,84 +1866,154 @@ class TLSCertificatesRequiresV1(Object): Returns: None """ - relation = self.model.get_relation(self.relationship_name) - if not relation: - logger.warning(f"No relation: {self.relationship_name}") + if not event.app: + logger.warning("No remote app in relation - Skipping") return - if not relation.app: - logger.warning(f"No remote app in relation: {self.relationship_name}") - return - provider_relation_data = _load_relation_data(relation.data[relation.app]) - if not self._relation_data_is_valid(provider_relation_data): - logger.warning( - f"Provider relation data did not pass JSON Schema validation: " - f"{event.relation.data[relation.app]}" - ) + if not _relation_data_is_valid(event.relation, event.app, PROVIDER_JSON_SCHEMA): + logger.debug("Relation data did not pass JSON Schema validation") return + provider_certificates = self.get_provider_certificates() requirer_csrs = [ - certificate_creation_request["certificate_signing_request"] - for certificate_creation_request in self._requirer_csrs + certificate_creation_request.csr + for certificate_creation_request in self.get_requirer_csrs() ] - for certificate in self._provider_certificates: - if certificate["certificate_signing_request"] in requirer_csrs: - if certificate.get("revoked", False): - self.on.certificate_revoked.emit( - certificate_signing_request=certificate["certificate_signing_request"], - certificate=certificate["certificate"], - ca=certificate["ca"], - chain=certificate["chain"], - revoked=True, + for certificate in provider_certificates: + if certificate.csr in requirer_csrs: + csr_in_sha256_hex = get_sha256_hex(certificate.csr) + if certificate.revoked: + with suppress(SecretNotFoundError): + logger.debug( + "Removing secret with label %s", + f"{LIBID}-{csr_in_sha256_hex}", + ) + secret = self.model.get_secret( + label=f"{LIBID}-{csr_in_sha256_hex}") + secret.remove_all_revisions() + self.on.certificate_invalidated.emit( + reason="revoked", + certificate=certificate.certificate, + certificate_signing_request=certificate.csr, + ca=certificate.ca, + chain=certificate.chain, ) else: + try: + logger.debug( + "Setting secret with label %s", f"{LIBID}-{csr_in_sha256_hex}" + ) + secret = self.model.get_secret(label=f"{LIBID}-{csr_in_sha256_hex}") + secret.set_content( + {"certificate": certificate.certificate, "csr": certificate.csr} + ) + secret.set_info( + expire=self._get_next_secret_expiry_time(certificate), + ) + except SecretNotFoundError: + logger.debug( + "Creating new secret with label %s", f"{LIBID}-{csr_in_sha256_hex}" + ) + secret = self.charm.unit.add_secret( + {"certificate": certificate.certificate, "csr": certificate.csr}, + label=f"{LIBID}-{csr_in_sha256_hex}", + expire=self._get_next_secret_expiry_time(certificate), + ) self.on.certificate_available.emit( - certificate_signing_request=certificate["certificate_signing_request"], - certificate=certificate["certificate"], - ca=certificate["ca"], - chain=certificate["chain"], + certificate_signing_request=certificate.csr, + certificate=certificate.certificate, + ca=certificate.ca, + chain=certificate.chain, ) - def _on_update_status(self, event: UpdateStatusEvent) -> None: - """Triggered on update status event. + def _get_next_secret_expiry_time(self, certificate: ProviderCertificate) -> Optional[datetime]: + """Return the expiry time or expiry notification time. - Goes through each certificate in the "certificates" relation and checks their expiry date. - If they are close to expire (<7 days), emits a CertificateExpiringEvent event and if - they are expired, emits a CertificateExpiredEvent. + Extracts the expiry time from the provided certificate, calculates the + expiry notification time and return the closest of the two, that is in + the future. Args: - event (UpdateStatusEvent): Juju event + certificate: ProviderCertificate object + + Returns: + Optional[datetime]: None if the certificate expiry time cannot be read, + next expiry time otherwise. + """ + if not certificate.expiry_time or not certificate.expiry_notification_time: + return None + return _get_closest_future_time( + certificate.expiry_notification_time, + certificate.expiry_time, + ) + + def _on_relation_broken(self, event: RelationBrokenEvent) -> None: + """Handle Relation Broken Event. + + Emitting `all_certificates_invalidated` from `relation-broken` rather + than `relation-departed` since certs are stored in app data. + + Args: + event: Juju event Returns: None """ - relation = self.model.get_relation(self.relationship_name) - if not relation: - logger.warning(f"No relation: {self.relationship_name}") + self.on.all_certificates_invalidated.emit() + + def _on_secret_expired(self, event: SecretExpiredEvent) -> None: + """Handle Secret Expired Event. + + Loads the certificate from the secret, and will emit 1 of 2 + events. + + If the certificate is not yet expired, emits CertificateExpiringEvent + and updates the expiry time of the secret to the exact expiry time on + the certificate. + + If the certificate is expired, emits CertificateInvalidedEvent and + deletes the secret. + + Args: + event (SecretExpiredEvent): Juju event + """ + if not event.secret.label or not event.secret.label.startswith(f"{LIBID}-"): return - if not relation.app: - logger.warning(f"No remote app in relation: {self.relationship_name}") + csr = event.secret.get_content()["csr"] + provider_certificate = self._find_certificate_in_relation_data(csr) + if not provider_certificate: + # A secret expired but we did not find matching certificate. Cleaning up + event.secret.remove_all_revisions() return - provider_relation_data = _load_relation_data(relation.data[relation.app]) - if not self._relation_data_is_valid(provider_relation_data): - logger.warning( - f"Provider relation data did not pass JSON Schema validation: " - f"{relation.data[relation.app]}" + + if not provider_certificate.expiry_time: + # A secret expired but matching certificate is invalid. Cleaning up + event.secret.remove_all_revisions() + return + + if datetime.now(timezone.utc) < provider_certificate.expiry_time: + logger.warning("Certificate almost expired") + self.on.certificate_expiring.emit( + certificate=provider_certificate.certificate, + expiry=provider_certificate.expiry_time.isoformat(), ) - return - for certificate_dict in self._provider_certificates: - certificate = certificate_dict["certificate"] - try: - certificate_object = x509.load_pem_x509_certificate(data=certificate.encode()) - except ValueError: - logger.warning("Could not load certificate.") + event.secret.set_info( + expire=provider_certificate.expiry_time, + ) + else: + logger.warning("Certificate is expired") + self.on.certificate_invalidated.emit( + reason="expired", + certificate=provider_certificate.certificate, + certificate_signing_request=provider_certificate.csr, + ca=provider_certificate.ca, + chain=provider_certificate.chain, + ) + self.request_certificate_revocation(provider_certificate.certificate.encode()) + event.secret.remove_all_revisions() + + def _find_certificate_in_relation_data(self, csr: str) -> Optional[ProviderCertificate]: + """Return the certificate that match the given CSR.""" + for provider_certificate in self.get_provider_certificates(): + if provider_certificate.csr != csr: continue - time_difference = certificate_object.not_valid_after - datetime.utcnow() - if time_difference.total_seconds() < 0: - logger.warning("Certificate is expired") - self.on.certificate_expired.emit(certificate=certificate) - self.request_certificate_revocation(certificate.encode()) - continue - if time_difference.total_seconds() < (self.expiry_notification_time * 60 * 60): - logger.warning("Certificate almost expired") - self.on.certificate_expiring.emit( - certificate=certificate, expiry=certificate_object.not_valid_after.isoformat() - ) + return provider_certificate + return None diff --git a/ops-sunbeam/ops_sunbeam/charm.py b/ops-sunbeam/ops_sunbeam/charm.py index 90c242b8..ff3ba0c8 100644 --- a/ops-sunbeam/ops_sunbeam/charm.py +++ b/ops-sunbeam/ops_sunbeam/charm.py @@ -431,11 +431,11 @@ class OSBaseOperatorCharm(ops.charm.CharmBase): if isinstance(event, RabbitMQGoneAwayEvent): _is_broken = True case "certificates": - from charms.tls_certificates_interface.v1.tls_certificates import ( - CertificateExpiredEvent, + from charms.tls_certificates_interface.v3.tls_certificates import ( + AllCertificatesInvalidatedEvent, ) - if isinstance(event, CertificateExpiredEvent): + if isinstance(event, AllCertificatesInvalidatedEvent): _is_broken = True case "ovsdb-cms": from charms.ovn_central_k8s.v0.ovsdb import ( diff --git a/ops-sunbeam/ops_sunbeam/relation_handlers.py b/ops-sunbeam/ops_sunbeam/relation_handlers.py index 230da737..0739b232 100644 --- a/ops-sunbeam/ops_sunbeam/relation_handlers.py +++ b/ops-sunbeam/ops_sunbeam/relation_handlers.py @@ -14,11 +14,13 @@ """Base classes for defining a charm using the Operator framework.""" +import abc import hashlib import json import logging import secrets import string +import typing from typing import ( Callable, Dict, @@ -783,33 +785,118 @@ class CephClientHandler(RelationHandler): return ctxt +class _StoreEntry(typing.TypedDict, total=False): + """Type definition for a store entry.""" + + private_key: str + csr: str + + +class _Store(abc.ABC): + + @abc.abstractmethod + def ready(self) -> bool: + """Check if store is ready.""" + ... + + @abc.abstractmethod + def get_entries(self) -> dict[str, _StoreEntry]: + """Get store dict from relation data.""" + ... + + @abc.abstractmethod + def save_entries(self, entries: dict[str, _StoreEntry]): + """Save store dict to relation data.""" + ... + + def get_entry(self, name: str) -> _StoreEntry | None: + """Return store entry.""" + if not self.ready(): + logger.debug("Store not ready, cannot get entry.") + return None + return self.get_entries().get(name) + + def save_entry(self, name: str, entry: _StoreEntry): + """Save store entry.""" + if not self.ready(): + logger.debug("Store not ready, cannot set entry.") + return + store = self.get_entries() + store[name] = entry + self.save_entries(store) + + def get_private_key(self, name: str) -> str | None: + """Return private key.""" + if entry := self.get_entry(name): + return entry.get("private_key") + return None + + def get_csr(self, name: str) -> str | None: + """Return csr.""" + if entry := self.get_entry(name): + return entry.get("csr") + return None + + def set_private_key(self, name: str, private_key: str): + """Update private key.""" + entry = self.get_entry(name) or {} + entry["private_key"] = private_key + self.save_entry(name, entry) + + def set_csr(self, name: str, csr: bytes): + """Update csr.""" + entry = self.get_entry(name) or {} + entry["csr"] = csr.decode() + self.save_entry(name, entry) + + def delete_csr(self, name: str): + """Delete csr.""" + entry = self.get_entry(name) or {} + entry.pop("csr", None) + self.save_entry(name, entry) + + class TlsCertificatesHandler(RelationHandler): """Handler for certificates interface.""" - class PeerKeyStore: - """Store private key sercret id in peer storage relation.""" + if typing.TYPE_CHECKING: + from charms.tls_certificates_interface.v3.tls_certificates import ( + TLSCertificatesRequiresV3, + ) - def __init__(self, relation, unit): + interface: TLSCertificatesRequiresV3 + + class PeerStore(_Store): + """Store private key secret id in peer storage relation.""" + + STORE_KEY: str = "tls-store" + + def __init__( + self, relation: ops.Relation, entity: ops.Unit | ops.Application + ): self.relation = relation - self.unit = unit + self.entity = entity - def store_ready(self) -> bool: + def ready(self) -> bool: """Check if store is ready.""" - return bool(self.relation) + return bool(self.relation) and self.relation.active - def get_private_key(self) -> str: - """Return private key.""" - try: - key = self.relation.data[self.unit].get("private_key") - except AttributeError: - key = None - return key + def get_entries(self) -> dict[str, _StoreEntry]: + """Get store dict from relation data.""" + if not self.ready(): + return {} + return json.loads( + self.relation.data[self.entity].get(self.STORE_KEY, "{}") + ) - def set_private_key(self, value: str): - """Update private key.""" - self.relation.data[self.unit]["private_key"] = value + def save_entries(self, entries: dict[str, _StoreEntry]): + """Save store dict to relation data.""" + if self.ready(): + self.relation.data[self.entity][self.STORE_KEY] = json.dumps( + entries + ) - class LocalDBKeyStore: + class LocalDBStore(_Store): """Store private key sercret id in local unit db. This is a fallback for when the peer relation is not @@ -819,56 +906,111 @@ class TlsCertificatesHandler(RelationHandler): def __init__(self, state_db): self.state_db = state_db try: - self.state_db.private_key + self.state_db.tls_store except AttributeError: - self.state_db.private_key = None + self.state_db.tls_store = "{}" - def store_ready(self) -> bool: + def ready(self) -> bool: """Check if store is ready.""" return True - def get_private_key(self) -> str: - """Return private key.""" - return self.state_db.private_key + def get_entries(self) -> dict[str, _StoreEntry]: + """Get store dict from relation data.""" + return json.loads(self.state_db.tls_store) - def set_private_key(self, value: str): - """Update private key.""" - self.state_db.private_key = value + def save_entries(self, entries: dict[str, _StoreEntry]): + """Save store dict to relation data.""" + self.state_db.tls_store = json.dumps(entries) def __init__( self, - charm: ops.charm.CharmBase, + charm: ops.CharmBase, relation_name: str, callback_f: Callable, - sans_dns: List[str] = None, - sans_ips: List[str] = None, + sans_dns: list[str] | None = None, + sans_ips: list[str] | None = None, mandatory: bool = False, ) -> None: """Run constructor.""" - self._private_key = None + self._private_keys: dict[str, str] = {} self.sans_dns = sans_dns self.sans_ips = sans_ips super().__init__(charm, relation_name, callback_f, mandatory) try: - self.store = self.PeerKeyStore( - self.model.get_relation("peers"), self.charm.model.unit + self.store = self.PeerStore( + self.model.get_relation("peers"), self.get_entity() ) except KeyError: - self.store = self.LocalDBKeyStore(charm._state) - self.setup_private_key() + if self.app_managed_certificates(): + raise RuntimeError( + "Application managed certificates require a peer relation" + ) + self.store = self.LocalDBStore(charm._state) + self.setup_private_keys() - def setup_event_handler(self) -> None: + def get_entity(self) -> ops.Unit | ops.Application: + """Return the entity for the key store. + + Defaults to the unit. + """ + return self.charm.model.unit + + def i_am_allowed(self) -> bool: + """Whether this unit is allowed to modify the store.""" + i_need_to_be_leader = self.app_managed_certificates() + if i_need_to_be_leader: + return self.charm.unit.is_leader() + + return True + + def app_managed_certificates(self) -> bool: + """Whether the application manages its own certificates.""" + return isinstance(self.get_entity(), ops.Application) + + def key_names(self) -> list[str]: + """Return the key names managed by this relation. + + First key is considered as default key. + """ + return ["main"] + + def csrs(self) -> dict[str, bytes]: + """Return a dict of generated csrs for self.key_names(). + + The method calling this method will ensure that all keys have a matching + csr. + """ + # Lazy import to ensure this lib is only required if the charm + # has this relation. + from charms.tls_certificates_interface.v3.tls_certificates import ( + generate_csr, + ) + + main_key = self._private_keys.get("main") + if not main_key: + return {} + return { + "main": generate_csr( + private_key=main_key.encode(), + subject=self.get_entity().name.replace("/", "-"), + sans_dns=self.sans_dns, + sans_ip=self.sans_ips, + ) + } + + def setup_event_handler(self) -> ops.Object: """Configure event handlers for tls relation.""" logger.debug("Setting up certificates event handler") # Lazy import to ensure this lib is only required if the charm # has this relation. - from charms.tls_certificates_interface.v1.tls_certificates import ( - TLSCertificatesRequiresV1, + from charms.tls_certificates_interface.v3.tls_certificates import ( + TLSCertificatesRequiresV3, ) - self.certificates = TLSCertificatesRequiresV1( + self.certificates = TLSCertificatesRequiresV3( self.charm, "certificates" ) + self.framework.observe( self.charm.on.certificates_relation_joined, self._on_certificates_relation_joined, @@ -886,26 +1028,25 @@ class TlsCertificatesHandler(RelationHandler): self._on_certificate_expiring, ) self.framework.observe( - self.certificates.on.certificate_expired, - self._on_certificate_expired, + self.certificates.on.certificate_invalidated, + self._on_certificate_invalidated, + ) + self.framework.observe( + self.certificates.on.all_certificates_invalidated, + self._on_all_certificate_invalidated, ) return self.certificates - def setup_private_key(self) -> None: + def _setup_private_key(self, key: str): """Create and store private key if needed.""" # Lazy import to ensure this lib is only required if the charm # has this relation. - from charms.tls_certificates_interface.v1.tls_certificates import ( + from charms.tls_certificates_interface.v3.tls_certificates import ( generate_private_key, ) - if not self.store.store_ready(): - logger.debug("Store not ready, cannot generate key") - return - - if self.store.get_private_key(): + if private_key_secret_id := self.store.get_private_key(key): logger.debug("Private key already present") - private_key_secret_id = self.store.get_private_key() try: private_key_secret = self.model.get_secret( id=private_key_secret_id @@ -924,29 +1065,49 @@ class TlsCertificatesHandler(RelationHandler): private_key_secret = self.model.get_secret( id=private_key_secret_id ) - self._private_key = ( - private_key_secret.get_content(refresh=True) - .get("private-key") - .encode() + self._private_keys[key] = private_key_secret.get_content( + refresh=True + )["private-key"] + return + + self._private_keys[key] = generate_private_key().decode() + private_key_secret = self.get_entity().add_secret( + {"private-key": self._private_keys[key]}, + label=f"{self.get_entity().name}-{key}-private-key", + ) + + self.store.set_private_key( + key, typing.cast(str, private_key_secret.id) + ) + + def setup_private_keys(self) -> None: + """Create and store private key if needed.""" + if not self.i_am_allowed(): + logger.debug( + "Unit is not allow to handle private keys, skipping setup" ) return - self._private_key = generate_private_key() - private_key_secret = self.model.unit.add_secret( - {"private-key": self._private_key.decode()}, - label=f"{self.charm.model.unit}-private-key", - ) + if not self.store.ready(): + logger.debug("Store not ready, cannot generate key") + return - self.store.set_private_key(private_key_secret.id) + keys = self.key_names() + if not keys: + raise RuntimeError("No keys to generate, this is always a bug.") + + for key in keys: + self._setup_private_key(key) @property - def private_key(self): - """Private key for certificates.""" - if self._private_key: - return self._private_key.decode() - else: - # Private key has not been set yet - return None + def private_key(self) -> str | None: + """Private key for certificates. + + Return the first key from key_names. + """ + if private_key := self._private_keys.get(self.key_names()[0]): + return private_key + return None def update_relation_data(self): """Request certificates outside of relation context.""" @@ -957,120 +1118,131 @@ class TlsCertificatesHandler(RelationHandler): "Not updating certificate request data, no relation found" ) - def _on_certificates_relation_joined( - self, event: ops.framework.EventBase - ) -> None: + def _on_certificates_relation_joined(self, event: ops.EventBase) -> None: """Request certificates in response to relation join event.""" self._request_certificates() - def _request_certificates(self): + def _request_certificates(self, renew=False): """Request certificates from remote provider.""" - # Lazy import to ensure this lib is only required if the charm - # has this relation. - from charms.tls_certificates_interface.v1.tls_certificates import ( - generate_csr, - ) + if not self.i_am_allowed(): + logger.debug( + "Unit is not allow to handle private keys, skipping setup" + ) + return if self.ready: logger.debug("Certificate request already complete.") return - if self.private_key: - logger.debug("Private key found, requesting certificates") - else: - logger.debug("Cannot request certificates, private key not found") + keys = self.key_names() + if set(keys) != set(self._private_keys.keys()): + logger.debug("Not all private keys are setup, skipping request.") return - csr = generate_csr( - private_key=self.private_key.encode(), - subject=self.charm.model.unit.name.replace("/", "-"), - sans_dns=self.sans_dns, - sans_ip=self.sans_ips, - ) - self.certificates.request_certificate_creation( - certificate_signing_request=csr - ) + csrs = self.csrs() - def _on_certificates_relation_broken( - self, event: ops.framework.EventBase - ) -> None: + if set(keys) != set(csrs.keys()): + raise RuntimeError( + "Mismatch between keys and csrs, this is always a bug." + ) + + for name, csr in csrs.items(): + previous_csr = self.store.get_csr(name) + csr = csr.strip() + if renew and previous_csr: + self.certificates.request_certificate_renewal( + old_certificate_signing_request=previous_csr.encode(), + new_certificate_signing_request=csr, + ) + self.store.set_csr(name, csr) + elif previous_csr: + logger.debug( + "CSR already exists for %s, skipping request.", name + ) + else: + self.certificates.request_certificate_creation( + certificate_signing_request=csr + ) + self.store.set_csr(name, csr) + + def _on_certificates_relation_broken(self, event: ops.EventBase) -> None: if self.mandatory: self.status.set(BlockedStatus("integration missing")) - def _on_certificate_available( - self, event: ops.framework.EventBase - ) -> None: + def _on_certificate_available(self, event: ops.EventBase) -> None: self.callback_f(event) - def _on_certificate_expiring(self, event: ops.framework.EventBase) -> None: - logger.warning("Certificate getting expired") + def _on_certificate_expiring(self, event: ops.EventBase) -> None: self.status.set(ActiveStatus("Certificates are getting expired soon")) + logger.warning("Certificate getting expired, requesting new ones.") + self._request_certificates(renew=True) + self.callback_f(event) - def _on_certificate_expired(self, event: ops.framework.EventBase) -> None: - logger.warning("Certificate expired") - self.status.set(BlockedStatus("Certificates expired")) + def _on_certificate_invalidated(self, event: ops.EventBase) -> None: + logger.warning("Certificate invalidated, requesting new ones.") + if ( + self.i_am_allowed() + and (relation := self.model.get_relation(self.relation_name)) + and relation.active + ): + self._request_certificates(renew=True) + self.callback_f(event) - def _get_csr_from_relation_unit_data(self) -> Optional[str]: - certificate_relations = list(self.model.relations[self.relation_name]) - if not certificate_relations: - return None + def _on_all_certificate_invalidated(self, event: ops.EventBase) -> None: + logger.warning( + "Certificates invalidated, most likely a relation broken." + ) + self.status.set(BlockedStatus("Certificates invalidated")) + if self.i_am_allowed(): + for name in self.key_names(): + self.store.delete_csr(name) + self.callback_f(event) - # unit_data format: - # {"certificate_signing_requests": "['certificate_signing_request': 'CSRTEXT']"} - unit_data = certificate_relations[0].data[self.charm.model.unit] - csr = json.loads(unit_data.get("certificate_signing_requests", "[]")) - if not csr: - return None - - csr = csr[0].get("certificate_signing_request", None) - return csr - - def _get_cert_from_relation_data(self, csr: str) -> dict: - certificate_relations = list(self.model.relations[self.relation_name]) - if not certificate_relations: - return {} - - # app data format: - # {"certificates": "['certificate_signing_request': 'CSR', - # 'certificate': 'CERT', 'ca': 'CA', 'chain': 'CHAIN']"} - certs = certificate_relations[0].data[certificate_relations[0].app] - certs = json.loads(certs.get("certificates", "[]")) - for certificate in certs: - csr_from_app = certificate.get("certificate_signing_request", "") - if csr.strip() == csr_from_app.strip(): - return { - "cert": certificate.get("certificate", None), - "ca": certificate.get("ca", None), - "chain": certificate.get("chain", []), - } - - return {} + def get_certs(self) -> list: + """Return certificates.""" + # If certificates are managed at the app level + # return all the certificates + if self.app_managed_certificates(): + return self.interface.get_provider_certificates() + # If the certificates are managed at the unit level + # return the certificates for the unit + return self.interface.get_assigned_certificates() @property def ready(self) -> bool: """Whether handler ready for use.""" - csr_from_unit = self._get_csr_from_relation_unit_data() - if not csr_from_unit: - return False + certs = self.get_certs() - certs = self._get_cert_from_relation_data(csr_from_unit) - return True if certs else False + if len(certs) != len(self.key_names()): + return False + return True def context(self) -> dict: """Certificates context.""" - csr_from_unit = self._get_csr_from_relation_unit_data() - if not csr_from_unit: + certs = self.get_certs() + if len(certs) != len(self.key_names()): return {} - - certs = self._get_cert_from_relation_data(csr_from_unit) - cert = certs["cert"] - ca_cert = certs["ca"] + "\n" + "\n".join(certs["chain"]) - - ctxt = { - "key": self.private_key, - "cert": cert, - "ca_cert": ca_cert, - } + ctxt = {} + for name, entry in self.store.get_entries().items(): + csr = entry.get("csr") + key = self._private_keys.get(name) + if csr is None or key is None: + logger.warning("Tls Store Entry %s is incomplete", name) + continue + for cert in certs: + if cert.csr == csr: + ctxt.update( + { + "key_" + name: key, + "ca_cert_" + + name: cert.ca + + "\n" + + "\n".join(cert.chain), + "cert_" + name: cert.certificate, + } + ) + else: + logger.debug("No certificate found for CSR %s", name) return ctxt diff --git a/ops-sunbeam/ops_sunbeam/test_utils.py b/ops-sunbeam/ops_sunbeam/test_utils.py index 0baab920..7d310f70 100644 --- a/ops-sunbeam/ops_sunbeam/test_utils.py +++ b/ops-sunbeam/ops_sunbeam/test_utils.py @@ -589,6 +589,7 @@ def add_base_certificates_relation(harness: Harness) -> str: "certificate_signing_requests": json.dumps([csr]), }, ) + harness.charm.certs.store.set_csr("main", TEST_CSR.encode()) return rel_id