summit/backend/venv/lib/python3.12/site-packages/asyncpg/connect_utils.py

1314 lines
42 KiB
Python

# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
from __future__ import annotations
import asyncio
import configparser
import collections
from collections.abc import Callable
import enum
import functools
import getpass
import os
import pathlib
import platform
import random
import re
import socket
import ssl as ssl_module
import stat
import struct
import sys
import typing
import urllib.parse
import warnings
import inspect
from . import compat
from . import exceptions
from . import protocol
class SSLMode(enum.IntEnum):
    """libpq ``sslmode`` values, ordered from least to most strict.

    Being an ``IntEnum`` lets callers compare modes with ``<``/``>=``.
    """
    disable = 0
    allow = 1
    prefer = 2
    require = 3
    verify_ca = 4
    verify_full = 5

    @classmethod
    def parse(cls, sslmode):
        """Return the member designated by *sslmode*.

        *sslmode* is either an ``SSLMode`` member (returned unchanged) or a
        libpq-style name such as ``'verify-full'``; hyphens are mapped to
        underscores before the lookup.

        :raises AttributeError: if *sslmode* does not name a member.
        """
        if isinstance(sslmode, cls):
            return sslmode
        try:
            return cls.__members__[sslmode.replace('-', '_')]
        except KeyError:
            # Look up real members only: a plain getattr() would also
            # resolve names like 'parse' to non-member class attributes.
            # AttributeError is kept because callers catch exactly that.
            raise AttributeError(sslmode) from None
class SSLNegotiation(compat.StrEnum):
    """How the TLS session is established with the server.

    ``postgres`` first sends an SSLRequest over plaintext and upgrades;
    ``direct`` opens the connection with TLS straight away.
    """
    postgres = 'postgres'
    direct = 'direct'
# Fully resolved per-connection parameters produced by
# _parse_connect_dsn_and_args().
_ConnectionParameters = collections.namedtuple(
    'ConnectionParameters',
    'user password database ssl sslmode ssl_negotiation '
    'server_settings target_session_attrs krbsrvname gsslib')

# Client-side settings produced by _parse_connect_arguments().
_ClientConfiguration = collections.namedtuple(
    'ConnectionConfiguration',
    'command_timeout statement_cache_size '
    'max_cached_statement_lifetime max_cacheable_statement_size')
_system = platform.uname().system

# Name of the password file inside the PostgreSQL home directory; it is
# spelled differently on Windows.
PGPASSFILE = 'pgpass.conf' if _system == 'Windows' else '.pgpass'

# Name of the connection service file inside the home directory.
PG_SERVICEFILE = '.pg_service.conf'
def _read_password_file(passfile: typing.Optional[pathlib.Path]) \
        -> typing.List[typing.Tuple[str, ...]]:
    """Read a pgpass-format file and return its records.

    Every returned tuple has exactly five fields:
    ``(host, port, database, user, password)``.  Blank lines and ``#``
    comments are skipped, as are malformed records with fewer than five
    colon-separated fields (libpq ignores such lines too; keeping them
    would break the 5-way unpacking done by the caller).  Backslash
    escapes are kept in the returned fields.

    An empty list is returned when *passfile* is ``None``, does not exist,
    is not a regular file, is group/world accessible (non-Windows only;
    a warning is emitted), or cannot be read.
    """
    passtab = []

    if passfile is None:
        return passtab

    try:
        if not passfile.exists():
            return []

        if not passfile.is_file():
            warnings.warn(
                'password file {!r} is not a plain file'.format(passfile))

            return []

        if _system != 'Windows':
            if passfile.stat().st_mode & (stat.S_IRWXG | stat.S_IRWXO):
                warnings.warn(
                    'password file {!r} has group or world access; '
                    'permissions should be u=rw (0600) or less'.format(
                        passfile))

                return []

        with passfile.open('rt') as f:
            for line in f:
                line = line.strip()
                if not line or line.startswith('#'):
                    # Skip empty lines and comments.
                    continue
                # Backslash escapes both itself and the colon,
                # which is a record separator.  Temporarily hide escaped
                # backslashes so the look-behind split only honours
                # unescaped colons.
                line = line.replace(R'\\', '\n')
                fields = tuple(
                    p.replace('\n', R'\\')
                    for p in re.split(r'(?<!\\):', line, maxsplit=4)
                )
                if len(fields) != 5:
                    # Malformed record: skip it instead of letting it
                    # crash the password lookup.
                    continue
                passtab.append(fields)
    except OSError:
        pass

    return passtab
def _read_password_from_pgpass(
        *, passfile: typing.Optional[pathlib.Path],
        hosts: typing.List[str],
        ports: typing.List[int],
        database: str,
        user: str):
    """Find the password for one of the given host/port pairs in a pgpass file.

    Host/port pairs are tried in order.  For each pair, the first record
    whose host, port, database and user fields all match wins; a ``*``
    field matches anything.  UNIX socket hosts (paths beginning with
    ``/``) are matched as ``localhost``.

    :return:
        Password string, if found, ``None`` otherwise.
    """
    records = _read_password_file(passfile)
    if not records:
        return None

    def _field_matches(pattern, value):
        return pattern == '*' or pattern == value

    for conn_host, conn_port in zip(hosts, ports):
        wanted = (
            'localhost' if conn_host.startswith('/') else conn_host,
            str(conn_port),
            database,
            user,
        )
        for rhost, rport, rdatabase, ruser, rpassword in records:
            patterns = (rhost, rport, rdatabase, ruser)
            if all(map(_field_matches, patterns, wanted)):
                return rpassword

    return None
def _validate_port_spec(hosts, port):
    """Expand *port* into a list holding one port per entry of *hosts*.

    *port* may be a single value or a list/tuple of values.  A one-element
    sequence is repeated for every host; a longer sequence must have
    exactly as many entries as *hosts* and is returned as a list.
    Previously a tuple was treated as a single port, producing
    ``[(p1, p2), (p1, p2)]``.

    :raises exceptions.ClientConfigurationError: if a multi-port sequence
        does not match the number of hosts.
    """
    if isinstance(port, (list, tuple)):
        if len(port) > 1:
            # If there is a list of ports, its length must
            # match that of the host list.
            if len(port) != len(hosts):
                raise exceptions.ClientConfigurationError(
                    'could not match {} port numbers to {} hosts'.format(
                        len(port), len(hosts)))
            return list(port)
        if len(port) == 1:
            return [port[0] for _ in range(len(hosts))]

    return [port for _ in range(len(hosts))]
def _parse_hostlist(hostlist, port, *, unquote=False):
    """Split a comma-separated libpq host list into addresses and ports.

    Each entry may be a UNIX socket directory (``/path``), a bracketed IPv6
    address with an optional ``:port``, or ``host[:port]``.  When *port* is
    falsy, ports come from the entries themselves, falling back to
    ``PGPORT`` (or 5432) for entries without one.  Otherwise *port* is
    expanded to one value per host and per-entry ports are ignored.  With
    *unquote*, addresses and ports are percent-decoded (URI form).

    :return: ``(hosts, ports)``.
    :raises exceptions.ClientConfigurationError: on a malformed IPv6 entry
        or a port list that cannot be matched to the hosts.
    """
    hostspecs = hostlist.split(',')

    if port:
        port = _validate_port_spec(hostspecs, port)
    else:
        env_port = os.environ.get('PGPORT')
        if not env_port:
            default_port = 5432
        elif ',' in env_port:
            default_port = [int(p) for p in env_port.split(',')]
        else:
            default_port = int(env_port)
        default_port = _validate_port_spec(hostspecs, default_port)

    hosts = []
    spec_ports = []
    for index, spec in enumerate(hostspecs):
        leading = spec[0]
        if leading == '/':
            # UNIX socket directory; never carries a port.
            address, spec_port = spec, ''
        elif leading == '[':
            # Bracketed IPv6 address, optionally followed by ":port".
            match = re.match(r'(?:\[([^\]]+)\])(?::([0-9]+))?', spec)
            if match is None:
                raise exceptions.ClientConfigurationError(
                    'invalid IPv6 address in the connection URI: {!r}'.format(
                        spec
                    )
                )
            address, spec_port = match.group(1), match.group(2)
        else:
            # Host name or IPv4 address, optionally followed by ":port".
            address, _, spec_port = spec.partition(':')

        if unquote:
            address = urllib.parse.unquote(address)
        hosts.append(address)

        if port:
            # An explicit port specification overrides per-entry ports.
            continue
        if spec_port:
            if unquote:
                spec_port = urllib.parse.unquote(spec_port)
            spec_ports.append(int(spec_port))
        else:
            spec_ports.append(default_port[index])

    return hosts, (port if port else spec_ports)
def _parse_tls_version(tls_version):
    """Map a protocol name such as ``'TLSv1.2'`` to ``ssl.TLSVersion``.

    :raises exceptions.ClientConfigurationError: for ``SSL*`` protocol
        names, which are rejected outright, and for unknown names.
    """
    if tls_version.startswith('SSL'):
        raise exceptions.ClientConfigurationError(
            f"Unsupported TLS version: {tls_version}"
        )
    member_name = tls_version.replace('.', '_')
    try:
        return ssl_module.TLSVersion[member_name]
    except KeyError:
        raise exceptions.ClientConfigurationError(
            f"No such TLS version: {tls_version}"
        )
def _dot_postgresql_path(filename) -> typing.Optional[pathlib.Path]:
    """Return the resolved path of *filename* inside ``~/.postgresql``.

    Returns ``None`` when the home directory cannot be determined.
    """
    try:
        config_dir = pathlib.Path.home() / '.postgresql'
    except (RuntimeError, KeyError):
        # Path.home() could not work out the home directory.
        return None
    return config_dir.joinpath(filename).resolve()
def _parse_connect_dsn_and_args(*, dsn, host, port, user,
                                password, passfile, database, ssl,
                                service, servicefile,
                                direct_tls, server_settings,
                                target_session_attrs, krbsrvname, gsslib):
    """Resolve connection parameters from all configuration sources.

    Sources, from highest to lowest priority: explicitly passed arguments,
    the DSN (netloc/path first, then query parameters), the connection
    service file (consulted only when a DSN is given), and finally ``PG*``
    environment variables and built-in defaults.  Unrecognized DSN query
    parameters become server settings; explicit *server_settings* win over
    them.

    :return:
        ``(addrs, params)`` where *addrs* is a list of UNIX socket paths
        (``str``) and ``(host, port)`` tuples, and *params* is a
        ``_ConnectionParameters`` instance.

    :raises exceptions.ClientConfigurationError:
        on a malformed DSN or invalid option values.
    """
    # `auth_hosts` is the version of host information for the purposes
    # of reading the pgpass file.
    auth_hosts = None
    sslcert = sslkey = sslrootcert = sslcrl = sslpassword = None
    ssl_min_protocol_version = ssl_max_protocol_version = None
    sslnegotiation = None

    if dsn:
        parsed = urllib.parse.urlparse(dsn)

        query = None
        if parsed.query:
            query = urllib.parse.parse_qs(parsed.query, strict_parsing=True)
            for key, val in query.items():
                if isinstance(val, list):
                    # A repeated parameter: the last occurrence wins.
                    query[key] = val[-1]

            if 'service' in query:
                val = query.pop('service')
                if not service and val:
                    service = val

        connection_service_file = servicefile

        if connection_service_file is None:
            connection_service_file = os.getenv('PGSERVICEFILE')

        if connection_service_file is None:
            homedir = compat.get_pg_home_directory()
            if homedir:
                connection_service_file = homedir / PG_SERVICEFILE
            else:
                connection_service_file = None
        else:
            connection_service_file = pathlib.Path(connection_service_file)

        if parsed.scheme not in {'postgresql', 'postgres'}:
            raise exceptions.ClientConfigurationError(
                'invalid DSN: scheme is expected to be either '
                '"postgresql" or "postgres", got {!r}'.format(parsed.scheme))

        if parsed.netloc:
            if '@' in parsed.netloc:
                dsn_auth, _, dsn_hostspec = parsed.netloc.partition('@')
            else:
                dsn_hostspec = parsed.netloc
                dsn_auth = ''
        else:
            dsn_auth = dsn_hostspec = ''

        if dsn_auth:
            dsn_user, _, dsn_password = dsn_auth.partition(':')
        else:
            dsn_user = dsn_password = ''

        if not host and dsn_hostspec:
            host, port = _parse_hostlist(dsn_hostspec, port, unquote=True)

        if parsed.path and database is None:
            dsn_database = parsed.path
            if dsn_database.startswith('/'):
                dsn_database = dsn_database[1:]
            database = urllib.parse.unquote(dsn_database)

        if user is None and dsn_user:
            user = urllib.parse.unquote(dsn_user)

        if password is None and dsn_password:
            password = urllib.parse.unquote(dsn_password)

        if query:
            if 'port' in query:
                val = query.pop('port')
                if not port and val:
                    port = [int(p) for p in val.split(',')]

            if 'host' in query:
                val = query.pop('host')
                if not host and val:
                    host, port = _parse_hostlist(val, port)

            if 'dbname' in query:
                val = query.pop('dbname')
                if database is None:
                    database = val

            if 'database' in query:
                val = query.pop('database')
                if database is None:
                    database = val

            if 'user' in query:
                val = query.pop('user')
                if user is None:
                    user = val

            if 'password' in query:
                val = query.pop('password')
                if password is None:
                    password = val

            if 'passfile' in query:
                val = query.pop('passfile')
                if passfile is None:
                    passfile = val

            if 'sslmode' in query:
                val = query.pop('sslmode')
                if ssl is None:
                    ssl = val

            if 'sslcert' in query:
                sslcert = query.pop('sslcert')

            if 'sslkey' in query:
                sslkey = query.pop('sslkey')

            if 'sslrootcert' in query:
                sslrootcert = query.pop('sslrootcert')

            if 'sslnegotiation' in query:
                sslnegotiation = query.pop('sslnegotiation')

            if 'sslcrl' in query:
                sslcrl = query.pop('sslcrl')

            if 'sslpassword' in query:
                sslpassword = query.pop('sslpassword')

            if 'ssl_min_protocol_version' in query:
                ssl_min_protocol_version = query.pop(
                    'ssl_min_protocol_version'
                )

            if 'ssl_max_protocol_version' in query:
                ssl_max_protocol_version = query.pop(
                    'ssl_max_protocol_version'
                )

            if 'target_session_attrs' in query:
                dsn_target_session_attrs = query.pop(
                    'target_session_attrs'
                )
                if target_session_attrs is None:
                    target_session_attrs = dsn_target_session_attrs

            if 'krbsrvname' in query:
                val = query.pop('krbsrvname')
                if krbsrvname is None:
                    krbsrvname = val

            if 'gsslib' in query:
                val = query.pop('gsslib')
                if gsslib is None:
                    gsslib = val

            # NOTE(review): 'service' was already popped right after the
            # query was parsed, so this branch is normally unreachable.
            if 'service' in query:
                val = query.pop('service')
                if service is None:
                    service = val

            # Whatever is left over is passed to the server as settings.
            if query:
                if server_settings is None:
                    server_settings = query
                else:
                    server_settings = {**query, **server_settings}

        if connection_service_file is not None and service is not None:
            pg_service = configparser.ConfigParser()
            pg_service.read(connection_service_file)
            if service in pg_service.sections():
                service_params = pg_service[service]
                if 'port' in service_params:
                    val = service_params.pop('port')
                    if not port and val:
                        port = [int(p) for p in val.split(',')]

                if 'host' in service_params:
                    val = service_params.pop('host')
                    if not host and val:
                        host, port = _parse_hostlist(val, port)

                if 'dbname' in service_params:
                    val = service_params.pop('dbname')
                    if database is None:
                        database = val

                if 'database' in service_params:
                    val = service_params.pop('database')
                    if database is None:
                        database = val

                if 'user' in service_params:
                    val = service_params.pop('user')
                    if user is None:
                        user = val

                if 'password' in service_params:
                    val = service_params.pop('password')
                    if password is None:
                        password = val

                if 'passfile' in service_params:
                    val = service_params.pop('passfile')
                    if passfile is None:
                        passfile = val

                if 'sslmode' in service_params:
                    val = service_params.pop('sslmode')
                    if ssl is None:
                        ssl = val

                if 'sslcert' in service_params:
                    val = service_params.pop('sslcert')
                    if sslcert is None:
                        sslcert = val

                if 'sslkey' in service_params:
                    val = service_params.pop('sslkey')
                    if sslkey is None:
                        sslkey = val

                if 'sslrootcert' in service_params:
                    val = service_params.pop('sslrootcert')
                    if sslrootcert is None:
                        sslrootcert = val

                if 'sslnegotiation' in service_params:
                    val = service_params.pop('sslnegotiation')
                    if sslnegotiation is None:
                        sslnegotiation = val

                if 'sslcrl' in service_params:
                    val = service_params.pop('sslcrl')
                    if sslcrl is None:
                        sslcrl = val

                if 'sslpassword' in service_params:
                    val = service_params.pop('sslpassword')
                    if sslpassword is None:
                        sslpassword = val

                if 'ssl_min_protocol_version' in service_params:
                    val = service_params.pop(
                        'ssl_min_protocol_version'
                    )
                    if ssl_min_protocol_version is None:
                        ssl_min_protocol_version = val

                if 'ssl_max_protocol_version' in service_params:
                    val = service_params.pop(
                        'ssl_max_protocol_version'
                    )
                    if ssl_max_protocol_version is None:
                        ssl_max_protocol_version = val

                if 'target_session_attrs' in service_params:
                    dsn_target_session_attrs = service_params.pop(
                        'target_session_attrs'
                    )
                    if target_session_attrs is None:
                        target_session_attrs = dsn_target_session_attrs

                if 'krbsrvname' in service_params:
                    val = service_params.pop('krbsrvname')
                    if krbsrvname is None:
                        krbsrvname = val

                if 'gsslib' in service_params:
                    val = service_params.pop('gsslib')
                    if gsslib is None:
                        gsslib = val

    # NOTE(review): `service` is not read anywhere below, so PGSERVICE
    # currently has no effect; the service file is only consulted inside
    # the `if dsn:` branch above.
    if not service:
        service = os.environ.get('PGSERVICE')

    if not host:
        hostspec = os.environ.get('PGHOST')
        if hostspec:
            host, port = _parse_hostlist(hostspec, port)

    if not host:
        auth_hosts = ['localhost']

        if _system == 'Windows':
            host = ['localhost']
        else:
            host = ['/run/postgresql', '/var/run/postgresql',
                    '/tmp', '/private/tmp', 'localhost']

    if not isinstance(host, (list, tuple)):
        host = [host]

    if auth_hosts is None:
        auth_hosts = host

    if not port:
        portspec = os.environ.get('PGPORT')
        if portspec:
            if ',' in portspec:
                port = [int(p) for p in portspec.split(',')]
            else:
                port = int(portspec)
        else:
            port = 5432

    elif isinstance(port, (list, tuple)):
        port = [int(p) for p in port]

    else:
        port = int(port)

    port = _validate_port_spec(host, port)

    if user is None:
        user = os.getenv('PGUSER')
        if not user:
            user = getpass.getuser()

    if password is None:
        password = os.getenv('PGPASSWORD')

    if database is None:
        database = os.getenv('PGDATABASE')

    if database is None:
        database = user

    if user is None:
        raise exceptions.ClientConfigurationError(
            'could not determine user name to connect with')

    if database is None:
        raise exceptions.ClientConfigurationError(
            'could not determine database name to connect to')

    if password is None:
        if passfile is None:
            passfile = os.getenv('PGPASSFILE')

        if passfile is None:
            homedir = compat.get_pg_home_directory()
            if homedir:
                passfile = homedir / PGPASSFILE
            else:
                passfile = None
        else:
            passfile = pathlib.Path(passfile)

        if passfile is not None:
            password = _read_password_from_pgpass(
                hosts=auth_hosts, ports=port,
                database=database, user=user,
                passfile=passfile)

    addrs = []
    have_tcp_addrs = False
    for h, p in zip(host, port):
        if h.startswith('/'):
            # UNIX socket name
            if '.s.PGSQL.' not in h:
                h = os.path.join(h, '.s.PGSQL.{}'.format(p))
            addrs.append(h)
        else:
            # TCP host/port
            addrs.append((h, p))
            have_tcp_addrs = True

    if not addrs:
        raise exceptions.InternalClientError(
            'could not determine the database address to connect to')

    if ssl is None:
        ssl = os.getenv('PGSSLMODE')

    # Default to opportunistic TLS only when at least one TCP address
    # is involved.
    if ssl is None and have_tcp_addrs:
        ssl = 'prefer'

    if direct_tls is not None:
        sslneg = (
            SSLNegotiation.direct if direct_tls else SSLNegotiation.postgres
        )
    else:
        if sslnegotiation is None:
            sslnegotiation = os.environ.get("PGSSLNEGOTIATION")

        if sslnegotiation is not None:
            try:
                sslneg = SSLNegotiation(sslnegotiation)
            except ValueError:
                modes = ', '.join(
                    m.name.replace('_', '-')
                    for m in SSLNegotiation
                )
                raise exceptions.ClientConfigurationError(
                    f'`sslnegotiation` parameter must be one of: {modes}'
                ) from None
        else:
            sslneg = SSLNegotiation.postgres

    if isinstance(ssl, (str, SSLMode)):
        try:
            sslmode = SSLMode.parse(ssl)
        except AttributeError:
            modes = ', '.join(m.name.replace('_', '-') for m in SSLMode)
            raise exceptions.ClientConfigurationError(
                '`sslmode` parameter must be one of: {}'.format(modes)
            ) from None

        # docs at https://www.postgresql.org/docs/10/static/libpq-connect.html
        if sslmode < SSLMode.allow:
            ssl = False
        else:
            ssl = ssl_module.SSLContext(ssl_module.PROTOCOL_TLS_CLIENT)
            ssl.check_hostname = sslmode >= SSLMode.verify_full
            if sslmode < SSLMode.require:
                ssl.verify_mode = ssl_module.CERT_NONE
            else:
                if sslrootcert is None:
                    sslrootcert = os.getenv('PGSSLROOTCERT')
                if sslrootcert:
                    ssl.load_verify_locations(cafile=sslrootcert)
                    ssl.verify_mode = ssl_module.CERT_REQUIRED
                else:
                    try:
                        sslrootcert = _dot_postgresql_path('root.crt')
                        if sslrootcert is not None:
                            ssl.load_verify_locations(cafile=sslrootcert)
                        else:
                            raise exceptions.ClientConfigurationError(
                                'cannot determine location of user '
                                'PostgreSQL configuration directory'
                            )
                    except (
                        exceptions.ClientConfigurationError,
                        FileNotFoundError,
                        NotADirectoryError,
                    ):
                        if sslmode > SSLMode.require:
                            if sslrootcert is None:
                                sslrootcert = '~/.postgresql/root.crt'
                                detail = (
                                    'Could not determine location of user '
                                    'home directory (HOME is either unset, '
                                    'inaccessible, or does not point to a '
                                    'valid directory)'
                                )
                            else:
                                detail = None
                            raise exceptions.ClientConfigurationError(
                                f'root certificate file "{sslrootcert}" does '
                                f'not exist or cannot be accessed',
                                hint='Provide the certificate file directly '
                                     f'or make sure "{sslrootcert}" '
                                     'exists and is readable.',
                                detail=detail,
                            )
                        elif sslmode == SSLMode.require:
                            # sslmode=require without a root certificate:
                            # encrypt, but do not verify the server.
                            ssl.verify_mode = ssl_module.CERT_NONE
                        else:
                            assert False, 'unreachable'
                    else:
                        ssl.verify_mode = ssl_module.CERT_REQUIRED

                if sslcrl is None:
                    sslcrl = os.getenv('PGSSLCRL')
                if sslcrl:
                    ssl.load_verify_locations(cafile=sslcrl)
                    ssl.verify_flags |= ssl_module.VERIFY_CRL_CHECK_CHAIN
                else:
                    sslcrl = _dot_postgresql_path('root.crl')
                    if sslcrl is not None:
                        try:
                            ssl.load_verify_locations(cafile=sslcrl)
                        except (
                            FileNotFoundError,
                            NotADirectoryError,
                        ):
                            pass
                        else:
                            ssl.verify_flags |= \
                                ssl_module.VERIFY_CRL_CHECK_CHAIN

            if sslkey is None:
                sslkey = os.getenv('PGSSLKEY')
            if not sslkey:
                sslkey = _dot_postgresql_path('postgresql.key')
                if sslkey is not None and not sslkey.exists():
                    sslkey = None
            if not sslpassword:
                sslpassword = ''
            if sslcert is None:
                sslcert = os.getenv('PGSSLCERT')
            if sslcert:
                ssl.load_cert_chain(
                    sslcert, keyfile=sslkey, password=lambda: sslpassword
                )
            else:
                sslcert = _dot_postgresql_path('postgresql.crt')
                if sslcert is not None:
                    try:
                        ssl.load_cert_chain(
                            sslcert,
                            keyfile=sslkey,
                            password=lambda: sslpassword
                        )
                    except (FileNotFoundError, NotADirectoryError):
                        pass

            # OpenSSL 1.1.1 keylog file, copied from create_default_context()
            if hasattr(ssl, 'keylog_filename'):
                keylogfile = os.environ.get('SSLKEYLOGFILE')
                if keylogfile and not sys.flags.ignore_environment:
                    ssl.keylog_filename = keylogfile

            if ssl_min_protocol_version is None:
                ssl_min_protocol_version = os.getenv('PGSSLMINPROTOCOLVERSION')
            if ssl_min_protocol_version:
                ssl.minimum_version = _parse_tls_version(
                    ssl_min_protocol_version
                )
            else:
                ssl.minimum_version = _parse_tls_version('TLSv1.2')

            if ssl_max_protocol_version is None:
                ssl_max_protocol_version = os.getenv('PGSSLMAXPROTOCOLVERSION')
            if ssl_max_protocol_version:
                ssl.maximum_version = _parse_tls_version(
                    ssl_max_protocol_version
                )

    elif ssl is True:
        ssl = ssl_module.create_default_context()
        sslmode = SSLMode.verify_full
    else:
        sslmode = SSLMode.disable

    if server_settings is not None and (
            not isinstance(server_settings, dict) or
            not all(isinstance(k, str) for k in server_settings) or
            not all(isinstance(v, str) for v in server_settings.values())):
        raise exceptions.ClientConfigurationError(
            'server_settings is expected to be None or '
            'a Dict[str, str]')

    if target_session_attrs is None:
        target_session_attrs = os.getenv(
            "PGTARGETSESSIONATTRS", SessionAttribute.any
        )
    try:
        target_session_attrs = SessionAttribute(target_session_attrs)
    except ValueError:
        # List the accepted values; this previously formatted the bound
        # method `SessionAttribute.__members__.values` instead.
        raise exceptions.ClientConfigurationError(
            "target_session_attrs is expected to be one of "
            "{!r}"
            ", got {!r}".format(
                [attr.value for attr in SessionAttribute],
                target_session_attrs
            )
        ) from None

    if krbsrvname is None:
        krbsrvname = os.getenv('PGKRBSRVNAME')

    if gsslib is None:
        gsslib = os.getenv('PGGSSLIB')
        if gsslib is None:
            gsslib = 'sspi' if _system == 'Windows' else 'gssapi'
    if gsslib not in {'gssapi', 'sspi'}:
        raise exceptions.ClientConfigurationError(
            "gsslib parameter must be either 'gssapi' or 'sspi'"
            ", got {!r}".format(gsslib))

    params = _ConnectionParameters(
        user=user, password=password, database=database, ssl=ssl,
        sslmode=sslmode, ssl_negotiation=sslneg,
        server_settings=server_settings,
        target_session_attrs=target_session_attrs,
        krbsrvname=krbsrvname, gsslib=gsslib)

    return addrs, params
def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
                             database, command_timeout,
                             statement_cache_size,
                             max_cached_statement_lifetime,
                             max_cacheable_statement_size,
                             ssl, direct_tls, server_settings,
                             target_session_attrs, krbsrvname, gsslib,
                             service, servicefile):
    """Validate client options and resolve the connection parameters.

    :return: ``(addrs, params, config)`` where *addrs* and *params* come
        from ``_parse_connect_dsn_and_args`` and *config* is a
        ``_ClientConfiguration``.
    :raises ValueError: if a statement-cache option is ``None``, a bool,
        or negative, or if *command_timeout* is not a positive number.
    """
    # Check in a fixed (signature) order so that the reported error is
    # deterministic; iterating a set of names made it depend on string
    # hash randomization.
    for var_name, var_val in (
        ('statement_cache_size', statement_cache_size),
        ('max_cached_statement_lifetime', max_cached_statement_lifetime),
        ('max_cacheable_statement_size', max_cacheable_statement_size),
    ):
        if var_val is None or isinstance(var_val, bool) or var_val < 0:
            raise ValueError(
                '{} is expected to be greater '
                'or equal to 0, got {!r}'.format(var_name, var_val))

    if command_timeout is not None:
        try:
            if isinstance(command_timeout, bool):
                raise ValueError
            command_timeout = float(command_timeout)
            if command_timeout <= 0:
                raise ValueError
        except ValueError:
            raise ValueError(
                'invalid command_timeout value: '
                'expected greater than 0 float (got {!r})'.format(
                    command_timeout)) from None

    addrs, params = _parse_connect_dsn_and_args(
        dsn=dsn, host=host, port=port, user=user,
        password=password, passfile=passfile, ssl=ssl,
        direct_tls=direct_tls, database=database,
        server_settings=server_settings,
        target_session_attrs=target_session_attrs,
        krbsrvname=krbsrvname, gsslib=gsslib,
        service=service, servicefile=servicefile)

    config = _ClientConfiguration(
        command_timeout=command_timeout,
        statement_cache_size=statement_cache_size,
        max_cached_statement_lifetime=max_cached_statement_lifetime,
        max_cacheable_statement_size=max_cacheable_statement_size,)

    return addrs, params, config
class TLSUpgradeProto(asyncio.Protocol):
    """Short-lived protocol that reads the server's answer to SSLRequest.

    ``on_data`` resolves to ``True`` when the server replies ``b'S'``
    (proceed with TLS), to ``False`` when a ``b'N'`` reply may fall back to
    plaintext, and fails with ``ConnectionError`` for any other reply or
    when the connection is lost first.
    """

    def __init__(
        self,
        loop: asyncio.AbstractEventLoop,
        host: str,
        port: int,
        ssl_context: ssl_module.SSLContext,
        ssl_is_advisory: bool,
    ) -> None:
        self.on_data = _create_future(loop)
        self.host = host
        self.port = port
        self.ssl_context = ssl_context
        self.ssl_is_advisory = ssl_is_advisory

    def data_received(self, data: bytes) -> None:
        if data == b'S':
            self.on_data.set_result(True)
            return

        # A plaintext fallback requires advisory SSL (sslmode=prefer) *and*
        # a context that does not verify certificates.  The second check is
        # redundant today, but it refuses to downgrade a context that asks
        # for real security.
        may_fall_back = (
            self.ssl_is_advisory
            and self.ssl_context.verify_mode == ssl_module.CERT_NONE
            and data == b'N'
        )
        if may_fall_back:
            self.on_data.set_result(False)
            return

        self.on_data.set_exception(
            ConnectionError(
                'PostgreSQL server at "{host}:{port}" '
                'rejected SSL upgrade'.format(
                    host=self.host, port=self.port)))

    def connection_lost(self, exc: typing.Optional[Exception]) -> None:
        if self.on_data.done():
            return
        if exc is None:
            exc = ConnectionError('unexpected connection_lost() call')
        self.on_data.set_exception(exc)
# Type of the protocol instance built by the factory handed to
# _create_ssl_connection(); its return type follows the factory.
_ProctolFactoryR = typing.TypeVar(
    "_ProctolFactoryR", bound=asyncio.Protocol
)
async def _create_ssl_connection(
    # TODO: The return type is a specific combination of subclasses of
    # asyncio.protocols.Protocol that we can't express. For now, having the
    # return type be dependent on signature of the factory is an improvement
    protocol_factory: Callable[[], _ProctolFactoryR],
    host: str,
    port: int,
    *,
    loop: asyncio.AbstractEventLoop,
    ssl_context: ssl_module.SSLContext,
    ssl_is_advisory: bool = False,
) -> typing.Tuple[asyncio.Transport, _ProctolFactoryR]:
    """Open a TCP connection, send an SSLRequest and upgrade to TLS if accepted.

    A temporary ``TLSUpgradeProto`` reads the server's one-byte answer.  If
    the loop provides ``start_tls`` the existing transport is upgraded in
    place (or kept as-is when no upgrade happens) and handed to a fresh
    protocol from *protocol_factory*.  Otherwise the socket is duplicated,
    the original transport is closed, and a new connection is created on
    the duplicate (with *ssl_context* when upgrading).

    The returned protocol has ``is_ssl`` set to whether TLS is in use.
    Any failure while waiting for the answer or upgrading closes the
    transport (or duplicated socket) and re-raises.
    """
    tr, pr = await loop.create_connection(
        lambda: TLSUpgradeProto(loop, host, port,
                                ssl_context, ssl_is_advisory),
        host, port)

    tr.write(struct.pack('!ll', 8, 80877103))  # SSLRequest message.

    try:
        do_ssl_upgrade = await pr.on_data
    except (Exception, asyncio.CancelledError):
        tr.close()
        raise

    if hasattr(loop, 'start_tls'):
        if do_ssl_upgrade:
            try:
                new_tr = await loop.start_tls(
                    tr, pr, ssl_context, server_hostname=host)
                assert new_tr is not None
            except (Exception, asyncio.CancelledError):
                tr.close()
                raise
        else:
            # Server declined and plaintext is acceptable: keep the
            # existing transport.
            new_tr = tr

        # Swap the temporary TLSUpgradeProto for the real protocol.
        pg_proto = protocol_factory()
        pg_proto.is_ssl = do_ssl_upgrade
        pg_proto.connection_made(new_tr)
        new_tr.set_protocol(pg_proto)

        return new_tr, pg_proto
    else:
        # No start_tls(): rebuild the connection on a duplicate of the
        # already-connected socket.
        conn_factory = functools.partial(
            loop.create_connection, protocol_factory)

        if do_ssl_upgrade:
            conn_factory = functools.partial(
                conn_factory, ssl=ssl_context, server_hostname=host)

        sock = _get_socket(tr)
        # Duplicate before closing the transport, which closes the
        # original socket.
        sock = sock.dup()
        _set_nodelay(sock)
        tr.close()

        try:
            new_tr, pg_proto = await conn_factory(sock=sock)
            pg_proto.is_ssl = do_ssl_upgrade
            return new_tr, pg_proto
        except (Exception, asyncio.CancelledError):
            sock.close()
            raise
async def _connect_addr(
    *,
    addr,
    loop,
    params,
    config,
    connection_class,
    record_class
):
    """Connect to a single address, applying the sslmode fallback policy.

    A callable ``params.password`` is invoked (and awaited if it returns an
    awaitable) first.  With ``sslmode=allow`` a plaintext attempt is made
    before an SSL one; with ``sslmode=prefer`` the SSL attempt comes first.
    The second attempt only happens when the first raises
    ``_RetryConnectSignal``.  Any other sslmode gets a single attempt.
    """
    assert loop is not None

    original_params = params

    if callable(params.password):
        resolved = params.password()
        if inspect.isawaitable(resolved):
            resolved = await resolved
        params = params._replace(password=resolved)

    shared_args = (
        addr, loop, config, connection_class, record_class, original_params)

    mode = params.sslmode
    if mode == SSLMode.allow:
        first_params, fallback_params = params._replace(ssl=None), params
    elif mode == SSLMode.prefer:
        first_params, fallback_params = params, params._replace(ssl=None)
    else:
        # No fallback to attempt for this sslmode.
        return await __connect_addr(params, False, *shared_args)

    try:
        return await __connect_addr(first_params, True, *shared_args)
    except _RetryConnectSignal:
        pass

    return await __connect_addr(fallback_params, False, *shared_args)
class _RetryConnectSignal(Exception):
    """Internal signal: retry the connection with the alternate SSL params."""
async def __connect_addr(
    params,
    retry,
    addr,
    loop,
    config,
    connection_class,
    record_class,
    params_input,
):
    """Make one connection attempt to *addr* using *params*.

    The transport is chosen from *addr* and *params*: a UNIX socket for a
    ``str`` address, TLS from the first byte for direct SSL negotiation, an
    SSLRequest upgrade for other SSL cases, and plaintext TCP otherwise.
    After the startup handshake completes, a ``connection_class`` instance
    is built (with the caller's original *params_input*) and returned.

    When *retry* is true and authorization fails in a way the allow/prefer
    fallback can address, ``_RetryConnectSignal`` is raised so the caller
    can try again with the alternate parameters.  The transport is closed
    on every failure path.
    """
    connected = _create_future(loop)

    proto_factory = lambda: protocol.Protocol(
        addr, connected, params, record_class, loop)

    if isinstance(addr, str):
        # UNIX socket
        connector = loop.create_unix_connection(proto_factory, addr)

    elif params.ssl and params.ssl_negotiation is SSLNegotiation.direct:
        # if ssl and ssl_negotiation is `direct`, skip STARTTLS and perform
        # direct SSL connection
        # NOTE(review): this path does not set `pr.is_ssl` here, so the
        # sslmode=prefer retry check below depends on what
        # protocol.Protocol defaults it to; confirm.
        connector = loop.create_connection(
            proto_factory, *addr, ssl=params.ssl
        )

    elif params.ssl:
        connector = _create_ssl_connection(
            proto_factory, *addr, loop=loop, ssl_context=params.ssl,
            ssl_is_advisory=params.sslmode == SSLMode.prefer)

    else:
        connector = loop.create_connection(proto_factory, *addr)

    tr, pr = await connector

    try:
        # Resolved by the protocol once the startup handshake is done.
        await connected
    except (
        exceptions.InvalidAuthorizationSpecificationError,
        exceptions.ConnectionDoesNotExistError,  # seen on Windows
    ):
        tr.close()

        # retry=True here is a redundant check because we don't want to
        # accidentally raise the internal _RetryConnectSignal to the user
        if retry and (
            params.sslmode == SSLMode.allow and not pr.is_ssl or
            params.sslmode == SSLMode.prefer and pr.is_ssl
        ):
            # Trigger retry when:
            #   1. First attempt with sslmode=allow, ssl=None failed
            #   2. First attempt with sslmode=prefer, ssl=ctx failed while the
            #      server claimed to support SSL (returning "S" for SSLRequest)
            #      (likely because pg_hba.conf rejected the connection)
            raise _RetryConnectSignal()

        else:
            # but will NOT retry if:
            #   1. First attempt with sslmode=prefer failed but the server
            #      doesn't support SSL (returning 'N' for SSLRequest), because
            #      we already tried to connect without SSL thru ssl_is_advisory
            #   2. Second attempt with sslmode=prefer, ssl=None failed
            #   3. Second attempt with sslmode=allow, ssl=ctx failed
            #   4. Any other sslmode
            raise

    except (Exception, asyncio.CancelledError):
        tr.close()
        raise

    con = connection_class(pr, tr, loop, addr, config, params_input)
    pr.set_connection(con)
    return con
class SessionAttribute(str, enum.Enum):
    """Accepted values of ``target_session_attrs`` (libpq spelling)."""
    any = 'any'
    primary = 'primary'
    standby = 'standby'
    prefer_standby = 'prefer-standby'
    read_write = 'read-write'
    read_only = 'read-only'
def _accept_in_hot_standby(should_be_in_hot_standby: bool):
    """Build a check that a connection's hot-standby state is as requested.

    The server's reported ``in_hot_standby`` setting is used when present.
    Otherwise the state is queried with
    ``SELECT pg_catalog.pg_is_in_recovery()``; a server that accepts the
    connection and reports being in recovery is a replica/standby.
    """
    async def can_be_used(connection):
        reported = getattr(connection.get_settings(), 'in_hot_standby', None)
        if reported is None:
            in_standby = await connection.fetchval(
                "SELECT pg_catalog.pg_is_in_recovery()"
            )
        else:
            in_standby = reported == 'on'
        return in_standby == should_be_in_hot_standby

    return can_be_used
def _accept_read_only(should_be_read_only: bool):
    """Build a check that a connection is (or is not) read-only.

    ``default_transaction_read_only=on`` marks the server as read-only.
    Otherwise the decision falls back to the hot-standby check.
    """
    async def can_be_used(connection):
        read_only_setting = getattr(
            connection.get_settings(), 'default_transaction_read_only', 'off')
        if read_only_setting != "on":
            return await _accept_in_hot_standby(should_be_read_only)(
                connection)
        return should_be_read_only

    return can_be_used
async def _accept_any(_):
    """Target-attribute check that accepts every connection."""
    accepted = True
    return accepted
# Maps each SessionAttribute to an async predicate `check(connection)`
# that reports whether the connection satisfies that attribute.
target_attrs_check = {
    SessionAttribute.any: _accept_any,
    # Decided by hot-standby status.
    SessionAttribute.primary: _accept_in_hot_standby(False),
    SessionAttribute.standby: _accept_in_hot_standby(True),
    # Same predicate as `standby`; _connect() picks a random candidate
    # when no address matches.
    SessionAttribute.prefer_standby: _accept_in_hot_standby(True),
    # default_transaction_read_only first, then hot-standby status.
    SessionAttribute.read_write: _accept_read_only(False),
    SessionAttribute.read_only: _accept_read_only(True),
}
async def _can_use_connection(connection, attr: SessionAttribute):
    """Return whether *connection* satisfies the target attribute *attr*."""
    return await target_attrs_check[attr](connection)
# Strong references to the fire-and-forget cleanup tasks scheduled by
# _connect(). The event loop keeps only weak references to tasks, so an
# unreferenced task could be garbage-collected before it finishes.
_background_tasks = set()


async def _connect(*, loop, connection_class, record_class, **kwargs):
    """Connect to the first address that satisfies ``target_session_attrs``.

    Addresses are tried in order.  An ``OSError`` from one address is
    remembered and the next address is tried.  With ``prefer-standby`` and
    no matching server, a random successfully connected candidate is
    returned instead.  Every connected candidate that is not returned is
    closed in a background task.

    :raises OSError: the last connection error, when nothing was chosen
        and at least one address failed with ``OSError``.
    :raises exceptions.TargetServerAttributeNotMatched: when nothing was
        chosen and no ``OSError`` occurred.
    """
    if loop is None:
        # Always inside a coroutine here, so a running loop exists;
        # get_event_loop() is discouraged for this purpose.
        loop = asyncio.get_running_loop()

    addrs, params, config = _parse_connect_arguments(**kwargs)
    target_attr = params.target_session_attrs

    candidates = []
    chosen_connection = None
    last_error = None
    try:
        for addr in addrs:
            try:
                conn = await _connect_addr(
                    addr=addr,
                    loop=loop,
                    params=params,
                    config=config,
                    connection_class=connection_class,
                    record_class=record_class,
                )
                candidates.append(conn)
                if await _can_use_connection(conn, target_attr):
                    chosen_connection = conn
                    break
            except OSError as ex:
                last_error = ex
        else:
            # No address matched: prefer-standby settles for any candidate.
            if target_attr == SessionAttribute.prefer_standby and candidates:
                chosen_connection = random.choice(candidates)
    finally:

        async def _close_candidates(conns, chosen):
            await asyncio.gather(
                *(c.close() for c in conns if c is not chosen),
                return_exceptions=True
            )

        if candidates:
            cleanup = asyncio.create_task(
                _close_candidates(candidates, chosen_connection))
            _background_tasks.add(cleanup)
            cleanup.add_done_callback(_background_tasks.discard)

    if chosen_connection is not None:
        return chosen_connection

    raise last_error or exceptions.TargetServerAttributeNotMatched(
        'None of the hosts match the target attribute requirement '
        '{!r}'.format(target_attr)
    )
async def _cancel(*, loop, addr, params: _ConnectionParameters,
                  backend_pid, backend_secret):
    """Ask the server to cancel the query running in another backend.

    Opens a separate connection to *addr* (UNIX socket path or
    ``(host, port)``), sends a CancelRequest for *backend_pid* /
    *backend_secret*, and waits for the server to close the connection.
    For TCP addresses, an SSLRequest-upgraded connection is used when
    ``params.ssl`` is set, except with ``sslmode=allow``.  The transport is
    always closed before returning, including when setup after connecting
    fails.
    """

    class CancelProto(asyncio.Protocol):

        def __init__(self):
            self.on_disconnect = _create_future(loop)
            # _create_ssl_connection() overwrites this on the protocol it
            # returns.
            self.is_ssl = False

        def connection_lost(self, exc):
            if not self.on_disconnect.done():
                self.on_disconnect.set_result(True)

    # Pack the CancelRequest message before connecting, so a packing error
    # cannot leave an open transport behind.
    msg = struct.pack('!llll', 16, 80877102, backend_pid, backend_secret)

    needs_nodelay = False
    if isinstance(addr, str):
        tr, pr = await loop.create_unix_connection(CancelProto, addr)
    elif params.ssl and params.sslmode != SSLMode.allow:
        tr, pr = await _create_ssl_connection(
            CancelProto,
            *addr,
            loop=loop,
            ssl_context=params.ssl,
            ssl_is_advisory=params.sslmode == SSLMode.prefer)
    else:
        tr, pr = await loop.create_connection(
            CancelProto, *addr)
        needs_nodelay = True

    try:
        if needs_nodelay:
            # Inside the try block so the transport is still closed if the
            # underlying socket cannot be obtained.
            _set_nodelay(_get_socket(tr))
        tr.write(msg)
        await pr.on_disconnect
    finally:
        tr.close()
def _get_socket(transport):
    """Return the socket object underlying *transport*.

    :raises ConnectionError: if the transport does not expose one.
    """
    sock = transport.get_extra_info('socket')
    if sock is not None:
        return sock
    # Shouldn't happen with any asyncio-compliant event loop.
    raise ConnectionError(
        'could not get the socket for transport {!r}'.format(transport))
def _set_nodelay(sock):
    """Enable TCP_NODELAY on *sock* unless it is a UNIX-domain socket."""
    is_unix_socket = hasattr(socket, 'AF_UNIX') and \
        sock.family == socket.AF_UNIX
    if is_unix_socket:
        return
    sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
def _create_future(loop):
    """Create a future bound to *loop*.

    Prefers ``loop.create_future()`` and falls back to constructing
    ``asyncio.Future`` directly for loops that lack that method.
    """
    if not hasattr(loop, 'create_future'):
        return asyncio.Future(loop=loop)
    return loop.create_future()