Source code for swh.storage.proxies.blocking
# Copyright (C) 2024 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information

from contextlib import contextmanager
from typing import Dict, Iterable, Iterator, List, Optional, Union
import warnings

import psycopg_pool

from swh.model.model import Origin, OriginVisit, OriginVisitStatus
from swh.storage import get_storage
from swh.storage.exc import BlockedOriginException
from swh.storage.interface import StorageInterface
from swh.storage.metrics import DifferentialTimer
from swh.storage.proxies.blocking.db import BlockingState

from .db import BlockingQuery

BLOCKING_OVERHEAD_METRIC = "swh_storage_blocking_overhead_seconds"


def get_datastore(cls, db=None, blocking_db=None, **kwargs):
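    """Return a :class:`BlockingAdmin` connected to the blocking proxy database,
    falling back to the legacy ``blocking_db`` argument when ``db`` is not given."""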
assert cls in ("postgresql", "blocking")
from .db import BlockingAdmin
if db is None:
db = blocking_db
return BlockingAdmin.connect(db)


def blocking_overhead_timer(method_name: str) -> DifferentialTimer:
"""Return a properly setup DifferentialTimer for ``method_name`` of the storage"""
return DifferentialTimer(BLOCKING_OVERHEAD_METRIC, tags={"endpoint": method_name})


class BlockingProxyStorage:
"""Blocking storage proxy
This proxy prevents visits from a known list of origins to be performed at all.
It uses a specific PostgreSQL database (which for now is colocated with the
swh.storage PostgreSQL database), the access to which is implemented in the
:mod:`.db` submodule.
Sample configuration
.. code-block: yaml
storage:
cls: blocking
db: 'dbname=swh-blocking-proxy'
max_pool_conns: 10
storage:
- cls: remote
url: http://storage.internal.staging.swh.network:5002/
"""
def __init__(
self,
storage: Union[Dict, StorageInterface],
db: Optional[str] = None,
blocking_db: Optional[str] = None,
min_pool_conns: int = 1,
max_pool_conns: int = 5,
):
if db is None:
assert blocking_db is not None
            warnings.warn(
                "The 'blocking_db' field in the blocking storage configuration "
                "has been renamed to 'db'",
                DeprecationWarning,
            )
db = blocking_db
self.storage: StorageInterface = (
get_storage(**storage) if isinstance(storage, dict) else storage
)
self._blocking_pool = psycopg_pool.ConnectionPool(
db,
min_size=min_pool_conns,
max_size=max_pool_conns,
)
def origin_visit_status_add(
self, visit_statuses: List[OriginVisitStatus]
) -> Dict[str, int]:
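        """Forward ``visit_statuses`` to the backend storage, unless at least one
        of the referenced origins is blocked, in which case a
        :exc:`BlockedOriginException` is raised."""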
with self._blocking_query() as q:
statuses = q.origins_are_blocked([v.origin for v in visit_statuses])
if statuses and any(
status.state != BlockingState.NON_BLOCKED
for status in statuses.values()
):
raise BlockedOriginException(statuses)
return self.storage.origin_visit_status_add(visit_statuses)
def origin_visit_add(self, visits: List[OriginVisit]) -> Iterable[OriginVisit]:
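        """Forward ``visits`` to the backend storage, unless at least one of the
        referenced origins is blocked, in which case a
        :exc:`BlockedOriginException` is raised."""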
with self._blocking_query() as q:
statuses = q.origins_are_blocked([v.origin for v in visits])
if statuses and any(
status.state != BlockingState.NON_BLOCKED
for status in statuses.values()
):
raise BlockedOriginException(statuses)
return self.storage.origin_visit_add(visits)
def origin_add(self, origins: List[Origin]) -> Dict[str, int]:
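        """Forward ``origins`` to the backend storage, unless at least one of them
        is blocked, in which case a :exc:`BlockedOriginException` is raised."""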
with self._blocking_query() as q:
statuses = {}
for origin in origins:
status = q.origin_is_blocked(origin.url)
if status and status.state != BlockingState.NON_BLOCKED:
statuses[origin.url] = status
if statuses:
raise BlockedOriginException(statuses)
return self.storage.origin_add(origins)
@contextmanager
def _blocking_query(self) -> Iterator[BlockingQuery]:
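        """Context manager yielding a :class:`BlockingQuery` borrowed from the
        connection pool; the underlying connection is returned to the pool on
        exit."""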
ret = None
try:
ret = BlockingQuery.from_pool(self._blocking_pool)
yield ret
finally:
if ret:
ret.put_conn()
def __getattr__(self, key):
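        """Delegate every other storage method to the backend storage."""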
method = getattr(self.storage, key)
if method:
# Avoid going through __getattr__ again next time
setattr(self, key, method)
return method
        # Raise a NotImplementedError to make sure we don't forget to add
        # blocking to any new storage functions
raise NotImplementedError(key)
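

# Usage sketch (illustrative, not part of this module): wiring the proxy in
# front of another storage backend. The "memory" backend and the database DSN
# below are assumptions chosen for the example, not values from this module.
#
#   proxy = BlockingProxyStorage(
#       db="dbname=swh-blocking-proxy",
#       storage={"cls": "memory"},
#   )
#   # origin_add, origin_visit_add and origin_visit_status_add raise
#   # BlockedOriginException when any referenced origin is blocked; every
#   # other endpoint is forwarded to the wrapped storage unchanged.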