import asyncio import logging from abc import ABC, abstractmethod from enum import Enum from typing import List, Optional, Tuple, Union from redis.asyncio import Redis from redis.asyncio.http.http_client import DEFAULT_TIMEOUT, AsyncHTTPClientWrapper from redis.backoff import NoBackoff from redis.http.http_client import HttpClient from redis.multidb.exception import UnhealthyDatabaseException from redis.retry import Retry DEFAULT_HEALTH_CHECK_PROBES = 3 DEFAULT_HEALTH_CHECK_INTERVAL = 5 DEFAULT_HEALTH_CHECK_DELAY = 0.5 DEFAULT_LAG_AWARE_TOLERANCE = 5000 logger = logging.getLogger(__name__) class HealthCheck(ABC): @abstractmethod async def check_health(self, database) -> bool: """Function to determine the health status.""" pass class HealthCheckPolicy(ABC): """ Health checks execution policy. """ @property @abstractmethod def health_check_probes(self) -> int: """Number of probes to execute health checks.""" pass @property @abstractmethod def health_check_delay(self) -> float: """Delay between health check probes.""" pass @abstractmethod async def execute(self, health_checks: List[HealthCheck], database) -> bool: """Execute health checks and return database health status.""" pass class AbstractHealthCheckPolicy(HealthCheckPolicy): def __init__(self, health_check_probes: int, health_check_delay: float): if health_check_probes < 1: raise ValueError("health_check_probes must be greater than 0") self._health_check_probes = health_check_probes self._health_check_delay = health_check_delay @property def health_check_probes(self) -> int: return self._health_check_probes @property def health_check_delay(self) -> float: return self._health_check_delay @abstractmethod async def execute(self, health_checks: List[HealthCheck], database) -> bool: pass class HealthyAllPolicy(AbstractHealthCheckPolicy): """ Policy that returns True if all health check probes are successful. """ def __init__(self, health_check_probes: int, health_check_delay: float): super().__init__(health_check_probes, health_check_delay) async def execute(self, health_checks: List[HealthCheck], database) -> bool: for health_check in health_checks: for attempt in range(self.health_check_probes): try: if not await health_check.check_health(database): return False except Exception as e: raise UnhealthyDatabaseException("Unhealthy database", database, e) if attempt < self.health_check_probes - 1: await asyncio.sleep(self._health_check_delay) return True class HealthyMajorityPolicy(AbstractHealthCheckPolicy): """ Policy that returns True if a majority of health check probes are successful. """ def __init__(self, health_check_probes: int, health_check_delay: float): super().__init__(health_check_probes, health_check_delay) async def execute(self, health_checks: List[HealthCheck], database) -> bool: for health_check in health_checks: if self.health_check_probes % 2 == 0: allowed_unsuccessful_probes = self.health_check_probes / 2 else: allowed_unsuccessful_probes = (self.health_check_probes + 1) / 2 for attempt in range(self.health_check_probes): try: if not await health_check.check_health(database): allowed_unsuccessful_probes -= 1 if allowed_unsuccessful_probes <= 0: return False except Exception as e: allowed_unsuccessful_probes -= 1 if allowed_unsuccessful_probes <= 0: raise UnhealthyDatabaseException( "Unhealthy database", database, e ) if attempt < self.health_check_probes - 1: await asyncio.sleep(self._health_check_delay) return True class HealthyAnyPolicy(AbstractHealthCheckPolicy): """ Policy that returns True if at least one health check probe is successful. """ def __init__(self, health_check_probes: int, health_check_delay: float): super().__init__(health_check_probes, health_check_delay) async def execute(self, health_checks: List[HealthCheck], database) -> bool: is_healthy = False for health_check in health_checks: exception = None for attempt in range(self.health_check_probes): try: if await health_check.check_health(database): is_healthy = True break else: is_healthy = False except Exception as e: exception = UnhealthyDatabaseException( "Unhealthy database", database, e ) if attempt < self.health_check_probes - 1: await asyncio.sleep(self._health_check_delay) if not is_healthy and not exception: return is_healthy elif not is_healthy and exception: raise exception return is_healthy class HealthCheckPolicies(Enum): HEALTHY_ALL = HealthyAllPolicy HEALTHY_MAJORITY = HealthyMajorityPolicy HEALTHY_ANY = HealthyAnyPolicy DEFAULT_HEALTH_CHECK_POLICY: HealthCheckPolicies = HealthCheckPolicies.HEALTHY_ALL class PingHealthCheck(HealthCheck): """ Health check based on PING command. """ async def check_health(self, database) -> bool: if isinstance(database.client, Redis): return await database.client.execute_command("PING") else: # For a cluster checks if all nodes are healthy. all_nodes = database.client.get_nodes() for node in all_nodes: if not await node.redis_connection.execute_command("PING"): return False return True class LagAwareHealthCheck(HealthCheck): """ Health check available for Redis Enterprise deployments. Verify via REST API that the database is healthy based on different lags. """ def __init__( self, rest_api_port: int = 9443, lag_aware_tolerance: int = DEFAULT_LAG_AWARE_TOLERANCE, timeout: float = DEFAULT_TIMEOUT, auth_basic: Optional[Tuple[str, str]] = None, verify_tls: bool = True, # TLS verification (server) options ca_file: Optional[str] = None, ca_path: Optional[str] = None, ca_data: Optional[Union[str, bytes]] = None, # Mutual TLS (client cert) options client_cert_file: Optional[str] = None, client_key_file: Optional[str] = None, client_key_password: Optional[str] = None, ): """ Initialize LagAwareHealthCheck with the specified parameters. Args: rest_api_port: Port number for Redis Enterprise REST API (default: 9443) lag_aware_tolerance: Tolerance in lag between databases in MS (default: 100) timeout: Request timeout in seconds (default: DEFAULT_TIMEOUT) auth_basic: Tuple of (username, password) for basic authentication verify_tls: Whether to verify TLS certificates (default: True) ca_file: Path to CA certificate file for TLS verification ca_path: Path to CA certificates directory for TLS verification ca_data: CA certificate data as string or bytes client_cert_file: Path to client certificate file for mutual TLS client_key_file: Path to client private key file for mutual TLS client_key_password: Password for encrypted client private key """ self._http_client = AsyncHTTPClientWrapper( HttpClient( timeout=timeout, auth_basic=auth_basic, retry=Retry(NoBackoff(), retries=0), verify_tls=verify_tls, ca_file=ca_file, ca_path=ca_path, ca_data=ca_data, client_cert_file=client_cert_file, client_key_file=client_key_file, client_key_password=client_key_password, ) ) self._rest_api_port = rest_api_port self._lag_aware_tolerance = lag_aware_tolerance async def check_health(self, database) -> bool: if database.health_check_url is None: raise ValueError( "Database health check url is not set. Please check DatabaseConfig for the current database." ) if isinstance(database.client, Redis): db_host = database.client.get_connection_kwargs()["host"] else: db_host = database.client.startup_nodes[0].host base_url = f"{database.health_check_url}:{self._rest_api_port}" self._http_client.client.base_url = base_url # Find bdb matching to the current database host matching_bdb = None for bdb in await self._http_client.get("/v1/bdbs"): for endpoint in bdb["endpoints"]: if endpoint["dns_name"] == db_host: matching_bdb = bdb break # In case if the host was set as public IP for addr in endpoint["addr"]: if addr == db_host: matching_bdb = bdb break if matching_bdb is None: logger.warning("LagAwareHealthCheck failed: Couldn't find a matching bdb") raise ValueError("Could not find a matching bdb") url = ( f"/v1/bdbs/{matching_bdb['uid']}/availability" f"?extend_check=lag&availability_lag_tolerance_ms={self._lag_aware_tolerance}" ) await self._http_client.get(url, expect_json=False) # Status checked in an http client, otherwise HttpError will be raised return True