1415 lines
51 KiB
Python
1415 lines
51 KiB
Python
# Python implementation of the MySQL client-server protocol
|
|
# http://dev.mysql.com/doc/internals/en/client-server-protocol.html
|
|
|
|
import asyncio
|
|
import os
|
|
import socket
|
|
import struct
|
|
import sys
|
|
import warnings
|
|
import configparser
|
|
import getpass
|
|
from functools import partial
|
|
|
|
from pymysql.charset import charset_by_name, charset_by_id
|
|
from pymysql.constants import SERVER_STATUS
|
|
from pymysql.constants import CLIENT
|
|
from pymysql.constants import COMMAND
|
|
from pymysql.constants import CR
|
|
from pymysql.constants import FIELD_TYPE
|
|
from pymysql.converters import (escape_item, encoders, decoders,
|
|
escape_string, escape_bytes_prefixed, through)
|
|
from pymysql.err import (Warning, Error,
|
|
InterfaceError, DataError, DatabaseError,
|
|
OperationalError,
|
|
IntegrityError, InternalError, NotSupportedError,
|
|
ProgrammingError)
|
|
|
|
from pymysql.connections import TEXT_TYPES, MAX_PACKET_LEN, DEFAULT_CHARSET
|
|
from pymysql.connections import _auth
|
|
|
|
from pymysql.connections import MysqlPacket
|
|
from pymysql.connections import FieldDescriptorPacket
|
|
from pymysql.connections import EOFPacketWrapper
|
|
from pymysql.connections import OKPacketWrapper
|
|
from pymysql.connections import LoadLocalPacketWrapper
|
|
|
|
# from aiomysql.utils import _convert_to_str
|
|
from .cursors import Cursor
|
|
from .utils import _pack_int24, _lenenc_int, _ConnectionContextManager, _ContextManager
|
|
from .log import logger
|
|
|
|
# Best-effort default login name used when the caller does not pass `user`.
try:
    DEFAULT_USER = getpass.getuser()
except (ImportError, KeyError, OSError):
    # KeyError: no password-database entry for the current uid.
    # ImportError: no `pwd` module on platforms without it (older Pythons).
    # OSError: what Python 3.13+ raises for every lookup failure.
    DEFAULT_USER = "unknown"
|
|
|
|
|
|
def connect(host="localhost", user=None, password="",
            db=None, port=3306, unix_socket=None,
            charset='', sql_mode=None,
            read_default_file=None, conv=decoders, use_unicode=None,
            client_flag=0, cursorclass=Cursor, init_command=None,
            connect_timeout=None, read_default_group=None,
            autocommit=False, echo=False,
            local_infile=False, loop=None, ssl=None, auth_plugin='',
            program_name='', server_public_key=None):
    """Create a connection to a MySQL server.

    All arguments are forwarded unchanged to ``Connection.__init__()``;
    see that method for their meaning and defaults.

    :returns: a ``_ConnectionContextManager`` wrapping the coroutine that
        builds and connects the ``Connection``.
    """
    coro = _connect(host=host, user=user, password=password, db=db,
                    port=port, unix_socket=unix_socket, charset=charset,
                    sql_mode=sql_mode, read_default_file=read_default_file,
                    conv=conv, use_unicode=use_unicode,
                    client_flag=client_flag, cursorclass=cursorclass,
                    init_command=init_command,
                    connect_timeout=connect_timeout,
                    read_default_group=read_default_group,
                    autocommit=autocommit, echo=echo,
                    local_infile=local_infile, loop=loop, ssl=ssl,
                    auth_plugin=auth_plugin, program_name=program_name,
                    # Previously accepted but never forwarded, so a
                    # caller-supplied key was silently ignored.
                    server_public_key=server_public_key)
    return _ConnectionContextManager(coro)
|
|
|
|
|
|
async def _connect(*args, **kwargs):
    """Instantiate a ``Connection`` with the given arguments, run its
    handshake via ``Connection._connect()``, and return it."""
    connection = Connection(*args, **kwargs)
    await connection._connect()
    return connection
|
|
|
|
|
|
async def _open_connection(host=None, port=None, **kwds):
    """Open a TCP stream pair backed by the custom ``_StreamReader``.

    Modeled on ``asyncio.open_connection``; the ``limit`` argument is
    omitted because it is not currently used. Extra keyword arguments are
    passed straight to ``loop.create_connection``.

    :returns: ``(reader, writer)`` tuple.
    """
    running_loop = asyncio.events.get_running_loop()
    stream_reader = _StreamReader(loop=running_loop)
    stream_protocol = asyncio.StreamReaderProtocol(
        stream_reader, loop=running_loop)

    def protocol_factory():
        return stream_protocol

    transport, _ = await running_loop.create_connection(
        protocol_factory, host, port, **kwds)
    stream_writer = asyncio.StreamWriter(
        transport, stream_protocol, stream_reader, running_loop)
    return stream_reader, stream_writer
|
|
|
|
|
|
async def _open_unix_connection(path=None, **kwds):
    """Open a UNIX-socket stream pair backed by the custom ``_StreamReader``.

    Modeled on ``asyncio.open_unix_connection``; the ``limit`` argument is
    omitted because it is not currently used. Extra keyword arguments are
    passed straight to ``loop.create_unix_connection``.

    :returns: ``(reader, writer)`` tuple.
    """
    running_loop = asyncio.events.get_running_loop()
    stream_reader = _StreamReader(loop=running_loop)
    stream_protocol = asyncio.StreamReaderProtocol(
        stream_reader, loop=running_loop)

    def protocol_factory():
        return stream_protocol

    transport, _ = await running_loop.create_unix_connection(
        protocol_factory, path, **kwds)
    stream_writer = asyncio.StreamWriter(
        transport, stream_protocol, stream_reader, running_loop)
    return stream_reader, stream_writer
|
|
|
|
|
|
class _StreamReader(asyncio.StreamReader):
    """StreamReader that remembers whether EOF has been fed to it.

    The flag lets a connection whose peer already sent EOF be discarded
    rather than handed back from the pool when free connections are
    checked in Pool._fill_free_pool().

    The `limit` argument is not accepted because it is not currently used.
    """

    def __init__(self, loop=None):
        # Set before delegating so the attribute exists from the start.
        self._eof_received = False
        super().__init__(loop=loop)

    def feed_eof(self) -> None:
        """Record that EOF arrived, then let the base class handle it."""
        self._eof_received = True
        super().feed_eof()

    @property
    def eof_received(self):
        """``True`` once ``feed_eof()`` has been called on this reader."""
        return self._eof_received
|
|
|
|
|
|
class Connection:
|
|
"""Representation of a socket with a mysql server.
|
|
|
|
The proper way to get an instance of this class is to call
|
|
connect().
|
|
"""
|
|
|
|
def __init__(self, host="localhost", user=None, password="",
             db=None, port=3306, unix_socket=None,
             charset='', sql_mode=None,
             read_default_file=None, conv=decoders, use_unicode=None,
             client_flag=0, cursorclass=Cursor, init_command=None,
             connect_timeout=None, read_default_group=None,
             autocommit=False, echo=False,
             local_infile=False, loop=None, ssl=None, auth_plugin='',
             program_name='', server_public_key=None):
    """
    Store the settings used to connect to a MySQL server.

    This constructor only records configuration; the reader and writer
    are left as ``None`` until ``_connect()`` is awaited.

    :param host: Host where the database server is located
    :param user: Username to log in as
    :param password: Password to use.
    :param db: Database to use, None to not use a particular one.
    :param port: MySQL port to use, default is usually OK.
    :param unix_socket: Optionally, you can use a unix socket rather
        than TCP/IP.
    :param charset: Charset you want to use.
    :param sql_mode: Default SQL_MODE to use.
    :param read_default_file: Specifies my.cnf file to read these
        parameters from under the [client] section. Values found in the
        file take precedence over the corresponding arguments.
    :param conv: Decoders dictionary to use instead of the default one.
        This is used to provide custom marshalling of types.
        See converters.
    :param use_unicode: Whether or not to default to unicode strings.
    :param client_flag: Custom flags to send to MySQL. Find
        potential values in constants.CLIENT.
    :param cursorclass: Custom cursor class to use.
    :param init_command: Initial SQL statement to run when connection is
        established.
    :param connect_timeout: Timeout before throwing an exception
        when connecting.
    :param read_default_group: Group to read from in the configuration
        file (defaults to "client" when a file is given).
    :param autocommit: Autocommit mode. None means use server default.
        (default: False)
    :param echo: Echo flag; it is stored and passed to every cursor
        created by this connection. (default: False)
    :param local_infile: boolean to enable the use of LOAD DATA LOCAL
        command. (default: False)
    :param ssl: Optional SSL Context to force SSL
    :param auth_plugin: String to manually specify the authentication
        plugin to use, i.e you will want to use mysql_clear_password
        when using IAM authentication with Amazon RDS.
        (default: Server Default)
    :param program_name: Program name string to provide when
        handshaking with MySQL. (omitted by default)
    :param server_public_key: SHA256 authentication plugin public
        key value.
    :param loop: asyncio loop
    """
    self._loop = loop or asyncio.get_event_loop()

    # The version check is always true on Python 3, so an unspecified
    # use_unicode becomes True here.
    if use_unicode is None and sys.version_info[0] > 2:
        use_unicode = True

    if read_default_file:
        if not read_default_group:
            read_default_group = "client"
        cfg = configparser.RawConfigParser()
        cfg.read(os.path.expanduser(read_default_file))
        # Bind the section name so each lookup only names the option.
        _config = partial(cfg.get, read_default_group)

        # Each option falls back to the argument value when the file
        # does not define it.
        user = _config("user", fallback=user)
        password = _config("password", fallback=password)
        host = _config("host", fallback=host)
        db = _config("database", fallback=db)
        unix_socket = _config("socket", fallback=unix_socket)
        port = int(_config("port", fallback=port))
        charset = _config("default-character-set", fallback=charset)

    self._host = host
    self._port = port
    self._user = user or DEFAULT_USER
    self._password = password or ""
    self._db = db
    self._echo = echo
    self._last_usage = self._loop.time()
    self._client_auth_plugin = auth_plugin
    self._server_auth_plugin = ""
    self._auth_plugin_used = ""
    self._secure = False
    self.server_public_key = server_public_key
    self.salt = None

    # Imported here rather than at module top.
    # NOTE(review): presumably to avoid a circular import - confirm.
    from . import __version__
    # Client attributes sent to the server during the handshake.
    self._connect_attrs = {
        '_client_name': 'aiomysql',
        '_pid': str(os.getpid()),
        '_client_version': __version__,
    }
    if program_name:
        self._connect_attrs["program_name"] = program_name

    self._unix_socket = unix_socket
    if charset:
        self._charset = charset
        self.use_unicode = True
    else:
        self._charset = DEFAULT_CHARSET
        self.use_unicode = False

    # An explicit (or defaulted-above) use_unicode overrides the
    # charset-derived value.
    if use_unicode is not None:
        self.use_unicode = use_unicode

    self._ssl_context = ssl
    if ssl:
        client_flag |= CLIENT.SSL

    self._encoding = charset_by_name(self._charset).encoding

    if local_infile:
        client_flag |= CLIENT.LOCAL_FILES

    client_flag |= CLIENT.CAPABILITIES
    client_flag |= CLIENT.MULTI_STATEMENTS
    if self._db:
        client_flag |= CLIENT.CONNECT_WITH_DB
    self.client_flag = client_flag

    self.cursorclass = cursorclass
    self.connect_timeout = connect_timeout

    self._result = None
    self._affected_rows = 0
    self.host_info = "Not connected"

    #: specified autocommit mode. None means use server default.
    self.autocommit_mode = autocommit

    self.encoders = encoders  # Need for MySQLdb compatibility.
    self.decoders = conv
    self.sql_mode = sql_mode
    self.init_command = init_command

    # asyncio StreamReader, StreamWriter
    self._reader = None
    self._writer = None
    # If connection was closed for specific reason, we should show that to
    # user
    self._close_reason = None
|
|
|
|
@property
def host(self):
    """Name or IP address of the MySQL server this connection targets."""
    return self._host
|
|
|
|
@property
def port(self):
    """TCP/IP port of the MySQL server."""
    return self._port
|
|
|
|
@property
def unix_socket(self):
    """Path of the MySQL UNIX socket file, if one was configured."""
    return self._unix_socket
|
|
|
|
@property
def db(self):
    """Name of the database configured for this connection."""
    return self._db
|
|
|
|
@property
def user(self):
    """User name this connection logs in to MySQL with."""
    return self._user
|
|
|
|
@property
def echo(self):
    """Whether echo mode is enabled for this connection."""
    return self._echo
|
|
|
|
@property
def last_usage(self):
    """Event-loop ``time()`` recorded when the connection was last used."""
    return self._last_usage
|
|
|
|
@property
def loop(self):
    """asyncio event loop this connection is bound to."""
    return self._loop
|
|
|
|
@property
def closed(self):
    """Read-only flag: ``True`` when the connection has no writer,
    i.e. it is closed (or was never opened).
    """
    return self._writer is None
|
|
|
|
@property
def encoding(self):
    """Text encoding in use on this connection."""
    return self._encoding
|
|
|
|
@property
def charset(self):
    """Character set name of the current connection."""
    return self._charset
|
|
|
|
def close(self):
    """Close the underlying transport immediately and drop the streams.

    Does nothing when there is no writer.
    """
    if not self._writer:
        return
    self._writer.transport.close()
    self._writer = None
    self._reader = None
|
|
|
|
async def ensure_closed(self):
    """Send the quit command, flush it, then close the socket.

    Returns immediately when the connection is already closed.
    """
    if self._writer is not None:
        quit_packet = struct.pack('<i', 1) + bytes([COMMAND.COM_QUIT])
        self._writer.write(quit_packet)
        await self._writer.drain()
        self.close()
|
|
|
|
async def autocommit(self, value):
    """Switch autocommit on or off for the current MySQL session.

    The requested mode is always remembered; a statement is only sent
    when it differs from the status last reported by the server.

    :param value: ``bool``, toggle autocommit
    """
    self.autocommit_mode = bool(value)
    if value != self.get_autocommit():
        await self._send_autocommit_mode()
|
|
|
|
def get_autocommit(self):
    """Report the autocommit status of the current MySQL session, as
    taken from the stored server status flags.

    :returns bool: current autocommit status."""
    autocommit_bit = SERVER_STATUS.SERVER_STATUS_AUTOCOMMIT
    return bool(self.server_status & autocommit_bit)
|
|
|
|
async def _read_ok_packet(self):
    """Read one packet, require it to be an OK packet, and store the
    server status it carries.

    :returns: ``True`` on success.
    :raises OperationalError: (2014) if the packet is not an OK packet.
    """
    packet = await self._read_packet()
    if packet.is_ok_packet():
        self.server_status = OKPacketWrapper(packet).server_status
        return True
    raise OperationalError(2014, "Command Out of Sync")
|
|
|
|
async def _send_autocommit_mode(self):
    """Send the stored autocommit mode to the server and await its OK."""
    statement = "SET AUTOCOMMIT = %s" % self.escape(self.autocommit_mode)
    await self._execute_command(COMMAND.COM_QUERY, statement)
    await self._read_ok_packet()
|
|
|
|
async def begin(self):
    """Start a transaction: send BEGIN and wait for the server's OK."""
    statement = "BEGIN"
    await self._execute_command(COMMAND.COM_QUERY, statement)
    await self._read_ok_packet()
|
|
|
|
async def commit(self):
    """Commit the current transaction: send COMMIT and wait for OK."""
    statement = "COMMIT"
    await self._execute_command(COMMAND.COM_QUERY, statement)
    await self._read_ok_packet()
|
|
|
|
async def rollback(self):
    """Undo the current transaction: send ROLLBACK and wait for OK."""
    statement = "ROLLBACK"
    await self._execute_command(COMMAND.COM_QUERY, statement)
    await self._read_ok_packet()
|
|
|
|
async def select_db(self, db):
    """Make ``db`` the current database via COM_INIT_DB and wait for OK.

    :param db: name of the database to switch to.
    """
    init_db = COMMAND.COM_INIT_DB
    await self._execute_command(init_db, db)
    await self._read_ok_packet()
|
|
|
|
async def show_warnings(self):
    """Run SHOW WARNINGS and return the rows of its result."""
    await self._execute_command(COMMAND.COM_QUERY, "SHOW WARNINGS")
    warning_result = MySQLResult(self)
    await warning_result.read()
    return warning_result.rows
|
|
|
|
def escape(self, obj):
    """Return ``obj`` escaped for inclusion in an SQL statement.

    ``str`` values are escaped with ``escape_string()`` and wrapped in
    single quotes, ``bytes`` go through ``escape_bytes_prefixed``, and
    anything else is delegated to ``escape_item`` with the connection
    charset.
    """
    if isinstance(obj, str):
        escaped = self.escape_string(obj)
        return "'" + escaped + "'"
    if isinstance(obj, bytes):
        return escape_bytes_prefixed(obj)
    return escape_item(obj, self._charset)
|
|
|
|
def literal(self, obj):
    """Same as ``escape()``; kept as an alternative name."""
    escaped = self.escape(obj)
    return escaped
|
|
|
|
def escape_string(self, s):
    """Escape the string ``s`` according to the server's escaping mode.

    When the server status has NO_BACKSLASH_ESCAPES set, only single
    quotes are doubled; otherwise the regular ``escape_string`` converter
    is used.
    """
    no_backslash_escapes = (
        self.server_status &
        SERVER_STATUS.SERVER_STATUS_NO_BACKSLASH_ESCAPES)
    if not no_backslash_escapes:
        return escape_string(s)
    return s.replace("'", "''")
|
|
|
|
def cursor(self, *cursors):
    """Create a cursor bound to this connection.

    With no arguments the connection's default ``cursorclass`` is used.
    With one class, that class is instantiated. With several classes, a
    new class inheriting from all of them is created on the fly and
    instantiated.

    :param cursors: optional custom cursor classes; each must be a
        subclass of :class:`Cursor`.
    :returns: a ``_ContextManager`` wrapping an already-resolved future
        whose result is the cursor instance.
    :raises TypeError: an argument is not a subclass of Cursor.
    """
    self._ensure_alive()
    self._last_usage = self._loop.time()
    try:
        # issubclass() itself raises TypeError for non-class arguments;
        # both cases end up as the same error below.
        if cursors and \
                any(not issubclass(klass, Cursor) for klass in cursors):
            raise TypeError('Custom cursor must be subclass of Cursor')
    except TypeError:
        raise TypeError('Custom cursor must be subclass of Cursor')

    if not cursors:
        cur = self.cursorclass(self, self._echo)
    elif len(cursors) == 1:
        cur = cursors[0](self, self._echo)
    else:
        # Combine the class names, keeping a single trailing "Cursor".
        joined_names = ''.join(klass.__name__ for klass in cursors)
        cursor_name = joined_names.replace('Cursor', '') + 'Cursor'
        combined_class = type(cursor_name, cursors, {})
        cur = combined_class(self, self._echo)

    fut = self._loop.create_future()
    fut.set_result(cur)
    return _ContextManager(fut)
|
|
|
|
# The following methods are INTERNAL USE ONLY (called from Cursor)
|
|
async def query(self, sql, unbuffered=False):
    """Send ``sql`` as a COM_QUERY and read its result.

    :param sql: statement as ``str`` (encoded with the connection
        encoding using ``surrogateescape``) or already-encoded bytes.
    :param unbuffered: passed on to ``_read_query_result``.
    :returns: the affected-rows count stored by the result reader.
    """
    # logger.debug("DEBUG: sending query: %s", _convert_to_str(sql))
    if isinstance(sql, str):
        payload = sql.encode(self.encoding, 'surrogateescape')
    else:
        payload = sql
    await self._execute_command(COMMAND.COM_QUERY, payload)
    await self._read_query_result(unbuffered=unbuffered)
    return self._affected_rows
|
|
|
|
async def next_result(self):
    """Read the next query result and return its affected-rows count."""
    await self._read_query_result()
    affected = self._affected_rows
    return affected
|
|
|
|
def affected_rows(self):
    """Return the affected-rows count of the most recent result."""
    count = self._affected_rows
    return count
|
|
|
|
async def kill(self, thread_id):
    """Send COM_PROCESS_KILL for ``thread_id`` and wait for the OK packet.

    :param thread_id: id packed as a little-endian unsigned 32-bit int.
    """
    packed_id = struct.pack('<I', thread_id)
    await self._execute_command(COMMAND.COM_PROCESS_KILL, packed_id)
    await self._read_ok_packet()
|
|
|
|
async def ping(self, reconnect=True):
    """Check that the server is alive by sending COM_PING.

    :param reconnect: when true, a closed connection is re-opened before
        pinging, and a failed ping triggers one reconnect followed by a
        second, non-retrying ping.
    :raises Error: the connection is closed and ``reconnect`` is false.
    """
    if self._writer is None and self._reader is None:
        if not reconnect:
            raise Error("Already closed")
        await self._connect()
        # Just reconnected; a failure below must not reconnect again.
        reconnect = False
    try:
        await self._execute_command(COMMAND.COM_PING, "")
        await self._read_ok_packet()
    except Exception:
        if not reconnect:
            raise
        await self._connect()
        await self.ping(False)
|
|
|
|
async def set_charset(self, charset):
    """Change the character set of the current connection.

    The charset is looked up first, so an unsupported name fails before
    anything is sent; the stored charset and encoding are updated only
    after the server's reply has been read.
    """
    new_encoding = charset_by_name(charset).encoding
    statement = "SET NAMES %s" % self.escape(charset)
    await self._execute_command(COMMAND.COM_QUERY, statement)
    await self._read_packet()
    self._charset = charset
    self._encoding = new_encoding
|
|
|
|
async def _connect(self):
    """Open the transport, run the handshake, and apply session settings.

    Connects over the UNIX socket when one is configured, otherwise over
    TCP (enabling keep-alive and TCP_NODELAY). It then reads the server
    greeting, authenticates, and applies ``sql_mode``, ``init_command``
    and the autocommit mode when they are set.

    :raises OperationalError: (CR_CONN_HOST_ERROR) when an
        IOError/OSError/asyncio.TimeoutError occurs at any step.
        Any other exception is re-raised unchanged. In both cases the
        transport is closed and the streams are reset to ``None``.
    """
    # TODO: Set close callback
    # raise OperationalError(CR.CR_SERVER_GONE_ERROR,
    #                        "MySQL server has gone away (%r)" % (e,))
    try:
        if self._unix_socket:
            self._reader, self._writer = await \
                asyncio.wait_for(
                    _open_unix_connection(
                        self._unix_socket),
                    timeout=self.connect_timeout)
            self.host_info = "Localhost via UNIX socket: " + \
                self._unix_socket
            # UNIX-socket connections are flagged secure, which lets the
            # sha2 auth helpers send the password in plain form.
            self._secure = True
        else:
            self._reader, self._writer = await \
                asyncio.wait_for(
                    _open_connection(
                        self._host,
                        self._port),
                    timeout=self.connect_timeout)
            self._set_keep_alive()
            self._set_nodelay(True)
            self.host_info = "socket %s:%d" % (self._host, self._port)

        # The server greeting is expected to carry sequence id 0.
        self._next_seq_id = 0

        await self._get_server_information()
        await self._request_authentication()

        self.connected_time = self._loop.time()

        # NOTE(review): sql_mode is interpolated as-is (no quoting or
        # escaping), so callers must pass an already valid SQL value.
        if self.sql_mode is not None:
            await self.query(f"SET sql_mode={self.sql_mode}")

        if self.init_command is not None:
            await self.query(self.init_command)
            await self.commit()

        if self.autocommit_mode is not None:
            await self.autocommit(self.autocommit_mode)
    except Exception as e:
        # Tear down whatever was opened so the connection reads as closed.
        if self._writer:
            self._writer.transport.close()
        self._reader = None
        self._writer = None

        # As of 3.11, asyncio.TimeoutError is a deprecated alias of the
        # builtin TimeoutError, which is an OSError subclass. For
        # consistency, we're also considering this an OperationalError
        # on earlier python versions.
        if isinstance(e, (IOError, OSError, asyncio.TimeoutError)):
            raise OperationalError(
                CR.CR_CONN_HOST_ERROR,
                "Can't connect to MySQL server on %r" % self._host,
            ) from e

        # If e is neither IOError nor OSError, it's a bug.
        # Raising AssertionError would hide the original error, so we just
        # reraise it.
        raise
|
|
|
|
def _set_keep_alive(self):
    """Turn on SO_KEEPALIVE for the connection's socket.

    Reading is paused on the transport while the option is set.

    :raises RuntimeError: the transport does not expose its socket.
    """
    tr = self._writer.transport
    tr.pause_reading()
    sock = tr.get_extra_info('socket', default=None)
    if sock is None:
        raise RuntimeError("Transport does not expose socket instance")
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
    tr.resume_reading()
|
|
|
|
def _set_nodelay(self, value):
    """Set TCP_NODELAY on the connection's socket to the truthiness of
    ``value``.

    Reading is paused on the transport while the option is set.

    :raises RuntimeError: the transport does not expose its socket.
    """
    nodelay = 1 if value else 0
    tr = self._writer.transport
    tr.pause_reading()
    sock = tr.get_extra_info('socket', default=None)
    if sock is None:
        raise RuntimeError("Transport does not expose socket instance")
    sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, nodelay)
    tr.resume_reading()
|
|
|
|
def write_packet(self, payload):
    """Write one complete "mysql packet" to the network, prefixing the
    payload with its length and the current sequence number.
    """
    # Internal note: code that builds a packet by hand and calls
    # _write_bytes() directly must set self._next_seq_id itself.
    header = _pack_int24(len(payload)) + bytes([self._next_seq_id])
    self._write_bytes(header + payload)
    # The sequence id is a single byte, so it wraps at 256.
    self._next_seq_id = (self._next_seq_id + 1) % 256
|
|
|
|
async def _read_packet(self, packet_type=MysqlPacket):
    """Read an entire "mysql packet" in its entirety from the network
    and return a MysqlPacket type that represents the results.

    Payloads split over several wire packets are concatenated before the
    packet object is built.

    :param packet_type: class used to wrap the payload; it is called
        with the payload bytes and the connection encoding.
    :raises OperationalError: (CR_SERVER_LOST) a packet with sequence
        number 0 arrived when a different number was expected.
    :raises InternalError: any other sequence-number mismatch.
    """
    buff = b''
    while True:
        try:
            packet_header = await self._read_bytes(4)
        except asyncio.CancelledError:
            self._close_on_cancel()
            raise

        # Header: 3-byte little-endian payload length (read as a 16-bit
        # low part plus an 8-bit high part) and a 1-byte sequence number.
        btrl, btrh, packet_number = struct.unpack(
            '<HBB', packet_header)
        bytes_to_read = btrl + (btrh << 16)

        # Outbound and inbound packets are numbered sequentially, so
        # we increment in both write_packet and read_packet. The count
        # is reset at new COMMAND PHASE.
        if packet_number != self._next_seq_id:
            self.close()
            if packet_number == 0:
                # MySQL 8.0 sends error packet with seqno==0 when shutdown
                raise OperationalError(
                    CR.CR_SERVER_LOST,
                    "Lost connection to MySQL server during query")

            raise InternalError(
                "Packet sequence number wrong - got %d expected %d" %
                (packet_number, self._next_seq_id))
        self._next_seq_id = (self._next_seq_id + 1) % 256

        try:
            recv_data = await self._read_bytes(bytes_to_read)
        except asyncio.CancelledError:
            self._close_on_cancel()
            raise

        buff += recv_data
        # https://dev.mysql.com/doc/internals/en/sending-more-than-16mbyte.html
        # A length of 0xffffff means more data follows in the next packet.
        if bytes_to_read == 0xffffff:
            continue
        if bytes_to_read < MAX_PACKET_LEN:
            break

    packet = packet_type(buff, self._encoding)
    if packet.is_error_packet():
        # An error ends any unbuffered result that was still streaming.
        if self._result is not None and \
                self._result.unbuffered_active is True:
            self._result.unbuffered_active = False
        packet.raise_for_error()
    return packet
|
|
|
|
async def _read_bytes(self, num_bytes):
    """Read exactly ``num_bytes`` bytes from the stream reader.

    :raises OperationalError: (CR_SERVER_LOST) when the stream ends early
        or an ``OSError`` occurs; the connection is closed first.
    """
    try:
        return await self._reader.readexactly(num_bytes)
    except asyncio.IncompleteReadError as exc:
        message = "Lost connection to MySQL server during query"
        self.close()
        raise OperationalError(CR.CR_SERVER_LOST, message) from exc
    except OSError as exc:
        message = f"Lost connection to MySQL server during query ({exc})"
        self.close()
        raise OperationalError(CR.CR_SERVER_LOST, message) from exc
|
|
|
|
def _write_bytes(self, data):
    """Hand ``data`` to the stream writer and return whatever its
    ``write()`` returns."""
    writer = self._writer
    return writer.write(data)
|
|
|
|
async def _read_query_result(self, unbuffered=False):
    """Read the result of the command just sent and store it.

    On success ``self._result`` and ``self._affected_rows`` are updated,
    and ``self.server_status`` too when the result carries one.

    :param unbuffered: when true, only initialise an unbuffered result
        via ``init_unbuffered_query()`` instead of reading it fully.
    """
    self._result = None
    # Built before the ``try`` so the cleanup handler below can never
    # reference an unbound name and mask the real error.
    result = MySQLResult(self)
    if unbuffered:
        try:
            await result.init_unbuffered_query()
        except BaseException:
            # Includes cancellation: detach the half-initialised result.
            result.unbuffered_active = False
            result.connection = None
            raise
    else:
        await result.read()
    self._result = result
    self._affected_rows = result.affected_rows
    if result.server_status is not None:
        self.server_status = result.server_status
|
|
|
|
def insert_id(self):
    """Return ``insert_id`` of the last stored result, or 0 when there
    is no (truthy) result."""
    return self._result.insert_id if self._result else 0
|
|
|
|
async def __aenter__(self):
    """Enter ``async with``: the connection itself is the context value."""
    connection = self
    return connection
|
|
|
|
async def __aexit__(self, exc_type, exc_val, exc_tb):
    """Leave ``async with``: quit gracefully on a clean exit, or close
    the socket straight away when the body raised. Exceptions are never
    suppressed.
    """
    if not exc_type:
        await self.ensure_closed()
        return
    self.close()
    return
|
|
|
|
async def _execute_command(self, command, sql):
    """Send a command packet, splitting oversized payloads into chunks.

    Any result left over from the previous command is drained first. The
    reply is not read here; callers read it afterwards.

    :param command: command byte (a ``COMMAND`` constant).
    :param sql: payload as bytes, or ``str`` (encoded with the
        connection encoding).
    """
    self._ensure_alive()

    # If the last query was unbuffered, make sure it finishes before
    # sending new commands
    if self._result is not None:
        if self._result.unbuffered_active:
            warnings.warn("Previous unbuffered result was left incomplete")
            await self._result._finish_unbuffered_query()
        # Consume any further result sets that are still pending.
        while self._result.has_next:
            await self.next_result()
        self._result = None

    if isinstance(sql, str):
        sql = sql.encode(self._encoding)

    chunk_size = min(MAX_PACKET_LEN, len(sql) + 1)  # +1 is for command

    # '<i' packs the length into 4 bytes; since chunk_size is at most
    # MAX_PACKET_LEN the 4th byte doubles as sequence id 0 (assuming
    # MAX_PACKET_LEN fits in 3 bytes - its value is defined elsewhere).
    prelude = struct.pack('<iB', chunk_size, command)
    self._write_bytes(prelude + sql[:chunk_size - 1])
    # logger.debug(dump_packet(prelude + sql))
    self._next_seq_id = 1

    if chunk_size < MAX_PACKET_LEN:
        return

    # Payload filled the first packet: send the rest in follow-up
    # packets, ending with one shorter than MAX_PACKET_LEN (possibly
    # empty) to mark the end.
    sql = sql[chunk_size - 1:]
    while True:
        chunk_size = min(MAX_PACKET_LEN, len(sql))
        self.write_packet(sql[:chunk_size])
        sql = sql[chunk_size:]
        if not sql and chunk_size < MAX_PACKET_LEN:
            break
|
|
|
|
async def _request_authentication(self):
    """Build and send the handshake response, then finish authentication.

    Steps:
      1. Upgrade the transport to TLS when an SSL context is configured
         and the server advertises ``CLIENT.SSL``.
      2. Compute the initial auth response for the chosen plugin (the
         client-specified one, else the one from the server greeting).
      3. Append database, plugin name and connection attributes as far
         as the server capabilities allow, and send the packet.
      4. Handle an auth-switch request or extra auth data in the reply.

    :raises ValueError: no user name is set.
    :raises RuntimeError: the transport does not expose its socket
        (needed for the TLS upgrade).
    :raises OperationalError: extra auth data arrived for a plugin that
        does not expect any.
    """
    # https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
    if int(self.server_version.split('.', 1)[0]) >= 5:
        self.client_flag |= CLIENT.MULTI_RESULTS

    if self.user is None:
        raise ValueError("Did not specify a username")

    charset_id = charset_by_name(self.charset).id
    data_init = struct.pack('<iIB23s', self.client_flag, MAX_PACKET_LEN,
                            charset_id, b'')

    if self._ssl_context and self.server_capabilities & CLIENT.SSL:
        # The bare header sent on its own is the SSL request.
        self.write_packet(data_init)

        # Stop sending events to data_received
        self._writer.transport.pause_reading()

        # Get the raw socket from the transport
        raw_sock = self._writer.transport.get_extra_info('socket',
                                                         default=None)
        if raw_sock is None:
            raise RuntimeError("Transport does not expose socket instance")

        # Duplicate the socket so closing the old transport does not
        # close the underlying connection.
        raw_sock = raw_sock.dup()
        self._writer.transport.close()
        # MySQL expects TLS negotiation to happen in the middle of a
        # TCP connection not at start. Passing in a socket to
        # open_connection will cause it to negotiate TLS on an existing
        # connection not initiate a new one.
        self._reader, self._writer = await _open_connection(
            sock=raw_sock, ssl=self._ssl_context,
            server_hostname=self._host
        )

        self._secure = True

    if isinstance(self.user, str):
        _user = self.user.encode(self.encoding)
    else:
        _user = self.user

    data = data_init + _user + b'\0'

    authresp = b''

    auth_plugin = self._client_auth_plugin
    if not self._client_auth_plugin:
        # Contains the auth plugin from handshake
        auth_plugin = self._server_auth_plugin

    if auth_plugin in ('', 'mysql_native_password'):
        authresp = _auth.scramble_native_password(
            self._password.encode('latin1'), self.salt)
    elif auth_plugin == 'caching_sha2_password':
        if self._password:
            authresp = _auth.scramble_caching_sha2(
                self._password.encode('latin1'), self.salt
            )
        # Else: empty password
    elif auth_plugin == 'sha256_password':
        if self._ssl_context and self.server_capabilities & CLIENT.SSL:
            authresp = self._password.encode('latin1') + b'\0'
        elif self._password:
            authresp = b'\1'  # request public key
        else:
            authresp = b'\0'  # empty password
    elif auth_plugin == 'mysql_clear_password':
        # ('' is already handled by the native-password branch above.)
        authresp = self._password.encode('latin1') + b'\0'

    if self.server_capabilities & CLIENT.PLUGIN_AUTH_LENENC_CLIENT_DATA:
        data += _lenenc_int(len(authresp)) + authresp
    elif self.server_capabilities & CLIENT.SECURE_CONNECTION:
        data += struct.pack('B', len(authresp)) + authresp
    else:  # pragma: no cover
        # not testing against servers without secure auth (>=5.0)
        data += authresp + b'\0'

    if self._db and self.server_capabilities & CLIENT.CONNECT_WITH_DB:
        if isinstance(self._db, str):
            db = self._db.encode(self.encoding)
        else:
            db = self._db
        data += db + b'\0'

    if self.server_capabilities & CLIENT.PLUGIN_AUTH:
        name = auth_plugin
        if isinstance(name, str):
            name = name.encode('ascii')
        data += name + b'\0'

    self._auth_plugin_used = auth_plugin

    # Sends the server a few pieces of client info
    if self.server_capabilities & CLIENT.CONNECT_ATTRS:
        # Keys, values and the total size are length-encoded in the
        # protocol. A single raw byte is only equivalent below 251 and
        # fails outright above 255 (e.g. with a long program_name).
        attr_chunks = []
        for key, value in self._connect_attrs.items():
            key, value = key.encode('utf8'), value.encode('utf8')
            attr_chunks.append(_lenenc_int(len(key)) + key)
            attr_chunks.append(_lenenc_int(len(value)) + value)
        connect_attrs = b''.join(attr_chunks)
        data += _lenenc_int(len(connect_attrs)) + connect_attrs

    self.write_packet(data)
    auth_packet = await self._read_packet()

    # if authentication method isn't accepted the first byte
    # will have the octet 254
    if auth_packet.is_auth_switch_request():
        # https://dev.mysql.com/doc/internals/en/
        # connection-phase-packets.html#packet-Protocol::AuthSwitchRequest
        auth_packet.read_uint8()  # 0xfe packet identifier
        plugin_name = auth_packet.read_string()
        if (self.server_capabilities & CLIENT.PLUGIN_AUTH and
                plugin_name is not None):
            await self._process_auth(plugin_name, auth_packet)
        else:
            # send legacy handshake
            data = _auth.scramble_old_password(
                self._password.encode('latin1'),
                auth_packet.read_all()) + b'\0'
            self.write_packet(data)
            await self._read_packet()
    elif auth_packet.is_extra_auth_data():
        if auth_plugin == "caching_sha2_password":
            await self.caching_sha2_password_auth(auth_packet)
        elif auth_plugin == "sha256_password":
            await self.sha256_password_auth(auth_packet)
        else:
            # Format the message here; passing the plugin as a second
            # argument left the "%r" placeholder unexpanded.
            raise OperationalError(
                "Received extra packet "
                "for auth method %r" % (auth_plugin,))
|
|
|
|
async def _process_auth(self, plugin_name, auth_packet):
    """Answer an auth-switch request for the plugin named by the server.

    :param plugin_name: plugin name as bytes.
    :param auth_packet: the auth-switch packet; its remaining data is
        used as the scramble where the plugin needs one.
    :returns: the server's reply packet for the simple plugins handled
        here; ``None`` for the sha2 plugins, which do their own packet
        handling.
    :raises OperationalError: (2059) the plugin is not supported.
    """
    # These auth plugins do their own packet handling
    if plugin_name == b"caching_sha2_password":
        await self.caching_sha2_password_auth(auth_packet)
        self._auth_plugin_used = plugin_name.decode()
        return None
    if plugin_name == b"sha256_password":
        await self.sha256_password_auth(auth_packet)
        self._auth_plugin_used = plugin_name.decode()
        return None

    if plugin_name == b"mysql_native_password":
        # https://dev.mysql.com/doc/internals/en/
        # secure-password-authentication.html#packet-Authentication::
        # Native41
        response = _auth.scramble_native_password(
            self._password.encode('latin1'),
            auth_packet.read_all())
    elif plugin_name == b"mysql_old_password":
        # https://dev.mysql.com/doc/internals/en/
        # old-password-authentication.html
        scrambled = _auth.scramble_old_password(
            self._password.encode('latin1'),
            auth_packet.read_all()
        )
        response = scrambled + b'\0'
    elif plugin_name == b"mysql_clear_password":
        # https://dev.mysql.com/doc/internals/en/
        # clear-text-authentication.html
        response = self._password.encode('latin1') + b'\0'
    else:
        raise OperationalError(
            2059,
            "Authentication plugin '{}' not configured".format(plugin_name)
        )

    self.write_packet(response)
    reply = await self._read_packet()
    reply.check_error()

    self._auth_plugin_used = plugin_name.decode()

    return reply
|
|
|
|
async def caching_sha2_password_auth(self, pkt):
    """Run the ``caching_sha2_password`` exchange starting from ``pkt``.

    Tries the fast (scramble) path first; if the server asks for full
    authentication, the password is sent in plain form over a secure
    connection or RSA-encrypted with the server's public key otherwise.

    :param pkt: the packet that triggered this exchange (an auth-switch
        request or an extra-auth-data packet).
    :returns: the last packet read from the server on every path.
    :raises OperationalError: the server sent an unexpected packet or
        fast-auth result code.
    """
    # No password fast path
    if not self._password:
        self.write_packet(b'')
        pkt = await self._read_packet()
        pkt.check_error()
        return pkt

    if pkt.is_auth_switch_request():
        # Try from fast auth
        logger.debug("caching sha2: Trying fast path")
        self.salt = pkt.read_all()
        scrambled = _auth.scramble_caching_sha2(
            self._password.encode('latin1'), self.salt
        )

        self.write_packet(scrambled)
        pkt = await self._read_packet()
        pkt.check_error()

    # else: fast auth is tried in initial handshake

    if not pkt.is_extra_auth_data():
        raise OperationalError(
            "caching sha2: Unknown packet "
            "for fast auth: {}".format(pkt._data[:1])
        )

    # magic numbers:
    # 2 - request public key
    # 3 - fast auth succeeded
    # 4 - need full auth

    pkt.advance(1)
    n = pkt.read_uint8()

    if n == 3:
        logger.debug("caching sha2: succeeded by fast path.")
        pkt = await self._read_packet()
        pkt.check_error()  # pkt must be OK packet
        return pkt

    if n != 4:
        raise OperationalError("caching sha2: Unknown "
                               "result for fast auth: {}".format(n))

    logger.debug("caching sha2: Trying full auth...")

    if self._secure:
        logger.debug("caching sha2: Sending plain "
                     "password via secure connection")
        self.write_packet(self._password.encode('latin1') + b'\0')
        pkt = await self._read_packet()
        pkt.check_error()
        return pkt

    if not self.server_public_key:
        self.write_packet(b'\x02')
        pkt = await self._read_packet()  # Request public key
        pkt.check_error()

        if not pkt.is_extra_auth_data():
            raise OperationalError(
                "caching sha2: Unknown packet "
                "for public key: {}".format(pkt._data[:1])
            )

        # Skip the leading marker byte; the rest is the key itself.
        self.server_public_key = pkt._data[1:]
        logger.debug(self.server_public_key.decode('ascii'))

    data = _auth.sha2_rsa_encrypt(
        self._password.encode('latin1'), self.salt,
        self.server_public_key
    )
    self.write_packet(data)
    pkt = await self._read_packet()
    pkt.check_error()
    # Return the reply like every other path does (this path used to
    # fall off the end and return None).
    return pkt
|
|
|
|
async def sha256_password_auth(self, pkt):
    """Run the ``sha256_password`` authentication exchange.

    :param pkt: the packet received from the server that triggered this
        exchange; it is inspected for an auth-switch request and for
        extra auth data (the server's public key).
    :returns: the server's reply to the final credentials packet, after
        ``check_error()`` has been called on it.
    :raises OperationalError: a password is set but no server public key
        is known or was received.
    """
    if self._secure:
        # On a secure connection the password goes out as-is,
        # NUL-terminated.
        logger.debug("sha256: Sending plain password")
        self.write_packet(self._password.encode('latin1') + b'\0')
        reply = await self._read_packet()
        reply.check_error()
        return reply

    if pkt.is_auth_switch_request():
        self.salt = pkt.read_all()
        key_missing = not self.server_public_key
        if key_missing and self._password:
            # Request server public key
            logger.debug("sha256: Requesting server public key")
            self.write_packet(b'\1')
            pkt = await self._read_packet()
            pkt.check_error()

    if pkt.is_extra_auth_data():
        # Everything after the first byte of the packet is the key.
        self.server_public_key = pkt._data[1:]
        logger.debug(
            "Received public key:\n%s",
            self.server_public_key.decode('ascii')
        )

    # An empty password is answered with an empty payload.
    payload = b''
    if self._password:
        if not self.server_public_key:
            raise OperationalError("Couldn't receive server's public key")
        payload = _auth.sha2_rsa_encrypt(
            self._password.encode('latin1'), self.salt,
            self.server_public_key
        )

    self.write_packet(payload)
    reply = await self._read_packet()
    reply.check_error()
    return reply
|
|
|
|
# _mysql support
|
|
def thread_id(self):
    """Return the first element of ``self.server_thread_id``.

    ``server_thread_id`` is stored as the tuple produced by
    ``struct.unpack`` while parsing the server handshake.
    """
    unpacked = self.server_thread_id
    return unpacked[0]
|
|
|
|
def character_set_name(self):
    """Return the character set stored on this connection (``_charset``)."""
    charset = self._charset
    return charset
|
|
|
|
def get_host_info(self):
    """Return the ``host_info`` value stored on this connection."""
    info = self.host_info
    return info
|
|
|
|
def get_proto_info(self):
    """Return the protocol version read from the server handshake."""
    version = self.protocol_version
    return version
|
|
|
|
async def _get_server_information(self):
    """Read the server's initial handshake packet and record its fields.

    Always sets ``protocol_version``, ``server_version``,
    ``server_thread_id`` (kept as the 1-tuple returned by
    ``struct.unpack``), ``salt`` and ``server_capabilities``.

    When the packet is long enough to carry the extended section this
    also sets ``server_language``, ``server_charset`` (``None`` for an
    unknown collation id), ``server_status``, the upper 16 capability
    bits and the remainder of the salt.  ``_server_auth_plugin`` is set
    when the server advertises ``CLIENT.PLUGIN_AUTH`` and the packet
    reaches that far.

    A packet that ends right after the lower capability bits is accepted:
    only the always-present attributes are assigned in that case.
    """
    packet = await self._read_packet()
    data = packet.get_all_data()

    i = 0
    self.protocol_version = data[i]
    i += 1

    # NUL-terminated server version string
    server_end = data.find(b'\0', i)
    self.server_version = data[i:server_end].decode('latin1')
    i = server_end + 1

    self.server_thread_id = struct.unpack('<I', data[i:i + 4])
    i += 4

    self.salt = data[i:i + 8]
    i += 9  # 8 + 1(filler)

    self.server_capabilities = struct.unpack('<H', data[i:i + 2])[0]
    i += 2

    # Without the extended section there is no second salt part.  The
    # default keeps the length test below from hitting an unbound name.
    salt_len = 0
    if len(data) >= i + 6:
        lang, stat, cap_h, salt_len = struct.unpack('<BHHB', data[i:i + 6])
        i += 6
        self.server_language = lang
        try:
            self.server_charset = charset_by_id(lang).name
        except KeyError:
            # unknown collation
            self.server_charset = None

        self.server_status = stat
        self.server_capabilities |= cap_h << 16
        salt_len = max(12, salt_len - 9)

    # reserved
    i += 10

    if len(data) >= i + salt_len:
        # salt_len includes auth_plugin_data_part_1 and filler
        self.salt += data[i:i + salt_len]
        i += salt_len

    i += 1

    # AUTH PLUGIN NAME may appear here.
    # Both operands are side-effect free; the cheap length check goes
    # first.
    if len(data) >= i and self.server_capabilities & CLIENT.PLUGIN_AUTH:
        # Due to Bug#59453 the auth-plugin-name is missing the terminating
        # NUL-char in versions prior to 5.5.10 and 5.6.2.
        # ref: https://dev.mysql.com/doc/internals/en/
        # connection-phase-packets.html#packet-Protocol::Handshake
        # didn't use version checks as mariadb is corrected and reports
        # earlier than those two.
        server_end = data.find(b'\0', i)
        if server_end < 0:  # pragma: no cover - very specific upstream bug
            # not found \0 and last field so take it all
            self._server_auth_plugin = data[i:].decode('latin1')
        else:
            self._server_auth_plugin = data[i:server_end].decode('latin1')
|
|
|
|
def get_transaction_status(self):
    """Return True when the in-transaction bit is set in ``server_status``."""
    in_transaction = (
        self.server_status & SERVER_STATUS.SERVER_STATUS_IN_TRANS
    )
    return bool(in_transaction)
|
|
|
|
def get_server_info(self):
    """Return the server version string read from the handshake."""
    version = self.server_version
    return version
|
|
|
|
# Two helpers that keep connection-state errors consistent
|
|
|
|
def _close_on_cancel(self):
    """Close the connection and record cancellation as the close reason.

    The reason is stored after ``close()`` has run; ``_ensure_alive``
    reports it when the connection is used again.
    """
    reason = "Cancelled during execution"
    self.close()
    self._close_reason = reason
|
|
|
|
def _ensure_alive(self):
    """Raise ``InterfaceError`` unless the connection has a writer.

    The error message is the recorded close reason when there is one,
    and a generic "Not connected" message otherwise.
    """
    if self._writer:
        return

    reason = self._close_reason
    if reason is None:
        reason = "(0, 'Not connected')"
    raise InterfaceError(reason)
|
|
|
|
def __del__(self):
    """Warn about, then close, a connection that still has a writer."""
    if not self._writer:
        return

    warnings.warn(f"Unclosed connection {self!r}",
                  ResourceWarning)
    self.close()
|
|
|
|
# Bind the exception classes imported from pymysql.err under the same
# names in this scope, grouped base classes first.
Warning = Warning
Error = Error

InterfaceError = InterfaceError
DatabaseError = DatabaseError

DataError = DataError
IntegrityError = IntegrityError
InternalError = InternalError
NotSupportedError = NotSupportedError
OperationalError = OperationalError
ProgrammingError = ProgrammingError
|
|
|
|
|
|
# TODO: move OK and EOF packet parsing/logic into a proper subclass
|
|
# of MysqlPacket like has been done with FieldDescriptorPacket.
|
|
class MySQLResult:
    """Holds the server's response to one command sent over a connection.

    A result is filled either in one go by :meth:`read` (buffered) or
    step by step: :meth:`init_unbuffered_query` reads the header and
    column descriptions, then :meth:`_read_rowdata_packet_unbuffered`
    returns one row per call.  The reference to the connection is dropped
    once the response has been consumed or reading has failed.
    """

    def __init__(self, connection):
        """Create an empty result that will read from *connection*."""
        self.connection = connection
        self.affected_rows = None
        self.insert_id = None
        self.server_status = None
        self.warning_count = 0
        self.message = None
        self.field_count = 0
        self.description = None
        self.rows = None
        self.has_next = None
        # True while an unbuffered result set still has unread packets
        self.unbuffered_active = False

    async def read(self):
        """Read a complete (buffered) response from the connection.

        The first packet decides the kind of response: an OK packet, a
        load-local request, or the header of a result set.  The connection
        reference is always released afterwards, even on error.
        """
        try:
            first_packet = await self.connection._read_packet()

            # TODO: use classes for different packet types?
            if first_packet.is_ok_packet():
                self._read_ok_packet(first_packet)
            elif first_packet.is_load_local_packet():
                await self._read_load_local_packet(first_packet)
            else:
                await self._read_result_packet(first_packet)
        finally:
            self.connection = None

    async def init_unbuffered_query(self):
        """Read the response header, leaving result rows unread.

        For OK and load-local responses the result is complete when this
        returns.  For a result set only the column descriptions are read
        and ``unbuffered_active`` stays True until the rows are consumed.

        If anything raises while the header is being read, the result is
        marked inactive and the connection reference is released before
        the exception propagates, mirroring the cleanup done by
        :meth:`read`.
        """
        self.unbuffered_active = True
        try:
            first_packet = await self.connection._read_packet()

            if first_packet.is_ok_packet():
                self._read_ok_packet(first_packet)
                self.unbuffered_active = False
                self.connection = None
            elif first_packet.is_load_local_packet():
                await self._read_load_local_packet(first_packet)
                self.unbuffered_active = False
                self.connection = None
            else:
                self.field_count = first_packet.read_length_encoded_integer()
                await self._get_descriptions()

                # Apparently, MySQLdb picks this number because it's the
                # maximum value of a 64bit unsigned integer. Since we're
                # emulating MySQLdb, we set it to this instead of None,
                # which would be preferred.
                self.affected_rows = 18446744073709551615
        except BaseException:
            # Cleanup only, the exception is re-raised.  Without this a
            # failed header read would leave the result flagged as active
            # with no rows ever going to arrive.
            self.unbuffered_active = False
            self.connection = None
            raise

    def _read_ok_packet(self, first_packet):
        """Copy the fields of an OK packet onto this result."""
        ok_packet = OKPacketWrapper(first_packet)
        self.affected_rows = ok_packet.affected_rows
        self.insert_id = ok_packet.insert_id
        self.server_status = ok_packet.server_status
        self.warning_count = ok_packet.warning_count
        self.message = ok_packet.message
        self.has_next = ok_packet.has_next

    async def _read_load_local_packet(self, first_packet):
        """Send the requested local file, then read the server's reply.

        If sending fails with an ``Exception`` the server's pending reply
        is read and discarded before the error is re-raised.

        :raises OperationalError: (2014) the reply is not an OK packet.
        """
        load_packet = LoadLocalPacketWrapper(first_packet)
        sender = LoadLocalFile(load_packet.filename, self.connection)
        try:
            await sender.send_data()
        except Exception:
            # Skip ok packet
            await self.connection._read_packet()
            raise

        ok_packet = await self.connection._read_packet()
        if not ok_packet.is_ok_packet():
            raise OperationalError(2014, "Commands Out of Sync")
        self._read_ok_packet(ok_packet)

    def _check_packet_is_eof(self, packet):
        """Return True for an EOF packet, recording its status fields."""
        if packet.is_eof_packet():
            eof_packet = EOFPacketWrapper(packet)
            self.warning_count = eof_packet.warning_count
            self.has_next = eof_packet.has_next
            return True
        return False

    async def _read_result_packet(self, first_packet):
        """Read a whole result set: field count, descriptions, all rows."""
        self.field_count = first_packet.read_length_encoded_integer()
        await self._get_descriptions()
        await self._read_rowdata_packet()

    async def _read_rowdata_packet_unbuffered(self):
        """Read and return the next row of an unbuffered result set.

        Returns ``None`` when no unbuffered query is active or when the
        EOF packet is reached; in the latter case the result is marked
        inactive and the connection reference is released.
        """
        # Check if in an active query
        if not self.unbuffered_active:
            return

        packet = await self.connection._read_packet()
        if self._check_packet_is_eof(packet):
            self.unbuffered_active = False
            self.connection = None
            self.rows = None
            return

        row = self._read_row_from_packet(packet)
        self.affected_rows = 1
        # rows should tuple of row for MySQL-python compatibility.
        self.rows = (row,)
        return row

    async def _finish_unbuffered_query(self):
        """Discard the remaining packets of an unbuffered result set.

        An ``OperationalError`` whose code is 3024 or 1969 (query or
        statement timeout) ends the drain quietly; any other error is
        re-raised.
        """
        # After much reading on the MySQL protocol, it appears that there
        # is, in fact, no way to stop MySQL from sending all the data after
        # executing a query, so we just spin, and wait for an EOF packet.
        while self.unbuffered_active:
            try:
                packet = await self.connection._read_packet()
            except OperationalError as e:
                # TODO: replace these numbers with constants when available
                # TODO: in a new PyMySQL release
                if e.args[0] in (
                    3024,  # ER.QUERY_TIMEOUT
                    1969,  # ER.STATEMENT_TIMEOUT
                ):
                    # if the query timed out we can simply ignore this error
                    self.unbuffered_active = False
                    self.connection = None
                    return

                raise

            if self._check_packet_is_eof(packet):
                self.unbuffered_active = False
                # release reference to kill cyclic reference.
                self.connection = None

    async def _read_rowdata_packet(self):
        """Read a rowdata packet for each data row in the result set."""
        rows = []
        while True:
            packet = await self.connection._read_packet()
            if self._check_packet_is_eof(packet):
                # release reference to kill cyclic reference.
                self.connection = None
                break
            rows.append(self._read_row_from_packet(packet))

        self.affected_rows = len(rows)
        self.rows = tuple(rows)

    def _read_row_from_packet(self, packet):
        """Decode one row packet into a tuple of column values.

        Each column is read as a length-coded string, decoded with the
        column's encoding (if any) and passed through its converter (if
        any).  ``None`` values are kept as ``None``.
        """
        row = []
        for encoding, converter in self.converters:
            try:
                data = packet.read_length_coded_string()
            except IndexError:
                # No more columns in this row
                # See https://github.com/PyMySQL/PyMySQL/pull/434
                break
            if data is not None:
                if encoding is not None:
                    data = data.decode(encoding)
                if converter is not None:
                    data = converter(data)
            row.append(data)
        return tuple(row)

    async def _get_descriptions(self):
        """Read a column descriptor packet for each column in the result.

        Fills ``fields``, ``converters`` (one ``(encoding, converter)``
        pair per column) and ``description``.

        :raises InternalError: the packet following the descriptors is not
            an EOF packet.
        """
        self.fields = []
        self.converters = []
        use_unicode = self.connection.use_unicode
        conn_encoding = self.connection.encoding
        description = []
        for _ in range(self.field_count):
            field = await self.connection._read_packet(
                FieldDescriptorPacket)
            self.fields.append(field)
            description.append(field.description())
            field_type = field.type_code
            if use_unicode:
                if field_type == FIELD_TYPE.JSON:
                    # When SELECT from JSON column: charset = binary
                    # When SELECT CAST(... AS JSON): charset = connection
                    # encoding
                    # This behavior is different from TEXT / BLOB.
                    # We should decode result by connection encoding
                    # regardless charsetnr.
                    # See https://github.com/PyMySQL/PyMySQL/issues/488
                    encoding = conn_encoding  # SELECT CAST(... AS JSON)
                elif field_type in TEXT_TYPES:
                    if field.charsetnr == 63:  # binary
                        # TEXTs with charset=binary means BINARY types.
                        encoding = None
                    else:
                        encoding = conn_encoding
                else:
                    # Integers, Dates and Times, and other basic data
                    # is encoded in ascii
                    encoding = 'ascii'
            else:
                encoding = None
            converter = self.connection.decoders.get(field_type)
            if converter is through:
                converter = None
            self.converters.append((encoding, converter))

        eof_packet = await self.connection._read_packet()
        # Explicit check rather than ``assert`` so that the validation is
        # not stripped when Python runs with -O.
        if not eof_packet.is_eof_packet():
            raise InternalError('Protocol error, expecting EOF')
        self.description = tuple(description)
|
|
|
|
|
|
class LoadLocalFile:
    """Sends the contents of a local file to the server packet by packet.

    File opening and reading run in the event loop's executor so that
    blocking file I/O does not stall the loop.
    """

    def __init__(self, filename, connection):
        """Remember *filename* and the *connection* to send it over."""
        self.filename = filename
        self.connection = connection
        self._loop = connection.loop
        self._file_object = None
        self._executor = None  # means use default executor

    def _close_file(self):
        """Close the current file object, if any, and forget it."""
        file_object, self._file_object = self._file_object, None
        if file_object is not None:
            file_object.close()

    def _open_file(self):
        """Open the file in the executor; return the resulting future.

        The future fails with ``OperationalError`` (1017) when the file
        cannot be opened.
        """

        def opener(filename):
            try:
                self._file_object = open(filename, 'rb')
            except OSError as e:
                msg = f"Can't find file '{filename}'"
                raise OperationalError(1017, msg) from e

        return self._loop.run_in_executor(
            self._executor, opener, self.filename)

    def _file_read(self, chunk_size):
        """Read up to *chunk_size* bytes in the executor; return the future.

        The file is closed when the end is reached (empty chunk) or when
        reading fails; a failure surfaces as ``OperationalError`` (1024).
        """

        def freader(chunk_size):
            try:
                chunk = self._file_object.read(chunk_size)
                if not chunk:
                    self._close_file()
            except Exception as e:
                # The handle may already be gone here; _close_file copes
                # with that instead of failing on ``None.close()``.
                self._close_file()
                msg = f"Error reading file {self.filename}"
                raise OperationalError(1024, msg) from e
            return chunk

        return self._loop.run_in_executor(
            self._executor, freader, chunk_size)

    async def send_data(self):
        """Send data packets from the local file to the server.

        An empty packet is always written at the end to tell the server
        that no more data follows, including when opening or reading the
        file failed.  The one exception is cancellation: the connection is
        closed in that case, so nothing more is written and the
        ``CancelledError`` propagates unchanged.
        """
        self.connection._ensure_alive()
        conn = self.connection
        closed_on_cancel = False

        try:
            await self._open_file()
            with self._file_object:
                chunk_size = MAX_PACKET_LEN
                while True:
                    chunk = await self._file_read(chunk_size)
                    if not chunk:
                        break
                    # TODO: consider drain data
                    conn.write_packet(chunk)
        except asyncio.CancelledError:
            conn._close_on_cancel()
            closed_on_cancel = True
            raise
        finally:
            if not closed_on_cancel:
                # send the empty packet to signify we are done sending data
                conn.write_packet(b"")
|