summit/backend/venv/lib/python3.12/site-packages/asyncpg/cluster.py


# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


import asyncio
import os
import os.path
import platform
import random
import re
import shutil
import socket
import string
import subprocess
import sys
import tempfile
import textwrap
import time

import asyncpg
from asyncpg import serverversion
_system = platform.uname().system


if _system == 'Windows':
    def platform_exe(name):
        if name.endswith('.exe'):
            return name
        return name + '.exe'
else:
    def platform_exe(name):
        return name


def find_available_port():
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        sock.bind(('127.0.0.1', 0))
        return sock.getsockname()[1]
    except Exception:
        return None
    finally:
        sock.close()
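# NB: a port returned by find_available_port() above is released before it is
# used, so another process may grab it in the meantime; callers such as
# start(port='dynamic') below treat the value as best-effort.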
def _world_readable_mkdtemp(suffix=None, prefix=None, dir=None):
    name = "".join(random.choices(string.ascii_lowercase, k=8))
    if dir is None:
        dir = tempfile.gettempdir()
    if prefix is None:
        prefix = tempfile.gettempprefix()
    if suffix is None:
        suffix = ""
    fn = os.path.join(dir, prefix + name + suffix)
    os.mkdir(fn, 0o755)
    return fn


def _mkdtemp(suffix=None, prefix=None, dir=None):
    if _system == 'Windows' and os.environ.get("GITHUB_ACTIONS"):
        # Due to mitigations introduced in python/cpython#118486,
        # when Python runs in a session created via an SSH connection,
        # tempfile.mkdtemp creates directories that are not accessible.
        return _world_readable_mkdtemp(suffix, prefix, dir)
    else:
        return tempfile.mkdtemp(suffix, prefix, dir)


class ClusterError(Exception):
    pass
class Cluster:
    def __init__(self, data_dir, *, pg_config_path=None):
        self._data_dir = data_dir
        self._pg_config_path = pg_config_path
        self._pg_bin_dir = (
            os.environ.get('PGINSTALLATION')
            or os.environ.get('PGBIN')
        )
        self._pg_ctl = None
        self._daemon_pid = None
        self._daemon_process = None
        self._connection_addr = None
        self._connection_spec_override = None

    def get_pg_version(self):
        return self._pg_version

    def is_managed(self):
        return True

    def get_data_dir(self):
        return self._data_dir

    def get_status(self):
        if self._pg_ctl is None:
            self._init_env()

        process = subprocess.run(
            [self._pg_ctl, 'status', '-D', self._data_dir],
            stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        stdout, stderr = process.stdout, process.stderr

        if (process.returncode == 4 or not os.path.exists(self._data_dir) or
                not os.listdir(self._data_dir)):
            return 'not-initialized'
        elif process.returncode == 3:
            return 'stopped'
        elif process.returncode == 0:
            r = re.match(r'.*PID\s?:\s+(\d+).*', stdout.decode())
            if not r:
                raise ClusterError(
                    'could not parse pg_ctl status output: {}'.format(
                        stdout.decode()))
            self._daemon_pid = int(r.group(1))
            return self._test_connection(timeout=0)
        else:
            raise ClusterError(
                'pg_ctl status exited with status {:d}: {}'.format(
                    process.returncode, stderr))

    async def connect(self, loop=None, **kwargs):
        conn_info = self.get_connection_spec()
        conn_info.update(kwargs)
        return await asyncpg.connect(loop=loop, **conn_info)
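    # Illustrative use of connect() (not part of the original module):
    #
    #   conn = await cluster.connect(database='postgres', user='postgres')
    #
    # Keyword arguments override the spec derived from the running server
    # (see get_connection_spec() and _connection_addr_from_pidfile() below).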
    def init(self, **settings):
        """Initialize cluster."""
        if self.get_status() != 'not-initialized':
            raise ClusterError(
                'cluster in {!r} has already been initialized'.format(
                    self._data_dir))

        settings = dict(settings)

        if 'encoding' not in settings:
            settings['encoding'] = 'UTF-8'

        if settings:
            settings_args = ['--{}={}'.format(k, v)
                             for k, v in settings.items()]
            extra_args = ['-o'] + [' '.join(settings_args)]
        else:
            extra_args = []

        os.makedirs(self._data_dir, exist_ok=True)
        process = subprocess.run(
            [self._pg_ctl, 'init', '-D', self._data_dir] + extra_args,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            cwd=self._data_dir,
        )

        output = process.stdout

        if process.returncode != 0:
            raise ClusterError(
                'pg_ctl init exited with status {:d}:\n{}'.format(
                    process.returncode, output.decode()))

        return output.decode()
    def start(self, wait=60, *, server_settings={}, **opts):
        """Start the cluster."""
        status = self.get_status()
        if status == 'running':
            return
        elif status == 'not-initialized':
            raise ClusterError(
                'cluster in {!r} has not been initialized'.format(
                    self._data_dir))

        port = opts.pop('port', None)
        if port == 'dynamic':
            port = find_available_port()

        extra_args = ['--{}={}'.format(k, v) for k, v in opts.items()]
        extra_args.append('--port={}'.format(port))

        sockdir = server_settings.get('unix_socket_directories')
        if sockdir is None:
            sockdir = server_settings.get('unix_socket_directory')
        if sockdir is None and _system != 'Windows':
            sockdir = tempfile.gettempdir()

        ssl_key = server_settings.get('ssl_key_file')
        if ssl_key:
            # Make sure server certificate key file has correct permissions.
            keyfile = os.path.join(self._data_dir, 'srvkey.pem')
            shutil.copy(ssl_key, keyfile)
            os.chmod(keyfile, 0o600)
            server_settings = server_settings.copy()
            server_settings['ssl_key_file'] = keyfile

        if sockdir is not None:
            if self._pg_version < (9, 3):
                sockdir_opt = 'unix_socket_directory'
            else:
                sockdir_opt = 'unix_socket_directories'

            server_settings[sockdir_opt] = sockdir

        for k, v in server_settings.items():
            extra_args.extend(['-c', '{}={}'.format(k, v)])

        if _system == 'Windows':
            # On Windows we have to use pg_ctl as direct execution
            # of postgres daemon under an Administrative account
            # is not permitted and there is no easy way to drop
            # privileges.
            if os.getenv('ASYNCPG_DEBUG_SERVER'):
                stdout = sys.stdout
                print(
                    'asyncpg.cluster: Running',
                    ' '.join([
                        self._pg_ctl, 'start', '-D', self._data_dir,
                        '-o', ' '.join(extra_args)
                    ]),
                    file=sys.stderr,
                )
            else:
                stdout = subprocess.DEVNULL

            process = subprocess.run(
                [self._pg_ctl, 'start', '-D', self._data_dir,
                 '-o', ' '.join(extra_args)],
                stdout=stdout,
                stderr=subprocess.STDOUT,
                cwd=self._data_dir,
            )

            if process.returncode != 0:
                if process.stderr:
                    stderr = ':\n{}'.format(process.stderr.decode())
                else:
                    stderr = ''
                raise ClusterError(
                    'pg_ctl start exited with status {:d}{}'.format(
                        process.returncode, stderr))
        else:
            if os.getenv('ASYNCPG_DEBUG_SERVER'):
                stdout = sys.stdout
            else:
                stdout = subprocess.DEVNULL

            self._daemon_process = \
                subprocess.Popen(
                    [self._postgres, '-D', self._data_dir, *extra_args],
                    stdout=stdout,
                    stderr=subprocess.STDOUT,
                    cwd=self._data_dir,
                )

            self._daemon_pid = self._daemon_process.pid

        self._test_connection(timeout=wait)
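    # For reference, on non-Windows systems the command assembled above looks
    # roughly like the following (assuming port 5432 and a temp socket dir):
    #
    #   postgres -D <data_dir> --port=5432 -c unix_socket_directories=/tmp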
    def reload(self):
        """Reload server configuration."""
        status = self.get_status()
        if status != 'running':
            raise ClusterError('cannot reload: cluster is not running')

        process = subprocess.run(
            [self._pg_ctl, 'reload', '-D', self._data_dir],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            cwd=self._data_dir,
        )

        stderr = process.stderr

        if process.returncode != 0:
            raise ClusterError(
                'pg_ctl reload exited with status {:d}: {}'.format(
                    process.returncode, stderr.decode()))

    def stop(self, wait=60):
        process = subprocess.run(
            [self._pg_ctl, 'stop', '-D', self._data_dir, '-t', str(wait),
             '-m', 'fast'],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            cwd=self._data_dir,
        )

        stderr = process.stderr

        if process.returncode != 0:
            raise ClusterError(
                'pg_ctl stop exited with status {:d}: {}'.format(
                    process.returncode, stderr.decode()))

        if (self._daemon_process is not None and
                self._daemon_process.returncode is None):
            self._daemon_process.kill()

    def destroy(self):
        status = self.get_status()
        if status == 'stopped' or status == 'not-initialized':
            shutil.rmtree(self._data_dir)
        else:
            raise ClusterError('cannot destroy {} cluster'.format(status))
    def _get_connection_spec(self):
        if self._connection_addr is None:
            self._connection_addr = self._connection_addr_from_pidfile()

        if self._connection_addr is not None:
            if self._connection_spec_override:
                args = self._connection_addr.copy()
                args.update(self._connection_spec_override)
                return args
            else:
                return self._connection_addr

    def get_connection_spec(self):
        status = self.get_status()
        if status != 'running':
            raise ClusterError('cluster is not running')

        return self._get_connection_spec()

    def override_connection_spec(self, **kwargs):
        self._connection_spec_override = kwargs

    def reset_wal(self, *, oid=None, xid=None):
        status = self.get_status()
        if status == 'not-initialized':
            raise ClusterError(
                'cannot modify WAL status: cluster is not initialized')

        if status == 'running':
            raise ClusterError(
                'cannot modify WAL status: cluster is running')

        opts = []
        if oid is not None:
            opts.extend(['-o', str(oid)])
        if xid is not None:
            opts.extend(['-x', str(xid)])
        if not opts:
            return

        opts.append(self._data_dir)

        try:
            reset_wal = self._find_pg_binary('pg_resetwal')
        except ClusterError:
            # pg_resetxlog is the pre-PostgreSQL 10 name of pg_resetwal.
            reset_wal = self._find_pg_binary('pg_resetxlog')

        process = subprocess.run(
            [reset_wal] + opts,
            stdout=subprocess.PIPE, stderr=subprocess.PIPE)

        stderr = process.stderr

        if process.returncode != 0:
            raise ClusterError(
                'pg_resetwal exited with status {:d}: {}'.format(
                    process.returncode, stderr.decode()))
    def reset_hba(self):
        """Remove all records from pg_hba.conf."""
        status = self.get_status()
        if status == 'not-initialized':
            raise ClusterError(
                'cannot modify HBA records: cluster is not initialized')

        pg_hba = os.path.join(self._data_dir, 'pg_hba.conf')

        try:
            with open(pg_hba, 'w'):
                pass
        except IOError as e:
            raise ClusterError(
                'cannot modify HBA records: {}'.format(e)) from e

    def add_hba_entry(self, *, type='host', database, user, address=None,
                      auth_method, auth_options=None):
        """Add a record to pg_hba.conf."""
        status = self.get_status()
        if status == 'not-initialized':
            raise ClusterError(
                'cannot modify HBA records: cluster is not initialized')

        if type not in {'local', 'host', 'hostssl', 'hostnossl'}:
            raise ValueError('invalid HBA record type: {!r}'.format(type))

        pg_hba = os.path.join(self._data_dir, 'pg_hba.conf')

        record = '{} {} {}'.format(type, database, user)

        if type != 'local':
            if address is None:
                raise ValueError(
                    '{!r} entry requires a valid address'.format(type))
            else:
                record += ' {}'.format(address)

        record += ' {}'.format(auth_method)

        if auth_options is not None:
            record += ' ' + ' '.join(
                '{}={}'.format(k, v) for k, v in auth_options)

        try:
            with open(pg_hba, 'a') as f:
                print(record, file=f)
        except IOError as e:
            raise ClusterError(
                'cannot modify HBA records: {}'.format(e)) from e
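    # For example (illustrative only), add_hba_entry(type='host',
    # database='all', user='all', address='127.0.0.1/32',
    # auth_method='trust') appends this line to pg_hba.conf:
    #
    #   host all all 127.0.0.1/32 trust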
    def trust_local_connections(self):
        self.reset_hba()

        if _system != 'Windows':
            self.add_hba_entry(type='local', database='all',
                               user='all', auth_method='trust')
        self.add_hba_entry(type='host', address='127.0.0.1/32',
                           database='all', user='all',
                           auth_method='trust')
        self.add_hba_entry(type='host', address='::1/128',
                           database='all', user='all',
                           auth_method='trust')
        status = self.get_status()
        if status == 'running':
            self.reload()

    def trust_local_replication_by(self, user):
        if _system != 'Windows':
            self.add_hba_entry(type='local', database='replication',
                               user=user, auth_method='trust')
        self.add_hba_entry(type='host', address='127.0.0.1/32',
                           database='replication', user=user,
                           auth_method='trust')
        self.add_hba_entry(type='host', address='::1/128',
                           database='replication', user=user,
                           auth_method='trust')
        status = self.get_status()
        if status == 'running':
            self.reload()
    def _init_env(self):
        if not self._pg_bin_dir:
            pg_config = self._find_pg_config(self._pg_config_path)
            pg_config_data = self._run_pg_config(pg_config)

            self._pg_bin_dir = pg_config_data.get('bindir')
            if not self._pg_bin_dir:
                raise ClusterError(
                    'pg_config output did not provide the BINDIR value')

        self._pg_ctl = self._find_pg_binary('pg_ctl')
        self._postgres = self._find_pg_binary('postgres')
        self._pg_version = self._get_pg_version()

    def _connection_addr_from_pidfile(self):
        pidfile = os.path.join(self._data_dir, 'postmaster.pid')

        try:
            with open(pidfile, 'rt') as f:
                piddata = f.read()
        except FileNotFoundError:
            return None

        lines = piddata.splitlines()
        if len(lines) < 6:
            # A complete postgres pidfile is at least 6 lines.
            return None

        pmpid = int(lines[0])
        if self._daemon_pid and pmpid != self._daemon_pid:
            # This might be an old pidfile left over from a previous
            # postgres daemon run.
            return None

        # postmaster.pid layout (one value per line): PID, data directory,
        # start timestamp, port, Unix-socket directory, listen address, ...
        portnum = lines[3]
        sockdir = lines[4]
        hostaddr = lines[5]

        if sockdir:
            if sockdir[0] != '/':
                # Relative sockdir
                sockdir = os.path.normpath(
                    os.path.join(self._data_dir, sockdir))
            host_str = sockdir
        else:
            host_str = hostaddr

        if host_str == '*':
            host_str = 'localhost'
        elif host_str == '0.0.0.0':
            host_str = '127.0.0.1'
        elif host_str == '::':
            host_str = '::1'

        return {
            'host': host_str,
            'port': portnum
        }
    def _test_connection(self, timeout=60):
        self._connection_addr = None

        loop = asyncio.new_event_loop()

        try:
            for i in range(timeout):
                if self._connection_addr is None:
                    conn_spec = self._get_connection_spec()
                    if conn_spec is None:
                        time.sleep(1)
                        continue

                try:
                    con = loop.run_until_complete(
                        asyncpg.connect(database='postgres',
                                        user='postgres',
                                        timeout=5, loop=loop,
                                        **self._connection_addr))
                except (OSError, asyncio.TimeoutError,
                        asyncpg.CannotConnectNowError,
                        asyncpg.PostgresConnectionError):
                    time.sleep(1)
                    continue
                except asyncpg.PostgresError:
                    # Any error other than ServerNotReadyError or
                    # ConnectionError is interpreted to indicate that the
                    # server is up.
                    break
                else:
                    loop.run_until_complete(con.close())
                    break
        finally:
            loop.close()

        return 'running'
    def _run_pg_config(self, pg_config_path):
        process = subprocess.run(
            pg_config_path, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        stdout, stderr = process.stdout, process.stderr
        if process.returncode != 0:
            raise ClusterError('pg_config exited with status {:d}: {}'.format(
                process.returncode, stderr))
        else:
            config = {}
            for line in stdout.splitlines():
                k, eq, v = line.decode('utf-8').partition('=')
                if eq:
                    config[k.strip().lower()] = v.strip()
            return config

    def _find_pg_config(self, pg_config_path):
        if pg_config_path is None:
            pg_install = (
                os.environ.get('PGINSTALLATION')
                or os.environ.get('PGBIN')
            )
            if pg_install:
                pg_config_path = platform_exe(
                    os.path.join(pg_install, 'pg_config'))
            else:
                pathenv = os.environ.get('PATH').split(os.pathsep)
                for path in pathenv:
                    pg_config_path = platform_exe(
                        os.path.join(path, 'pg_config'))
                    if os.path.exists(pg_config_path):
                        break
                else:
                    pg_config_path = None

        if not pg_config_path:
            raise ClusterError('could not find pg_config executable')

        if not os.path.isfile(pg_config_path):
            raise ClusterError('{!r} is not an executable'.format(
                pg_config_path))

        return pg_config_path

    def _find_pg_binary(self, binary):
        bpath = platform_exe(os.path.join(self._pg_bin_dir, binary))

        if not os.path.isfile(bpath):
            raise ClusterError(
                'could not find {} executable: '.format(binary) +
                '{!r} does not exist or is not a file'.format(bpath))

        return bpath

    def _get_pg_version(self):
        process = subprocess.run(
            [self._postgres, '--version'],
            stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        stdout, stderr = process.stdout, process.stderr
        if process.returncode != 0:
            raise ClusterError(
                'postgres --version exited with status {:d}: {}'.format(
                    process.returncode, stderr))

        version_string = stdout.decode('utf-8').strip(' \n')
        prefix = 'postgres (PostgreSQL) '
        if not version_string.startswith(prefix):
            raise ClusterError(
                'could not determine server version from {!r}'.format(
                    version_string))
        version_string = version_string[len(prefix):]

        return serverversion.split_server_version_string(version_string)
class TempCluster(Cluster):
    def __init__(self, *,
                 data_dir_suffix=None, data_dir_prefix=None,
                 data_dir_parent=None, pg_config_path=None):
        self._data_dir = _mkdtemp(suffix=data_dir_suffix,
                                  prefix=data_dir_prefix,
                                  dir=data_dir_parent)
        super().__init__(self._data_dir, pg_config_path=pg_config_path)
class HotStandbyCluster(TempCluster):
    def __init__(self, *,
                 master, replication_user,
                 data_dir_suffix=None, data_dir_prefix=None,
                 data_dir_parent=None, pg_config_path=None):
        self._master = master
        self._repl_user = replication_user
        super().__init__(
            data_dir_suffix=data_dir_suffix,
            data_dir_prefix=data_dir_prefix,
            data_dir_parent=data_dir_parent,
            pg_config_path=pg_config_path)

    def _init_env(self):
        super()._init_env()
        self._pg_basebackup = self._find_pg_binary('pg_basebackup')

    def init(self, **settings):
        """Initialize cluster."""
        if self.get_status() != 'not-initialized':
            raise ClusterError(
                'cluster in {!r} has already been initialized'.format(
                    self._data_dir))

        process = subprocess.run(
            [self._pg_basebackup, '-h', self._master['host'],
             '-p', self._master['port'], '-D', self._data_dir,
             '-U', self._repl_user],
            stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
        output = process.stdout

        if process.returncode != 0:
            raise ClusterError(
                'pg_basebackup init exited with status {:d}:\n{}'.format(
                    process.returncode, output.decode()))

        if self._pg_version < (12, 0):
            with open(os.path.join(self._data_dir, 'recovery.conf'), 'w') as f:
                f.write(textwrap.dedent("""\
                    standby_mode = 'on'
                    primary_conninfo = 'host={host} port={port} user={user}'
                """.format(
                    host=self._master['host'],
                    port=self._master['port'],
                    user=self._repl_user)))
        else:
            # An empty standby.signal file puts the server into standby mode
            # on PostgreSQL 12 and later.
            with open(os.path.join(self._data_dir, 'standby.signal'), 'w'):
                pass

        return output.decode()

    def start(self, wait=60, *, server_settings={}, **opts):
        if self._pg_version >= (12, 0):
            server_settings = server_settings.copy()
            server_settings['primary_conninfo'] = (
                '"host={host} port={port} user={user}"'.format(
                    host=self._master['host'],
                    port=self._master['port'],
                    user=self._repl_user,
                )
            )

        super().start(wait=wait, server_settings=server_settings, **opts)
class RunningCluster(Cluster):
    def __init__(self, **kwargs):
        self.conn_spec = kwargs

    def is_managed(self):
        return False

    def get_connection_spec(self):
        return dict(self.conn_spec)

    def get_status(self):
        return 'running'

    def init(self, **settings):
        pass

    def start(self, wait=60, **settings):
        pass

    def stop(self, wait=60):
        pass

    def destroy(self):
        pass

    def reset_hba(self):
        raise ClusterError('cannot modify HBA records of unmanaged cluster')

    def add_hba_entry(self, *, type='host', database, user, address=None,
                      auth_method, auth_options=None):
        raise ClusterError('cannot modify HBA records of unmanaged cluster')
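# Illustrative lifecycle (not part of the original module; a hedged sketch of
# how the classes above fit together). A throwaway server can be run from a
# temporary data directory roughly like this, inside a coroutine:
#
#   cluster = TempCluster()
#   cluster.init()
#   cluster.trust_local_connections()
#   cluster.start(port='dynamic', server_settings={})
#   try:
#       conn = await cluster.connect(database='postgres', user='postgres')
#   finally:
#       cluster.stop()
#       cluster.destroy()
#
# A HotStandbyCluster is set up the same way but takes the primary's
# connection spec, e.g. HotStandbyCluster(
#     master=cluster.get_connection_spec(), replication_user='replication')
# where 'replication' is a placeholder role name, typically granted access
# beforehand via cluster.trust_local_replication_by('replication').
# RunningCluster simply wraps an already-running, unmanaged server.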