import re
import json
import warnings
import contextlib

from pymysql.err import (
    Warning, Error, InterfaceError, DataError,
    DatabaseError, OperationalError, IntegrityError, InternalError,
    NotSupportedError, ProgrammingError)

from .log import logger
from .connection import FIELD_TYPE

# https://github.com/PyMySQL/PyMySQL/blob/master/pymysql/cursors.py#L11-L18

#: Regular expression for :meth:`Cursor.executemany`.
#: executemany only supports simple bulk insert.
#: You can use it to load large dataset.
RE_INSERT_VALUES = re.compile(
    r"\s*((?:INSERT|REPLACE)\s.+\sVALUES?\s+)" +
    r"(\(\s*(?:%s|%\(.+\)s)\s*(?:,\s*(?:%s|%\(.+\)s)\s*)*\))" +
    r"(\s*(?:ON DUPLICATE.*)?);?\s*\Z",
    re.IGNORECASE | re.DOTALL)


class Cursor:
    """Cursor is used to interact with the database."""

    #: Max statement size which :meth:`executemany` generates.
    #:
    #: Max size of allowed statement is max_allowed_packet -
    #: packet_header_size.
    #: Default value of max_allowed_packet is 1048576.
    max_stmt_length = 1024000

    def __init__(self, connection, echo=False):
        """Do not create an instance of a Cursor yourself. Call
        connections.Connection.cursor().
        """
        self._connection = connection
        self._loop = self._connection.loop
        self._description = None
        self._rownumber = 0
        self._rowcount = -1
        self._arraysize = 1
        self._executed = None
        # Written by _query(); initialised here so the attribute always
        # exists, even before the first statement is sent.
        self._last_executed = None
        self._result = None
        self._rows = None
        self._lastrowid = None
        self._echo = echo

    @property
    def connection(self):
        """This read-only attribute return a reference to the Connection
        object on which the cursor was created."""
        return self._connection

    @property
    def description(self):
        """This read-only attribute is a sequence of 7-item sequences.

        Each of these sequences is a collections.namedtuple containing
        information describing one result column:

        0.  name: the name of the column returned.
        1.  type_code: the type of the column.
        2.  display_size: the actual length of the column in bytes.
        3.  internal_size: the size in bytes of the column associated to
            this column on the server.
        4.  precision: total number of significant digits in columns of
            type NUMERIC. None for other types.
        5.  scale: count of decimal digits in the fractional part in
            columns of type NUMERIC. None for other types.
        6.  null_ok: always None as not easy to retrieve from the libpq.

        This attribute will be None for operations that do not
        return rows or if the cursor has not had an operation invoked
        via the execute() method yet.
        """
        return self._description

    @property
    def rowcount(self):
        """Returns the number of rows that has been produced or affected.

        This read-only attribute specifies the number of rows that the
        last :meth:`execute` produced (for Data Query Language
        statements like SELECT) or affected (for Data Manipulation
        Language statements like UPDATE or INSERT).

        The attribute is -1 in case no .execute() has been performed
        on the cursor or the row count of the last operation if it
        can't be determined by the interface.
        """
        return self._rowcount

    @property
    def rownumber(self):
        """Row index.

        This read-only attribute provides the current 0-based index of the
        cursor in the result set or ``None`` if the index cannot be
        determined.
        """
        return self._rownumber

    @property
    def arraysize(self):
        """How many rows will be returned by fetchmany() call.

        This read/write attribute specifies the number of rows to
        fetch at a time with fetchmany(). It defaults to
        1 meaning to fetch a single row at a time.
        """
        return self._arraysize

    @arraysize.setter
    def arraysize(self, val):
        """How many rows will be returned by fetchmany() call.

        This read/write attribute specifies the number of rows to
        fetch at a time with fetchmany(). It defaults to
        1 meaning to fetch a single row at a time.
        """
        self._arraysize = val

    @property
    def lastrowid(self):
        """This read-only property returns the value generated for an
        AUTO_INCREMENT column by the previous INSERT or UPDATE statement
        or None when there is no such value available. For example,
        if you perform an INSERT into a table that contains an AUTO_INCREMENT
        column, lastrowid returns the AUTO_INCREMENT value for the new row.
        """
        return self._lastrowid

    @property
    def echo(self):
        """Return echo mode status."""
        return self._echo

    @property
    def closed(self):
        """Read-only property that is ``True`` once the connection has
        been detached from this cursor.
        """
        return not self._connection

    async def close(self):
        """Closing a cursor just exhausts all remaining data."""
        conn = self._connection
        if conn is None:
            return
        try:
            while (await self.nextset()):
                pass
        finally:
            # Detach even if draining the remaining result sets failed.
            self._connection = None

    def _get_db(self):
        """Return the attached connection.

        :raises ProgrammingError: if the cursor has been closed.
        """
        if not self._connection:
            raise ProgrammingError("Cursor closed")
        return self._connection

    def _check_executed(self):
        """Raise ProgrammingError unless a statement has been executed."""
        if not self._executed:
            raise ProgrammingError("execute() first")

    def _conv_row(self, row):
        """Row conversion hook; the base cursor returns the row as is."""
        return row

    def setinputsizes(self, *args):
        """Does nothing, required by DB API."""

    def setoutputsizes(self, *args):
        """Does nothing, required by DB API."""

    async def nextset(self):
        """Get the next query set

        :returns: ``True`` when the cursor advanced to another result
            set, ``None`` otherwise.
        """
        conn = self._get_db()
        current_result = self._result
        # Only advance when this cursor owns the connection's current result.
        if current_result is None or current_result is not conn._result:
            return
        if not current_result.has_next:
            return
        self._result = None
        self._clear_result()
        await conn.next_result()
        await self._do_get_result()
        return True

    def _escape_args(self, args, conn):
        """Escape query arguments with ``conn.escape``.

        Tuples and lists yield a tuple, dicts yield a dict with the same
        keys; any other value is passed to ``conn.escape`` directly.
        """
        if isinstance(args, (tuple, list)):
            return tuple(conn.escape(arg) for arg in args)
        elif isinstance(args, dict):
            return {key: conn.escape(val) for (key, val) in args.items()}
        else:
            # If it's not a dictionary let's try escaping it anyways.
            # Worst case it will throw a Value error
            return conn.escape(args)

    def mogrify(self, query, args=None):
        """ Returns the exact string that is sent to the database by calling
        the execute() method. This method follows the extension to the DB
        API 2.0 followed by Psycopg.

        :param query: ``str`` sql statement
        :param args: ``tuple`` or ``list`` of arguments for sql query
        """
        conn = self._get_db()
        if args is not None:
            query = query % self._escape_args(args, conn)
        return query

    async def execute(self, query, args=None):
        """Executes the given operation

        Executes the given operation substituting any markers with
        the given parameters.

        For example, getting all rows where id is 5:
          cursor.execute("SELECT * FROM t1 WHERE id = %s", (5,))

        :param query: ``str`` sql statement
        :param args: ``tuple`` or ``list`` of arguments for sql query
        :returns: ``int``, number of rows that has been produced or affected
        """
        conn = self._get_db()

        # Drain result sets left over from the previous statement.
        while (await self.nextset()):
            pass

        if args is not None:
            query = query % self._escape_args(args, conn)

        await self._query(query)
        self._executed = query
        if self._echo:
            logger.info(query)
            logger.info("%r", args)
        return self._rowcount

    async def executemany(self, query, args):
        """Execute the given operation multiple times

        The executemany() method will execute the operation iterating
        over the list of parameters in seq_params.

        Example: Inserting 3 new employees and their phone number

            data = [
                ('Jane','555-001'),
                ('Joe', '555-001'),
                ('John', '555-003')
                ]
            stmt = "INSERT INTO employees (name, phone) VALUES (%s, %s)"
            await cursor.executemany(stmt, data)

        INSERT or REPLACE statements are optimized by batching the data,
        that is using the MySQL multiple rows syntax. The batching path is
        taken only when the statement matches ``RE_INSERT_VALUES``, which
        expects bare ``%s`` / ``%(name)s`` placeholders.

        :param query: `str`, sql statement
        :param args: ``tuple`` or ``list`` of arguments for sql query
        :returns: total row count, or ``None`` when ``args`` is empty
        """
        if not args:
            return

        if self._echo:
            # NOTE(review): the "CALL" prefix looks copied from callproc();
            # kept as is so existing log output does not change.
            logger.info("CALL %s", query)
            logger.info("%r", args)

        m = RE_INSERT_VALUES.match(query)
        if m:
            q_prefix = m.group(1) % ()
            q_values = m.group(2).rstrip()
            q_postfix = m.group(3) or ''
            assert q_values[0] == '(' and q_values[-1] == ')'
            return (await self._do_execute_many(
                q_prefix, q_values, q_postfix, args, self.max_stmt_length,
                self._get_db().encoding))
        else:
            # Fallback: one round trip per parameter set.
            rows = 0
            for arg in args:
                await self.execute(query, arg)
                rows += self._rowcount
            self._rowcount = rows
        return self._rowcount

    async def _do_execute_many(self, prefix, values, postfix, args,
                               max_stmt_length, encoding):
        """Send ``args`` as multi-row statements of bounded size.

        Rendered value tuples are appended, comma separated, to ``prefix``
        until adding one more would exceed ``max_stmt_length``; the pending
        statement is then executed and a new one is started.

        :returns: sum of the row counts of all executed statements
        """
        conn = self._get_db()
        escape = self._escape_args
        if isinstance(prefix, str):
            prefix = prefix.encode(encoding)
        if isinstance(postfix, str):
            postfix = postfix.encode(encoding)
        sql = bytearray(prefix)
        args = iter(args)
        v = values % escape(next(args), conn)
        if isinstance(v, str):
            v = v.encode(encoding, 'surrogateescape')
        sql += v
        rows = 0
        for arg in args:
            v = values % escape(arg, conn)
            if isinstance(v, str):
                v = v.encode(encoding, 'surrogateescape')
            if len(sql) + len(v) + len(postfix) + 1 > max_stmt_length:
                # Flush the pending statement and start a new one.
                r = await self.execute(sql + postfix)
                rows += r
                sql = bytearray(prefix)
            else:
                sql += b','
            sql += v
        r = await self.execute(sql + postfix)
        rows += r
        self._rowcount = rows
        return rows

    async def callproc(self, procname, args=()):
        """Execute stored procedure procname with args

        Compatibility warning: PEP-249 specifies that any modified
        parameters must be returned. This is currently impossible
        as they are only available by storing them in a server
        variable and then retrieved by a query. Since stored
        procedures return zero or more result sets, there is no
        reliable way to get at OUT or INOUT parameters via callproc.
        The server variables are named @_procname_n, where procname
        is the parameter above and n is the position of the parameter
        (from zero). Once all result sets generated by the procedure
        have been fetched, you can issue a SELECT @_procname_0, ...
        query using .execute() to get any OUT or INOUT values.

        Compatibility warning: The act of calling a stored procedure
        itself creates an empty result set. This appears after any
        result sets generated by the procedure. This is non-standard
        behavior with respect to the DB-API. Be sure to use nextset()
        to advance through all result sets; otherwise you may get
        disconnected.

        :param procname: ``str``, name of procedure to execute on server
        :param args: sequence of parameters to use with procedure
        :returns: the original args.
        """
        conn = self._get_db()
        if self._echo:
            logger.info("CALL %s", procname)
            logger.info("%r", args)

        # Each argument is stored in a server variable @_<procname>_<index>.
        for index, arg in enumerate(args):
            q = "SET @_%s_%d=%s" % (procname, index, conn.escape(arg))
            await self._query(q)
            await self.nextset()

        _args = ','.join('@_%s_%d' % (procname, i) for i in range(len(args)))
        q = f"CALL {procname}({_args})"
        await self._query(q)
        self._executed = q
        return args

    def fetchone(self):
        """Fetch the next row

        :returns: a completed future holding the next row, or ``None``
            when there are no more rows.
        """
        self._check_executed()
        fut = self._loop.create_future()

        if self._rows is None or self._rownumber >= len(self._rows):
            fut.set_result(None)
            return fut
        result = self._rows[self._rownumber]
        self._rownumber += 1

        fut.set_result(result)
        return fut

    def fetchmany(self, size=None):
        """Returns the next set of rows of a query result, returning a
        list of tuples. When no more rows are available, it returns an
        empty list.

        The number of rows returned can be specified using the size argument,
        which defaults to :attr:`arraysize`

        :param size: ``int`` number of rows to return
        :returns: ``list`` of fetched rows
        """
        self._check_executed()
        fut = self._loop.create_future()
        if self._rows is None:
            fut.set_result([])
            return fut
        end = self._rownumber + (size or self._arraysize)
        result = self._rows[self._rownumber:end]
        self._rownumber = min(end, len(self._rows))

        fut.set_result(result)
        return fut

    def fetchall(self):
        """Returns all rows of a query result set

        :returns: ``list`` of fetched rows
        """
        self._check_executed()
        fut = self._loop.create_future()
        if self._rows is None:
            fut.set_result([])
            return fut

        if self._rownumber:
            result = self._rows[self._rownumber:]
        else:
            result = self._rows
        self._rownumber = len(self._rows)

        fut.set_result(result)
        return fut

    def scroll(self, value, mode='relative'):
        """Scroll the cursor in the result set to a new position according
         to mode.

        If mode is relative (default), value is taken as offset to the
        current position in the result set, if set to absolute, value
        states an absolute target position. An IndexError should be raised in
        case a scroll operation would leave the result set. In this case,
        the cursor position is left undefined (ideal would be to
        not move the cursor at all).

        :param int value: move cursor to next position according to mode.
        :param str mode: scroll mode, possible modes: `relative` and `absolute`
        :raises ProgrammingError: if ``mode`` is not a known mode.
        :raises IndexError: if the target position is outside the result
            set, including when there are no rows at all.
        """
        self._check_executed()
        if mode == 'relative':
            r = self._rownumber + value
        elif mode == 'absolute':
            r = value
        else:
            raise ProgrammingError("unknown scroll mode %s" % mode)

        rows = self._rows
        # With no row data every position is out of range; guard against
        # calling len() on None.
        if rows is None or not (0 <= r < len(rows)):
            raise IndexError("out of range")
        self._rownumber = r

        fut = self._loop.create_future()
        fut.set_result(None)
        return fut

    async def _query(self, q):
        """Reset cursor state, send ``q`` and load its result."""
        conn = self._get_db()
        self._last_executed = q
        self._clear_result()
        await conn.query(q)
        await self._do_get_result()

    def _clear_result(self):
        """Reset all per-result state."""
        self._rownumber = 0
        self._result = None

        self._rowcount = 0
        self._description = None
        self._lastrowid = None
        self._rows = None

    async def _do_get_result(self):
        """Copy the connection's current result into the cursor state."""
        conn = self._get_db()
        self._rownumber = 0
        self._result = result = conn._result
        self._rowcount = result.affected_rows
        self._description = result.description
        self._lastrowid = result.insert_id
        self._rows = result.rows

        if result.warning_count > 0:
            await self._show_warnings(conn)

    async def _show_warnings(self, conn):
        """Emit the server's warnings through the ``warnings`` module.

        Nothing is done while the current result has further result sets.
        """
        if self._result and self._result.has_next:
            return
        ws = await conn.show_warnings()
        if ws is None:
            return
        for w in ws:
            msg = w[-1]
            warnings.warn(str(msg), Warning, 4)

    # Exception classes re-exported as cursor attributes.
    Warning = Warning
    Error = Error
    InterfaceError = InterfaceError
    DatabaseError = DatabaseError
    DataError = DataError
    OperationalError = OperationalError
    IntegrityError = IntegrityError
    InternalError = InternalError
    ProgrammingError = ProgrammingError
    NotSupportedError = NotSupportedError

    def __aiter__(self):
        return self

    async def __anext__(self):
        ret = await self.fetchone()
        if ret is not None:
            return ret
        else:
            raise StopAsyncIteration  # noqa

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.close()
        return


class _DeserializationCursorMixin:
    """Mixin that decodes JSON-typed columns with ``json.loads``."""

    async def _do_get_result(self):
        await super()._do_get_result()
        if self._rows:
            self._rows = [self._deserialization_row(r) for r in self._rows]

    def _deserialization_row(self, row):
        """Return ``row`` with JSON columns decoded.

        Dict rows are updated in place and returned; other rows are copied
        and returned as a tuple. Values that ``json.loads`` rejects with
        ``ValueError`` or ``TypeError`` are left untouched.
        """
        if row is None:
            return None
        if isinstance(row, dict):
            dict_flag = True
        else:
            row = list(row)
            dict_flag = False
        for index, (name, field_type, *n) in enumerate(self._description):
            if field_type == FIELD_TYPE.JSON:
                # Dict rows are keyed by column name, sequences by position.
                point = name if dict_flag else index
                with contextlib.suppress(ValueError, TypeError):
                    row[point] = json.loads(row[point])
        if dict_flag:
            return row
        else:
            return tuple(row)

    def _conv_row(self, row):
        if row is None:
            return None
        row = super()._conv_row(row)
        return self._deserialization_row(row)


class DeserializationCursor(_DeserializationCursorMixin, Cursor):
    """A cursor automatic deserialization of json type fields"""


class _DictCursorMixin:
    """Mixin that converts each row into a ``dict_type`` mapping."""

    # You can override this to use OrderedDict or other dict-like types.
    dict_type = dict

    async def _do_get_result(self):
        await super()._do_get_result()
        fields = []
        if self._description:
            for f in self._result.fields:
                name = f.name
                # A repeated column name is qualified with its table name.
                if name in fields:
                    name = f.table_name + '.' + name
                fields.append(name)
            self._fields = fields

        if fields and self._rows:
            self._rows = [self._conv_row(r) for r in self._rows]

    def _conv_row(self, row):
        if row is None:
            return None
        row = super()._conv_row(row)
        return self.dict_type(zip(self._fields, row))


class DictCursor(_DictCursorMixin, Cursor):
    """A cursor which returns results as a dictionary"""


class SSCursor(Cursor):
    """Unbuffered Cursor, mainly useful for queries that return a lot of
    data, or for connections to remote servers over a slow network.

    Instead of copying every row of data into a buffer, this will fetch
    rows as needed. The upside of this, is the client uses much less memory,
    and rows are returned much faster when traveling over a slow network,
    or if the result set is very big.

    There are limitations, though. The MySQL protocol doesn't support
    returning the total number of rows, so the only way to tell how many rows
    there are is to iterate over every row returned. Also, it currently isn't
    possible to scroll backwards, as only the current row is held in memory.
    """

    async def close(self):
        """Finish the pending unbuffered query, drain any remaining result
        sets and detach from the connection.
        """
        conn = self._connection
        if conn is None:
            return

        try:
            # Inside the try block so the cursor is still detached when
            # finishing the unbuffered query raises.
            if self._result is not None and self._result is conn._result:
                await self._result._finish_unbuffered_query()

            while (await self.nextset()):
                pass
        finally:
            self._connection = None

    async def _query(self, q):
        """Send ``q`` as an unbuffered query and load its result."""
        conn = self._get_db()
        self._last_executed = q
        await conn.query(q, unbuffered=True)
        await self._do_get_result()
        return self._rowcount

    async def _read_next(self):
        """Read next row """
        row = await self._result._read_rowdata_packet_unbuffered()
        row = self._conv_row(row)
        return row

    async def fetchone(self):
        """ Fetch next row """
        self._check_executed()
        row = await self._read_next()
        if row is None:
            return
        self._rownumber += 1
        return row

    async def fetchall(self):
        """Fetch all, as per MySQLdb. Pretty useless for large queries, as
        it is buffered.
        """
        rows = []
        while True:
            row = await self.fetchone()
            if row is None:
                break
            rows.append(row)
        return rows

    async def fetchmany(self, size=None):
        """Returns the next set of rows of a query result, returning a
        list of tuples. When no more rows are available, it returns an
        empty list.

        The number of rows returned can be specified using the size argument,
        which defaults to :attr:`arraysize`

        :param size: ``int`` number of rows to return
        :returns: ``list`` of fetched rows
        """
        self._check_executed()
        if size is None:
            size = self._arraysize

        rows = []
        for i in range(size):
            row = await self._read_next()
            if row is None:
                break
            rows.append(row)
            self._rownumber += 1
        return rows

    async def scroll(self, value, mode='relative'):
        """Scroll the cursor in the result set to a new position
        according to mode . Same as :meth:`Cursor.scroll`, but move cursor
        on server side one by one row. If you want to move 20 rows forward
        scroll will make 20 queries to move cursor. Currently only forward
        scrolling is supported.

        :param int value: move cursor to next position according to mode.
        :param str mode: scroll mode, possible modes: `relative` and `absolute`
        :raises NotSupportedError: if the target lies before the current row.
        :raises ProgrammingError: if ``mode`` is not a known mode.
        """

        self._check_executed()

        if mode == 'relative':
            if value < 0:
                raise NotSupportedError("Backwards scrolling not supported "
                                        "by this cursor")

            # Skipping forward means reading and discarding rows.
            for _ in range(value):
                await self._read_next()
            self._rownumber += value
        elif mode == 'absolute':
            if value < self._rownumber:
                raise NotSupportedError(
                    "Backwards scrolling not supported by this cursor")

            end = value - self._rownumber
            for _ in range(end):
                await self._read_next()
            self._rownumber = value
        else:
            raise ProgrammingError("unknown scroll mode %s" % mode)


class SSDictCursor(_DictCursorMixin, SSCursor):
    """An unbuffered cursor, which returns results as a dictionary """