# Source code for swh.storage.proxies.blocking

# Copyright (C) 2024 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information

from contextlib import contextmanager
from typing import Dict, Iterable, Iterator, List, Union

import psycopg2.pool

from swh.model.model import Origin, OriginVisit, OriginVisitStatus
from swh.storage import get_storage
from swh.storage.exc import BlockedOriginException
from swh.storage.interface import StorageInterface
from swh.storage.metrics import DifferentialTimer
from swh.storage.proxies.blocking.db import BlockingState

from .db import BlockingQuery

# Name of the metric reported by the timers built in blocking_overhead_timer().
BLOCKING_OVERHEAD_METRIC: str = "swh_storage_blocking_overhead_seconds"


def get_datastore(cls, db):
    """Connect to the blocking database administration interface.

    Args:
        cls: datastore backend identifier; only ``"postgresql"`` is supported
        db: connection string forwarded to ``BlockingAdmin.connect``

    Returns:
        the result of ``BlockingAdmin.connect(db)``

    Raises:
        AssertionError: if ``cls`` is not ``"postgresql"``
    """
    # Explicit check rather than ``assert``: an ``assert`` statement is removed
    # under ``python -O``, which would let an unsupported backend through.
    # AssertionError is kept so existing callers see the same exception type.
    if cls != "postgresql":
        raise AssertionError(
            f"Unsupported blocking datastore class {cls!r}, expected 'postgresql'"
        )

    from .db import BlockingAdmin

    return BlockingAdmin.connect(db)
def blocking_overhead_timer(method_name: str) -> DifferentialTimer:
    """Build a :class:`DifferentialTimer` for the ``method_name`` storage endpoint.

    The timer reports under :data:`BLOCKING_OVERHEAD_METRIC`, tagged with
    ``endpoint=<method_name>``.
    """
    endpoint_tags = {"endpoint": method_name}
    return DifferentialTimer(BLOCKING_OVERHEAD_METRIC, tags=endpoint_tags)
class BlockingProxyStorage:
    """Blocking storage proxy

    This proxy refuses to add origins listed as blocked in the blocking
    database, and refuses to record visits or visit statuses for such origins.

    It uses a specific PostgreSQL database (which for now is colocated with the
    swh.storage PostgreSQL database), the access to which is implemented in the
    :mod:`.db` submodule.

    Sample configuration

    .. code-block:: yaml

        storage:
          cls: blocking
          blocking_db: 'dbname=swh-blocking-proxy'
          max_pool_conns: 10
          storage:
            cls: remote
            url: http://storage.internal.staging.swh.network:5002/

    Args:
        blocking_db: connection string of the blocking database
            (e.g. ``'dbname=swh-blocking-proxy'``)
        storage: backend storage, either an already instantiated storage or a
            configuration dict passed as keyword arguments to ``get_storage``
        min_pool_conns: ``minconn`` of the blocking database connection pool
        max_pool_conns: ``maxconn`` of the blocking database connection pool
    """

    def __init__(
        self,
        blocking_db: str,
        storage: Union[Dict, StorageInterface],
        min_pool_conns: int = 1,
        max_pool_conns: int = 5,
    ):
        self.storage: StorageInterface = (
            get_storage(**storage) if isinstance(storage, dict) else storage
        )
        self._blocking_pool = psycopg2.pool.ThreadedConnectionPool(
            min_pool_conns, max_pool_conns, blocking_db
        )

    def _raise_if_blocked(self, origin_urls: List[str]) -> None:
        """Raise if any origin in ``origin_urls`` is in a blocking state.

        Any state other than ``BlockingState.NON_BLOCKED`` counts as blocked.

        Raises:
            BlockedOriginException: carrying the mapping returned by
                ``origins_are_blocked`` for ``origin_urls``
        """
        with self._blocking_query() as q:
            statuses = q.origins_are_blocked(origin_urls)
            # The ``statuses and`` guard also covers a None return value.
            if statuses and any(
                status.state != BlockingState.NON_BLOCKED
                for status in statuses.values()
            ):
                raise BlockedOriginException(statuses)

    def origin_visit_status_add(
        self, visit_statuses: List[OriginVisitStatus]
    ) -> Dict[str, int]:
        """Forward ``visit_statuses`` to the backend unless an origin is blocked.

        Raises:
            BlockedOriginException: if any of the statuses' origins is blocked;
                nothing is forwarded to the backend in that case
        """
        self._raise_if_blocked([v.origin for v in visit_statuses])
        return self.storage.origin_visit_status_add(visit_statuses)

    def origin_visit_add(self, visits: List[OriginVisit]) -> Iterable[OriginVisit]:
        """Forward ``visits`` to the backend unless one of their origins is blocked.

        Raises:
            BlockedOriginException: if any of the visits' origins is blocked;
                nothing is forwarded to the backend in that case
        """
        self._raise_if_blocked([v.origin for v in visits])
        return self.storage.origin_visit_add(visits)

    def origin_add(self, origins: List[Origin]) -> Dict[str, int]:
        """Forward ``origins`` to the backend unless one of them is blocked.

        Raises:
            BlockedOriginException: carrying only the statuses of the origins
                whose state is not ``BlockingState.NON_BLOCKED``; nothing is
                forwarded to the backend in that case
        """
        # NOTE(review): this issues one query per origin, unlike the batched
        # ``origins_are_blocked`` used for visits; consider batching once both
        # queries are confirmed to have the same matching semantics.
        with self._blocking_query() as q:
            statuses = {}
            for origin in origins:
                status = q.origin_is_blocked(origin.url)
                if status and status.state != BlockingState.NON_BLOCKED:
                    statuses[origin.url] = status
            if statuses:
                raise BlockedOriginException(statuses)
        return self.storage.origin_add(origins)

    @contextmanager
    def _blocking_query(self) -> Iterator[BlockingQuery]:
        """Yield a ``BlockingQuery`` built from the connection pool.

        ``put_conn()`` is called on exit, including when the ``with`` body
        raises (presumably handing the connection back to the pool).
        """
        query = BlockingQuery.from_pool(self._blocking_pool)
        try:
            yield query
        finally:
            query.put_conn()

    def __getattr__(self, key):
        """Delegate lookups of attributes the proxy does not define to the backend."""
        if key == "storage":
            # ``storage`` is assigned in __init__; reaching this point means it
            # is missing (e.g. the instance was created without __init__, as
            # copy/unpickling do). Reading ``self.storage`` below would then
            # re-enter __getattr__ and recurse until RecursionError.
            raise AttributeError(key)

        # Raises AttributeError if the backend has no such attribute.
        method = getattr(self.storage, key)
        if method:
            # Avoid going through __getattr__ again next time
            setattr(self, key, method)
            return method

        # Only reached when the backend attribute exists but is falsy.
        # NOTE(review): confirm this NotImplementedError is still intended;
        # missing backend attributes already raised AttributeError above.
        raise NotImplementedError(key)