Source code for swh.provenance.backend.postgresql
# Copyright (C) 2024 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
from contextlib import contextmanager
import logging
from typing import List, Optional
import psycopg2.extras
import psycopg2.pool
from swh.core.db import BaseDb
from swh.core.db.common import db_transaction
from swh.core.db.db_utils import swh_db_version
from swh.model.swhids import CoreSWHID, QualifiedSWHID
from swh.provenance.exc import ProvenanceDBError
# Module-level logger named after this module; used by
# PostgresqlProvenance.check_config to report a schema version mismatch.
logger = logging.getLogger(__name__)
[docs]
class Db(BaseDb):
    """Database handle of the PostgreSQL backend for the Software Heritage
    provenance index.

    The class defines no attributes or methods of its own; everything it
    offers is inherited from ``BaseDb``.
    """
[docs]
class PostgresqlProvenance:
    """Provenance backend backed by a PostgreSQL database.

    The backend works in one of two modes, selected by the ``db`` argument
    of the constructor:

    * a single, already open psycopg2 connection that is wrapped in a
      :class:`Db` and reused for every call;
    * a ``psycopg2.pool.ThreadedConnectionPool`` from which a connection is
      borrowed by :meth:`get_db` and handed back by :meth:`put_db`.
    """

    # Database schema version this code expects; compared against the
    # version stored in the database by ``check_config``.
    current_version: int = 1

    def __init__(self, db, min_pool_conns=1, max_pool_conns=10):
        """Set up database access.

        Args:
            db: either an open ``psycopg2`` connection, or a value passed
                unchanged to ``ThreadedConnectionPool`` as its connection
                argument (presumably a libpq connection string -- confirm
                against callers).
            min_pool_conns: minimum number of pooled connections; only used
                when a pool is created.
            max_pool_conns: maximum number of pooled connections; only used
                when a pool is created.

        Raises:
            ProvenanceDBError: if psycopg2 reports an ``OperationalError``
                while setting up the connection or the pool.
        """
        try:
            if isinstance(db, psycopg2.extensions.connection):
                self._pool = None
                self._db = Db(db)
                # Same workaround as in get_db(): psycopg2 < 2.9.0 does not
                # handle fractional timezones, so force the session to UTC.
                self._db.cursor().execute("SET TIME ZONE 'UTC'")
            else:
                self._pool = psycopg2.pool.ThreadedConnectionPool(
                    min_pool_conns, max_pool_conns, db
                )
                self._db = None
        except psycopg2.OperationalError as e:
            # Chain explicitly so the driver error is reported as the cause.
            raise ProvenanceDBError(e) from e

    def get_db(self):
        """Return a :class:`Db` to run queries on.

        In single-connection mode the shared instance is returned. In pool
        mode a connection is borrowed from the pool and its session time
        zone is set to UTC; the caller must hand it back with
        :meth:`put_db`.
        """
        if self._db:
            return self._db

        db = Db.from_pool(self._pool)
        try:
            # Workaround for psycopg2 < 2.9.0 not handling fractional
            # timezones, which may happen on old revision/release dates on
            # systems configured with non-UTC timezones.
            # https://www.psycopg.org/docs/usage.html#time-zones-handling
            db.cursor().execute("SET TIME ZONE 'UTC'")
        except Exception:
            # The caller never receives this connection, so nobody else can
            # return it: give it back here to avoid leaking a pool slot.
            db.put_conn()
            raise
        return db

    def put_db(self, db):
        """Release ``db`` if it was borrowed from the pool.

        The shared single-connection instance is left untouched.
        """
        if db is not self._db:
            db.put_conn()

    @contextmanager
    def db(self):
        """Context manager yielding a :class:`Db` from :meth:`get_db`.

        The handle is passed to :meth:`put_db` when the block exits, whether
        or not an exception was raised inside it.
        """
        db = None
        try:
            db = self.get_db()
            yield db
        finally:
            if db:
                self.put_db(db)

    @db_transaction()
    def check_config(self, *, check_write: bool, db: Db, cur=None) -> bool:
        """Check that the database is usable by this backend.

        Args:
            check_write: when true, check the ``INSERT`` privilege;
                otherwise check the ``SELECT`` privilege.
            db: database handle, normally supplied by ``db_transaction``.
            cur: cursor, normally supplied by ``db_transaction``.

        Returns:
            ``False`` (after logging a warning) when the database version
            differs from ``current_version``; otherwise the result of
            ``has_table_privilege`` for the current user on the
            ``content_in_revision`` table.
        """
        dbversion = swh_db_version(db.conn.dsn)
        if dbversion != self.current_version:
            logger.warning(
                "database dbversion (%s) != %s current_version (%s)",
                dbversion,
                __name__,
                self.current_version,
            )
            return False

        # Check permissions on one of the tables
        check = "INSERT" if check_write else "SELECT"
        cur.execute(
            "select has_table_privilege(current_user, 'content_in_revision', %s)",
            (check,),
        )
        return cur.fetchone()[0]

    @db_transaction()
    def whereis(
        self, swhid: CoreSWHID, *, db: Db, cur=None
    ) -> Optional[QualifiedSWHID]:
        """Return provenance information for ``swhid``.

        As written, this method does not query the database: it returns a
        ``QualifiedSWHID`` built only from the object type and object id of
        ``swhid``, with no other fields set.

        Args:
            swhid: the object to look up.
            db: database handle, normally supplied by ``db_transaction``.
            cur: cursor, normally supplied by ``db_transaction``.
        """
        return QualifiedSWHID(
            object_type=swhid.object_type,
            object_id=swhid.object_id,
        )

    @db_transaction()
    def whereare(
        self, *, swhids: List[CoreSWHID], db: Db, cur=None
    ) -> List[Optional[QualifiedSWHID]]:
        """Given a list of SWHIDs, return a list of provenance info.

        The result has one entry per input SWHID, in the same order. See
        :meth:`whereis` for details on each entry.

        Args:
            swhids: the objects to look up.
            db: database handle, normally supplied by ``db_transaction``.
            cur: cursor, normally supplied by ``db_transaction``.
        """
        # ``db_transaction`` calls the wrapped method with ``db`` and ``cur``
        # keyword arguments, so they must appear in the signature. Forwarding
        # them makes every lookup reuse this transaction instead of opening a
        # new connection per SWHID.
        return [self.whereis(swhid=si, db=db, cur=cur) for si in swhids]