243 lines
8.0 KiB
Python
243 lines
8.0 KiB
Python
"""This module defines specific functions for MySQL dialect."""
|
|
|
|
from sqlalchemy import text
|
|
from sqlalchemy.dialects.mysql.base import ischema_names as _mysql_ischema_names
|
|
from sqlalchemy.ext.compiler import compiles
|
|
from sqlalchemy.sql.sqltypes import NullType
|
|
|
|
from geoalchemy2 import functions
|
|
from geoalchemy2.admin.dialects.common import _check_spatial_type
|
|
from geoalchemy2.admin.dialects.common import _spatial_idx_name
|
|
from geoalchemy2.admin.dialects.common import compile_bin_literal
|
|
from geoalchemy2.admin.dialects.common import setup_create_drop
|
|
from geoalchemy2.types import Geography
|
|
from geoalchemy2.types import Geometry
|
|
|
|
# Spatial DATA_TYPE values (as reported by INFORMATION_SCHEMA.COLUMNS) that are
# reflected as Geometry columns.
_POSSIBLE_TYPES = [
    "geometry",
    "point",
    "linestring",
    "polygon",
    "multipoint",
    "multilinestring",
    "multipolygon",
    "geometrycollection",
]

# Register Geometry to SQLAlchemy's MySQL reflection subsystem for every spatial
# type name listed above (only Geometry is registered for this dialect).
_mysql_ischema_names.update(dict.fromkeys(_POSSIBLE_TYPES, Geometry))
|
|
|
|
|
|
def reflect_geometry_column(inspector, table, column_info):
    """Reflect a column of type Geometry with MySQL / MariaDB dialect.

    The geometry type, SRID, nullability and presence of a spatial index are read from
    ``INFORMATION_SCHEMA`` and ``column_info["type"]`` is replaced by a matching
    :class:`Geometry` instance. Columns whose reflected type is neither ``Geometry`` nor
    ``NullType``, or whose DATA_TYPE is not a known spatial type, are left untouched.

    Args:
        inspector: The SQLAlchemy inspector performing the reflection.
        table: The table being reflected.
        column_info: The reflected column description; updated in place.
    """
    if not isinstance(column_info.get("type"), (Geometry, NullType)):
        return

    column_name = column_info.get("name")
    schema = table.schema or inspector.default_schema_name

    # MariaDB is queried with a constant SRID of -1 (presumably because its
    # INFORMATION_SCHEMA.COLUMNS has no SRS_ID column).
    select_srid = "-1" if inspector.dialect.name == "mariadb" else "SRS_ID"

    # Names are sent as bound parameters instead of being formatted into the SQL so that
    # identifiers containing quotes can neither break nor inject into the queries.
    params = {"table_name": table.name, "column_name": column_name}
    schema_filter = ""
    if schema is not None:
        params["table_schema"] = schema
        schema_filter = " AND TABLE_SCHEMA = :table_schema"

    # Check geometry type, SRID and if the column is nullable
    geometry_type_query = text(
        "SELECT DATA_TYPE, {}, IS_NULLABLE "
        "FROM INFORMATION_SCHEMA.COLUMNS "
        "WHERE TABLE_NAME = :table_name AND COLUMN_NAME = :column_name{}".format(
            select_srid, schema_filter
        )
    ).bindparams(**params)
    geometry_type, srid, nullable_str = inspector.bind.execute(geometry_type_query).one()
    is_nullable = str(nullable_str).lower() == "yes"

    if geometry_type not in _POSSIBLE_TYPES:
        return  # pragma: no cover

    # Check if the column has a spatial index. The column may appear in several indexes,
    # so every returned INDEX_TYPE is inspected rather than only the first row.
    has_index_query = text(
        "SELECT DISTINCT INDEX_TYPE "
        "FROM INFORMATION_SCHEMA.STATISTICS "
        "WHERE TABLE_NAME = :table_name AND COLUMN_NAME = :column_name{}".format(
            schema_filter
        )
    ).bindparams(**params)
    index_rows = inspector.bind.execute(has_index_query).fetchall()
    spatial_index = any(str(row[0]).lower() == "spatial" for row in index_rows)

    # Set attributes
    column_info["type"] = Geometry(
        geometry_type=geometry_type.upper(),
        srid=srid,
        spatial_index=spatial_index,
        nullable=is_nullable,
        _spatial_index_reflected=True,
    )
|
|
|
|
|
|
def _memoryview_to_bytes(value):
    """Return ``value.tobytes()`` if ``value`` is a memoryview, else ``value`` unchanged."""
    return value.tobytes() if isinstance(value, memoryview) else value


def _convert_parameter_set(parameters):
    """Convert the memoryview values of one DBAPI parameter set to bytes.

    Tuples and lists are returned as a new tuple; dicts are updated in place and
    returned; any other value (e.g. ``None``) is returned unchanged.
    """
    if isinstance(parameters, (tuple, list)):
        return tuple(_memoryview_to_bytes(x) for x in parameters)
    if isinstance(parameters, dict):
        for key in parameters:
            if isinstance(parameters[key], memoryview):
                parameters[key] = parameters[key].tobytes()
    return parameters


def before_cursor_execute(
    conn, cursor, statement, parameters, context, executemany, convert=True
):  # noqa: D417
    """Event handler to cast the parameters properly.

    memoryview parameter values are converted to ``bytes`` (presumably because the
    MySQL DBAPI drivers do not accept memoryview objects — confirm per driver).

    Args:
        conn: The connection the statement is executed on (unused).
        cursor: The DBAPI cursor (unused).
        statement: The SQL statement, returned unchanged.
        parameters: A single parameter set (tuple, list or dict), or, when
            ``executemany`` is true, a sequence of such parameter sets.
        context: The execution context (unused).
        executemany (bool): Whether ``parameters`` holds several parameter sets.
        convert (bool): Trigger the conversion.

    Returns:
        tuple: The ``(statement, parameters)`` pair with converted parameters.
    """
    if not convert:
        return statement, parameters

    if executemany and isinstance(parameters, (tuple, list)):
        # Each element is its own parameter set: convert inside every one of them.
        return statement, tuple(_convert_parameter_set(p) for p in parameters)

    return statement, _convert_parameter_set(parameters)
|
|
|
|
|
|
def before_create(table, bind, **kw):
    """Handle spatial indexes during the before_create event.

    Every index covering a spatial (Geometry) column is detached from ``table.indexes``
    so that ``table.create()`` does not emit it. Detached indexes are recorded in
    ``table.info["_after_create_indexes"]`` for ``after_create``, except the default
    spatial index of a column with ``spatial_index=True`` (named by
    ``_spatial_idx_name``), which ``after_create`` creates through ``ALTER TABLE``.
    Finally, the columns saved in ``table.info["_saved_columns"]`` are restored.
    """
    dialect, _, _ = setup_create_drop(table, bind)

    # Remove the spatial indexes from the table metadata because they should not be
    # created during the table.create() step since the associated columns do not exist
    # at this time.
    table.info["_after_create_indexes"] = []
    saved_columns = table.info["_saved_columns"]
    # Iterate over a copy because table.indexes is modified inside the loop.
    for idx in set(table.indexes):
        for col in saved_columns:
            if not (
                _check_spatial_type(col.type, Geometry, dialect)
                and col in idx.columns.values()
            ):
                continue
            table.indexes.remove(idx)
            if idx.name != _spatial_idx_name(table.name, col.name) or not getattr(
                col.type, "spatial_index", False
            ):
                table.info["_after_create_indexes"].append(idx)
            # The index is already detached: stop here so that an index covering
            # several spatial columns is neither removed twice (KeyError) nor
            # recorded twice.
            break

    table.columns = table.info.pop("_saved_columns")
|
|
|
|
|
|
def after_create(table, bind, **kw):
    """Handle spatial indexes during the after_create event.

    For each Geometry or Geography column with ``spatial_index=True`` that is not
    computed and is not covered by any index in ``table.indexes``, an
    ``ALTER TABLE ... ADD SPATIAL INDEX(...)`` statement is executed. The indexes put
    aside in ``table.info["_after_create_indexes"]`` are then attached to the table again.
    """
    dialect = bind.dialect
    preparer = dialect.identifier_preparer

    for col in table.columns:
        # Only spatial columns requesting a spatial index are handled here.
        if not (
            _check_spatial_type(col.type, (Geometry, Geography), dialect)
            and col.type.spatial_index is True
            and col.computed is None
        ):
            continue

        # If the index does not exist, define it and create it
        if not any(col in idx.columns.values() for idx in table.indexes):
            # Quote the identifiers and qualify the table with its schema so that
            # reserved words, special characters and non-default schemas work.
            sql = "ALTER TABLE {} ADD SPATIAL INDEX({});".format(
                preparer.format_table(table), preparer.quote(col.name)
            )
            bind.execute(text(sql))

    # NOTE(review): these indexes are only re-attached to the table metadata here; this
    # hook does not emit CREATE INDEX for them — confirm that this is intended.
    for idx in table.info.pop("_after_create_indexes", []):
        table.indexes.add(idx)
|
|
|
|
|
|
def before_drop(table, bind, **kw):
    """Handle the before_drop event: this hook performs no action and returns None."""
|
|
|
|
|
|
def after_drop(table, bind, **kw):
    """Handle the after_drop event: this hook performs no action and returns None."""
|
|
|
|
|
|
_MYSQL_FUNCTIONS = {"ST_AsEWKB": "ST_AsBinary", "ST_SetSRID": "ST_SRID"}
|
|
|
|
|
|
def _compiles_mysql(cls, fn):
    """Register a MySQL compilation rule for one GeoAlchemy2 function.

    Args:
        cls: Name of the function class to look up in ``geoalchemy2.functions``.
        fn: Name of the MySQL function rendered instead, called with the same
            (compiled) arguments.
    """
    function_class = getattr(functions, cls)

    @compiles(function_class, "mysql")
    def _compile_mysql(element, compiler, **kw):
        arguments = compiler.process(element.clauses, **kw)
        return "{}({})".format(fn, arguments)
|
|
|
|
|
|
def register_mysql_mapping(mapping):
    """Register compilation mappings for the given functions.

    Each entry makes the GeoAlchemy2 function named by the key compile, on MySQL,
    to a call of the function named by the value.

    Args:
        mapping: Should have the following form::

                {
                    "function_name_1": "mysql_function_name_1",
                    "function_name_2": "mysql_function_name_2",
                    ...
                }
    """
    for geoalchemy_name, mysql_name in mapping.items():
        _compiles_mysql(geoalchemy_name, mysql_name)
|
|
|
|
|
|
# Register the function mappings of _MYSQL_FUNCTIONS when this module is imported.
register_mysql_mapping(_MYSQL_FUNCTIONS)
|
|
|
|
|
|
def _compile_GeomFromText_MySql(element, compiler, **kw):
|
|
identifier = "ST_GeomFromText"
|
|
compiled = compiler.process(element.clauses, **kw)
|
|
srid = element.type.srid
|
|
|
|
if srid > 0:
|
|
return "{}({}, {})".format(identifier, compiled, srid)
|
|
else:
|
|
return "{}({})".format(identifier, compiled)
|
|
|
|
|
|
def _compile_GeomFromWKB_MySql(element, compiler, **kw):
|
|
# Store the SRID
|
|
clauses = list(element.clauses)
|
|
try:
|
|
srid = clauses[1].value
|
|
except (IndexError, TypeError, ValueError):
|
|
srid = element.type.srid
|
|
|
|
if kw.get("literal_binds", False):
|
|
wkb_clause = compile_bin_literal(clauses[0])
|
|
prefix = "unhex("
|
|
suffix = ")"
|
|
else:
|
|
wkb_clause = clauses[0]
|
|
prefix = ""
|
|
suffix = ""
|
|
|
|
compiled = compiler.process(wkb_clause, **kw)
|
|
|
|
if srid > 0:
|
|
return "{}({}{}{}, {})".format(element.identifier, prefix, compiled, suffix, srid)
|
|
else:
|
|
return "{}({}{}{})".format(element.identifier, prefix, compiled, suffix)
|
|
|
|
|
|
@compiles(functions.ST_GeomFromText, "mysql")  # type: ignore
def _MySQL_ST_GeomFromText(element, compiler, **kwargs):
    """Compile ``ST_GeomFromText`` for MySQL via ``_compile_GeomFromText_MySql``."""
    return _compile_GeomFromText_MySql(element, compiler, **kwargs)
|
|
|
|
|
|
@compiles(functions.ST_GeomFromEWKT, "mysql")  # type: ignore
def _MySQL_ST_GeomFromEWKT(element, compiler, **kwargs):
    """Compile ``ST_GeomFromEWKT`` for MySQL as an ``ST_GeomFromText`` call."""
    return _compile_GeomFromText_MySql(element, compiler, **kwargs)
|
|
|
|
|
|
@compiles(functions.ST_GeomFromWKB, "mysql")  # type: ignore
def _MySQL_ST_GeomFromWKB(element, compiler, **kwargs):
    """Compile ``ST_GeomFromWKB`` for MySQL via ``_compile_GeomFromWKB_MySql``."""
    return _compile_GeomFromWKB_MySql(element, compiler, **kwargs)
|
|
|
|
|
|
@compiles(functions.ST_GeomFromEWKB, "mysql")  # type: ignore
def _MySQL_ST_GeomFromEWKB(element, compiler, **kwargs):
    """Compile ``ST_GeomFromEWKB`` for MySQL via ``_compile_GeomFromWKB_MySql``."""
    return _compile_GeomFromWKB_MySql(element, compiler, **kwargs)
|