diff --git a/databases/backends/postgres.py b/databases/backends/asyncpg.py similarity index 88% rename from databases/backends/postgres.py rename to databases/backends/asyncpg.py index c42688e1..98ac44ea 100644 --- a/databases/backends/postgres.py +++ b/databases/backends/asyncpg.py @@ -7,7 +7,7 @@ from sqlalchemy.sql.ddl import DDLElement from databases.backends.common.records import Record, create_column_maps -from databases.backends.dialects.psycopg import dialect as psycopg_dialect +from databases.backends.dialects.psycopg import get_dialect from databases.core import LOG_EXTRA, DatabaseURL from databases.interfaces import ( ConnectionBackend, @@ -19,28 +19,15 @@ logger = logging.getLogger("databases") -class PostgresBackend(DatabaseBackend): +class AsyncpgBackend(DatabaseBackend): def __init__( self, database_url: typing.Union[DatabaseURL, str], **options: typing.Any ) -> None: self._database_url = DatabaseURL(database_url) self._options = options - self._dialect = self._get_dialect() + self._dialect = get_dialect() self._pool = None - def _get_dialect(self) -> Dialect: - dialect = psycopg_dialect(paramstyle="pyformat") - - dialect.implicit_returning = True - dialect.supports_native_enum = True - dialect.supports_smallserial = True # 9.2+ - dialect._backslash_escapes = False - dialect.supports_sane_multi_rowcount = True # psycopg 2.0.9+ - dialect._has_native_hstore = True - dialect.supports_native_decimal = True - - return dialect - def _get_connection_kwargs(self) -> dict: url_options = self._database_url.options @@ -78,12 +65,12 @@ async def disconnect(self) -> None: await self._pool.close() self._pool = None - def connection(self) -> "PostgresConnection": - return PostgresConnection(self, self._dialect) + def connection(self) -> "AsyncpgConnection": + return AsyncpgConnection(self, self._dialect) -class PostgresConnection(ConnectionBackend): - def __init__(self, database: PostgresBackend, dialect: Dialect): +class AsyncpgConnection(ConnectionBackend): + def __init__(self, database: AsyncpgBackend, dialect: Dialect): self._database = database self._dialect = dialect self._connection: typing.Optional[asyncpg.connection.Connection] = None @@ -159,7 +146,7 @@ async def iterate( yield Record(row, result_columns, self._dialect, column_maps) def transaction(self) -> TransactionBackend: - return PostgresTransaction(connection=self) + return AsyncpgTransaction(connection=self) def _compile(self, query: ClauseElement) -> typing.Tuple[str, list, tuple]: compiled = query.compile( @@ -197,8 +184,8 @@ def raw_connection(self) -> asyncpg.connection.Connection: return self._connection -class PostgresTransaction(TransactionBackend): - def __init__(self, connection: PostgresConnection): +class AsyncpgTransaction(TransactionBackend): + def __init__(self, connection: AsyncpgConnection): self._connection = connection self._transaction: typing.Optional[asyncpg.transaction.Transaction] = None diff --git a/databases/backends/dialects/psycopg.py b/databases/backends/dialects/psycopg.py index 07bd1880..1caf49fe 100644 --- a/databases/backends/dialects/psycopg.py +++ b/databases/backends/dialects/psycopg.py @@ -43,4 +43,13 @@ class PGDialect_psycopg(PGDialect): execution_ctx_cls = PGExecutionContext_psycopg -dialect = PGDialect_psycopg +def get_dialect() -> PGDialect_psycopg: + dialect = PGDialect_psycopg(paramstyle="pyformat") + dialect.implicit_returning = True + dialect.supports_native_enum = True + dialect.supports_smallserial = True # 9.2+ + dialect._backslash_escapes = False + dialect.supports_sane_multi_rowcount = True # psycopg 2.0.9+ + dialect._has_native_hstore = True + dialect.supports_native_decimal = True + return dialect diff --git a/databases/backends/psycopg.py b/databases/backends/psycopg.py new file mode 100644 index 00000000..981742ce --- /dev/null +++ b/databases/backends/psycopg.py @@ -0,0 +1,118 @@ +import typing +from collections.abc import Sequence + +import psycopg_pool +from sqlalchemy.sql import ClauseElement + +from databases.backends.dialects.psycopg import get_dialect +from databases.core import DatabaseURL +from databases.interfaces import ( + ConnectionBackend, + DatabaseBackend, + TransactionBackend, +) + + +class PsycopgBackend(DatabaseBackend): + def __init__( + self, database_url: typing.Union[DatabaseURL, str], **options: typing.Any + ) -> None: + self._database_url = DatabaseURL(database_url) + self._options = options + self._dialect = get_dialect() + self._pool: typing.Optional[psycopg_pool.AsyncConnectionPool] = None + + async def connect(self) -> None: + if self._pool is not None: + return + + self._pool = psycopg_pool.AsyncConnectionPool( + self._database_url.url, open=False, **self._options) + await self._pool.open() + + async def disconnect(self) -> None: + if self._pool is None: + return + + await self._pool.close() + self._pool = None + + def connection(self) -> "PsycopgConnection": + return PsycopgConnection(self) + + +class PsycopgConnection(ConnectionBackend): + def __init__(self, database: PsycopgBackend) -> None: + self._database = database + + async def acquire(self) -> None: + if self._connection is not None: + return + + if self._database._pool is None: + raise RuntimeError("PsycopgBackend is not running") + + # TODO: Add configurable timeouts + self._connection = await self._database._pool.getconn() + + async def release(self) -> None: + if self._connection is None: + return + + await self._database._pool.putconn(self._connection) + self._connection = None + + async def fetch_all(self, query: ClauseElement) -> typing.List["Record"]: + raise NotImplementedError() # pragma: no cover + + async def fetch_one(self, query: ClauseElement) -> typing.Optional["Record"]: + raise NotImplementedError() # pragma: no cover + + async def fetch_val( + self, query: ClauseElement, column: typing.Any = 0 + ) -> typing.Any: + row = await self.fetch_one(query) + return None if row is None else row[column] + + async def execute(self, query: ClauseElement) -> typing.Any: + raise NotImplementedError() # pragma: no cover + + async def execute_many(self, queries: typing.List[ClauseElement]) -> None: + raise NotImplementedError() # pragma: no cover + + async def iterate( + self, query: ClauseElement + ) -> typing.AsyncGenerator[typing.Mapping, None]: + raise NotImplementedError() # pragma: no cover + # mypy needs async iterators to contain a `yield` + # https://github.com/python/mypy/issues/5385#issuecomment-407281656 + yield True # pragma: no cover + + def transaction(self) -> "TransactionBackend": + raise NotImplementedError() # pragma: no cover + + @property + def raw_connection(self) -> typing.Any: + raise NotImplementedError() # pragma: no cover + + +class PsycopgTransaction(TransactionBackend): + async def start( + self, is_root: bool, extra_options: typing.Dict[typing.Any, typing.Any] + ) -> None: + raise NotImplementedError() # pragma: no cover + + async def commit(self) -> None: + raise NotImplementedError() # pragma: no cover + + async def rollback(self) -> None: + raise NotImplementedError() # pragma: no cover + + +class Record(Sequence): + @property + def _mapping(self) -> typing.Mapping: + raise NotImplementedError() # pragma: no cover + + def __getitem__(self, key: typing.Any) -> typing.Any: + raise NotImplementedError() # pragma: no cover diff --git a/databases/core.py b/databases/core.py index d55dd3c8..cba06ced 100644 --- a/databases/core.py +++ b/databases/core.py @@ -43,12 +43,16 @@ class Database: SUPPORTED_BACKENDS = { - "postgresql": "databases.backends.postgres:PostgresBackend", + "postgres": "databases.backends.asyncpg:AsyncpgBackend", + "postgresql": "databases.backends.asyncpg:AsyncpgBackend", "postgresql+aiopg": "databases.backends.aiopg:AiopgBackend", - "postgres": "databases.backends.postgres:PostgresBackend", + "postgresql+asyncpg": "databases.backends.asyncpg:AsyncpgBackend", + "postgresql+psycopg": "databases.backends.psycopg:PsycopgBackend", "mysql": "databases.backends.mysql:MySQLBackend", + "mysql+aiomysql": "databases.backends.asyncmy:MySQLBackend", "mysql+asyncmy": "databases.backends.asyncmy:AsyncMyBackend", "sqlite": "databases.backends.sqlite:SQLiteBackend", + "sqlite+aiosqlite": "databases.backends.sqlite:SQLiteBackend", } _connection_map: "weakref.WeakKeyDictionary[asyncio.Task, 'Connection']" diff --git a/tests/test_connection_options.py b/tests/test_connection_options.py index 81ce2ac7..757393a4 100644 --- a/tests/test_connection_options.py +++ b/tests/test_connection_options.py @@ -6,7 +6,7 @@ import pytest from databases.backends.aiopg import AiopgBackend -from databases.backends.postgres import PostgresBackend +from databases.backends.asyncpg import AsyncpgBackend from databases.core import DatabaseURL from tests.test_databases import DATABASE_URLS, async_adapter @@ -19,7 +19,7 @@ def test_postgres_pool_size(): - backend = PostgresBackend("postgres://localhost/database?min_size=1&max_size=20") + backend = AsyncpgBackend("postgres://localhost/database?min_size=1&max_size=20") kwargs = backend._get_connection_kwargs() assert kwargs == {"min_size": 1, "max_size": 20} @@ -29,43 +29,43 @@ async def test_postgres_pool_size_connect(): for url in DATABASE_URLS: if DatabaseURL(url).dialect != "postgresql": continue - backend = PostgresBackend(url + "?min_size=1&max_size=20") + backend = AsyncpgBackend(url + "?min_size=1&max_size=20") await backend.connect() await backend.disconnect() def test_postgres_explicit_pool_size(): - backend = PostgresBackend("postgres://localhost/database", min_size=1, max_size=20) + backend = AsyncpgBackend("postgres://localhost/database", min_size=1, max_size=20) kwargs = backend._get_connection_kwargs() assert kwargs == {"min_size": 1, "max_size": 20} def test_postgres_ssl(): - backend = PostgresBackend("postgres://localhost/database?ssl=true") + backend = AsyncpgBackend("postgres://localhost/database?ssl=true") kwargs = backend._get_connection_kwargs() assert kwargs == {"ssl": True} def test_postgres_ssl_verify_full(): - backend = PostgresBackend("postgres://localhost/database?ssl=verify-full") + backend = AsyncpgBackend("postgres://localhost/database?ssl=verify-full") kwargs = backend._get_connection_kwargs() assert kwargs == {"ssl": "verify-full"} def test_postgres_explicit_ssl(): - backend = PostgresBackend("postgres://localhost/database", ssl=True) + backend = AsyncpgBackend("postgres://localhost/database", ssl=True) kwargs = backend._get_connection_kwargs() assert kwargs == {"ssl": True} def test_postgres_explicit_ssl_verify_full(): - backend = PostgresBackend("postgres://localhost/database", ssl="verify-full") + backend = AsyncpgBackend("postgres://localhost/database", ssl="verify-full") kwargs = backend._get_connection_kwargs() assert kwargs == {"ssl": "verify-full"} def test_postgres_no_extra_options(): - backend = PostgresBackend("postgres://localhost/database") + backend = AsyncpgBackend("postgres://localhost/database") kwargs = backend._get_connection_kwargs() assert kwargs == {} @@ -74,7 +74,7 @@ def test_postgres_password_as_callable(): def gen_password(): return "Foo" - backend = PostgresBackend( + backend = AsyncpgBackend( "postgres://:password@localhost/database", password=gen_password ) kwargs = backend._get_connection_kwargs()