# Python implementation of the MySQL client-server protocol # http://dev.mysql.com/doc/internals/en/client-server-protocol.html import asyncio import os import socket import struct import sys import warnings import configparser import getpass from functools import partial from pymysql.charset import charset_by_name, charset_by_id from pymysql.constants import SERVER_STATUS from pymysql.constants import CLIENT from pymysql.constants import COMMAND from pymysql.constants import CR from pymysql.constants import FIELD_TYPE from pymysql.converters import (escape_item, encoders, decoders, escape_string, escape_bytes_prefixed, through) from pymysql.err import (Warning, Error, InterfaceError, DataError, DatabaseError, OperationalError, IntegrityError, InternalError, NotSupportedError, ProgrammingError) from pymysql.connections import TEXT_TYPES, MAX_PACKET_LEN, DEFAULT_CHARSET from pymysql.connections import _auth from pymysql.connections import MysqlPacket from pymysql.connections import FieldDescriptorPacket from pymysql.connections import EOFPacketWrapper from pymysql.connections import OKPacketWrapper from pymysql.connections import LoadLocalPacketWrapper # from aiomysql.utils import _convert_to_str from .cursors import Cursor from .utils import _pack_int24, _lenenc_int, _ConnectionContextManager, _ContextManager from .log import logger try: DEFAULT_USER = getpass.getuser() except KeyError: DEFAULT_USER = "unknown" def connect(host="localhost", user=None, password="", db=None, port=3306, unix_socket=None, charset='', sql_mode=None, read_default_file=None, conv=decoders, use_unicode=None, client_flag=0, cursorclass=Cursor, init_command=None, connect_timeout=None, read_default_group=None, autocommit=False, echo=False, local_infile=False, loop=None, ssl=None, auth_plugin='', program_name='', server_public_key=None): """See connections.Connection.__init__() for information about defaults.""" coro = _connect(host=host, user=user, password=password, db=db, port=port, 
unix_socket=unix_socket, charset=charset, sql_mode=sql_mode, read_default_file=read_default_file, conv=conv, use_unicode=use_unicode, client_flag=client_flag, cursorclass=cursorclass, init_command=init_command, connect_timeout=connect_timeout, read_default_group=read_default_group, autocommit=autocommit, echo=echo, local_infile=local_infile, loop=loop, ssl=ssl, auth_plugin=auth_plugin, program_name=program_name) return _ConnectionContextManager(coro) async def _connect(*args, **kwargs): conn = Connection(*args, **kwargs) await conn._connect() return conn async def _open_connection(host=None, port=None, **kwds): """This is based on asyncio.open_connection, allowing us to use a custom StreamReader. `limit` arg has been removed as we don't currently use it. """ loop = asyncio.events.get_running_loop() reader = _StreamReader(loop=loop) protocol = asyncio.StreamReaderProtocol(reader, loop=loop) transport, _ = await loop.create_connection( lambda: protocol, host, port, **kwds) writer = asyncio.StreamWriter(transport, protocol, reader, loop) return reader, writer async def _open_unix_connection(path=None, **kwds): """This is based on asyncio.open_unix_connection, allowing us to use a custom StreamReader. `limit` arg has been removed as we don't currently use it. """ loop = asyncio.events.get_running_loop() reader = _StreamReader(loop=loop) protocol = asyncio.StreamReaderProtocol(reader, loop=loop) transport, _ = await loop.create_unix_connection( lambda: protocol, path, **kwds) writer = asyncio.StreamWriter(transport, protocol, reader, loop) return reader, writer class _StreamReader(asyncio.StreamReader): """This StreamReader exposes whether EOF was received, allowing us to discard the associated connection instead of returning it from the pool when checking free connections in Pool._fill_free_pool(). `limit` arg has been removed as we don't currently use it. 
""" def __init__(self, loop=None): self._eof_received = False super().__init__(loop=loop) def feed_eof(self) -> None: self._eof_received = True super().feed_eof() @property def eof_received(self): return self._eof_received class Connection: """Representation of a socket with a mysql server. The proper way to get an instance of this class is to call connect(). """ def __init__(self, host="localhost", user=None, password="", db=None, port=3306, unix_socket=None, charset='', sql_mode=None, read_default_file=None, conv=decoders, use_unicode=None, client_flag=0, cursorclass=Cursor, init_command=None, connect_timeout=None, read_default_group=None, autocommit=False, echo=False, local_infile=False, loop=None, ssl=None, auth_plugin='', program_name='', server_public_key=None): """ Establish a connection to the MySQL database. Accepts several arguments: :param host: Host where the database server is located :param user: Username to log in as :param password: Password to use. :param db: Database to use, None to not use a particular one. :param port: MySQL port to use, default is usually OK. :param unix_socket: Optionally, you can use a unix socket rather than TCP/IP. :param charset: Charset you want to use. :param sql_mode: Default SQL_MODE to use. :param read_default_file: Specifies my.cnf file to read these parameters from under the [client] section. :param conv: Decoders dictionary to use instead of the default one. This is used to provide custom marshalling of types. See converters. :param use_unicode: Whether or not to default to unicode strings. :param client_flag: Custom flags to send to MySQL. Find potential values in constants.CLIENT. :param cursorclass: Custom cursor class to use. :param init_command: Initial SQL statement to run when connection is established. :param connect_timeout: Timeout before throwing an exception when connecting. :param read_default_group: Group to read from in the configuration file. :param autocommit: Autocommit mode. 
None means use server default. (default: False) :param local_infile: boolean to enable the use of LOAD DATA LOCAL command. (default: False) :param ssl: Optional SSL Context to force SSL :param auth_plugin: String to manually specify the authentication plugin to use, i.e you will want to use mysql_clear_password when using IAM authentication with Amazon RDS. (default: Server Default) :param program_name: Program name string to provide when handshaking with MySQL. (omitted by default) :param server_public_key: SHA256 authentication plugin public key value. :param loop: asyncio loop """ self._loop = loop or asyncio.get_event_loop() if use_unicode is None and sys.version_info[0] > 2: use_unicode = True if read_default_file: if not read_default_group: read_default_group = "client" cfg = configparser.RawConfigParser() cfg.read(os.path.expanduser(read_default_file)) _config = partial(cfg.get, read_default_group) user = _config("user", fallback=user) password = _config("password", fallback=password) host = _config("host", fallback=host) db = _config("database", fallback=db) unix_socket = _config("socket", fallback=unix_socket) port = int(_config("port", fallback=port)) charset = _config("default-character-set", fallback=charset) self._host = host self._port = port self._user = user or DEFAULT_USER self._password = password or "" self._db = db self._echo = echo self._last_usage = self._loop.time() self._client_auth_plugin = auth_plugin self._server_auth_plugin = "" self._auth_plugin_used = "" self._secure = False self.server_public_key = server_public_key self.salt = None from . 
import __version__ self._connect_attrs = { '_client_name': 'aiomysql', '_pid': str(os.getpid()), '_client_version': __version__, } if program_name: self._connect_attrs["program_name"] = program_name self._unix_socket = unix_socket if charset: self._charset = charset self.use_unicode = True else: self._charset = DEFAULT_CHARSET self.use_unicode = False if use_unicode is not None: self.use_unicode = use_unicode self._ssl_context = ssl if ssl: client_flag |= CLIENT.SSL self._encoding = charset_by_name(self._charset).encoding if local_infile: client_flag |= CLIENT.LOCAL_FILES client_flag |= CLIENT.CAPABILITIES client_flag |= CLIENT.MULTI_STATEMENTS if self._db: client_flag |= CLIENT.CONNECT_WITH_DB self.client_flag = client_flag self.cursorclass = cursorclass self.connect_timeout = connect_timeout self._result = None self._affected_rows = 0 self.host_info = "Not connected" #: specified autocommit mode. None means use server default. self.autocommit_mode = autocommit self.encoders = encoders # Need for MySQLdb compatibility. 
self.decoders = conv self.sql_mode = sql_mode self.init_command = init_command # asyncio StreamReader, StreamWriter self._reader = None self._writer = None # If connection was closed for specific reason, we should show that to # user self._close_reason = None @property def host(self): """MySQL server IP address or name""" return self._host @property def port(self): """MySQL server TCP/IP port""" return self._port @property def unix_socket(self): """MySQL Unix socket file location""" return self._unix_socket @property def db(self): """Current database name.""" return self._db @property def user(self): """User used while connecting to MySQL""" return self._user @property def echo(self): """Return echo mode status.""" return self._echo @property def last_usage(self): """Return time() when connection was used.""" return self._last_usage @property def loop(self): return self._loop @property def closed(self): """The readonly property that returns ``True`` if connections is closed. """ return self._writer is None @property def encoding(self): """Encoding employed for this connection.""" return self._encoding @property def charset(self): """Returns the character set for current connection.""" return self._charset def close(self): """Close socket connection""" if self._writer: self._writer.transport.close() self._writer = None self._reader = None async def ensure_closed(self): """Send quit command and then close socket connection""" if self._writer is None: # connection has been closed return send_data = struct.pack('= 5: self.client_flag |= CLIENT.MULTI_RESULTS if self.user is None: raise ValueError("Did not specify a username") charset_id = charset_by_name(self.charset).id data_init = struct.pack('=5.0) data += authresp + b'\0' if self._db and self.server_capabilities & CLIENT.CONNECT_WITH_DB: if isinstance(self._db, str): db = self._db.encode(self.encoding) else: db = self._db data += db + b'\0' if self.server_capabilities & CLIENT.PLUGIN_AUTH: name = auth_plugin if 
isinstance(name, str): name = name.encode('ascii') data += name + b'\0' self._auth_plugin_used = auth_plugin # Sends the server a few pieces of client info if self.server_capabilities & CLIENT.CONNECT_ATTRS: connect_attrs = b'' for k, v in self._connect_attrs.items(): k, v = k.encode('utf8'), v.encode('utf8') connect_attrs += struct.pack('B', len(k)) + k connect_attrs += struct.pack('B', len(v)) + v data += struct.pack('B', len(connect_attrs)) + connect_attrs self.write_packet(data) auth_packet = await self._read_packet() # if authentication method isn't accepted the first byte # will have the octet 254 if auth_packet.is_auth_switch_request(): # https://dev.mysql.com/doc/internals/en/ # connection-phase-packets.html#packet-Protocol::AuthSwitchRequest auth_packet.read_uint8() # 0xfe packet identifier plugin_name = auth_packet.read_string() if (self.server_capabilities & CLIENT.PLUGIN_AUTH and plugin_name is not None): await self._process_auth(plugin_name, auth_packet) else: # send legacy handshake data = _auth.scramble_old_password( self._password.encode('latin1'), auth_packet.read_all()) + b'\0' self.write_packet(data) await self._read_packet() elif auth_packet.is_extra_auth_data(): if auth_plugin == "caching_sha2_password": await self.caching_sha2_password_auth(auth_packet) elif auth_plugin == "sha256_password": await self.sha256_password_auth(auth_packet) else: raise OperationalError("Received extra packet " "for auth method %r", auth_plugin) async def _process_auth(self, plugin_name, auth_packet): # These auth plugins do their own packet handling if plugin_name == b"caching_sha2_password": await self.caching_sha2_password_auth(auth_packet) self._auth_plugin_used = plugin_name.decode() elif plugin_name == b"sha256_password": await self.sha256_password_auth(auth_packet) self._auth_plugin_used = plugin_name.decode() else: if plugin_name == b"mysql_native_password": # https://dev.mysql.com/doc/internals/en/ # 
secure-password-authentication.html#packet-Authentication:: # Native41 data = _auth.scramble_native_password( self._password.encode('latin1'), auth_packet.read_all()) elif plugin_name == b"mysql_old_password": # https://dev.mysql.com/doc/internals/en/ # old-password-authentication.html data = _auth.scramble_old_password( self._password.encode('latin1'), auth_packet.read_all() ) + b'\0' elif plugin_name == b"mysql_clear_password": # https://dev.mysql.com/doc/internals/en/ # clear-text-authentication.html data = self._password.encode('latin1') + b'\0' else: raise OperationalError( 2059, "Authentication plugin '{}'" " not configured".format(plugin_name) ) self.write_packet(data) pkt = await self._read_packet() pkt.check_error() self._auth_plugin_used = plugin_name.decode() return pkt async def caching_sha2_password_auth(self, pkt): # No password fast path if not self._password: self.write_packet(b'') pkt = await self._read_packet() pkt.check_error() return pkt if pkt.is_auth_switch_request(): # Try from fast auth logger.debug("caching sha2: Trying fast path") self.salt = pkt.read_all() scrambled = _auth.scramble_caching_sha2( self._password.encode('latin1'), self.salt ) self.write_packet(scrambled) pkt = await self._read_packet() pkt.check_error() # else: fast auth is tried in initial handshake if not pkt.is_extra_auth_data(): raise OperationalError( "caching sha2: Unknown packet " "for fast auth: {}".format(pkt._data[:1]) ) # magic numbers: # 2 - request public key # 3 - fast auth succeeded # 4 - need full auth pkt.advance(1) n = pkt.read_uint8() if n == 3: logger.debug("caching sha2: succeeded by fast path.") pkt = await self._read_packet() pkt.check_error() # pkt must be OK packet return pkt if n != 4: raise OperationalError("caching sha2: Unknown " "result for fast auth: {}".format(n)) logger.debug("caching sha2: Trying full auth...") if self._secure: logger.debug("caching sha2: Sending plain " "password via secure connection") 
self.write_packet(self._password.encode('latin1') + b'\0') pkt = await self._read_packet() pkt.check_error() return pkt if not self.server_public_key: self.write_packet(b'\x02') pkt = await self._read_packet() # Request public key pkt.check_error() if not pkt.is_extra_auth_data(): raise OperationalError( "caching sha2: Unknown packet " "for public key: {}".format(pkt._data[:1]) ) self.server_public_key = pkt._data[1:] logger.debug(self.server_public_key.decode('ascii')) data = _auth.sha2_rsa_encrypt( self._password.encode('latin1'), self.salt, self.server_public_key ) self.write_packet(data) pkt = await self._read_packet() pkt.check_error() async def sha256_password_auth(self, pkt): if self._secure: logger.debug("sha256: Sending plain password") data = self._password.encode('latin1') + b'\0' self.write_packet(data) pkt = await self._read_packet() pkt.check_error() return pkt if pkt.is_auth_switch_request(): self.salt = pkt.read_all() if not self.server_public_key and self._password: # Request server public key logger.debug("sha256: Requesting server public key") self.write_packet(b'\1') pkt = await self._read_packet() pkt.check_error() if pkt.is_extra_auth_data(): self.server_public_key = pkt._data[1:] logger.debug( "Received public key:\n%s", self.server_public_key.decode('ascii') ) if self._password: if not self.server_public_key: raise OperationalError("Couldn't receive server's public key") data = _auth.sha2_rsa_encrypt( self._password.encode('latin1'), self.salt, self.server_public_key ) else: data = b'' self.write_packet(data) pkt = await self._read_packet() pkt.check_error() return pkt # _mysql support def thread_id(self): return self.server_thread_id[0] def character_set_name(self): return self._charset def get_host_info(self): return self.host_info def get_proto_info(self): return self.protocol_version async def _get_server_information(self): i = 0 packet = await self._read_packet() data = packet.get_all_data() # logger.debug(dump_packet(data)) 
self.protocol_version = data[i] i += 1 server_end = data.find(b'\0', i) self.server_version = data[i:server_end].decode('latin1') i = server_end + 1 self.server_thread_id = struct.unpack('= i + 6: lang, stat, cap_h, salt_len = struct.unpack('= i + salt_len: # salt_len includes auth_plugin_data_part_1 and filler self.salt += data[i:i + salt_len] i += salt_len i += 1 # AUTH PLUGIN NAME may appear here. if self.server_capabilities & CLIENT.PLUGIN_AUTH and len(data) >= i: # Due to Bug#59453 the auth-plugin-name is missing the terminating # NUL-char in versions prior to 5.5.10 and 5.6.2. # ref: https://dev.mysql.com/doc/internals/en/ # connection-phase-packets.html#packet-Protocol::Handshake # didn't use version checks as mariadb is corrected and reports # earlier than those two. server_end = data.find(b'\0', i) if server_end < 0: # pragma: no cover - very specific upstream bug # not found \0 and last field so take it all self._server_auth_plugin = data[i:].decode('latin1') else: self._server_auth_plugin = data[i:server_end].decode('latin1') def get_transaction_status(self): return bool(self.server_status & SERVER_STATUS.SERVER_STATUS_IN_TRANS) def get_server_info(self): return self.server_version # Just to always have consistent errors 2 helpers def _close_on_cancel(self): self.close() self._close_reason = "Cancelled during execution" def _ensure_alive(self): if not self._writer: if self._close_reason is None: raise InterfaceError("(0, 'Not connected')") else: raise InterfaceError(self._close_reason) def __del__(self): if self._writer: warnings.warn(f"Unclosed connection {self!r}", ResourceWarning) self.close() Warning = Warning Error = Error InterfaceError = InterfaceError DatabaseError = DatabaseError DataError = DataError OperationalError = OperationalError IntegrityError = IntegrityError InternalError = InternalError ProgrammingError = ProgrammingError NotSupportedError = NotSupportedError # TODO: move OK and EOF packet parsing/logic into a proper subclass # of 
#       MysqlPacket like has been done with FieldDescriptorPacket.
class MySQLResult:
    """Result of one command, read from a connection's packet stream.

    The instance holds a reference to the connection only while packets
    still have to be read; the reference is dropped as soon as the result
    is complete.
    """

    def __init__(self, connection):
        """Create an empty result bound to *connection*."""
        self.connection = connection
        self.affected_rows = None
        self.insert_id = None
        self.server_status = None
        self.warning_count = 0
        self.message = None
        self.field_count = 0
        self.description = None
        self.rows = None
        self.has_next = None
        self.unbuffered_active = False

    async def read(self):
        """Read a complete (buffered) result and release the connection.

        Dispatches on the first packet: OK packet, LOAD DATA LOCAL request,
        or a result set (column descriptions followed by all rows).
        """
        try:
            first_packet = await self.connection._read_packet()

            # TODO: use classes for different packet types?
            if first_packet.is_ok_packet():
                self._read_ok_packet(first_packet)
            elif first_packet.is_load_local_packet():
                await self._read_load_local_packet(first_packet)
            else:
                await self._read_result_packet(first_packet)
        finally:
            self.connection = None

    async def init_unbuffered_query(self):
        """Read the result header only; rows are fetched one at a time later.

        For OK and LOAD DATA LOCAL responses there are no rows, so the
        result is finished immediately and the connection is released.
        """
        self.unbuffered_active = True
        first_packet = await self.connection._read_packet()

        if first_packet.is_ok_packet():
            self._read_ok_packet(first_packet)
            self.unbuffered_active = False
            self.connection = None
        elif first_packet.is_load_local_packet():
            await self._read_load_local_packet(first_packet)
            self.unbuffered_active = False
            self.connection = None
        else:
            self.field_count = first_packet.read_length_encoded_integer()
            await self._get_descriptions()
            # Apparently, MySQLdb picks this number because it's the maximum
            # value of a 64bit unsigned integer. Since we're emulating MySQLdb,
            # we set it to this instead of None, which would be preferred.
            self.affected_rows = 18446744073709551615

    def _read_ok_packet(self, first_packet):
        """Copy the fields of an OK packet onto this result."""
        ok_packet = OKPacketWrapper(first_packet)
        self.affected_rows = ok_packet.affected_rows
        self.insert_id = ok_packet.insert_id
        self.server_status = ok_packet.server_status
        self.warning_count = ok_packet.warning_count
        self.message = ok_packet.message
        self.has_next = ok_packet.has_next

    async def _read_load_local_packet(self, first_packet):
        """Answer a LOAD DATA LOCAL request by streaming the requested file.

        :raises OperationalError: (2014) if the packet following the file
            data is not an OK packet; errors from sending the file are
            re-raised after the pending response packet has been consumed.
        """
        load_packet = LoadLocalPacketWrapper(first_packet)
        sender = LoadLocalFile(load_packet.filename, self.connection)
        try:
            await sender.send_data()
        except Exception:
            # Skip ok packet
            await self.connection._read_packet()
            raise

        ok_packet = await self.connection._read_packet()
        if not ok_packet.is_ok_packet():
            raise OperationalError(2014, "Commands Out of Sync")
        self._read_ok_packet(ok_packet)

    def _check_packet_is_eof(self, packet):
        """Return True if *packet* is an EOF packet, recording its fields."""
        if packet.is_eof_packet():
            eof_packet = EOFPacketWrapper(packet)
            self.warning_count = eof_packet.warning_count
            self.has_next = eof_packet.has_next
            return True
        return False

    async def _read_result_packet(self, first_packet):
        """Read a full result set: column descriptions, then every row."""
        self.field_count = first_packet.read_length_encoded_integer()
        await self._get_descriptions()
        await self._read_rowdata_packet()

    async def _read_rowdata_packet_unbuffered(self):
        """Read and return the next row of an unbuffered query.

        Returns None when no unbuffered query is active or when the EOF
        packet has been reached (which also releases the connection).
        """
        # Check if in an active query
        if not self.unbuffered_active:
            return

        packet = await self.connection._read_packet()
        if self._check_packet_is_eof(packet):
            self.unbuffered_active = False
            self.connection = None
            self.rows = None
            return

        row = self._read_row_from_packet(packet)
        self.affected_rows = 1
        # rows should tuple of row for MySQL-python compatibility.
        self.rows = (row,)
        return row

    async def _finish_unbuffered_query(self):
        """Discard the remaining packets of an unbuffered query."""
        # After much reading on the MySQL protocol, it appears that there is,
        # in fact, no way to stop MySQL from sending all the data after
        # executing a query, so we just spin, and wait for an EOF packet.
        while self.unbuffered_active:
            try:
                packet = await self.connection._read_packet()
            except OperationalError as e:
                # TODO: replace these numbers with constants when available
                # TODO: in a new PyMySQL release
                if e.args[0] in (
                    3024,  # ER.QUERY_TIMEOUT
                    1969,  # ER.STATEMENT_TIMEOUT
                ):
                    # if the query timed out we can simply ignore this error
                    self.unbuffered_active = False
                    self.connection = None
                    return

                raise

            if self._check_packet_is_eof(packet):
                self.unbuffered_active = False
                # release reference to kill cyclic reference.
                self.connection = None

    async def _read_rowdata_packet(self):
        """Read a rowdata packet for each data row in the result set."""
        rows = []
        while True:
            packet = await self.connection._read_packet()
            if self._check_packet_is_eof(packet):
                # release reference to kill cyclic reference.
                self.connection = None
                break
            rows.append(self._read_row_from_packet(packet))

        self.affected_rows = len(rows)
        self.rows = tuple(rows)

    def _read_row_from_packet(self, packet):
        """Decode one row packet into a tuple of column values.

        Each column is decoded with its (encoding, converter) pair from
        ``self.converters``; NULL columns (``None``) are passed through.
        """
        row = []
        for encoding, converter in self.converters:
            try:
                data = packet.read_length_coded_string()
            except IndexError:
                # No more columns in this row
                # See https://github.com/PyMySQL/PyMySQL/pull/434
                break
            if data is not None:
                if encoding is not None:
                    data = data.decode(encoding)
                if converter is not None:
                    data = converter(data)
            row.append(data)
        return tuple(row)

    async def _get_descriptions(self):
        """Read a column descriptor packet for each column in the result."""
        self.fields = []
        self.converters = []
        use_unicode = self.connection.use_unicode
        conn_encoding = self.connection.encoding
        description = []
        for _ in range(self.field_count):
            field = await self.connection._read_packet(
                FieldDescriptorPacket)
            self.fields.append(field)
            description.append(field.description())
            field_type = field.type_code
            if use_unicode:
                if field_type == FIELD_TYPE.JSON:
                    # When SELECT from JSON column: charset = binary
                    # When SELECT CAST(... AS JSON): charset = connection
                    # encoding
                    # This behavior is different from TEXT / BLOB.
                    # We should decode result by connection encoding
                    # regardless charsetnr.
                    # See https://github.com/PyMySQL/PyMySQL/issues/488
                    encoding = conn_encoding  # SELECT CAST(... AS JSON)
                elif field_type in TEXT_TYPES:
                    if field.charsetnr == 63:  # binary
                        # TEXTs with charset=binary means BINARY types.
                        encoding = None
                    else:
                        encoding = conn_encoding
                else:
                    # Integers, Dates and Times, and other basic data
                    # is encoded in ascii
                    encoding = 'ascii'
            else:
                encoding = None
            converter = self.connection.decoders.get(field_type)
            # `through` is the identity converter; skip the call entirely.
            if converter is through:
                converter = None
            self.converters.append((encoding, converter))

        eof_packet = await self.connection._read_packet()
        assert eof_packet.is_eof_packet(), 'Protocol error, expecting EOF'
        self.description = tuple(description)


class LoadLocalFile:
    """Streams a local file to the server for LOAD DATA LOCAL INFILE.

    Blocking file operations are delegated to the loop's executor.
    """

    def __init__(self, filename, connection):
        self.filename = filename
        self.connection = connection
        self._loop = connection.loop
        self._file_object = None
        self._executor = None  # means use default executor

    def _open_file(self):
        """Open the file in the executor; return the awaitable future.

        The future fails with OperationalError(1017) if the file cannot
        be opened.
        """

        def opener(filename):
            try:
                self._file_object = open(filename, 'rb')
            except OSError as e:
                msg = f"Can't find file '{filename}'"
                raise OperationalError(1017, msg) from e

        fut = self._loop.run_in_executor(self._executor, opener, self.filename)
        return fut

    def _file_read(self, chunk_size):
        """Read up to *chunk_size* bytes in the executor; return the future.

        The file is closed when it is exhausted or when reading fails
        (the latter is reported as OperationalError(1024)).
        """

        def freader(chunk_size):
            try:
                chunk = self._file_object.read(chunk_size)

                if not chunk:
                    self._file_object.close()
                    self._file_object = None

            except Exception as e:
                self._file_object.close()
                self._file_object = None
                msg = f"Error reading file {self.filename}"
                raise OperationalError(1024, msg) from e
            return chunk

        fut = self._loop.run_in_executor(self._executor, freader, chunk_size)
        return fut

    async def send_data(self):
        """Send data packets from the local file to the server"""
        self.connection._ensure_alive()
        conn = self.connection

        try:
            await self._open_file()
            with self._file_object:
                chunk_size = MAX_PACKET_LEN
                while True:
                    chunk = await self._file_read(chunk_size)
                    if not chunk:
                        break
                    # TODO: consider drain data
                    conn.write_packet(chunk)
        except asyncio.CancelledError:
            self.connection._close_on_cancel()
            raise
        finally:
            # An empty packet signifies that we are done sending data.
            # Skip it when the connection has already been closed (the
            # cancellation branch above closes it): there is nothing left
            # to write to, and an error raised from this ``finally`` would
            # replace the exception that is being propagated.
            if not conn.closed:
                conn.write_packet(b"")