Source code for swh.storage.proxies.blocking
# Copyright (C) 2024 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
from contextlib import contextmanager
from typing import Dict, Iterable, Iterator, List, Optional, Union
import warnings
import psycopg2.pool
from swh.model.model import Origin, OriginVisit, OriginVisitStatus
from swh.storage import get_storage
from swh.storage.exc import BlockedOriginException
from swh.storage.interface import StorageInterface
from swh.storage.metrics import DifferentialTimer
from swh.storage.proxies.blocking.db import BlockingState
from .db import BlockingQuery
BLOCKING_OVERHEAD_METRIC = "swh_storage_blocking_overhead_seconds"
def get_datastore(cls, db=None, blocking_db=None, **kwargs):
assert cls in ("postgresql", "blocking")
from .db import BlockingAdmin
if db is None:
db = blocking_db
return BlockingAdmin.connect(db)
def blocking_overhead_timer(method_name: str) -> DifferentialTimer:
    """Return a properly set up DifferentialTimer for ``method_name`` of the storage"""
    return DifferentialTimer(BLOCKING_OVERHEAD_METRIC, tags={"endpoint": method_name})
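# Hedged usage sketch (assumption: DifferentialTimer can be used as a context
# manager, as it is in the other storage proxies; check swh.storage.metrics for
# the exact API before relying on this):
#
#     with blocking_overhead_timer("origin_add"):
#         ...  # proxy-side checks, reported as blocking overhead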
class BlockingProxyStorage:
    """Blocking storage proxy

    This proxy prevents visits to a known list of blocked origins from being
    performed at all.

    It uses a dedicated PostgreSQL database (which, for now, is colocated with
    the swh.storage PostgreSQL database); access to it is implemented in the
    :mod:`.db` submodule.

    Sample configuration:

    .. code-block:: yaml

        storage:
          cls: blocking
          db: 'dbname=swh-blocking-proxy'
          max_pool_conns: 10
          storage:
          - cls: remote
            url: http://storage.internal.staging.swh.network:5002/
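
    A minimal usage sketch (assuming an in-memory backing storage and a local
    database named ``swh-blocking-proxy``; both are illustrative, not
    requirements):

    .. code-block:: python

        from swh.storage import get_storage

        storage = get_storage(
            cls="blocking",
            db="dbname=swh-blocking-proxy",
            storage={"cls": "memory"},
        )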
"""
def __init__(
self,
storage: Union[Dict, StorageInterface],
db: Optional[str] = None,
blocking_db: Optional[str] = None,
min_pool_conns: int = 1,
max_pool_conns: int = 5,
):
if db is None:
assert blocking_db is not None
            warnings.warn(
                "The 'blocking_db' field in the blocking storage configuration "
                "has been renamed to 'db'",
                DeprecationWarning,
)
db = blocking_db
self.storage: StorageInterface = (
get_storage(**storage) if isinstance(storage, dict) else storage
)
self._blocking_pool = psycopg2.pool.ThreadedConnectionPool(
min_pool_conns, max_pool_conns, db
)
def origin_visit_status_add(
self, visit_statuses: List[OriginVisitStatus]
) -> Dict[str, int]:
with self._blocking_query() as q:
statuses = q.origins_are_blocked([v.origin for v in visit_statuses])
if statuses and any(
status.state != BlockingState.NON_BLOCKED
for status in statuses.values()
):
raise BlockedOriginException(statuses)
return self.storage.origin_visit_status_add(visit_statuses)
def origin_visit_add(self, visits: List[OriginVisit]) -> Iterable[OriginVisit]:
with self._blocking_query() as q:
statuses = q.origins_are_blocked([v.origin for v in visits])
if statuses and any(
status.state != BlockingState.NON_BLOCKED
for status in statuses.values()
):
raise BlockedOriginException(statuses)
return self.storage.origin_visit_add(visits)
def origin_add(self, origins: List[Origin]) -> Dict[str, int]:
with self._blocking_query() as q:
statuses = {}
for origin in origins:
status = q.origin_is_blocked(origin.url)
if status and status.state != BlockingState.NON_BLOCKED:
statuses[origin.url] = status
if statuses:
raise BlockedOriginException(statuses)
return self.storage.origin_add(origins)
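
    # Hedged caller-side sketch (hypothetical caller code; ``storage`` here is a
    # BlockingProxyStorage instance): the exception raised above carries the
    # mapping of blocked origin URLs to their blocking status built in this
    # method.
    #
    #     try:
    #         storage.origin_add([Origin(url="https://example.com/repo.git")])
    #     except BlockedOriginException:
    #         ...  # skip or report the blocked origins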
@contextmanager
def _blocking_query(self) -> Iterator[BlockingQuery]:
ret = None
try:
ret = BlockingQuery.from_pool(self._blocking_pool)
yield ret
finally:
if ret:
ret.put_conn()
def __getattr__(self, key):
method = getattr(self.storage, key)
if method:
# Avoid going through __getattr__ again next time
setattr(self, key, method)
return method
        # Raise a NotImplementedError to make sure we don't forget to add
        # blocking to any new storage functions
raise NotImplementedError(key)
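
# Hedged forwarding note: any storage method not overridden above (for example
# ``origin_get``) is looked up on the backing storage by ``__getattr__`` on
# first use and cached on the proxy instance, so it is forwarded without any
# blocking check. Illustrative call, assuming ``storage`` was built as in the
# class docstring:
#
#     storage.origin_get(["https://example.com/repo.git"])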