503 lines
		
	
	
		
			16 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			503 lines
		
	
	
		
			16 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
# testing/suite/test_results.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
 | |
| 
 | |
| import datetime
 | |
| import re
 | |
| 
 | |
| from .. import engines
 | |
| from .. import fixtures
 | |
| from ..assertions import eq_
 | |
| from ..config import requirements
 | |
| from ..schema import Column
 | |
| from ..schema import Table
 | |
| from ... import DateTime
 | |
| from ... import func
 | |
| from ... import Integer
 | |
| from ... import select
 | |
| from ... import sql
 | |
| from ... import String
 | |
| from ... import testing
 | |
| from ... import text
 | |
| 
 | |
| 
 | |
class RowFetchTest(fixtures.TablesTest):
    """Round trips that read a fetched row in each supported way.

    ``plain_pk`` rows are accessed by attribute, by string key through
    ``Row._mapping``, by integer index and by Column object;
    ``has_dates`` backs the scalar subquery test.

    """

    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        """Declare the ``plain_pk`` and ``has_dates`` tables."""
        plain_pk_columns = (
            Column("id", Integer, primary_key=True),
            Column("data", String(50)),
        )
        Table("plain_pk", metadata, *plain_pk_columns)

        has_dates_columns = (
            Column("id", Integer, primary_key=True),
            Column("today", DateTime),
        )
        Table("has_dates", metadata, *has_dates_columns)

    @classmethod
    def insert_data(cls, connection):
        """Insert three ``plain_pk`` rows and one ``has_dates`` row."""
        plain_pk_rows = [
            {"id": 1, "data": "d1"},
            {"id": 2, "data": "d2"},
            {"id": 3, "data": "d3"},
        ]
        connection.execute(cls.tables.plain_pk.insert(), plain_pk_rows)

        date_rows = [
            {"id": 1, "today": datetime.datetime(2006, 5, 12, 12, 0, 0)}
        ]
        connection.execute(cls.tables.has_dates.insert(), date_rows)

    def _first_plain_pk_row(self, connection):
        """Return the first ``plain_pk`` row when ordered by ``id``."""
        plain_pk = self.tables.plain_pk
        stmt = plain_pk.select().order_by(plain_pk.c.id)
        return connection.execute(stmt).first()

    def test_via_attr(self, connection):
        """Column values are reachable as attributes of the row."""
        first = self._first_plain_pk_row(connection)

        eq_(first.id, 1)
        eq_(first.data, "d1")

    def test_via_string(self, connection):
        """Column values are reachable by string key in ``_mapping``."""
        mapping = self._first_plain_pk_row(connection)._mapping

        eq_(mapping["id"], 1)
        eq_(mapping["data"], "d1")

    def test_via_int(self, connection):
        """Column values are reachable by integer position."""
        first = self._first_plain_pk_row(connection)

        eq_(first[0], 1)
        eq_(first[1], "d1")

    def test_via_col_object(self, connection):
        """Column values are reachable by Column object in ``_mapping``."""
        plain_pk = self.tables.plain_pk
        mapping = self._first_plain_pk_row(connection)._mapping

        eq_(mapping[plain_pk.c.id], 1)
        eq_(mapping[plain_pk.c.data], "d1")

    @requirements.duplicate_names_in_cursor_description
    def test_row_with_dupe_names(self, connection):
        """Two result columns may share the name ``data``."""
        plain_pk = self.tables.plain_pk
        data_col = plain_pk.c.data
        stmt = select(data_col, data_col.label("data")).order_by(
            plain_pk.c.id
        )

        result = connection.execute(stmt)
        # fetch the row first, then inspect the keys of the same result
        first = result.first()
        eq_(result.keys(), ["data", "data"])
        eq_(first, ("d1", "d1"))

    def test_row_w_scalar_select(self, connection):
        """A scalar subquery used as a column comes back as a typed value.

        Partly a SQLAlchemy Core test, partly a check for database
        backends that behave unusually with scalar selects.

        """
        has_dates = self.tables.has_dates
        today_subq = select(has_dates.alias("x").c.today).scalar_subquery()
        stmt = select(has_dates.c.id, today_subq.label("somelabel"))

        first = connection.execute(stmt).first()

        eq_(first.somelabel, datetime.datetime(2006, 5, 12, 12, 0, 0))
 | |
| 
 | |
| 
 | |
class PercentSchemaNamesTest(fixtures.TablesTest):
    """Round trips against table and column names holding ``%`` and spaces.

    PostgreSQL / MySQL drivers could not handle these names for a long
    time; they are supported now.

    """

    __requires__ = ("percent_schema_names",)

    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        """Register a ``Table`` and a lightweight ``sql.table`` twin."""
        table_name = "percent%table"

        cls.tables.percent_table = Table(
            table_name,
            metadata,
            Column("percent%", Integer),
            Column("spaces % more spaces", Integer),
        )
        cls.tables.lightweight_percent_table = sql.table(
            table_name,
            sql.column("percent%"),
            sql.column("spaces % more spaces"),
        )

    def test_single_roundtrip(self, connection):
        """Insert the four rows one statement at a time, then verify."""
        target = self.tables.percent_table
        rows = [
            {"percent%": 5, "spaces % more spaces": 12},
            {"percent%": 7, "spaces % more spaces": 11},
            {"percent%": 9, "spaces % more spaces": 10},
            {"percent%": 11, "spaces % more spaces": 9},
        ]
        for row_params in rows:
            connection.execute(target.insert(), row_params)

        self._assert_table(connection)

    def test_executemany_roundtrip(self, connection):
        """Insert one row singly and three with a parameter list."""
        target = self.tables.percent_table

        first_row = {"percent%": 5, "spaces % more spaces": 12}
        connection.execute(target.insert(), first_row)

        remaining_rows = [
            {"percent%": 7, "spaces % more spaces": 11},
            {"percent%": 9, "spaces % more spaces": 10},
            {"percent%": 11, "spaces % more spaces": 9},
        ]
        connection.execute(target.insert(), remaining_rows)

        self._assert_table(connection)

    @requirements.insert_executemany_returning
    def test_executemany_returning_roundtrip(self, connection):
        """Same as the executemany round trip, with RETURNING checked."""
        target = self.tables.percent_table

        first_row = {"percent%": 5, "spaces % more spaces": 12}
        connection.execute(target.insert(), first_row)

        returning_insert = target.insert().returning(
            target.c["percent%"],
            target.c["spaces % more spaces"],
        )
        remaining_rows = [
            {"percent%": 7, "spaces % more spaces": 11},
            {"percent%": 9, "spaces % more spaces": 10},
            {"percent%": 11, "spaces % more spaces": 9},
        ]
        result = connection.execute(returning_insert, remaining_rows)

        eq_(result.all(), [(7, 11), (9, 10), (11, 9)])
        self._assert_table(connection)

    def _assert_table(self, conn):
        """Check the four inserted rows, then update and re-check them.

        Reads go through the ``Table``, the lightweight table and an
        alias of each; the final UPDATE and SELECT use the ``Table``.

        """
        percent_table = self.tables.percent_table
        lightweight = self.tables.lightweight_percent_table

        selectables = (
            percent_table,
            percent_table.alias(),
            lightweight,
            lightweight.alias(),
        )

        for selectable in selectables:
            pct_col = selectable.c["percent%"]
            spaces_col = selectable.c["spaces % more spaces"]

            # every row, ordered by the "percent%" column
            all_rows = selectable.select().order_by(pct_col)
            eq_(
                list(conn.execute(all_rows)),
                [(5, 12), (7, 11), (9, 10), (11, 9)],
            )

            # only the rows whose second column is 9 or 10
            filtered_rows = (
                selectable.select()
                .where(spaces_col.in_([9, 10]))
                .order_by(pct_col)
            )
            eq_(list(conn.execute(filtered_rows)), [(9, 10), (11, 9)])

            # first row, looked up by string key and by column object
            mapping = conn.execute(all_rows).first()._mapping
            eq_(mapping["percent%"], 5)
            eq_(mapping["spaces % more spaces"], 12)

            eq_(mapping[pct_col], 5)
            eq_(mapping[spaces_col], 12)

        update_stmt = percent_table.update().values(
            {percent_table.c["spaces % more spaces"]: 15}
        )
        conn.execute(update_stmt)

        after_update = percent_table.select().order_by(
            percent_table.c["percent%"]
        )
        eq_(
            list(conn.execute(after_update)),
            [(5, 15), (7, 15), (9, 15), (11, 15)],
        )
 | |
| 
 | |
| 
 | |
class ServerSideCursorsTest(
    fixtures.TestBase, testing.AssertsExecutionResults
):
    """Tests for when a result is backed by a server side cursor.

    Covers the deprecated ``server_side_cursors`` engine option and the
    ``stream_results`` execution option at statement and connection
    level, plus data round trips on a server-side-cursor engine.

    """

    __requires__ = ("server_side_cursors",)

    __backend__ = True

    def _is_server_side(self, cursor):
        """Return whether ``cursor`` is server side for the current driver.

        The check is specific to each known driver name; an unknown
        driver yields ``False``.

        """
        # TODO: this is a huge issue as it prevents these tests from being
        # usable by third party dialects.
        driver = self.engine.dialect.driver

        if driver == "psycopg2":
            return bool(cursor.name)
        if driver == "psycopg":
            return bool(getattr(cursor, "name", False))
        if driver == "pymysql":
            ss_cursor_cls = __import__("pymysql.cursors").cursors.SSCursor
            return isinstance(cursor, ss_cursor_cls)
        if driver == "mysqldb":
            ss_cursor_cls = __import__("MySQLdb.cursors").cursors.SSCursor
            return isinstance(cursor, ss_cursor_cls)
        if driver in (
            "aiomysql",
            "asyncmy",
            "aioodbc",
            "asyncpg",
            "aiosqlite",
        ):
            return cursor.server_side
        if driver in ("pg8000", "oracledb"):
            return getattr(cursor, "server_side", False)
        if driver == "mariadbconnector":
            return not cursor.buffered
        return False

    def _fixture(self, server_side_cursors):
        """Create, store on ``self.engine`` and return a testing engine.

        A true ``server_side_cursors`` value is expected to emit the
        deprecation warning for that ``create_engine`` parameter.

        """
        engine_options = {"server_side_cursors": server_side_cursors}

        if not server_side_cursors:
            self.engine = engines.testing_engine(options=engine_options)
            return self.engine

        with testing.expect_deprecated(
            "The create_engine.server_side_cursors parameter is "
            "deprecated and will be removed in a future release.  "
            "Please use the Connection.execution_options.stream_results "
            "parameter."
        ):
            self.engine = engines.testing_engine(options=engine_options)
        return self.engine

    def stringify(self, str_):
        """Replace each ``SELECT <int>`` in ``str_`` with compiled SQL.

        Matching is case insensitive; the replacement is the string form
        of ``select(<int>)`` compiled against ``testing.db``.

        """

        def _compiled_select(match):
            literal_value = int(match.group(1))
            return str(select(literal_value).compile(testing.db))

        return re.sub(r"SELECT (\d+)", _compiled_select, str_, flags=re.I)

    @testing.combinations(
        ("global_string", True, lambda stringify: stringify("select 1"), True),
        (
            "global_text",
            True,
            lambda stringify: text(stringify("select 1")),
            True,
        ),
        ("global_expr", True, select(1), True),
        (
            "global_off_explicit",
            False,
            lambda stringify: text(stringify("select 1")),
            False,
        ),
        (
            "stmt_option",
            False,
            select(1).execution_options(stream_results=True),
            True,
        ),
        (
            "stmt_option_disabled",
            True,
            select(1).execution_options(stream_results=False),
            False,
        ),
        ("for_update_expr", True, select(1).with_for_update(), True),
        # TODO: need a real requirement for this, or dont use this test
        (
            "for_update_string",
            True,
            lambda stringify: stringify("SELECT 1 FOR UPDATE"),
            True,
            testing.skip_if(["sqlite", "mssql"]),
        ),
        (
            "text_no_ss",
            False,
            lambda stringify: text(stringify("select 42")),
            False,
        ),
        (
            "text_ss_option",
            False,
            lambda stringify: text(stringify("select 42")).execution_options(
                stream_results=True
            ),
            True,
        ),
        id_="iaaa",
        argnames="engine_ss_arg, statement, cursor_ss_status",
    )
    def test_ss_cursor_status(
        self, engine_ss_arg, statement, cursor_ss_status
    ):
        """The cursor's server side status matches the expected flag."""
        engine = self._fixture(engine_ss_arg)
        with engine.begin() as conn:
            # lambda parameters are resolved against self.stringify
            if callable(statement):
                statement = testing.resolve_lambda(
                    statement, stringify=self.stringify
                )

            # plain strings go straight to the driver
            run = (
                conn.exec_driver_sql
                if isinstance(statement, str)
                else conn.execute
            )
            result = run(statement)

            eq_(self._is_server_side(result.cursor), cursor_ss_status)
            result.close()

    def test_conn_option(self):
        """``stream_results`` set on the connection enables server side."""
        engine = self._fixture(False)

        with engine.connect() as conn:
            # should be enabled for this one
            streaming_conn = conn.execution_options(stream_results=True)
            result = streaming_conn.exec_driver_sql(
                self.stringify("select 1")
            )
            assert self._is_server_side(result.cursor)

            # The connection has autobegun, so leaving the block rolls
            # back; on MySQL at least that fails with "Commands out of
            # sync" while the result set is still open, hence the
            # explicit close.
            #
            # This close was not here before 2.0 even though the pool
            # rolls back unconditionally: running this test on 1.4 with
            # stdout shown reveals "Exception during reset or similar"
            # with "Commands out sync" emitted as a warning.  2.0's
            # architecture finds and fixes what used to be an expensive
            # silent error condition.
            result.close()

    def test_stmt_enabled_conn_option_disabled(self):
        """A connection-level ``stream_results=False`` wins over the stmt."""
        engine = self._fixture(False)

        streaming_stmt = select(1).execution_options(stream_results=True)

        with engine.connect() as conn:
            # not this one
            non_streaming_conn = conn.execution_options(stream_results=False)
            result = non_streaming_conn.execute(streaming_stmt)
            assert not self._is_server_side(result.cursor)

    def test_aliases_and_ss(self):
        """``stream_results`` on a subquery does not reach the outer stmt."""
        engine = self._fixture(False)

        inner = select(sql.literal_column("1").label("x"))
        subq = inner.execution_options(stream_results=True).subquery()

        # options don't propagate out when subquery is used as a FROM clause
        with engine.begin() as conn:
            result = conn.execute(subq.select())
            assert not self._is_server_side(result.cursor)
            result.close()

        outer = select(1).select_from(subq)
        with engine.begin() as conn:
            result = conn.execute(outer)
            assert not self._is_server_side(result.cursor)
            result.close()

    def _make_test_table(self):
        """Build the ``test_table`` Table on ``self.metadata``."""
        id_column = Column(
            "id", Integer, primary_key=True, test_needs_autoincrement=True
        )
        data_column = Column("data", String(50))
        return Table("test_table", self.metadata, id_column, data_column)

    def test_roundtrip_fetchall(self, metadata):
        """Insert, update and delete rows, reading back with fetchall()."""
        engine = self._fixture(True)
        test_table = self._make_test_table()
        ordered_select = test_table.select().order_by(test_table.c.id)

        with engine.begin() as connection:
            test_table.create(connection, checkfirst=True)

            for value in ("data1", "data2"):
                connection.execute(test_table.insert(), {"data": value})
            eq_(
                connection.execute(ordered_select).fetchall(),
                [(1, "data1"), (2, "data2")],
            )

            update_second = (
                test_table.update()
                .where(test_table.c.id == 2)
                .values(data=test_table.c.data + " updated")
            )
            connection.execute(update_second)
            eq_(
                connection.execute(ordered_select).fetchall(),
                [(1, "data1"), (2, "data2 updated")],
            )

            connection.execute(test_table.delete())
            row_count = connection.scalar(
                select(func.count("*")).select_from(test_table)
            )
            eq_(row_count, 0)

    def test_roundtrip_fetchmany(self, metadata):
        """Read 19 inserted rows back in chunks of 5, 10 and the rest."""
        engine = self._fixture(True)
        test_table = self._make_test_table()

        def expected_rows(start, stop):
            # (id, data) pairs for ids in [start, stop)
            return [(i, "data%d" % i) for i in range(start, stop)]

        with engine.begin() as connection:
            test_table.create(connection, checkfirst=True)

            insert_params = [dict(data="data%d" % i) for i in range(1, 20)]
            connection.execute(test_table.insert(), insert_params)

            result = connection.execute(
                test_table.select().order_by(test_table.c.id)
            )

            eq_(result.fetchmany(5), expected_rows(1, 6))
            eq_(result.fetchmany(10), expected_rows(6, 16))
            eq_(result.fetchall(), expected_rows(16, 20))
 | 
