286 lines
10 KiB
Python
286 lines
10 KiB
Python
import asyncio
|
|
import logging
|
|
from abc import ABC, abstractmethod
|
|
from enum import Enum
|
|
from typing import List, Optional, Tuple, Union
|
|
|
|
from redis.asyncio import Redis
|
|
from redis.asyncio.http.http_client import DEFAULT_TIMEOUT, AsyncHTTPClientWrapper
|
|
from redis.backoff import NoBackoff
|
|
from redis.http.http_client import HttpClient
|
|
from redis.multidb.exception import UnhealthyDatabaseException
|
|
from redis.retry import Retry
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Defaults for health check execution.
# NOTE(review): units are not stated in this module; the probe delay is passed
# to asyncio.sleep() by the policies below, so it is presumably seconds.
DEFAULT_HEALTH_CHECK_PROBES = 3
DEFAULT_HEALTH_CHECK_INTERVAL = 5
DEFAULT_HEALTH_CHECK_DELAY = 0.5

# Default lag tolerance; LagAwareHealthCheck sends it to the REST API as
# `availability_lag_tolerance_ms`.
DEFAULT_LAG_AWARE_TOLERANCE = 5000
|
|
|
|
|
|
class HealthCheck(ABC):
    """Interface for a single asynchronous health probe run against a database."""

    @abstractmethod
    async def check_health(self, database) -> bool:
        """
        Determine the health status of the given database.

        Args:
            database: The database object to probe.

        Returns:
            The health status as a boolean.
        """
|
|
|
|
|
|
class HealthCheckPolicy(ABC):
    """
    Health checks execution policy.

    Implementations decide how the outcomes of repeated probes are combined
    into one health verdict for a database.
    """

    @property
    @abstractmethod
    def health_check_probes(self) -> int:
        """Number of probes to execute health checks."""

    @property
    @abstractmethod
    def health_check_delay(self) -> float:
        """Delay between health check probes."""

    @abstractmethod
    async def execute(self, health_checks: List[HealthCheck], database) -> bool:
        """
        Execute health checks and return database health status.

        Args:
            health_checks: The health checks to run against the database.
            database: The database object handed to every health check.

        Returns:
            The resulting health status as a boolean.
        """
|
|
|
|
|
|
class AbstractHealthCheckPolicy(HealthCheckPolicy):
    """
    Base class that stores the probe count and the delay between probes.

    Subclasses supply the actual combination logic in ``execute``.
    """

    def __init__(self, health_check_probes: int, health_check_delay: float):
        """
        Store the probing configuration.

        Args:
            health_check_probes: Number of probes per health check; must be >= 1.
            health_check_delay: Delay between consecutive probes.

        Raises:
            ValueError: If ``health_check_probes`` is less than 1.
        """
        # At least one probe is required for a policy to produce a verdict.
        if health_check_probes < 1:
            raise ValueError("health_check_probes must be greater than 0")

        self._health_check_probes = health_check_probes
        self._health_check_delay = health_check_delay

    @property
    def health_check_probes(self) -> int:
        """Number of probes executed per health check."""
        return self._health_check_probes

    @property
    def health_check_delay(self) -> float:
        """Delay between consecutive probes."""
        return self._health_check_delay

    @abstractmethod
    async def execute(self, health_checks: List[HealthCheck], database) -> bool:
        """Run the health checks against the database; defined by subclasses."""
|
|
|
|
|
|
class HealthyAllPolicy(AbstractHealthCheckPolicy):
    """
    Policy that returns True if all health check probes are successful.
    """

    def __init__(self, health_check_probes: int, health_check_delay: float):
        super().__init__(health_check_probes, health_check_delay)

    async def execute(self, health_checks: List[HealthCheck], database) -> bool:
        """
        Probe every health check the configured number of times.

        Returns:
            False as soon as any probe reports a falsy result; True once all
            probes of all health checks have succeeded.

        Raises:
            UnhealthyDatabaseException: As soon as any probe raises; the
                original exception is passed along as the third argument.
        """
        for check in health_checks:
            for probe_index in range(self.health_check_probes):
                try:
                    probe_passed = bool(await check.check_health(database))
                except Exception as error:
                    raise UnhealthyDatabaseException(
                        "Unhealthy database", database, error
                    )

                if not probe_passed:
                    return False

                # No need to wait after the final probe of a health check.
                if probe_index < self.health_check_probes - 1:
                    await asyncio.sleep(self._health_check_delay)

        return True
|
|
|
|
|
|
class HealthyMajorityPolicy(AbstractHealthCheckPolicy):
    """
    Policy that returns True if a majority of health check probes are successful.
    """

    def __init__(self, health_check_probes: int, health_check_delay: float):
        super().__init__(health_check_probes, health_check_delay)

    async def execute(self, health_checks: List[HealthCheck], database) -> bool:
        """
        Probe every health check and tolerate a minority of failed probes.

        A probe counts as failed when it returns a falsy value or raises.

        Returns:
            False when the failure that reaches the limit was a falsy probe
            result; True when no health check reached its failure limit.

        Raises:
            UnhealthyDatabaseException: When the failure that reaches the
                limit was an exception raised by the probe.
        """
        for check in health_checks:
            # n / 2 for an even probe count, (n + 1) / 2 for an odd one:
            # both are exactly (n + 1) // 2.
            failure_limit = (self.health_check_probes + 1) // 2
            failures = 0

            for probe_index in range(self.health_check_probes):
                try:
                    probe_passed = bool(await check.check_health(database))
                except Exception as error:
                    failures += 1
                    if failures >= failure_limit:
                        raise UnhealthyDatabaseException(
                            "Unhealthy database", database, error
                        )
                else:
                    if not probe_passed:
                        failures += 1
                        if failures >= failure_limit:
                            return False

                # No need to wait after the final probe of a health check.
                if probe_index < self.health_check_probes - 1:
                    await asyncio.sleep(self._health_check_delay)

        return True
|
|
|
|
|
|
class HealthyAnyPolicy(AbstractHealthCheckPolicy):
    """
    Policy that returns True if at least one health check probe is successful.
    """

    def __init__(self, health_check_probes: int, health_check_delay: float):
        super().__init__(health_check_probes, health_check_delay)

    async def execute(self, health_checks: List[HealthCheck], database) -> bool:
        """
        Require at least one successful probe from every health check.

        Probing of a health check stops at its first successful probe.

        Args:
            health_checks: The health checks to run against the database.
            database: The database object handed to every health check.

        Returns:
            True when every health check produced a successful probe.
            False when some health check had no successful probe and none of
            its probes raised. An empty ``health_checks`` list also yields
            False.

        Raises:
            UnhealthyDatabaseException: When a health check had no successful
                probe and at least one of its probes raised; wraps the most
                recent probe exception.
        """
        is_healthy = False

        for health_check in health_checks:
            # Reset per health check. Otherwise a success from an earlier
            # check would mask a later check whose probes all raised.
            is_healthy = False
            exception = None

            for attempt in range(self.health_check_probes):
                try:
                    if await health_check.check_health(database):
                        is_healthy = True
                        break
                except Exception as e:
                    exception = UnhealthyDatabaseException(
                        "Unhealthy database", database, e
                    )

                # No need to wait after the final probe of a health check.
                if attempt < self.health_check_probes - 1:
                    await asyncio.sleep(self._health_check_delay)

            if not is_healthy:
                if exception is not None:
                    raise exception
                return False

        return is_healthy
|
|
|
|
|
|
class HealthCheckPolicies(Enum):
    """Available health check policies; each member's value is the policy class."""

    # All probes must succeed.
    HEALTHY_ALL = HealthyAllPolicy
    # A majority of probes must succeed.
    HEALTHY_MAJORITY = HealthyMajorityPolicy
    # At least one probe must succeed.
    HEALTHY_ANY = HealthyAnyPolicy


# Policy used by default.
DEFAULT_HEALTH_CHECK_POLICY: HealthCheckPolicies = HealthCheckPolicies.HEALTHY_ALL
|
|
|
|
|
|
class PingHealthCheck(HealthCheck):
    """
    Health check based on PING command.
    """

    async def check_health(self, database) -> bool:
        """
        Send PING through the database client.

        For a ``Redis`` client the reply of the PING command is returned
        as-is. For any other client every node from ``get_nodes()`` is pinged.

        Returns:
            The PING reply for a ``Redis`` client; otherwise False at the
            first node with a falsy reply, or True when all nodes replied.
        """
        client = database.client

        if isinstance(client, Redis):
            return await client.execute_command("PING")

        # For a cluster checks if all nodes are healthy.
        for cluster_node in client.get_nodes():
            ping_reply = await cluster_node.redis_connection.execute_command("PING")
            if not ping_reply:
                return False

        return True
|
|
|
|
|
|
class LagAwareHealthCheck(HealthCheck):
    """
    Health check available for Redis Enterprise deployments.
    Verify via REST API that the database is healthy based on different lags.
    """

    def __init__(
        self,
        rest_api_port: int = 9443,
        lag_aware_tolerance: int = DEFAULT_LAG_AWARE_TOLERANCE,
        timeout: float = DEFAULT_TIMEOUT,
        auth_basic: Optional[Tuple[str, str]] = None,
        verify_tls: bool = True,
        # TLS verification (server) options
        ca_file: Optional[str] = None,
        ca_path: Optional[str] = None,
        ca_data: Optional[Union[str, bytes]] = None,
        # Mutual TLS (client cert) options
        client_cert_file: Optional[str] = None,
        client_key_file: Optional[str] = None,
        client_key_password: Optional[str] = None,
    ):
        """
        Initialize LagAwareHealthCheck with the specified parameters.

        Args:
            rest_api_port: Port number for Redis Enterprise REST API (default: 9443)
            lag_aware_tolerance: Tolerance in lag between databases in MS
                (default: DEFAULT_LAG_AWARE_TOLERANCE, i.e. 5000)
            timeout: Request timeout in seconds (default: DEFAULT_TIMEOUT)
            auth_basic: Tuple of (username, password) for basic authentication
            verify_tls: Whether to verify TLS certificates (default: True)
            ca_file: Path to CA certificate file for TLS verification
            ca_path: Path to CA certificates directory for TLS verification
            ca_data: CA certificate data as string or bytes
            client_cert_file: Path to client certificate file for mutual TLS
            client_key_file: Path to client private key file for mutual TLS
            client_key_password: Password for encrypted client private key
        """
        # The HTTP client is configured with zero retries and no backoff.
        self._http_client = AsyncHTTPClientWrapper(
            HttpClient(
                timeout=timeout,
                auth_basic=auth_basic,
                retry=Retry(NoBackoff(), retries=0),
                verify_tls=verify_tls,
                ca_file=ca_file,
                ca_path=ca_path,
                ca_data=ca_data,
                client_cert_file=client_cert_file,
                client_key_file=client_key_file,
                client_key_password=client_key_password,
            )
        )
        self._rest_api_port = rest_api_port
        self._lag_aware_tolerance = lag_aware_tolerance

    @staticmethod
    def _find_matching_bdb(bdbs, db_host) -> Optional[dict]:
        """Return the first bdb with an endpoint whose DNS name or address equals db_host."""
        for bdb in bdbs:
            for endpoint in bdb["endpoints"]:
                if endpoint["dns_name"] == db_host:
                    return bdb

                # In case if the host was set as public IP
                if any(addr == db_host for addr in endpoint["addr"]):
                    return bdb

        return None

    async def check_health(self, database) -> bool:
        """
        Check database availability, including lag, through the REST API.

        The database host is resolved to a bdb via ``/v1/bdbs``, then that
        bdb's availability endpoint is queried with the configured lag
        tolerance.

        Args:
            database: Database object exposing ``health_check_url`` and ``client``.

        Returns:
            True when the availability request completes without raising.

        Raises:
            ValueError: If ``database.health_check_url`` is None, or if no bdb
                matches the database host.
        """
        if database.health_check_url is None:
            raise ValueError(
                "Database health check url is not set. Please check DatabaseConfig for the current database."
            )

        if isinstance(database.client, Redis):
            db_host = database.client.get_connection_kwargs()["host"]
        else:
            db_host = database.client.startup_nodes[0].host

        # NOTE(review): base_url is set on the shared HTTP client for each call;
        # if one instance is used concurrently for several databases this may
        # race - confirm against callers.
        base_url = f"{database.health_check_url}:{self._rest_api_port}"
        self._http_client.client.base_url = base_url

        # Find bdb matching to the current database host
        matching_bdb = self._find_matching_bdb(
            await self._http_client.get("/v1/bdbs"), db_host
        )

        if matching_bdb is None:
            logger.warning("LagAwareHealthCheck failed: Couldn't find a matching bdb")
            raise ValueError("Could not find a matching bdb")

        url = (
            f"/v1/bdbs/{matching_bdb['uid']}/availability"
            f"?extend_check=lag&availability_lag_tolerance_ms={self._lag_aware_tolerance}"
        )
        await self._http_client.get(url, expect_json=False)

        # Status checked in an http client, otherwise HttpError will be raised
        return True
|