426 lines
14 KiB
Python
426 lines
14 KiB
Python
# ported from:
# https://github.com/aio-libs/aiopg/blob/master/aiopg/sa/connection.py
|
|
import weakref
|
|
|
|
from sqlalchemy.sql import ClauseElement
|
|
from sqlalchemy.sql.dml import UpdateBase
|
|
from sqlalchemy.sql.ddl import DDLElement
|
|
|
|
from . import exc
|
|
from .result import create_result_proxy
|
|
from .transaction import (RootTransaction, Transaction,
|
|
NestedTransaction, TwoPhaseTransaction)
|
|
from ..utils import _TransactionContextManager, _SAConnectionContextManager
|
|
|
|
|
|
def noop(k):
    """Return *k* unchanged.

    Used as the fallback bind processor for parameters that have no
    dedicated processor registered on the compiled statement.
    """
    return k
|
|
|
|
|
|
class SAConnection:
    """Wrapper around a raw driver connection that executes SQL strings
    and SQLAlchemy expressions and tracks transaction state.

    The instance keeps the raw connection, the engine that produced it
    (and the engine's dialect), the currently active transaction object
    (if any), a counter used to generate savepoint names, a weak set of
    the result proxies it has handed out and an optional cache of
    compiled statements.
    """

    def __init__(self, connection, engine, compiled_cache=None):
        self._connection = connection
        self._transaction = None
        self._savepoint_seq = 0
        self._weak_results = weakref.WeakSet()
        self._engine = engine
        self._dialect = engine.dialect
        self._compiled_cache = compiled_cache

    def execute(self, query, *multiparams, **params):
        """Executes a SQL query with optional parameters.

        query - a SQL query string or any sqlalchemy expression.

        *multiparams/**params - represent bound parameter values to be
        used in the execution.  Typically, the format is a dictionary
        passed to *multiparams::

            await conn.execute(
                table.insert(),
                {"id":1, "value":"v1"},
            )

        ...or individual key/values interpreted by **params::

            await conn.execute(
                table.insert(), id=1, value="v1"
            )

        In the case that a plain SQL string is passed, a tuple or
        individual values in *multiparams may be passed::

            await conn.execute(
                "INSERT INTO table (id, value) VALUES (%s, %s)",
                (1, "v1")
            )

            await conn.execute(
                "INSERT INTO table (id, value) VALUES (%s, %s)",
                1, "v1"
            )

        Returns ResultProxy instance with results of SQL query
        execution.

        """
        coro = self._execute(query, *multiparams, **params)
        return _SAConnectionContextManager(coro)

    def _base_params(self, query, dp, compiled, is_update):
        """Build the processed parameter set for one execution.

        query - the SQLAlchemy expression being executed.
        dp - one distilled parameter set (mapping, list/tuple or empty).
        compiled - the compiled form of ``query``.
        is_update - True when ``query`` is an ``UpdateBase`` instance.

        A positional list/tuple is only accepted for ``UpdateBase``
        statements, where the values are zipped with the table's columns
        to form a dict; for any other expression ``exc.ArgumentError``
        is raised.

        Returns the first element of what the dialect's
        ``execute_sequence_format`` produces for the processed
        parameters.
        """
        if dp and isinstance(dp, (list, tuple)):
            if is_update:
                dp = {c.key: pval for c, pval in zip(query.table.c, dp)}
            else:
                raise exc.ArgumentError(
                    "Don't mix sqlalchemy SELECT "
                    "clause with positional "
                    "parameters"
                )
        compiled_params = compiled.construct_params(dp)
        processors = compiled._bind_processors
        # Run every bound value through its bind processor; values
        # without a registered processor pass through unchanged.
        params = [{
            key: processors.get(key, noop)(compiled_params[key])
            for key in compiled_params
        }]
        post_processed_params = self._dialect.execute_sequence_format(params)
        return post_processed_params[0]

    async def _executemany(self, query, dps, cursor):
        """Execute ``query`` once per parameter set in ``dps``.

        query - a SQL string or a non-DDL SQLAlchemy expression.
        dps - sequence of distilled parameter sets.
        cursor - an already opened driver cursor.

        Raises ``exc.ArgumentError`` for DDL expressions (they cannot
        take parameters) and for unsupported ``query`` types.

        Returns the result proxy created for ``cursor``.
        """
        result_map = None
        if isinstance(query, str):
            await cursor.executemany(query, dps)
        elif isinstance(query, DDLElement):
            raise exc.ArgumentError(
                "Don't mix sqlalchemy DDL clause "
                "and execution with parameters"
            )
        elif isinstance(query, ClauseElement):
            compiled = query.compile(dialect=self._dialect)
            is_update = isinstance(query, UpdateBase)
            params = [
                self._base_params(query, dp, compiled, is_update)
                for dp in dps
            ]
            await cursor.executemany(str(compiled), params)
            result_map = compiled._result_columns
        else:
            raise exc.ArgumentError(
                "sql statement should be str or "
                "SQLAlchemy data "
                "selection/modification clause"
            )
        ret = await create_result_proxy(
            self,
            cursor,
            self._dialect,
            result_map
        )
        self._weak_results.add(ret)
        return ret

    async def _execute(self, query, *multiparams, **params):
        """Open a cursor, run ``query`` and wrap the cursor in a result
        proxy.

        More than one distilled parameter set is delegated to
        ``_executemany``.  SQL strings are passed to the cursor as-is;
        SQLAlchemy expressions are compiled with the connection's
        dialect (optionally via the compiled cache) first.

        Raises ``exc.ArgumentError`` when a DDL expression is combined
        with parameters or when ``query`` is neither a string nor a
        SQLAlchemy expression.
        """
        cursor = await self._connection.cursor()
        dp = _distill_params(multiparams, params)
        if len(dp) > 1:
            return await self._executemany(query, dp, cursor)
        elif dp:
            dp = dp[0]

        result_map = None
        if isinstance(query, str):
            await cursor.execute(query, dp or None)
        elif isinstance(query, ClauseElement):
            if self._compiled_cache is not None:
                key = query
                compiled = self._compiled_cache.get(key)
                if not compiled:
                    compiled = query.compile(dialect=self._dialect)
                    # Positional parameters (list/tuple) have no .keys();
                    # such statements are simply not cached instead of
                    # failing with AttributeError.
                    keys_match = (
                        bool(dp)
                        and hasattr(dp, 'keys')
                        and dp.keys() == compiled.params.keys()
                    )
                    if keys_match or not (dp or compiled.params):
                        # we only want queries with bound params in cache
                        self._compiled_cache[key] = compiled
            else:
                compiled = query.compile(dialect=self._dialect)

            if not isinstance(query, DDLElement):
                post_processed_params = self._base_params(
                    query,
                    dp,
                    compiled,
                    isinstance(query, UpdateBase)
                )
                result_map = compiled._result_columns
            else:
                if dp:
                    raise exc.ArgumentError("Don't mix sqlalchemy DDL clause "
                                            "and execution with parameters")
                post_processed_params = compiled.construct_params()
                result_map = None
            await cursor.execute(str(compiled), post_processed_params)
        else:
            raise exc.ArgumentError("sql statement should be str or "
                                    "SQLAlchemy data "
                                    "selection/modification clause")

        ret = await create_result_proxy(
            self, cursor, self._dialect, result_map
        )
        self._weak_results.add(ret)
        return ret

    async def scalar(self, query, *multiparams, **params):
        """Executes a SQL query and returns a scalar value."""
        res = await self.execute(query, *multiparams, **params)
        return (await res.scalar())

    @property
    def closed(self):
        """The readonly property that returns True if the connection is
        closed."""
        return self._connection is None or self._connection.closed

    @property
    def connection(self):
        """The underlying raw connection (None after ``close()``)."""
        return self._connection

    def begin(self):
        """Begin a transaction and return a transaction handle.

        The returned object is an instance of Transaction.  This
        object represents the "scope" of the transaction, which
        completes when either the .rollback or .commit method is
        called.

        Nested calls to .begin on the same SAConnection instance will
        return new Transaction objects that represent an emulated
        transaction within the scope of the enclosing transaction,
        that is::

            trans = await conn.begin()   # outermost transaction
            trans2 = await conn.begin()  # "nested"
            await trans2.commit()        # does nothing
            await trans.commit()         # actually commits

        Calls to .commit only have an effect when invoked via the
        outermost Transaction object, though the .rollback method of
        any of the Transaction objects will roll back the transaction.

        See also:
          .begin_nested - use a SAVEPOINT
          .begin_twophase - use a two phase/XA transaction

        """
        coro = self._begin()
        return _TransactionContextManager(coro)

    async def _begin(self):
        """Start a root transaction, or return an emulated nested
        ``Transaction`` when one is already in progress."""
        if self._transaction is None:
            self._transaction = RootTransaction(self)
            await self._begin_impl()
            return self._transaction
        else:
            return Transaction(self, self._transaction)

    async def _begin_impl(self):
        """Issue ``BEGIN`` on a short-lived cursor."""
        cur = await self._connection.cursor()
        try:
            await cur.execute('BEGIN')
        finally:
            await cur.close()

    async def _commit_impl(self):
        """Issue ``COMMIT`` and clear the tracked transaction."""
        cur = await self._connection.cursor()
        try:
            await cur.execute('COMMIT')
        finally:
            await cur.close()
            # Cleared even when COMMIT failed.
            self._transaction = None

    async def _rollback_impl(self):
        """Issue ``ROLLBACK`` and clear the tracked transaction."""
        cur = await self._connection.cursor()
        try:
            await cur.execute('ROLLBACK')
        finally:
            await cur.close()
            # Cleared even when ROLLBACK failed.
            self._transaction = None

    async def begin_nested(self):
        """Begin a nested transaction and return a transaction handle.

        The returned object is an instance of :class:`.NestedTransaction`.

        Nested transactions require SAVEPOINT support in the
        underlying database.  Any transaction in the hierarchy may
        .commit() and .rollback(), however the outermost transaction
        still controls the overall .commit() or .rollback() of the
        transaction of a whole.
        """
        if self._transaction is None:
            self._transaction = RootTransaction(self)
            await self._begin_impl()
        else:
            self._transaction = NestedTransaction(self, self._transaction)
            self._transaction._savepoint = await self._savepoint_impl()
        return self._transaction

    async def _savepoint_impl(self, name=None):
        """Create a SAVEPOINT and return its name.

        name - savepoint name to use; when omitted a unique name is
        generated from the per-connection savepoint counter.
        """
        if name is None:
            self._savepoint_seq += 1
            name = 'aiomysql_sa_savepoint_%s' % self._savepoint_seq

        cur = await self._connection.cursor()
        try:
            await cur.execute('SAVEPOINT ' + name)
            return name
        finally:
            await cur.close()

    async def _rollback_to_savepoint_impl(self, name, parent):
        """Roll back to savepoint ``name`` and make ``parent`` the
        current transaction."""
        cur = await self._connection.cursor()
        try:
            await cur.execute('ROLLBACK TO SAVEPOINT ' + name)
        finally:
            await cur.close()
        self._transaction = parent

    async def _release_savepoint_impl(self, name, parent):
        """Release savepoint ``name`` and make ``parent`` the current
        transaction."""
        cur = await self._connection.cursor()
        try:
            await cur.execute('RELEASE SAVEPOINT ' + name)
        finally:
            await cur.close()
        self._transaction = parent

    async def begin_twophase(self, xid=None):
        """Begin a two-phase or XA transaction and return a transaction
        handle.

        The returned object is an instance of
        TwoPhaseTransaction, which in addition to the
        methods provided by Transaction, also provides a
        TwoPhaseTransaction.prepare() method.

        xid - the two phase transaction id.  If not supplied, a
        random id will be generated.

        Raises exc.InvalidRequestError when a transaction is already
        in progress on this connection.
        """

        if self._transaction is not None:
            raise exc.InvalidRequestError(
                "Cannot start a two phase transaction when a transaction "
                "is already in progress.")
        if xid is None:
            xid = self._dialect.create_xid()
        self._transaction = TwoPhaseTransaction(self, xid)
        await self.execute("XA START %s", xid)
        return self._transaction

    async def _prepare_twophase_impl(self, xid):
        """End and prepare the XA transaction identified by ``xid``."""
        # xid is bound as a parameter (as in begin_twophase) instead of
        # being interpolated into the SQL text.
        await self.execute("XA END %s", xid)
        await self.execute("XA PREPARE %s", xid)

    async def recover_twophase(self):
        """Return a list of prepared twophase transaction ids."""
        result = await self.execute("XA RECOVER;")
        return [row[0] for row in result]

    async def rollback_prepared(self, xid, *, is_prepared=True):
        """Rollback prepared twophase transaction."""
        if not is_prepared:
            await self.execute("XA END %s", xid)
        await self.execute("XA ROLLBACK %s", xid)

    async def commit_prepared(self, xid, *, is_prepared=True):
        """Commit prepared twophase transaction."""
        if not is_prepared:
            await self.execute("XA END %s", xid)
        await self.execute("XA COMMIT %s", xid)

    @property
    def in_transaction(self):
        """Return True if a transaction is in progress."""
        return self._transaction is not None and self._transaction.is_active

    async def close(self):
        """Close this SAConnection.

        This results in a release of the underlying database
        resources, that is, the underlying connection referenced
        internally. The underlying connection is typically restored
        back to the connection-holding Pool referenced by the Engine
        that produced this SAConnection. Any transactional state
        present on the underlying connection is also unconditionally
        released via calling Transaction.rollback() method.

        After .close() is called, the SAConnection is permanently in a
        closed state, and will allow no further operations.
        """
        if self._connection is None:
            return

        if self._transaction is not None:
            await self._transaction.rollback()
            self._transaction = None
        # don't close underlying connection, it can be reused by pool
        # conn.close()
        self._engine.release(self)
        self._connection = None
        self._engine = None

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.close()
|
|
|
|
|
|
def _distill_params(multiparams, params):
    """Given arguments from the calling form *multiparams, **params,
    return a list of bind parameter structures, usually a list of
    dictionaries.

    In the case of 'raw' execution which accepts positional parameters,
    it may be a list of tuples or lists.

    """

    def is_param_collection(value):
        # Iterable but not string-like: the 'strip' attribute is used to
        # tell strings apart from real parameter collections.
        return hasattr(value, '__iter__') and not hasattr(value, 'strip')

    if not multiparams:
        # Only keyword parameters, or nothing at all.
        return [params] if params else []

    if len(multiparams) > 1:
        if is_param_collection(multiparams[0]):
            # execute(stmt, {...}, {...}) / execute(stmt, (...), (...))
            return multiparams
        # execute(stmt, "value", "value")
        return [multiparams]

    zero = multiparams[0]
    if isinstance(zero, (list, tuple)):
        if not zero or is_param_collection(zero[0]):
            # execute(stmt, [{}, {}, {}, ...])
            # execute(stmt, [(), (), (), ...])
            return zero
        # execute(stmt, ("value", "value"))
        return [zero]
    if hasattr(zero, 'keys'):
        # execute(stmt, {"key":"value"})
        return [zero]
    # execute(stmt, "value")
    return [[zero]]
|