# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


from __future__ import annotations

import asyncio
import configparser
import collections
from collections.abc import Callable
import enum
import functools
import getpass
import os
import pathlib
import platform
import random
import re
import socket
import ssl as ssl_module
import stat
import struct
import sys
import typing
import urllib.parse
import warnings
import inspect

from . import compat
from . import exceptions
from . import protocol


class SSLMode(enum.IntEnum):
    """libpq-style ``sslmode`` values, ordered from least to most strict."""

    disable = 0
    allow = 1
    prefer = 2
    require = 3
    verify_ca = 4
    verify_full = 5

    @classmethod
    def parse(cls, sslmode):
        """Return the member for *sslmode* (a member or a name string).

        Dashes in string names are accepted (``'verify-full'``).  An unknown
        name raises ``AttributeError``.
        """
        if isinstance(sslmode, cls):
            return sslmode
        return getattr(cls, sslmode.replace('-', '_'))


class SSLNegotiation(compat.StrEnum):
    """How TLS is started: via an SSLRequest (``postgres``) or directly."""

    postgres = "postgres"
    direct = "direct"


_ConnectionParameters = collections.namedtuple(
    'ConnectionParameters',
    [
        'user',
        'password',
        'database',
        'ssl',
        'sslmode',
        'ssl_negotiation',
        'server_settings',
        'target_session_attrs',
        'krbsrvname',
        'gsslib',
    ])


_ClientConfiguration = collections.namedtuple(
    'ConnectionConfiguration',
    [
        'command_timeout',
        'statement_cache_size',
        'max_cached_statement_lifetime',
        'max_cacheable_statement_size',
    ])


_system = platform.uname().system


if _system == 'Windows':
    PGPASSFILE = 'pgpass.conf'
else:
    PGPASSFILE = '.pgpass'


PG_SERVICEFILE = '.pg_service.conf'


def _read_password_file(passfile: pathlib.Path) \
        -> typing.List[typing.Tuple[str, ...]]:
    """Read *passfile* and return its entries as tuples of fields.

    Returns an empty list (after warning, where applicable) when the file
    does not exist, is not a regular file, or — on non-Windows systems —
    is accessible by group or others.  I/O errors are ignored.
    """

    passtab = []

    try:
        if not passfile.exists():
            return []

        if not passfile.is_file():
            warnings.warn(
                'password file {!r} is not a plain file'.format(passfile))

            return []

        if _system != 'Windows':
            if passfile.stat().st_mode & (stat.S_IRWXG | stat.S_IRWXO):
                warnings.warn(
                    'password file {!r} has group or world access; '
                    'permissions should be u=rw (0600) or less'.format(
                        passfile))

                return []

        with passfile.open('rt') as f:
            for line in f:
                line = line.strip()
                if not line or line.startswith('#'):
                    # Skip empty lines and comments.
                    continue
                # Backslash escapes both itself and the colon,
                # which is a record separator.
                line = line.replace(R'\\', '\n')
                passtab.append(tuple(
                    p.replace('\n', R'\\')
                    for p in re.split(r'(?<!\\):', line, maxsplit=4)
                ))
    except IOError:
        pass

    return passtab


def _read_password_from_pgpass(
        *, passfile: typing.Optional[pathlib.Path],
        hosts: typing.List[str],
        ports: typing.List[int],
        database: str,
        user: str):
    """Parse the pgpass file and return the matching password.

    Hosts are tried in order, paired with *ports*; within a host the first
    file entry whose host, port, database and user fields each equal the
    requested value (or are ``*``) wins.

    :return:
        Password string, if found, ``None`` otherwise.
    """

    passtab = _read_password_file(passfile)
    if not passtab:
        return None

    for host, port in zip(hosts, ports):
        if host.startswith('/'):
            # Unix sockets get normalized into 'localhost'
            host = 'localhost'

        for entry in passtab:
            if len(entry) != 5:
                # Malformed line (fewer than five fields): skip it rather
                # than failing the whole connection attempt on unpacking.
                continue
            phost, pport, pdatabase, puser, ppassword = entry
            if phost != '*' and phost != host:
                continue
            if pport != '*' and pport != str(port):
                continue
            if pdatabase != '*' and pdatabase != database:
                continue
            if puser != '*' and puser != user:
                continue

            # Found a match.
            return ppassword

    return None


def _validate_port_spec(hosts, port):
    """Return a list of ports, one per entry in *hosts*.

    A single port (scalar or one-element list) is repeated for every host;
    a longer list must have exactly ``len(hosts)`` entries.

    :raises ClientConfigurationError: if a multi-port list does not match
        the number of hosts.
    """
    if isinstance(port, list) and len(port) > 1:
        # If there is a list of ports, its length must
        # match that of the host list.
        if len(port) != len(hosts):
            raise exceptions.ClientConfigurationError(
                'could not match {} port numbers to {} hosts'.format(
                    len(port), len(hosts)))
    elif isinstance(port, list) and len(port) == 1:
        port = [port[0] for _ in range(len(hosts))]
    else:
        port = [port for _ in range(len(hosts))]

    return port


def _parse_hostlist(hostlist, port, *, unquote=False):
    """Split a comma-separated host list into ``(hosts, ports)`` lists.

    Each host spec may be a Unix socket directory (starting with ``/``),
    a bracketed IPv6 address with optional ``:port``, or ``host[:port]``.
    If *port* is falsy, per-host ports from the spec are used, falling back
    to ``PGPORT`` or 5432; otherwise *port* is expanded to every host.
    With *unquote*, host and port parts are URL-unquoted.
    """
    if ',' in hostlist:
        # A comma-separated list of host addresses.
        hostspecs = hostlist.split(',')
    else:
        hostspecs = [hostlist]

    hosts = []
    hostlist_ports = []

    if not port:
        portspec = os.environ.get('PGPORT')
        if portspec:
            if ',' in portspec:
                default_port = [int(p) for p in portspec.split(',')]
            else:
                default_port = int(portspec)
        else:
            default_port = 5432

        default_port = _validate_port_spec(hostspecs, default_port)

    else:
        port = _validate_port_spec(hostspecs, port)

    for i, hostspec in enumerate(hostspecs):
        if hostspec[0] == '/':
            # Unix socket
            addr = hostspec
            hostspec_port = ''
        elif hostspec[0] == '[':
            # IPv6 address
            m = re.match(r'(?:\[([^\]]+)\])(?::([0-9]+))?', hostspec)
            if m:
                addr = m.group(1)
                hostspec_port = m.group(2)
            else:
                raise exceptions.ClientConfigurationError(
                    'invalid IPv6 address in the connection URI: {!r}'.format(
                        hostspec
                    )
                )
        else:
            # IPv4 address
            addr, _, hostspec_port = hostspec.partition(':')

        if unquote:
            addr = urllib.parse.unquote(addr)

        hosts.append(addr)
        if not port:
            if hostspec_port:
                if unquote:
                    hostspec_port = urllib.parse.unquote(hostspec_port)
                hostlist_ports.append(int(hostspec_port))
            else:
                hostlist_ports.append(default_port[i])

    if not port:
        port = hostlist_ports

    return hosts, port


def _parse_tls_version(tls_version):
    """Map a name such as ``'TLSv1.2'`` to an ``ssl.TLSVersion`` member.

    :raises ClientConfigurationError: for ``SSL*`` names or unknown names.
    """
    if tls_version.startswith('SSL'):
        raise exceptions.ClientConfigurationError(
            f"Unsupported TLS version: {tls_version}"
        )
    try:
        return ssl_module.TLSVersion[tls_version.replace('.', '_')]
    except KeyError:
        raise exceptions.ClientConfigurationError(
            f"No such TLS version: {tls_version}"
        )


def _dot_postgresql_path(filename) -> typing.Optional[pathlib.Path]:
    """Return ``~/.postgresql/<filename>`` resolved, or None if no home."""
    try:
        homedir = pathlib.Path.home()
    except (RuntimeError, KeyError):
        return None

    return (homedir / '.postgresql' / filename).resolve()


def _parse_connect_dsn_and_args(*, dsn, host, port, user,
                                password, passfile, database, ssl,
                                service, servicefile,
                                direct_tls, server_settings,
                                target_session_attrs, krbsrvname, gsslib):
    """Resolve the final connection addresses and parameters.

    Values passed explicitly generally win over values found in the DSN,
    which win over the service file, which win over ``PG*`` environment
    variables and built-in defaults.  Unrecognized DSN query parameters
    become server settings.

    :return: ``(addrs, params)`` where *addrs* is a list of Unix socket
        paths and/or ``(host, port)`` tuples and *params* is a
        ``_ConnectionParameters``.
    :raises ClientConfigurationError: on invalid or incomplete
        configuration.
    """
    # `auth_hosts` is the version of host information for the purposes
    # of reading the pgpass file.
    auth_hosts = None
    sslcert = sslkey = sslrootcert = sslcrl = sslpassword = None
    ssl_min_protocol_version = ssl_max_protocol_version = None
    sslnegotiation = None

    if dsn:
        parsed = urllib.parse.urlparse(dsn)

        query = None
        if parsed.query:
            query = urllib.parse.parse_qs(parsed.query, strict_parsing=True)
            for key, val in query.items():
                if isinstance(val, list):
                    query[key] = val[-1]

            if 'service' in query:
                val = query.pop('service')
                if not service and val:
                    service = val

        connection_service_file = servicefile

        if connection_service_file is None:
            connection_service_file = os.getenv('PGSERVICEFILE')

        if connection_service_file is None:
            homedir = compat.get_pg_home_directory()
            if homedir:
                connection_service_file = homedir / PG_SERVICEFILE
            else:
                connection_service_file = None
        else:
            connection_service_file = pathlib.Path(connection_service_file)

        if parsed.scheme not in {'postgresql', 'postgres'}:
            raise exceptions.ClientConfigurationError(
                'invalid DSN: scheme is expected to be either '
                '"postgresql" or "postgres", got {!r}'.format(parsed.scheme))

        if parsed.netloc:
            if '@' in parsed.netloc:
                dsn_auth, _, dsn_hostspec = parsed.netloc.partition('@')
            else:
                dsn_hostspec = parsed.netloc
                dsn_auth = ''
        else:
            dsn_auth = dsn_hostspec = ''

        if dsn_auth:
            dsn_user, _, dsn_password = dsn_auth.partition(':')
        else:
            dsn_user = dsn_password = ''

        if not host and dsn_hostspec:
            host, port = _parse_hostlist(dsn_hostspec, port, unquote=True)

        if parsed.path and database is None:
            dsn_database = parsed.path
            if dsn_database.startswith('/'):
                dsn_database = dsn_database[1:]
            database = urllib.parse.unquote(dsn_database)

        if user is None and dsn_user:
            user = urllib.parse.unquote(dsn_user)

        if password is None and dsn_password:
            password = urllib.parse.unquote(dsn_password)

        if query:
            if 'port' in query:
                val = query.pop('port')
                if not port and val:
                    port = [int(p) for p in val.split(',')]

            if 'host' in query:
                val = query.pop('host')
                if not host and val:
                    host, port = _parse_hostlist(val, port)

            if 'dbname' in query:
                val = query.pop('dbname')
                if database is None:
                    database = val

            if 'database' in query:
                val = query.pop('database')
                if database is None:
                    database = val

            if 'user' in query:
                val = query.pop('user')
                if user is None:
                    user = val

            if 'password' in query:
                val = query.pop('password')
                if password is None:
                    password = val

            if 'passfile' in query:
                val = query.pop('passfile')
                if passfile is None:
                    passfile = val

            if 'sslmode' in query:
                val = query.pop('sslmode')
                if ssl is None:
                    ssl = val

            if 'sslcert' in query:
                sslcert = query.pop('sslcert')

            if 'sslkey' in query:
                sslkey = query.pop('sslkey')

            if 'sslrootcert' in query:
                sslrootcert = query.pop('sslrootcert')

            if 'sslnegotiation' in query:
                sslnegotiation = query.pop('sslnegotiation')

            if 'sslcrl' in query:
                sslcrl = query.pop('sslcrl')

            if 'sslpassword' in query:
                sslpassword = query.pop('sslpassword')

            if 'ssl_min_protocol_version' in query:
                ssl_min_protocol_version = query.pop(
                    'ssl_min_protocol_version'
                )

            if 'ssl_max_protocol_version' in query:
                ssl_max_protocol_version = query.pop(
                    'ssl_max_protocol_version'
                )

            if 'target_session_attrs' in query:
                dsn_target_session_attrs = query.pop(
                    'target_session_attrs'
                )
                if target_session_attrs is None:
                    target_session_attrs = dsn_target_session_attrs

            if 'krbsrvname' in query:
                val = query.pop('krbsrvname')
                if krbsrvname is None:
                    krbsrvname = val

            if 'gsslib' in query:
                val = query.pop('gsslib')
                if gsslib is None:
                    gsslib = val

            if 'service' in query:
                val = query.pop('service')
                if service is None:
                    service = val

            if query:
                # Any remaining query parameters are server settings;
                # explicitly passed server_settings take precedence.
                if server_settings is None:
                    server_settings = query
                else:
                    server_settings = {**query, **server_settings}

        # NOTE(review): the service file is only consulted when a DSN is
        # given, since this block sits inside `if dsn:`.
        if connection_service_file is not None and service is not None:
            # Interpolation is disabled so that '%' in values (e.g. in a
            # password) is taken literally instead of raising
            # InterpolationSyntaxError.
            pg_service = configparser.ConfigParser(interpolation=None)
            pg_service.read(connection_service_file)
            if service in pg_service.sections():
                service_params = pg_service[service]
                if 'port' in service_params:
                    val = service_params.pop('port')
                    if not port and val:
                        port = [int(p) for p in val.split(',')]

                if 'host' in service_params:
                    val = service_params.pop('host')
                    if not host and val:
                        host, port = _parse_hostlist(val, port)

                if 'dbname' in service_params:
                    val = service_params.pop('dbname')
                    if database is None:
                        database = val

                if 'database' in service_params:
                    val = service_params.pop('database')
                    if database is None:
                        database = val

                if 'user' in service_params:
                    val = service_params.pop('user')
                    if user is None:
                        user = val

                if 'password' in service_params:
                    val = service_params.pop('password')
                    if password is None:
                        password = val

                if 'passfile' in service_params:
                    val = service_params.pop('passfile')
                    if passfile is None:
                        passfile = val

                if 'sslmode' in service_params:
                    val = service_params.pop('sslmode')
                    if ssl is None:
                        ssl = val

                if 'sslcert' in service_params:
                    val = service_params.pop('sslcert')
                    if sslcert is None:
                        sslcert = val

                if 'sslkey' in service_params:
                    val = service_params.pop('sslkey')
                    if sslkey is None:
                        sslkey = val

                if 'sslrootcert' in service_params:
                    val = service_params.pop('sslrootcert')
                    if sslrootcert is None:
                        sslrootcert = val

                if 'sslnegotiation' in service_params:
                    val = service_params.pop('sslnegotiation')
                    if sslnegotiation is None:
                        sslnegotiation = val

                if 'sslcrl' in service_params:
                    val = service_params.pop('sslcrl')
                    if sslcrl is None:
                        sslcrl = val

                if 'sslpassword' in service_params:
                    val = service_params.pop('sslpassword')
                    if sslpassword is None:
                        sslpassword = val

                if 'ssl_min_protocol_version' in service_params:
                    val = service_params.pop(
                        'ssl_min_protocol_version'
                    )
                    if ssl_min_protocol_version is None:
                        ssl_min_protocol_version = val

                if 'ssl_max_protocol_version' in service_params:
                    val = service_params.pop(
                        'ssl_max_protocol_version'
                    )
                    if ssl_max_protocol_version is None:
                        ssl_max_protocol_version = val

                if 'target_session_attrs' in service_params:
                    dsn_target_session_attrs = service_params.pop(
                        'target_session_attrs'
                    )
                    if target_session_attrs is None:
                        target_session_attrs = dsn_target_session_attrs

                if 'krbsrvname' in service_params:
                    val = service_params.pop('krbsrvname')
                    if krbsrvname is None:
                        krbsrvname = val

                if 'gsslib' in service_params:
                    val = service_params.pop('gsslib')
                    if gsslib is None:
                        gsslib = val

    if not service:
        # NOTE(review): `service` is not read again after this point, so
        # PGSERVICE currently has no effect on the result — confirm intent.
        service = os.environ.get('PGSERVICE')

    if not host:
        hostspec = os.environ.get('PGHOST')
        if hostspec:
            host, port = _parse_hostlist(hostspec, port)

    if not host:
        auth_hosts = ['localhost']

        if _system == 'Windows':
            host = ['localhost']
        else:
            host = ['/run/postgresql', '/var/run/postgresql',
                    '/tmp', '/private/tmp', 'localhost']

    if not isinstance(host, (list, tuple)):
        host = [host]

    if auth_hosts is None:
        auth_hosts = host

    if not port:
        portspec = os.environ.get('PGPORT')
        if portspec:
            if ',' in portspec:
                port = [int(p) for p in portspec.split(',')]
            else:
                port = int(portspec)
        else:
            port = 5432

    elif isinstance(port, (list, tuple)):
        port = [int(p) for p in port]

    else:
        port = int(port)

    port = _validate_port_spec(host, port)

    if user is None:
        user = os.getenv('PGUSER')
        if not user:
            user = getpass.getuser()

    if password is None:
        password = os.getenv('PGPASSWORD')

    if database is None:
        database = os.getenv('PGDATABASE')

    if database is None:
        database = user

    if user is None:
        raise exceptions.ClientConfigurationError(
            'could not determine user name to connect with')

    if database is None:
        raise exceptions.ClientConfigurationError(
            'could not determine database name to connect to')

    if password is None:
        if passfile is None:
            passfile = os.getenv('PGPASSFILE')

        if passfile is None:
            homedir = compat.get_pg_home_directory()
            if homedir:
                passfile = homedir / PGPASSFILE
            else:
                passfile = None
        else:
            passfile = pathlib.Path(passfile)

        if passfile is not None:
            password = _read_password_from_pgpass(
                hosts=auth_hosts, ports=port,
                database=database, user=user,
                passfile=passfile)

    addrs = []
    have_tcp_addrs = False
    for h, p in zip(host, port):
        if h.startswith('/'):
            # UNIX socket name
            if '.s.PGSQL.' not in h:
                h = os.path.join(h, '.s.PGSQL.{}'.format(p))
            addrs.append(h)
        else:
            # TCP host/port
            addrs.append((h, p))
            have_tcp_addrs = True

    if not addrs:
        raise exceptions.InternalClientError(
            'could not determine the database address to connect to')

    if ssl is None:
        ssl = os.getenv('PGSSLMODE')

    if ssl is None and have_tcp_addrs:
        ssl = 'prefer'

    if direct_tls is not None:
        sslneg = (
            SSLNegotiation.direct if direct_tls else SSLNegotiation.postgres
        )
    else:
        if sslnegotiation is None:
            sslnegotiation = os.environ.get("PGSSLNEGOTIATION")

        if sslnegotiation is not None:
            try:
                sslneg = SSLNegotiation(sslnegotiation)
            except ValueError:
                modes = ', '.join(
                    m.name.replace('_', '-')
                    for m in SSLNegotiation
                )
                raise exceptions.ClientConfigurationError(
                    f'`sslnegotiation` parameter must be one of: {modes}'
                ) from None
        else:
            sslneg = SSLNegotiation.postgres

    if isinstance(ssl, (str, SSLMode)):
        try:
            sslmode = SSLMode.parse(ssl)
        except AttributeError:
            modes = ', '.join(m.name.replace('_', '-') for m in SSLMode)
            raise exceptions.ClientConfigurationError(
                '`sslmode` parameter must be one of: {}'.format(modes)
            ) from None

        # docs at https://www.postgresql.org/docs/10/static/libpq-connect.html
        if sslmode < SSLMode.allow:
            ssl = False
        else:
            ssl = ssl_module.SSLContext(ssl_module.PROTOCOL_TLS_CLIENT)
            ssl.check_hostname = sslmode >= SSLMode.verify_full
            if sslmode < SSLMode.require:
                ssl.verify_mode = ssl_module.CERT_NONE
            else:
                if sslrootcert is None:
                    sslrootcert = os.getenv('PGSSLROOTCERT')
                if sslrootcert:
                    ssl.load_verify_locations(cafile=sslrootcert)
                    ssl.verify_mode = ssl_module.CERT_REQUIRED
                else:
                    try:
                        sslrootcert = _dot_postgresql_path('root.crt')
                        if sslrootcert is not None:
                            ssl.load_verify_locations(cafile=sslrootcert)
                        else:
                            raise exceptions.ClientConfigurationError(
                                'cannot determine location of user '
                                'PostgreSQL configuration directory'
                            )
                    except (
                        exceptions.ClientConfigurationError,
                        FileNotFoundError,
                        NotADirectoryError,
                    ):
                        if sslmode > SSLMode.require:
                            if sslrootcert is None:
                                sslrootcert = '~/.postgresql/root.crt'
                                detail = (
                                    'Could not determine location of user '
                                    'home directory (HOME is either unset, '
                                    'inaccessible, or does not point to a '
                                    'valid directory)'
                                )
                            else:
                                detail = None
                            raise exceptions.ClientConfigurationError(
                                f'root certificate file "{sslrootcert}" does '
                                f'not exist or cannot be accessed',
                                hint='Provide the certificate file directly '
                                     f'or make sure "{sslrootcert}" '
                                     'exists and is readable.',
                                detail=detail,
                            )
                        elif sslmode == SSLMode.require:
                            # Plain `require` does not verify the server
                            # certificate when no root cert is available.
                            ssl.verify_mode = ssl_module.CERT_NONE
                        else:
                            assert False, 'unreachable'
                    else:
                        ssl.verify_mode = ssl_module.CERT_REQUIRED

                if sslcrl is None:
                    sslcrl = os.getenv('PGSSLCRL')
                if sslcrl:
                    ssl.load_verify_locations(cafile=sslcrl)
                    ssl.verify_flags |= ssl_module.VERIFY_CRL_CHECK_CHAIN
                else:
                    sslcrl = _dot_postgresql_path('root.crl')
                    if sslcrl is not None:
                        try:
                            ssl.load_verify_locations(cafile=sslcrl)
                        except (
                            FileNotFoundError,
                            NotADirectoryError,
                        ):
                            pass
                        else:
                            ssl.verify_flags |= \
                                ssl_module.VERIFY_CRL_CHECK_CHAIN

            if sslkey is None:
                sslkey = os.getenv('PGSSLKEY')
            if not sslkey:
                sslkey = _dot_postgresql_path('postgresql.key')
                if sslkey is not None and not sslkey.exists():
                    sslkey = None
            if not sslpassword:
                sslpassword = ''
            if sslcert is None:
                sslcert = os.getenv('PGSSLCERT')
            if sslcert:
                ssl.load_cert_chain(
                    sslcert, keyfile=sslkey, password=lambda: sslpassword
                )
            else:
                sslcert = _dot_postgresql_path('postgresql.crt')
                if sslcert is not None:
                    try:
                        ssl.load_cert_chain(
                            sslcert,
                            keyfile=sslkey,
                            password=lambda: sslpassword
                        )
                    except (FileNotFoundError, NotADirectoryError):
                        pass

            # OpenSSL 1.1.1 keylog file, copied from create_default_context()
            if hasattr(ssl, 'keylog_filename'):
                keylogfile = os.environ.get('SSLKEYLOGFILE')
                if keylogfile and not sys.flags.ignore_environment:
                    ssl.keylog_filename = keylogfile

            if ssl_min_protocol_version is None:
                ssl_min_protocol_version = os.getenv('PGSSLMINPROTOCOLVERSION')
            if ssl_min_protocol_version:
                ssl.minimum_version = _parse_tls_version(
                    ssl_min_protocol_version
                )
            else:
                ssl.minimum_version = _parse_tls_version('TLSv1.2')

            if ssl_max_protocol_version is None:
                ssl_max_protocol_version = os.getenv('PGSSLMAXPROTOCOLVERSION')
            if ssl_max_protocol_version:
                ssl.maximum_version = _parse_tls_version(
                    ssl_max_protocol_version
                )

    elif ssl is True:
        ssl = ssl_module.create_default_context()
        sslmode = SSLMode.verify_full
    else:
        sslmode = SSLMode.disable

    if server_settings is not None and (
            not isinstance(server_settings, dict) or
            not all(isinstance(k, str) for k in server_settings) or
            not all(isinstance(v, str) for v in server_settings.values())):
        raise exceptions.ClientConfigurationError(
            'server_settings is expected to be None or '
            'a Dict[str, str]')

    if target_session_attrs is None:
        target_session_attrs = os.getenv(
            "PGTARGETSESSIONATTRS", SessionAttribute.any
        )
    try:
        target_session_attrs = SessionAttribute(target_session_attrs)
    except ValueError:
        # List the accepted values (previously this formatted the unbound
        # `__members__.values` method, producing an unreadable message).
        raise exceptions.ClientConfigurationError(
            "target_session_attrs is expected to be one of "
            "{!r}"
            ", got {!r}".format(
                [attr.value for attr in SessionAttribute],
                target_session_attrs
            )
        ) from None

    if krbsrvname is None:
        krbsrvname = os.getenv('PGKRBSRVNAME')

    if gsslib is None:
        gsslib = os.getenv('PGGSSLIB')
        if gsslib is None:
            gsslib = 'sspi' if _system == 'Windows' else 'gssapi'
    if gsslib not in {'gssapi', 'sspi'}:
        raise exceptions.ClientConfigurationError(
            "gsslib parameter must be either 'gssapi' or 'sspi'"
            ", got {!r}".format(gsslib))

    params = _ConnectionParameters(
        user=user, password=password, database=database, ssl=ssl,
        sslmode=sslmode, ssl_negotiation=sslneg,
        server_settings=server_settings,
        target_session_attrs=target_session_attrs,
        krbsrvname=krbsrvname, gsslib=gsslib)

    return addrs, params


def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
                             database, command_timeout,
                             statement_cache_size,
                             max_cached_statement_lifetime,
                             max_cacheable_statement_size,
                             ssl, direct_tls, server_settings,
                             target_session_attrs, krbsrvname, gsslib,
                             service, servicefile):
    """Validate client-side options and resolve connection parameters.

    :return: ``(addrs, params, config)``; see
        ``_parse_connect_dsn_and_args`` for *addrs* and *params*, and
        *config* is a ``_ClientConfiguration``.
    :raises ValueError: for negative/None/bool cache sizes or a
        non-positive ``command_timeout``.
    """
    local_vars = locals()
    # A tuple (not a set) so the first reported offender is deterministic.
    for var_name in ('max_cacheable_statement_size',
                     'max_cached_statement_lifetime',
                     'statement_cache_size'):
        var_val = local_vars[var_name]
        if var_val is None or isinstance(var_val, bool) or var_val < 0:
            raise ValueError(
                '{} is expected to be greater '
                'or equal to 0, got {!r}'.format(var_name, var_val))

    if command_timeout is not None:
        try:
            if isinstance(command_timeout, bool):
                raise ValueError
            command_timeout = float(command_timeout)
            if command_timeout <= 0:
                raise ValueError
        except ValueError:
            raise ValueError(
                'invalid command_timeout value: '
                'expected greater than 0 float (got {!r})'.format(
                    command_timeout)) from None

    addrs, params = _parse_connect_dsn_and_args(
        dsn=dsn, host=host, port=port, user=user,
        password=password, passfile=passfile, ssl=ssl,
        direct_tls=direct_tls, database=database,
        server_settings=server_settings,
        target_session_attrs=target_session_attrs,
        krbsrvname=krbsrvname, gsslib=gsslib,
        service=service, servicefile=servicefile)

    config = _ClientConfiguration(
        command_timeout=command_timeout,
        statement_cache_size=statement_cache_size,
        max_cached_statement_lifetime=max_cached_statement_lifetime,
        max_cacheable_statement_size=max_cacheable_statement_size,)

    return addrs, params, config


class TLSUpgradeProto(asyncio.Protocol):
    """Transient protocol that waits for the reply to an SSLRequest.

    ``on_data`` resolves to True on ``b'S'``; to False on ``b'N'`` only
    when SSL is advisory and the context does not verify certificates;
    otherwise (including connection loss) it fails with an exception.
    """

    def __init__(
        self,
        loop: asyncio.AbstractEventLoop,
        host: str,
        port: int,
        ssl_context: ssl_module.SSLContext,
        ssl_is_advisory: bool,
    ) -> None:
        self.on_data = _create_future(loop)
        self.host = host
        self.port = port
        self.ssl_context = ssl_context
        self.ssl_is_advisory = ssl_is_advisory

    def data_received(self, data: bytes) -> None:
        if data == b'S':
            self.on_data.set_result(True)
        elif (self.ssl_is_advisory and
                self.ssl_context.verify_mode == ssl_module.CERT_NONE and
                data == b'N'):
            # ssl_is_advisory will imply that ssl.verify_mode == CERT_NONE,
            # since the only way to get ssl_is_advisory is from
            # sslmode=prefer. But be extra sure to disallow insecure
            # connections when the ssl context asks for real security.
            self.on_data.set_result(False)
        else:
            self.on_data.set_exception(
                ConnectionError(
                    'PostgreSQL server at "{host}:{port}" '
                    'rejected SSL upgrade'.format(
                        host=self.host, port=self.port)))

    def connection_lost(self, exc: typing.Optional[Exception]) -> None:
        if not self.on_data.done():
            if exc is None:
                exc = ConnectionError('unexpected connection_lost() call')
            self.on_data.set_exception(exc)


_ProctolFactoryR = typing.TypeVar(
    "_ProctolFactoryR", bound=asyncio.protocols.Protocol
)


async def _create_ssl_connection(
    # TODO: The return type is a specific combination of subclasses of
    # asyncio.protocols.Protocol that we can't express. For now, having the
    # return type be dependent on signature of the factory is an improvement
    protocol_factory: Callable[[], _ProctolFactoryR],
    host: str,
    port: int,
    *,
    loop: asyncio.AbstractEventLoop,
    ssl_context: ssl_module.SSLContext,
    ssl_is_advisory: bool = False,
) -> typing.Tuple[asyncio.Transport, _ProctolFactoryR]:
    """Open a TCP connection, send an SSLRequest and upgrade if accepted.

    Uses ``loop.start_tls`` when available; otherwise the socket is
    duplicated and a new connection is created over it.  The returned
    protocol's ``is_ssl`` attribute records whether TLS was negotiated.
    """

    tr, pr = await loop.create_connection(
        lambda: TLSUpgradeProto(loop, host, port,
                                ssl_context, ssl_is_advisory),
        host, port)

    tr.write(struct.pack('!ll', 8, 80877103))  # SSLRequest message.

    try:
        do_ssl_upgrade = await pr.on_data
    except (Exception, asyncio.CancelledError):
        tr.close()
        raise

    if hasattr(loop, 'start_tls'):
        if do_ssl_upgrade:
            try:
                new_tr = await loop.start_tls(
                    tr, pr, ssl_context, server_hostname=host)
                assert new_tr is not None
            except (Exception, asyncio.CancelledError):
                tr.close()
                raise
        else:
            new_tr = tr

        pg_proto = protocol_factory()
        pg_proto.is_ssl = do_ssl_upgrade
        pg_proto.connection_made(new_tr)
        new_tr.set_protocol(pg_proto)

        return new_tr, pg_proto
    else:
        conn_factory = functools.partial(
            loop.create_connection, protocol_factory)

        if do_ssl_upgrade:
            conn_factory = functools.partial(
                conn_factory, ssl=ssl_context, server_hostname=host)

        sock = _get_socket(tr)
        sock = sock.dup()
        _set_nodelay(sock)
        tr.close()

        try:
            new_tr, pg_proto = await conn_factory(sock=sock)
            pg_proto.is_ssl = do_ssl_upgrade
            return new_tr, pg_proto
        except (Exception, asyncio.CancelledError):
            sock.close()
            raise


async def _connect_addr(
    *,
    addr,
    loop,
    params,
    config,
    connection_class,
    record_class
):
    """Connect to a single address, honouring sslmode allow/prefer retry.

    A callable ``params.password`` is invoked (and awaited if it returns
    an awaitable) before connecting.  For ``allow`` the first attempt is
    without SSL; for ``prefer`` it is with SSL; a rejected first attempt
    is retried once with the other setting.
    """
    assert loop is not None

    params_input = params
    if callable(params.password):
        password = params.password()
        if inspect.isawaitable(password):
            password = await password

        params = params._replace(password=password)
    args = (addr, loop, config, connection_class, record_class, params_input)

    # prepare the params (which attempt has ssl) for the 2 attempts
    if params.sslmode == SSLMode.allow:
        params_retry = params
        params = params._replace(ssl=None)
    elif params.sslmode == SSLMode.prefer:
        params_retry = params._replace(ssl=None)
    else:
        # skip retry if we don't have to
        return await __connect_addr(params, False, *args)

    # first attempt
    try:
        return await __connect_addr(params, True, *args)
    except _RetryConnectSignal:
        pass

    # second attempt
    return await __connect_addr(params_retry, False, *args)


class _RetryConnectSignal(Exception):
    """Internal signal: the first sslmode allow/prefer attempt should retry."""
    pass


async def __connect_addr(
    params,
    retry,
    addr,
    loop,
    config,
    connection_class,
    record_class,
    params_input,
):
    """Perform one connection attempt and wrap it in *connection_class*.

    Raises ``_RetryConnectSignal`` only when *retry* is true and the
    failure qualifies for the allow/prefer fallback (see comments below).
    """
    connected = _create_future(loop)

    proto_factory = lambda: protocol.Protocol(
        addr, connected, params, record_class, loop)

    if isinstance(addr, str):
        # UNIX socket
        connector = loop.create_unix_connection(proto_factory, addr)

    elif params.ssl and params.ssl_negotiation is SSLNegotiation.direct:
        # if ssl and ssl_negotiation is `direct`, skip STARTTLS and perform
        # direct SSL connection
        connector = loop.create_connection(
            proto_factory, *addr, ssl=params.ssl
        )

    elif params.ssl:
        connector = _create_ssl_connection(
            proto_factory, *addr, loop=loop, ssl_context=params.ssl,
            ssl_is_advisory=params.sslmode == SSLMode.prefer)
    else:
        connector = loop.create_connection(proto_factory, *addr)

    tr, pr = await connector

    try:
        await connected
    except (
        exceptions.InvalidAuthorizationSpecificationError,
        exceptions.ConnectionDoesNotExistError,  # seen on Windows
    ):
        tr.close()

        # retry=True here is a redundant check because we don't want to
        # accidentally raise the internal _RetryConnectSignal to the user
        if retry and (
            params.sslmode == SSLMode.allow and not pr.is_ssl or
            params.sslmode == SSLMode.prefer and pr.is_ssl
        ):
            # Trigger retry when:
            #   1. First attempt with sslmode=allow, ssl=None failed
            #   2. First attempt with sslmode=prefer, ssl=ctx failed while the
            #      server claimed to support SSL (returning "S" for SSLRequest)
            #      (likely because pg_hba.conf rejected the connection)
            raise _RetryConnectSignal()

        else:
            # but will NOT retry if:
            #   1. First attempt with sslmode=prefer failed but the server
            #      doesn't support SSL (returning 'N' for SSLRequest), because
            #      we already tried to connect without SSL thru ssl_is_advisory
            #   2. Second attempt with sslmode=prefer, ssl=None failed
            #   3. Second attempt with sslmode=allow, ssl=ctx failed
            #   4. Any other sslmode
            raise

    except (Exception, asyncio.CancelledError):
        tr.close()
        raise

    con = connection_class(pr, tr, loop, addr, config, params_input)
    pr.set_connection(con)
    return con


class SessionAttribute(str, enum.Enum):
    """Accepted values of ``target_session_attrs``."""

    any = 'any'
    primary = 'primary'
    standby = 'standby'
    prefer_standby = 'prefer-standby'
    read_write = "read-write"
    read_only = "read-only"


def _accept_in_hot_standby(should_be_in_hot_standby: bool):
    """
    If the server didn't report "in_hot_standby" at startup, we must determine
    the state by checking "SELECT pg_catalog.pg_is_in_recovery()".
    If the server allows a connection and states it is in recovery it must
    be a replica/standby server.
    """
    async def can_be_used(connection):
        settings = connection.get_settings()
        hot_standby_status = getattr(settings, 'in_hot_standby', None)
        if hot_standby_status is not None:
            is_in_hot_standby = hot_standby_status == 'on'
        else:
            is_in_hot_standby = await connection.fetchval(
                "SELECT pg_catalog.pg_is_in_recovery()"
            )
        return is_in_hot_standby == should_be_in_hot_standby

    return can_be_used


def _accept_read_only(should_be_read_only: bool):
    """
    Verify the server has not set default_transaction_read_only=True
    """
    async def can_be_used(connection):
        settings = connection.get_settings()
        is_readonly = getattr(settings, 'default_transaction_read_only', 'off')

        if is_readonly == "on":
            return should_be_read_only

        return await _accept_in_hot_standby(should_be_read_only)(connection)
    return can_be_used


async def _accept_any(_):
    """Accept every connection."""
    return True


target_attrs_check = {
    SessionAttribute.any: _accept_any,
    SessionAttribute.primary: _accept_in_hot_standby(False),
    SessionAttribute.standby: _accept_in_hot_standby(True),
    SessionAttribute.prefer_standby: _accept_in_hot_standby(True),
    SessionAttribute.read_write: _accept_read_only(False),
    SessionAttribute.read_only: _accept_read_only(True),
}


async def _can_use_connection(connection, attr: SessionAttribute):
    """Return whether *connection* satisfies the session attribute *attr*."""
    can_use = target_attrs_check[attr]
    return await can_use(connection)


# Strong references to fire-and-forget cleanup tasks; the event loop keeps
# only weak references to tasks, so an unreferenced task may be collected
# before it finishes.
_background_tasks: typing.Set[asyncio.Task] = set()


async def _connect(*, loop, connection_class, record_class, **kwargs):
    """Connect to the first address whose server matches target_session_attrs.

    Addresses are tried in order; ``OSError`` from one address moves on to
    the next.  With ``prefer-standby`` and no standby found, a random
    successfully opened candidate is used.  All other opened candidates
    are closed in a background task.
    """
    if loop is None:
        loop = asyncio.get_event_loop()

    addrs, params, config = _parse_connect_arguments(**kwargs)
    target_attr = params.target_session_attrs

    candidates = []
    chosen_connection = None
    last_error = None
    try:
        for addr in addrs:
            try:
                conn = await _connect_addr(
                    addr=addr,
                    loop=loop,
                    params=params,
                    config=config,
                    connection_class=connection_class,
                    record_class=record_class,
                )
                candidates.append(conn)
                if await _can_use_connection(conn, target_attr):
                    chosen_connection = conn
                    break
            except OSError as ex:
                last_error = ex
        else:
            if target_attr == SessionAttribute.prefer_standby and candidates:
                chosen_connection = random.choice(candidates)
    finally:

        async def _close_candidates(conns, chosen):
            await asyncio.gather(
                *(c.close() for c in conns if c is not chosen),
                return_exceptions=True
            )
        if candidates:
            task = asyncio.create_task(
                _close_candidates(candidates, chosen_connection))
            _background_tasks.add(task)
            task.add_done_callback(_background_tasks.discard)

    if chosen_connection is not None:
        return chosen_connection

    raise last_error or exceptions.TargetServerAttributeNotMatched(
        'None of the hosts match the target attribute requirement '
        '{!r}'.format(target_attr)
    )


async def _cancel(*, loop, addr, params: _ConnectionParameters,
                  backend_pid, backend_secret):
    """Send a CancelRequest for *backend_pid* and wait for the server to
    close the connection."""

    class CancelProto(asyncio.Protocol):

        def __init__(self):
            self.on_disconnect = _create_future(loop)
            self.is_ssl = False

        def connection_lost(self, exc):
            if not self.on_disconnect.done():
                self.on_disconnect.set_result(True)

    if isinstance(addr, str):
        tr, pr = await loop.create_unix_connection(CancelProto, addr)
    else:
        if params.ssl and params.sslmode != SSLMode.allow:
            tr, pr = await _create_ssl_connection(
                CancelProto,
                *addr,
                loop=loop,
                ssl_context=params.ssl,
                ssl_is_advisory=params.sslmode == SSLMode.prefer)
        else:
            tr, pr = await loop.create_connection(
                CancelProto, *addr)
            _set_nodelay(_get_socket(tr))

    # Pack a CancelRequest message
    msg = struct.pack('!llll', 16, 80877102, backend_pid, backend_secret)

    try:
        tr.write(msg)
        await pr.on_disconnect
    finally:
        tr.close()


def _get_socket(transport):
    """Return the socket underlying *transport*.

    :raises ConnectionError: if the transport exposes no socket.
    """
    sock = transport.get_extra_info('socket')
    if sock is None:
        # Shouldn't happen with any asyncio-compliant event loop.
        raise ConnectionError(
            'could not get the socket for transport {!r}'.format(transport))
    return sock


def _set_nodelay(sock):
    """Enable TCP_NODELAY on *sock* unless it is a Unix-domain socket."""
    if not hasattr(socket, 'AF_UNIX') or sock.family != socket.AF_UNIX:
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)


def _create_future(loop):
    """Create a future bound to *loop*, preferring ``loop.create_future``."""
    try:
        create_future = loop.create_future
    except AttributeError:
        return asyncio.Future(loop=loop)
    else:
        return create_future()