1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
|
From 90d5b8e220db51465e4dbac8df6e4bd4941c9ba6 Mon Sep 17 00:00:00 2001
From: Steve Kowalik <steven@wedontsleep.org>
Date: Tue, 26 Sep 2023 11:59:39 +1000
Subject: [PATCH] Migrate to SQLAlchemy 2
https://github.com/wireservice/agate-sql/pull/40
Remove the upper bound on SQLAlchemy by converting the code idioms in
use to support both SQLAlchemy 1.4 and SQLAlchemy 2, and only setting a
lower bound SQLAlchemy of >= 1.4.
Closes #39
diff --git a/agatesql/table.py b/agatesql/table.py
index b141937..e4efe91 100644
--- a/agatesql/table.py
+++ b/agatesql/table.py
@@ -82,2 +82,2 @@ def from_sql(cls, connection_or_string, table_name):
- metadata = MetaData(connection)
- sql_table = Table(table_name, metadata, autoload=True, autoload_with=connection)
+ metadata = MetaData()
+ sql_table = Table(table_name, metadata, autoload_with=connection)
@@ -113 +113 @@ def from_sql(cls, connection_or_string, table_name):
- s = select([sql_table])
+ s = select(sql_table)
@@ -182 +182 @@ def make_sql_table(table, table_name, dialect=None, db_schema=None, constraints=
- metadata = MetaData(connection)
+ metadata = MetaData()
@@ -276,2 +276,3 @@ def to_sql(self, connection_or_string, table_name, overwrite=False,
- if overwrite:
- sql_table.drop(checkfirst=True)
+ with connection.begin():
+ if overwrite:
+ sql_table.drop(bind=connection, checkfirst=True)
@@ -279 +280 @@ def to_sql(self, connection_or_string, table_name, overwrite=False,
- sql_table.create(checkfirst=create_if_not_exists)
+ sql_table.create(bind=connection, checkfirst=create_if_not_exists)
@@ -282,13 +283,14 @@ def to_sql(self, connection_or_string, table_name, overwrite=False,
- insert = sql_table.insert()
- for prefix in prefixes:
- insert = insert.prefix_with(prefix)
- if chunk_size is None:
- connection.execute(insert, [dict(zip(self.column_names, row)) for row in self.rows])
- else:
- number_of_rows = len(self.rows)
- for index in range((number_of_rows - 1) // chunk_size + 1):
- end_index = (index + 1) * chunk_size
- if end_index > number_of_rows:
- end_index = number_of_rows
- connection.execute(insert, [dict(zip(self.column_names, row)) for row in
- self.rows[index * chunk_size:end_index]])
+ with connection.begin():
+ insert = sql_table.insert()
+ for prefix in prefixes:
+ insert = insert.prefix_with(prefix)
+ if chunk_size is None:
+ connection.execute(insert, [dict(zip(self.column_names, row)) for row in self.rows])
+ else:
+ number_of_rows = len(self.rows)
+ for index in range((number_of_rows - 1) // chunk_size + 1):
+ end_index = (index + 1) * chunk_size
+ if end_index > number_of_rows:
+ end_index = number_of_rows
+ connection.execute(insert, [dict(zip(self.column_names, row)) for row in
+ self.rows[index * chunk_size:end_index]])
@@ -354 +356 @@ def sql_query(self, query, table_name='agate'):
- rows = connection.execute(q)
+ rows = connection.exec_driver_sql(q)
diff --git a/setup.py b/setup.py
index 3905203..7257399 100644
--- a/setup.py
+++ b/setup.py
@@ -37 +37 @@ setup(
- 'sqlalchemy<2',
+ 'sqlalchemy>=1.4',
|