504 lines
		
	
	
		
			16 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			504 lines
		
	
	
		
			16 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # testing/fixtures/sql.py
 | |
| # Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
 | |
| # <see AUTHORS file>
 | |
| #
 | |
| # This module is part of SQLAlchemy and is released under
 | |
| # the MIT License: https://www.opensource.org/licenses/mit-license.php
 | |
| # mypy: ignore-errors
 | |
| from __future__ import annotations
 | |
| 
 | |
| import itertools
 | |
| import random
 | |
| import re
 | |
| import sys
 | |
| 
 | |
| import sqlalchemy as sa
 | |
| from .base import TestBase
 | |
| from .. import config
 | |
| from .. import mock
 | |
| from ..assertions import eq_
 | |
| from ..assertions import ne_
 | |
| from ..util import adict
 | |
| from ..util import drop_all_tables_from_metadata
 | |
| from ... import event
 | |
| from ... import util
 | |
| from ...schema import sort_tables_and_constraints
 | |
| from ...sql import visitors
 | |
| from ...sql.elements import ClauseElement
 | |
| 
 | |
| 
 | |
class TablesTest(TestBase):
    """Base for tests that need tables defined, created and populated.

    The ``run_*`` class attributes choose when each lifecycle step runs:
    ``"once"`` (per test class), ``"each"`` (per test function) or ``None``;
    the comment above each attribute lists the values it accepts.
    Subclasses supply the schema and data by overriding
    :meth:`define_tables`, :meth:`fixtures` and/or :meth:`insert_data`.
    """

    # 'once', None
    run_setup_bind = "once"

    # 'once', 'each', None
    run_define_tables = "once"

    # 'once', 'each', None
    run_create_tables = "once"

    # 'once', 'each', None
    run_inserts = "each"

    # 'each', None
    run_deletes = "each"

    # 'once', None
    run_dispose_bind = None

    bind = None
    _tables_metadata = None
    tables = None
    other = None
    sequences = None

    @config.fixture(autouse=True, scope="class")
    def _setup_tables_test_class(self):
        """Class-scoped autouse fixture: perform the once-per-class setup,
        hand control to the tests, then run the class-level teardown."""
        klass = type(self)
        klass._init_class()
        klass._setup_once_tables()
        klass._setup_once_inserts()

        yield

        klass._teardown_once_metadata_bind()

    @config.fixture(autouse=True, scope="function")
    def _setup_tables_test_instance(self):
        """Function-scoped autouse fixture wrapping the per-test steps."""
        self._setup_each_tables()
        self._setup_each_inserts()

        yield

        self._teardown_each_tables()

    @property
    def tables_test_metadata(self):
        """The ``MetaData`` collection holding this class's tables."""
        return self._tables_metadata

    @classmethod
    def _init_class(cls):
        """Normalise the ``run_*`` settings and create the class state:
        the ``other``/``tables``/``sequences`` dicts, the bind and a fresh
        ``MetaData``."""
        if cls.run_define_tables == "each":
            # per-test table definition forces per-test creation, and only
            # per-test (or no) inserts are permitted
            if cls.run_create_tables == "once":
                cls.run_create_tables = "each"
            assert cls.run_inserts in ("each", None)

        cls.other, cls.tables, cls.sequences = adict(), adict(), adict()

        cls.bind = cls.setup_bind()
        cls._tables_metadata = sa.MetaData()

    @classmethod
    def _setup_once_inserts(cls):
        """Load fixture rows and call :meth:`insert_data` once per class
        when ``run_inserts == "once"``."""
        if cls.run_inserts != "once":
            return
        cls._load_fixtures()
        with cls.bind.begin() as connection:
            cls.insert_data(connection)

    @classmethod
    def _setup_once_tables(cls):
        """Define (and, if configured, create) tables once per class and
        publish them in ``cls.tables`` / ``cls.sequences``."""
        if cls.run_define_tables != "once":
            return
        cls.define_tables(cls._tables_metadata)
        if cls.run_create_tables == "once":
            cls._tables_metadata.create_all(cls.bind)
        cls.tables.update(cls._tables_metadata.tables)
        cls.sequences.update(cls._tables_metadata._sequences)

    def _setup_each_tables(self):
        """Per-test counterpart of :meth:`_setup_once_tables`; also creates
        already-defined tables when only ``run_create_tables == "each"``."""
        defines_per_test = self.run_define_tables == "each"
        if defines_per_test:
            self.define_tables(self._tables_metadata)
        if self.run_create_tables == "each":
            self._tables_metadata.create_all(self.bind)
        if defines_per_test:
            self.tables.update(self._tables_metadata.tables)
            self.sequences.update(self._tables_metadata._sequences)

    def _setup_each_inserts(self):
        """Load fixture rows and call :meth:`insert_data` before each test
        when ``run_inserts == "each"``."""
        if self.run_inserts != "each":
            return
        self._load_fixtures()
        with self.bind.begin() as connection:
            self.insert_data(connection)

    def _teardown_each_tables(self):
        """Undo per-test table setup, or empty the tables when they persist
        across tests and ``run_deletes == "each"``."""
        defines_per_test = self.run_define_tables == "each"
        if defines_per_test:
            self.tables.clear()
        if self.run_create_tables == "each":
            drop_all_tables_from_metadata(self._tables_metadata, self.bind)
        if defines_per_test:
            self._tables_metadata.clear()

        savepoint_requirement = getattr(
            config.requirements, "savepoints", False
        )
        savepoints = (
            savepoint_requirement.enabled
            if savepoint_requirement
            else savepoint_requirement
        )

        # no need to run deletes if tables are recreated on setup
        if (
            defines_per_test
            or self.run_create_tables == "each"
            or self.run_deletes != "each"
        ):
            return

        with self.bind.begin() as conn:
            dependency_ordered = [
                table
                for table, _fks in sort_tables_and_constraints(
                    self._tables_metadata.tables.values()
                )
                if table is not None
            ]
            # walk the sorted list backwards, deleting each table's rows
            for table in reversed(dependency_ordered):
                try:
                    if savepoints:
                        # isolate each DELETE so one failure does not
                        # abort the enclosing transaction
                        with conn.begin_nested():
                            conn.execute(table.delete())
                    else:
                        conn.execute(table.delete())
                except sa.exc.DBAPIError as err:
                    # best-effort cleanup: report and carry on
                    print(
                        "Error emptying table %s: %r" % (table, err),
                        file=sys.stderr,
                    )

    @classmethod
    def _teardown_once_metadata_bind(cls):
        """Class-level teardown: drop tables, optionally dispose of the
        bind, and detach the bind from the class."""
        if cls.run_create_tables:
            drop_all_tables_from_metadata(cls._tables_metadata, cls.bind)
        if cls.run_dispose_bind == "once":
            cls.dispose_bind(cls.bind)
        cls._tables_metadata.bind = None
        if cls.run_setup_bind is not None:
            cls.bind = None

    @classmethod
    def setup_bind(cls):
        """Return the bind used by this class (default: ``config.db``)."""
        return config.db

    @classmethod
    def dispose_bind(cls, bind):
        """Call ``bind.dispose()`` if available, else ``bind.close()`` if
        available, else do nothing."""
        for method_name in ("dispose", "close"):
            if hasattr(bind, method_name):
                getattr(bind, method_name)()
                break

    @classmethod
    def define_tables(cls, metadata):
        """Hook: add ``Table`` objects to *metadata*; no-op by default."""

    @classmethod
    def fixtures(cls):
        """Hook: map tables (or table names) to row data; the first item
        of each value is the column-name header, the rest are row tuples.
        Empty by default."""
        return {}

    @classmethod
    def insert_data(cls, connection):
        """Hook: insert extra data using *connection*; no-op by default."""

    def sql_count_(self, count, fn):
        """Delegate to ``assert_sql_count`` against ``self.bind``."""
        self.assert_sql_count(self.bind, fn, count)

    def sql_eq_(self, callable_, statements):
        """Delegate to ``assert_sql`` against ``self.bind``."""
        self.assert_sql(self.bind, callable_, statements)

    @classmethod
    def _load_fixtures(cls):
        """Insert rows as represented by the fixtures() method.

        Entries with fewer than two items (no rows after the header) are
        skipped.  String keys are resolved through ``cls.tables``.  Tables
        are filled in the order produced by ``sort_tables_and_constraints``,
        each in its own transaction.
        """
        pending = {}
        for key, data in cls.fixtures().items():
            if len(data) < 2:
                continue
            table = cls.tables[key] if isinstance(key, str) else key
            pending[table] = (data[0], data[1:])

        for table, _fks in sort_tables_and_constraints(
            cls._tables_metadata.tables.values()
        ):
            if table is None or table not in pending:
                continue
            column_names, value_rows = pending[table]
            with cls.bind.begin() as conn:
                conn.execute(
                    table.insert(),
                    [dict(zip(column_names, values)) for values in value_rows],
                )
 | |
| 
 | |
| 
 | |
class NoCache:
    """Mixin that runs every test function with the compiled-statement
    cache of ``config.db`` switched off."""

    @config.fixture(autouse=True, scope="function")
    def _disable_cache(self):
        # stash the current cache and disable caching in a single step; the
        # stashed value is put back after the test function has run
        saved_cache, config.db._compiled_cache = (
            config.db._compiled_cache,
            None,
        )
        yield
        config.db._compiled_cache = saved_cache
 | |
| 
 | |
| 
 | |
class RemovesEvents:
    """Mixin whose :meth:`event_listen` registers event listeners that are
    removed automatically after each test function."""

    @util.memoized_property
    def _event_fns(self):
        # (target, name, fn) triples successfully registered through
        # event_listen(); consumed by the _remove_events fixture
        return set()

    def event_listen(self, target, name, fn, **kw):
        """Call ``event.listen(target, name, fn, **kw)`` and remember the
        listener so it is removed when the current test finishes.

        The listener is recorded only after ``event.listen()`` returns, so
        a registration that raises is never handed to ``event.remove()``
        during teardown (which would raise a second, misleading error).
        """
        event.listen(target, name, fn, **kw)
        self._event_fns.add((target, name, fn))

    @config.fixture(autouse=True, scope="function")
    def _remove_events(self):
        yield
        for key in self._event_fns:
            event.remove(*key)
 | |
| 
 | |
| 
 | |
class ComputedReflectionFixtureTest(TablesTest):
    """Backend fixture defining tables with ``Computed`` columns, for tests
    that reflect computed-column definitions."""

    run_inserts = run_deletes = None

    __backend__ = True
    __requires__ = ("computed_columns", "table_reflection")

    # matches brackets, parentheses, quote/backtick characters and
    # whitespace
    regexp = re.compile(r"[\[\]\(\)\s`'\"]*")

    def normalize(self, text):
        """Return *text* with every ``regexp`` match removed, lower-cased,
        so SQL expression text can be compared loosely."""
        stripped = self.regexp.sub("", text)
        return stripped.lower()

    @classmethod
    def define_tables(cls, metadata):
        from ... import Integer
        from ... import testing
        from ...schema import Column
        from ...schema import Computed
        from ...schema import Table

        Table(
            "computed_default_table",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("normal", Integer),
            Column("computed_col", Integer, Computed("normal + 42")),
            Column("with_default", Integer, server_default="42"),
        )

        requires = testing.requires
        with_schemas = requires.schemas.enabled

        # optional computed columns, added only when the backend supports
        # them: (name, default-schema expression, test-schema expression,
        # persisted flag)
        optional_specs = []
        if requires.computed_columns_virtual.enabled:
            optional_specs.append(
                ("computed_virtual", "normal + 2", "normal / 2", False)
            )
        if requires.computed_columns_stored.enabled:
            optional_specs.append(
                ("computed_stored", "normal - 42", "normal * 42", True)
            )

        Table(
            "computed_column_table",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("normal", Integer),
            Column("computed_no_flag", Integer, Computed("normal + 42")),
            *[
                Column(name, Integer, Computed(expr, persisted=persisted))
                for name, expr, _schema_expr, persisted in optional_specs
            ],
        )

        if with_schemas:
            # same table name in the test schema, with different
            # expressions so reflection results can be told apart
            Table(
                "computed_column_table",
                metadata,
                Column("id", Integer, primary_key=True),
                Column("normal", Integer),
                Column("computed_no_flag", Integer, Computed("normal / 42")),
                *[
                    Column(name, Integer, Computed(expr, persisted=persisted))
                    for name, _default_expr, expr, persisted in optional_specs
                ],
                schema=config.test_schema,
            )
 | |
| 
 | |
| 
 | |
class CacheKeyFixture:
    """Mixin with assertions about the cache keys of SQL constructs.

    The ``fixture`` callables given to the ``_run_*`` methods return a
    sequence of constructs.  Each fixture is called twice, so every element
    is compared against an independently built copy of each element.
    """

    def _compare_equal(self, a, b, compare_values):
        """Assert that *a* and *b* generate equivalent cache keys.

        If *a* yields no cache key (``None``), it must carry a truthy
        ``nocache`` annotation and *b* must yield no key either.  Otherwise
        *b* must also yield a key.  The two keys and their hashes must be
        equal, and each pair of bound parameters must pass
        ``compare(..., compare_values=compare_values)``.

        :return: the ``(a_key, b_key)`` tuple.
        """
        a_key = a._generate_cache_key()
        b_key = b._generate_cache_key()

        if a_key is None:
            assert a._annotations.get("nocache")

            assert b_key is None
        else:
            # report a one-sided None as an assertion failure instead of an
            # AttributeError from ``b_key.key`` below
            assert (
                b_key is not None
            ), "first element produced a cache key but the second did not"

            eq_(a_key.key, b_key.key)
            eq_(hash(a_key.key), hash(b_key.key))

            for a_param, b_param in zip(a_key.bindparams, b_key.bindparams):
                assert a_param.compare(b_param, compare_values=compare_values)
        return a_key, b_key

    def _run_cache_key_fixture(self, fixture, compare_values):
        """Assert that elements at the same position have equal cache keys
        and elements at different positions have distinct keys.

        Two different positions may share a key only if at least one pair
        of their bound parameters differs under *compare_values*.  For
        ClauseElements, also assert that each cache key collected the bound
        parameters found by traversing the element.
        """
        case_a = fixture()
        case_b = fixture()

        for a, b in itertools.combinations_with_replacement(
            range(len(case_a)), 2
        ):
            if a == b:
                a_key, b_key = self._compare_equal(
                    case_a[a], case_b[b], compare_values
                )
                if a_key is None:
                    continue
            else:
                a_key = case_a[a]._generate_cache_key()
                b_key = case_b[b]._generate_cache_key()

                if a_key is None or b_key is None:
                    if a_key is None:
                        assert case_a[a]._annotations.get("nocache")
                    if b_key is None:
                        assert case_b[b]._annotations.get("nocache")
                    continue

                if a_key.key == b_key.key:
                    for a_param, b_param in zip(
                        a_key.bindparams, b_key.bindparams
                    ):
                        if not a_param.compare(
                            b_param, compare_values=compare_values
                        ):
                            break
                    else:
                        # this fails unconditionally since we could not
                        # find bound parameter values that differed.
                        # Usually we intended to get two distinct keys here
                        # so the failure will be more descriptive using the
                        # ne_() assertion.
                        ne_(a_key.key, b_key.key)
                else:
                    ne_(a_key.key, b_key.key)

            # ClauseElement-specific test to ensure the cache key
            # collected all the bound parameters that aren't marked
            # as "literal execute"
            if isinstance(case_a[a], ClauseElement) and isinstance(
                case_b[b], ClauseElement
            ):
                assert_a_params = [
                    elem
                    for elem in visitors.iterate(case_a[a])
                    if elem.__visit_name__ == "bindparam"
                ]
                assert_b_params = [
                    elem
                    for elem in visitors.iterate(case_b[b])
                    if elem.__visit_name__ == "bindparam"
                ]

                # both sides are sorted by parameter key before comparing,
                # so this checks which parameters are present and whether
                # there are duplicates (the traversal list is de-duplicated,
                # the cache-key list is not). It does not check their
                # overall order.
                eq_(
                    sorted(a_key.bindparams, key=lambda param: param.key),
                    sorted(
                        util.unique_list(assert_a_params),
                        key=lambda param: param.key,
                    ),
                )
                eq_(
                    sorted(b_key.bindparams, key=lambda param: param.key),
                    sorted(
                        util.unique_list(assert_b_params),
                        key=lambda param: param.key,
                    ),
                )

    def _run_cache_key_equal_fixture(self, fixture, compare_values):
        """Assert that every pair of elements from two ``fixture()`` calls
        (including each element with its own copy) has equivalent cache
        keys."""
        case_a = fixture()
        case_b = fixture()

        for a, b in itertools.combinations_with_replacement(
            range(len(case_a)), 2
        ):
            self._compare_equal(case_a[a], case_b[b], compare_values)
 | |
| 
 | |
| 
 | |
def insertmanyvalues_fixture(
    connection, randomize_rows=False, warn_on_downgraded=False
):
    """Instrument *connection* so "insertmanyvalues" batching can be checked.

    ``connection._exec_insertmany_context`` is replaced by a wrapper that,
    for the duration of each call, patches the dialect's
    ``_deliver_insertmanyvalues_batches`` so that:

    * with ``randomize_rows=True`` the cursor handed to the original
      delivery method returns its ``fetchall()`` rows in shuffled order;
    * with ``warn_on_downgraded=True`` a warning is emitted for each
      delivered batch whose ``is_downgraded`` attribute is true.

    Nothing is restored afterwards; the connection stays instrumented.
    """
    deliver_batches = connection.dialect._deliver_insertmanyvalues_batches
    exec_insertmany = connection._exec_insertmany_context

    class _ShufflingCursor:
        # Exposes only ``description`` and ``fetchall()``; with __slots__
        # and no other members, any other attribute access raises
        # AttributeError, flagging unexpected cursor use by the delivery
        # method.
        __slots__ = ("cursor",)

        def __init__(self, cursor):
            self.cursor = cursor

        @property
        def description(self):
            return self.cursor.description

        def fetchall(self):
            shuffled = list(self.cursor.fetchall())
            random.shuffle(shuffled)
            return shuffled

    def _deliver_insertmanyvalues_batches(
        connection,
        cursor,
        statement,
        parameters,
        generic_setinputsizes,
        context,
    ):
        batches = deliver_batches(
            connection,
            _ShufflingCursor(cursor) if randomize_rows else cursor,
            statement,
            parameters,
            generic_setinputsizes,
            context,
        )
        for batch in batches:
            if warn_on_downgraded and batch.is_downgraded:
                util.warn("Batches were downgraded for sorted INSERT")

            yield batch

    def _exec_insertmany_context(dialect, context):
        # the dialect is patched only while the wrapped call runs
        with mock.patch.object(
            dialect,
            "_deliver_insertmanyvalues_batches",
            new=_deliver_insertmanyvalues_batches,
        ):
            return exec_insertmany(dialect, context)

    connection._exec_insertmany_context = _exec_insertmany_context
 | 
