# Source code for the masking storage proxy (module path lost during page extraction)

# Copyright (C) 2024 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information

from contextlib import contextmanager
import functools
import inspect
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Union

import psycopg2.pool

from swh.core.api.classes import PagedResult
from swh.model.hashutil import MultiHash
from swh.model.model import Origin
from swh.model.swhids import ExtendedObjectType, ExtendedSWHID
from import get_storage
from import MaskedObjectException
from import HashDict, Sha1, StorageInterface
from import DifferentialTimer
from import MaskedStatus

from .db import MaskingQuery

MASKING_OVERHEAD_METRIC = "swh_storage_masking_overhead_seconds"

def get_datastore(cls, db):
    """Return the datastore object for the masking database.

    Arguments:
      cls: datastore backend name; only ``"postgresql"`` is accepted
      db: connection information, passed unchanged to
        ``MaskingAdmin.connect``

    Returns:
      whatever ``MaskingAdmin.connect(db)`` returns

    Raises:
      AssertionError: if ``cls`` is not ``"postgresql"``
    """
    assert cls == "postgresql"

    # Imported at call time, after the backend check, as in the original code.
    from .db import MaskingAdmin

    return MaskingAdmin.connect(db)
def masking_overhead_timer(method_name: str) -> DifferentialTimer:
    """Return a properly setup DifferentialTimer for ``method_name`` of the storage

    The timer reports to the :const:`MASKING_OVERHEAD_METRIC` metric and is
    tagged with ``endpoint=<method_name>``.

    Arguments:
      method_name: name of the storage method being timed
    """
    endpoint_tags = {"endpoint": method_name}
    return DifferentialTimer(MASKING_OVERHEAD_METRIC, tags=endpoint_tags)
class MaskingProxyStorage:
    """Masking storage proxy

    This proxy filters out objects that are present in the archive but should
    not be displayed for policy reasons.

    It uses a specific PostgreSQL database (which for now is colocated with
    the storage's PostgreSQL database), the access to which is implemented in
    the :mod:`.db` submodule.

    Sample configuration

    .. code-block:: yaml

        storage:
          cls: masking
          masking_db: 'dbname=swh-storage'
          max_pool_conns: 10
          storage:
            cls: remote
            url: <URL of the proxied storage>

    """

    def __init__(
        self,
        masking_db: str,
        storage: Union[Dict, StorageInterface],
        min_pool_conns: int = 1,
        max_pool_conns: int = 5,
    ):
        """Set up the proxied storage and the masking database connection pool.

        Arguments:
          masking_db: connection string of the masking database, handed to
            the psycopg2 connection pool
          storage: either an already instantiated storage, or a configuration
            dict which is expanded into :func:`get_storage`
          min_pool_conns: minimum number of connections kept in the pool
          max_pool_conns: maximum number of connections in the pool
        """
        self.storage: StorageInterface = (
            get_storage(**storage) if isinstance(storage, dict) else storage
        )

        self._masking_pool = psycopg2.pool.ThreadedConnectionPool(
            min_pool_conns, max_pool_conns, masking_db
        )

        # Generate the method dictionaries once per instantiation, instead of
        # doing it on every (first) __getattr__ call.
        self._gen_method_dicts()

    @contextmanager
    def _masking_query(self) -> Iterator[MaskingQuery]:
        """Yield a :class:`MaskingQuery` taken from the connection pool, and
        hand its connection back when the ``with`` block exits."""
        ret = None
        try:
            ret = MaskingQuery.from_pool(self._masking_pool)
            yield ret
        finally:
            # Only skip the release when from_pool() itself failed; do not
            # depend on the truthiness of the query object.
            if ret is not None:
                ret.put_conn()

    @staticmethod
    def _get_swhids_in_result(method_name: str, result: Any) -> List[ExtendedSWHID]:
        """Return the SWHIDs whose masking status decides whether ``result``
        (one object returned by the storage method ``method_name``) may be
        shown.

        Raises:
          TypeError: if ``result`` is :const:`None` (callers must filter these)
          ValueError: if no rule is known for this method/result combination
        """
        if result is None:
            raise TypeError(f"Filtering of Nones missing in {method_name}")

        result_type = getattr(result, "object_type", None)
        if result_type == "raw_extrinsic_metadata":
            # Raw Extrinsic Metadata have a swhid, but we also mask them if
            # the target is masked
            return [result.swhid(), result.target]

        if hasattr(result, "swhid"):
            swhid = result.swhid()
            if hasattr(swhid, "to_extended"):
                swhid = swhid.to_extended()
            assert isinstance(
                swhid, ExtendedSWHID
            ), f"{method_name} returned object with unexpected swhid() method return type"
            return [swhid]

        if result_type:
            if result_type == "extid":
                # NOTE(review): no SWHID is returned here, so ExtID results
                # are never masked by this proxy; confirm this is intended,
                # given extid_get_from_target is routed through _getter_list.
                return []
            if result_type.startswith("origin_visit"):
                return [Origin(url=result.origin).swhid()]

        if (
            object_type := method_name.removesuffix("_get_random").removesuffix(
                "_get_id_partition"
            )
        ) != method_name:
            # Returns bare sha1_gits
            assert isinstance(result, bytes), f"{method_name} returned unexpected type"
            return [
                ExtendedSWHID(
                    object_type=ExtendedObjectType[object_type.upper()],
                    object_id=result,
                )
            ]
        elif method_name == "revision_log":
            # returns dicts of revisions
            assert (
                isinstance(result, dict) and "id" in result
            ), f"{method_name} returned unexpected type"
            return [
                ExtendedSWHID(
                    object_type=ExtendedObjectType.REVISION, object_id=result["id"]
                )
            ]
        elif method_name == "revision_shortlog":
            # Returns tuples (revision.id, revision.parents)
            assert isinstance(
                result[0], bytes
            ), f"{method_name} returned unexpected type"
            return [
                ExtendedSWHID(
                    object_type=ExtendedObjectType.REVISION, object_id=result[0]
                )
            ]
        elif method_name == "origin_get_by_sha1":
            # Returns origin dicts, because why not
            assert (
                isinstance(result, dict) and "url" in result
            ), f"{method_name} returned unexpected type"
            return [Origin(url=result["url"]).swhid()]
        elif method_name == "origin_visit_get_with_statuses":
            # Returns an OriginVisitWithStatuses
            return [Origin(url=result.visit.origin).swhid()]
        elif method_name == "snapshot_get":
            # Returns a snapshot dict
            return [
                ExtendedSWHID(
                    object_type=ExtendedObjectType.SNAPSHOT, object_id=result["id"]
                )
            ]

        raise ValueError(f"Cannot get swhid for result of method {method_name}")

    def _masked_result(
        self, method_name: str, result: Any
    ) -> Optional[Dict[ExtendedSWHID, List[MaskedStatus]]]:
        """Find the SWHIDs of the ``result`` object, and check if any of them is
        masked, returning the associated masking information."""
        with self._masking_query() as q:
            return q.swhids_are_masked(self._get_swhids_in_result(method_name, result))

    def _raise_if_masked_result(self, method_name: str, result: Any) -> None:
        """Raise a :exc:`MaskedObjectException` if ``result`` is masked."""
        masked = self._masked_result(method_name, result)
        if masked:
            raise MaskedObjectException(masked)

    def _raise_if_masked_swhids(self, swhids: List[ExtendedSWHID]) -> None:
        """Raise a :exc:`MaskedObjectException` if any SWHID is masked."""
        with self._masking_query() as q:
            masked = q.swhids_are_masked(swhids)

        if masked:
            raise MaskedObjectException(masked)

    def __getattr__(self, key):
        """Build, cache and return the masking-aware wrapper for the storage
        method ``key``.

        The exact name is looked up first, then the part of the name after
        the last underscore.

        Raises:
          NotImplementedError: if no masking rule is registered for ``key``
        """
        method = None
        if key in self._methods_by_name:
            method = self._methods_by_name[key](key)
        else:
            suffix = key.rsplit("_", 1)[-1]
            if suffix in self._methods_by_suffix:
                method = self._methods_by_suffix[suffix](key)

        if method:
            # Avoid going through __getattr__ again next time
            setattr(self, key, method)
            return method

        # Raise a NotImplementedError to make sure we don't forget to add
        # masking to any new storage functions
        raise NotImplementedError(key)

    def _gen_method_dicts(self):
        """Generate the :attr:`_methods_by_name` and :attr:`_methods_by_suffix`
        used by :meth:`__getattr__`"""

        # Returns the attribute of the proxied storage unchanged
        _passthrough = functools.partial(getattr, self.storage)

        self._methods_by_name = {
            # Returns a single object
            "snapshot_get": self._getter_optional,
            "origin_visit_find_by_date": self._getter_optional,
            "origin_visit_get_by": self._getter_optional,
            "origin_visit_status_get_latest": self._getter_optional,
            "origin_visit_get_latest": self._getter_optional,
            # Returns a PagedResult
            "origin_list": self._getter_pagedresult,
            "origin_visit_get": self._getter_pagedresult,
            "origin_search": self._getter_pagedresult,
            "raw_extrinsic_metadata_get": self._getter_pagedresult,
            "origin_visit_get_with_statuses": self._getter_pagedresult,
            "origin_visit_status_get": self._getter_pagedresult,
            # Returns a list of (optional) objects
            "origin_get_by_sha1": self._getter_list,
            "content_find": self._getter_list,
            "skipped_content_find": self._getter_list,
            "revision_log": self._getter_list,
            "revision_shortlog": self._getter_list,
            "extid_get_from_target": self._getter_list,
            "raw_extrinsic_metadata_get_by_ids": self._getter_list,
            # Filter arguments
            "directory_entry_get_by_path": self._getter_filtering_arguments,
            "directory_get_entries": self._getter_filtering_arguments,
            "directory_get_raw_manifest": self._getter_filtering_arguments,
            "directory_ls": self._getter_filtering_arguments,
            "raw_extrinsic_metadata_get_authorities": self._getter_filtering_arguments,
            "snapshot_branch_get_by_name": self._getter_filtering_arguments,
            "snapshot_count_branches": self._getter_filtering_arguments,
            "snapshot_get_branches": self._getter_filtering_arguments,
            "origin_snapshot_get_all": self._getter_filtering_arguments,
            # Content functions that don't match common getter or adder suffixes
            "content_add_metadata": _passthrough,
            "content_missing_per_sha1": _passthrough,
            "content_missing_per_sha1_git": _passthrough,
            "content_update": _passthrough,
            # These objects aren't maskable
            "extid_get_from_extid": _passthrough,
            "object_find_by_sha1_git": _passthrough,
            "object_find_recent_references": _passthrough,
            "metadata_authority_get": _passthrough,
            "metadata_fetcher_get": _passthrough,
            # Utility methods
            "check_config": _passthrough,
            "clear_buffers": _passthrough,
            "flush": _passthrough,
            "origin_count": _passthrough,
            "refresh_stat_counters": _passthrough,
            "stat_counters": _passthrough,
            # For tests
            "journal_writer": _passthrough,
        }

        self._methods_by_suffix = {
            # These methods will never need do any masking
            "add": _passthrough,
            "missing": _passthrough,
            # Partitions return PagedResults
            "partition": self._getter_pagedresult,
            # Getters return lists of optional objects
            "get": self._getter_list,
            "random": self._getter_random,
        }

    def content_get_data(self, content: Union[HashDict, Sha1]) -> Optional[bytes]:
        """Return the data of ``content`` from the proxied storage.

        Returns:
          the content data, or :const:`None` if the proxied storage returned
          :const:`None`

        Raises:
          MaskedObjectException: if the content is masked
        """
        with masking_overhead_timer("content_get_data") as t:
            with t.inner():
                ret = self.storage.content_get_data(content)

            if ret is None:
                return None

            if isinstance(content, dict) and "sha1_git" in content:
                sha1_git = content["sha1_git"]
            else:
                # We did not get the SWHID of the object as argument, so we need to
                # hash the resulting content to check if its SWHID was masked.
                sha1_git = MultiHash.from_data(ret, ["sha1_git"]).digest()["sha1_git"]

            self._raise_if_masked_swhids(
                [
                    ExtendedSWHID(
                        object_type=ExtendedObjectType.CONTENT,
                        object_id=sha1_git,
                    )
                ]
            )

            return ret

    def _get_swhids_in_args(
        self, method_name: str, parsed_args: Dict[str, Any]
    ) -> List[ExtendedSWHID]:
        """Extract SWHIDs from the parsed arguments of ``method_name``.

        Arguments:
          method_name: name of the called method
          parsed_args: mapping of the method's argument names to the values
            it was called with

        Raises:
          ValueError: if no extraction rule is known for ``method_name``
        """
        if method_name in ("directory_entry_get_by_path", "directory_ls"):
            return [
                ExtendedSWHID(
                    object_type=ExtendedObjectType.DIRECTORY,
                    object_id=parsed_args["directory"],
                )
            ]
        elif method_name == "directory_get_entries":
            return [
                ExtendedSWHID(
                    object_type=ExtendedObjectType.DIRECTORY,
                    object_id=parsed_args["directory_id"],
                )
            ]
        elif method_name.startswith("snapshot_"):
            return [
                ExtendedSWHID(
                    object_type=ExtendedObjectType.SNAPSHOT,
                    object_id=parsed_args["snapshot_id"],
                )
            ]
        elif method_name == "raw_extrinsic_metadata_get_authorities":
            return [parsed_args["target"]]
        elif method_name == "directory_get_raw_manifest":
            return [
                ExtendedSWHID(
                    object_type=ExtendedObjectType.DIRECTORY,
                    object_id=object_id,
                )
                for object_id in parsed_args["directory_ids"]
            ]
        elif method_name == "origin_snapshot_get_all":
            return [Origin(url=parsed_args["origin_url"]).swhid()]
        else:
            raise ValueError(f"Cannot get swhid for arguments of method {method_name}")

    def _getter_filtering_arguments(self, method_name: str):
        """Handles methods that should filter on their argument, instead of the
        returned value.

        If the underlying storage returns :const:`None`, return it, else, raise
        :exc:`MaskedObjectException` if the requested object is masked"""

        # The interface signature does not change between calls: compute it
        # once per wrapper rather than on every invocation.
        signature = inspect.signature(getattr(StorageInterface, method_name))

        @functools.wraps(getattr(self.storage, method_name))
        def newf(*args, **kwargs):
            with masking_overhead_timer(method_name) as t:
                method = getattr(self.storage, method_name)
                with t.inner():
                    result = method(*args, **kwargs)

                if result is None:
                    return None

                # The signature comes from the unbound interface method, so
                # its first parameter is ``self``; the proxy fills that slot.
                bound_args = signature.bind(self, *args, **kwargs)
                self._raise_if_masked_swhids(
                    self._get_swhids_in_args(method_name, bound_args.arguments)
                )

                return result

        return newf

    RANDOM_ATTEMPTS = 5

    def _getter_random(self, method_name: str):
        """Handles methods returning a random object. Try
        :const:`RANDOM_ATTEMPTS` times for a non-masked object, and return it,
        else return :const:`None`."""

        @functools.wraps(getattr(self.storage, method_name))
        def newf(*args, **kwargs):
            with masking_overhead_timer(method_name) as t:
                method = getattr(self.storage, method_name)
                for _ in range(self.RANDOM_ATTEMPTS):
                    with t.inner():
                        result = method(*args, **kwargs)
                    if result is None:
                        return None
                    if not self._masked_result(method_name, result):
                        return result

                # Every attempt returned a masked object
                return None

        return newf

    def _getter_optional(self, method_name: str):
        """Handles methods returning an optional object: if the return value is
        :const:`None`, return it, else, raise a :exc:`MaskedObjectException` if
        the return value should be masked."""

        @functools.wraps(getattr(self.storage, method_name))
        def newf(*args, **kwargs):
            with masking_overhead_timer(method_name) as t:
                method = getattr(self.storage, method_name)
                with t.inner():
                    result = method(*args, **kwargs)
                if result is None:
                    return None

                self._raise_if_masked_result(method_name, result)
                return result

        return newf

    def _raise_if_masked_result_in_list(
        self, method_name: str, results: Iterable[Any]
    ) -> None:
        """Raise a :exc:`MaskedObjectException` if any non-:const:`None` object
        in ``results`` is masked."""
        # A set, so that each SWHID is only checked once per batch
        result_swhids = set()
        for result in results:
            if result is not None:
                result_swhids.update(self._get_swhids_in_result(method_name, result))

        if result_swhids:
            self._raise_if_masked_swhids(list(result_swhids))

    def _getter_list(
        self,
        method_name: str,
    ):
        """Handle methods returning a list (or a generator) of optional objects,
        raising :exc:`MaskedObjectException` for all the masked objects in the
        batch."""

        @functools.wraps(getattr(self.storage, method_name))
        def newf(*args, **kwargs):
            with masking_overhead_timer(method_name) as t:
                method = getattr(self.storage, method_name)
                with t.inner():
                    # Materialized because the results are iterated once for
                    # the masking check and then returned to the caller
                    results = list(method(*args, **kwargs))
                self._raise_if_masked_result_in_list(method_name, results)
                return results

        return newf

    def _getter_pagedresult(self, method_name: str) -> Callable[..., PagedResult]:
        """Handle methods returning a :class:`PagedResult`, raising
        :exc:`MaskedObjectException` if some objects in the returned page are
        masked."""

        @functools.wraps(getattr(self.storage, method_name))
        def newf(*args, **kwargs) -> PagedResult:
            with masking_overhead_timer(method_name) as t:
                method = getattr(self.storage, method_name)
                with t.inner():
                    results = method(*args, **kwargs)

                self._raise_if_masked_result_in_list(method_name, results.results)
                return results

        return newf