# Copyright (C) 2024-2025 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
from contextlib import contextmanager
import functools
import inspect
import itertools
from typing import (
Any,
Callable,
Dict,
Iterable,
Iterator,
List,
Optional,
TypeVar,
Union,
)
import warnings
import attr
import psycopg_pool
from swh.core.utils import grouper
from swh.model.hashutil import MultiHash
from swh.model.model import (
ExtID,
Origin,
Person,
RawExtrinsicMetadata,
Release,
Revision,
Sha1Git,
)
from swh.model.swhids import ExtendedObjectType, ExtendedSWHID
from swh.storage import get_storage
from swh.storage.exc import MaskedObjectException
from swh.storage.interface import HashDict, PagedResult, Sha1, StorageInterface, TResult
from swh.storage.metrics import DifferentialTimer
from swh.storage.proxies.masking.db import MaskedStatus
from .db import MaskingQuery
BATCH_SIZE = 1024
MASKING_OVERHEAD_METRIC = "swh_storage_masking_overhead_seconds"
def get_datastore(cls, db=None, masking_db=None, **kwargs):
    """Return a :class:`MaskingAdmin` connected to the masking database.

    ``masking_db`` is the legacy spelling of ``db`` and is used as a
    fallback when ``db`` is not provided.
    """
    assert cls in ("postgresql", "masking")
    from .db import MaskingAdmin

    dsn = db if db is not None else masking_db
    return MaskingAdmin.connect(dsn)
def masking_overhead_timer(method_name: str) -> DifferentialTimer:
    """Build a :class:`DifferentialTimer` that records the masking overhead
    of the storage endpoint ``method_name``."""
    endpoint_tags = {"endpoint": method_name}
    return DifferentialTimer(MASKING_OVERHEAD_METRIC, tags=endpoint_tags)
class MaskingProxyStorage:
    """Masking storage proxy

    This proxy can return modified objects or stop them from being retrieved at
    all.

    It uses a specific PostgreSQL database (which for now is colocated with the
    swh.storage PostgreSQL database), the access to which is implemented in the
    :mod:`.db` submodule.

    Sample configuration

    .. code-block:: yaml

        storage:
          cls: masking
          db: 'dbname=swh-storage'
          max_pool_conns: 10
          storage:
          - cls: remote
            url: http://storage.internal.staging.swh.network:5002/

    """

    def __init__(
        self,
        storage: Union[Dict, StorageInterface],
        db: Optional[str] = None,
        masking_db: Optional[str] = None,
        min_pool_conns: int = 1,
        max_pool_conns: int = 5,
    ):
        if db is None:
            # Accept the legacy 'masking_db' key, but nudge users to migrate.
            assert masking_db is not None
            warnings.warn(
                "'masking_db' field in the masking storage configuration "
                "was renamed 'db' field",
                DeprecationWarning,
            )
            db = masking_db
        # Accept either an instantiated storage or its configuration dict
        self.storage: StorageInterface = (
            storage if not isinstance(storage, dict) else get_storage(**storage)
        )
        self._masking_pool = psycopg_pool.ConnectionPool(
            db,
            min_size=min_pool_conns,
            max_size=max_pool_conns,
        )
        # Generate the method dictionaries once per instantiation, instead of
        # doing it on every (first) __getattr__ call.
        self._gen_method_dicts()
@contextmanager
def _masking_query(self) -> Iterator[MaskingQuery]:
ret = None
try:
ret = MaskingQuery.from_pool(self._masking_pool)
yield ret
finally:
if ret:
ret.put_conn()
    @staticmethod
    def _get_swhids_in_result(method_name: str, result: Any) -> List[ExtendedSWHID]:
        """Return the extended SWHIDs referenced by ``result``, a single
        return value of ``method_name`` on the underlying storage, so they
        can be checked against the masking database.

        Raises:
            TypeError: if ``result`` is :const:`None` (callers must filter
                out Nones before calling this)
            ValueError: if no SWHID can be derived for this method's results
        """
        if result is None:
            raise TypeError(f"Filtering of Nones missing in {method_name}")
        result_type = getattr(result, "object_type", None)
        if result_type == RawExtrinsicMetadata.object_type:
            # Raw Extrinsic Metadata have a swhid, but we also mask them if the target is masked
            return [result.swhid(), result.target]
        if hasattr(result, "swhid"):
            # Model objects carrying their own SWHID (contents, directories,
            # snapshots, ...); core SWHIDs are upgraded to extended ones.
            swhid = result.swhid()
            if hasattr(swhid, "to_extended"):
                swhid = swhid.to_extended()
            assert isinstance(
                swhid, ExtendedSWHID
            ), f"{method_name} returned object with unexpected swhid() method return type"
            return [swhid]
        if result_type:
            if result_type == ExtID.object_type:
                # ExtIDs are masked when their target object is masked
                return [result.target.to_extended()]
            if result_type.value.startswith("origin_visit"):
                # Visits and visit statuses are masked when their origin is
                return [Origin(url=result.origin).swhid()]
        if (
            object_type := method_name.removesuffix("_get_random").removesuffix(
                "_get_id_partition"
            )
        ) != method_name:
            # Returns bare sha1_gits
            assert isinstance(result, bytes), f"{method_name} returned unexpected type"
            return [
                ExtendedSWHID(
                    object_type=ExtendedObjectType[object_type.upper()],
                    object_id=result,
                )
            ]
        elif method_name == "revision_log":
            # returns dicts of revisions
            assert (
                isinstance(result, dict) and "id" in result
            ), f"{method_name} returned unexpected type"
            return [
                ExtendedSWHID(
                    object_type=ExtendedObjectType.REVISION, object_id=result["id"]
                )
            ]
        elif method_name == "revision_shortlog":
            # Returns tuples (revision.id, revision.parents)
            assert isinstance(
                result[0], bytes
            ), f"{method_name} returned unexpected type"
            return [
                ExtendedSWHID(
                    object_type=ExtendedObjectType.REVISION, object_id=result[0]
                )
            ]
        elif method_name == "origin_get_by_sha1":
            # Returns origin dicts, because why not
            assert (
                isinstance(result, dict) and "url" in result
            ), f"{method_name} returned unexpected type"
            return [Origin(url=result["url"]).swhid()]
        elif method_name == "origin_visit_get_with_statuses":
            # Returns an OriginVisitWithStatuses
            return [Origin(url=result.visit.origin).swhid()]
        elif method_name == "snapshot_get":
            # Returns a snapshot dict
            return [
                ExtendedSWHID(
                    object_type=ExtendedObjectType.SNAPSHOT, object_id=result["id"]
                )
            ]
        raise ValueError(f"Cannot get swhid for result of method {method_name}")
def _masked_result(
self, method_name: str, result: Any
) -> Optional[Dict[ExtendedSWHID, List[MaskedStatus]]]:
"""Find the SWHIDs of the ``result`` object, and check if any of them is
masked, returning the associated masking information."""
with self._masking_query() as q:
return q.swhids_are_masked(self._get_swhids_in_result(method_name, result))
def _raise_if_masked_result(self, method_name: str, result: Any) -> None:
"""Raise a :exc:`MaskedObjectException` if ``result`` is masked."""
masked = self._masked_result(method_name, result)
if masked:
raise MaskedObjectException(masked)
def _raise_if_masked_swhids(self, swhids: List[ExtendedSWHID]) -> None:
"""Raise a :exc:`MaskedObjectException` if any SWHID is masked."""
with self._masking_query() as q:
masked = q.swhids_are_masked(swhids)
if masked:
raise MaskedObjectException(masked)
def __getattr__(self, key):
method = None
if key in self._methods_by_name:
method = self._methods_by_name[key](key)
else:
suffix = key.rsplit("_", 1)[-1]
if suffix in self._methods_by_suffix:
method = self._methods_by_suffix[suffix](key)
if method:
# Avoid going through __getattr__ again next time
setattr(self, key, method)
return method
# Raise a NotImplementedError to make sure we don't forget to add
# masking to any new storage functions
raise NotImplementedError(key)
    def _gen_method_dicts(self):
        """Generate the :attr:`_methods_by_name` and :attr:`_methods_by_suffix`
        used by :meth:`__getattr__`.

        Each value is a factory that takes the method name and returns the
        wrapped (or passed-through) callable. Exact-name entries take
        precedence over suffix entries.
        """
        # Forward the call unchanged to the underlying storage
        _passthrough = functools.partial(getattr, self.storage)
        self._methods_by_name = {
            # Returns a single object
            "snapshot_get": self._getter_optional,
            "origin_visit_find_by_date": self._getter_optional,
            "origin_visit_get_by": self._getter_optional,
            "origin_visit_status_get_latest": self._getter_optional,
            "origin_visit_get_latest": self._getter_optional,
            # Returns a PagedResult
            "origin_list": self._getter_pagedresult,
            "origin_visit_get": self._getter_pagedresult,
            "origin_search": self._getter_pagedresult,
            "raw_extrinsic_metadata_get": self._getter_pagedresult,
            "origin_visit_get_with_statuses": self._getter_pagedresult,
            "origin_visit_status_get": self._getter_pagedresult,
            # Returns a list of (optional) objects
            "origin_get_by_sha1": self._getter_list,
            "content_find": self._getter_list,
            "skipped_content_find": self._getter_list,
            "revision_shortlog": self._getter_list,
            "extid_get_from_target": self._getter_list,
            "raw_extrinsic_metadata_get_by_ids": self._getter_list,
            # Filter arguments
            "directory_entry_get_by_path": self._getter_filtering_arguments,
            "directory_get_entries": self._getter_filtering_arguments,
            "directory_get_raw_manifest": self._getter_filtering_arguments,
            "directory_ls": self._getter_filtering_arguments,
            "raw_extrinsic_metadata_get_authorities": self._getter_filtering_arguments,
            "snapshot_branch_get_by_name": self._getter_filtering_arguments,
            "snapshot_count_branches": self._getter_filtering_arguments,
            "snapshot_get_branches": self._getter_filtering_arguments,
            "origin_snapshot_get_all": self._getter_filtering_arguments,
            # Content functions that don't match common getter or adder suffixes
            "content_add_metadata": _passthrough,
            "content_missing_per_sha1": _passthrough,
            "content_missing_per_sha1_git": _passthrough,
            "content_update": _passthrough,
            # These objects aren't maskable
            "extid_get_from_extid": _passthrough,
            "object_find_by_sha1_git": _passthrough,
            "object_find_recent_references": _passthrough,
            "metadata_authority_get": _passthrough,
            "metadata_fetcher_get": _passthrough,
            # Utility methods
            "check_config": _passthrough,
            "clear_buffers": _passthrough,
            "flush": _passthrough,
            "origin_count": _passthrough,
            "refresh_stat_counters": _passthrough,
            "stat_counters": _passthrough,
            # For tests
            "journal_writer": _passthrough,
        }
        self._methods_by_suffix = {
            # These methods will never need do any masking
            "add": _passthrough,
            "missing": _passthrough,
            # Partitions return PagedResults
            "partition": self._getter_pagedresult,
            # Getters return lists of optional objects
            "get": self._getter_list,
            "random": self._getter_random,
        }
[docs]
def content_get_data(self, content: Union[HashDict, Sha1]) -> Optional[bytes]:
with masking_overhead_timer("content_get_data") as t:
with t.inner():
ret = self.storage.content_get_data(content)
if ret is None:
return None
if isinstance(content, dict) and "sha1_git" in content:
self._raise_if_masked_swhids(
[
ExtendedSWHID(
object_type=ExtendedObjectType.CONTENT,
object_id=content["sha1_git"],
)
]
)
else:
# We did not get the SWHID of the object as argument, so we need to
# hash the resulting content to check if its SWHID was masked.
self._raise_if_masked_swhids(
[
ExtendedSWHID(
object_type=ExtendedObjectType.CONTENT,
object_id=MultiHash.from_data(ret, ["sha1_git"]).digest()[
"sha1_git"
],
)
]
)
return ret
    def _get_swhids_in_args(
        self, method_name: str, parsed_args: Dict[str, Any]
    ) -> List[ExtendedSWHID]:
        """Extract SWHIDs from the parsed arguments of ``method_name``.

        Arguments:
            method_name: name of the called method
            parsed_args: arguments of the method parsed with :func:`inspect.getcallargs`

        Raises:
            ValueError: if ``method_name`` is not one of the known
                argument-filtered methods
        """
        if method_name in ("directory_entry_get_by_path", "directory_ls"):
            # Both take a single directory id in the 'directory' argument
            return [
                ExtendedSWHID(
                    object_type=ExtendedObjectType.DIRECTORY,
                    object_id=parsed_args["directory"],
                )
            ]
        elif method_name == "directory_get_entries":
            return [
                ExtendedSWHID(
                    object_type=ExtendedObjectType.DIRECTORY,
                    object_id=parsed_args["directory_id"],
                )
            ]
        elif method_name.startswith("snapshot_"):
            # snapshot_count_branches / snapshot_get_branches /
            # snapshot_branch_get_by_name all take a 'snapshot_id'
            return [
                ExtendedSWHID(
                    object_type=ExtendedObjectType.SNAPSHOT,
                    object_id=parsed_args["snapshot_id"],
                )
            ]
        elif method_name == "raw_extrinsic_metadata_get_authorities":
            # 'target' is already an extended SWHID
            return [parsed_args["target"]]
        elif method_name == "directory_get_raw_manifest":
            # Takes a list of directory ids
            return [
                ExtendedSWHID(
                    object_type=ExtendedObjectType.DIRECTORY,
                    object_id=object_id,
                )
                for object_id in parsed_args["directory_ids"]
            ]
        elif method_name == "origin_snapshot_get_all":
            return [Origin(url=parsed_args["origin_url"]).swhid()]
        else:
            raise ValueError(f"Cannot get swhid for arguments of method {method_name}")
def _getter_filtering_arguments(self, method_name: str):
"""Handles methods that should filter on their argument, instead of the
returned value. If the underlying storage returns :const:`None`, return
it, else, raise :exc:`MaskedObjectException` if the requested object is
masked"""
@functools.wraps(getattr(self.storage, method_name))
def newf(*args, **kwargs):
with masking_overhead_timer(method_name) as t:
method = getattr(self.storage, method_name)
with t.inner():
result = method(*args, **kwargs)
if result is None:
return None
signature = inspect.signature(getattr(StorageInterface, method_name))
bound_args = signature.bind(self, *args, **kwargs)
self._raise_if_masked_swhids(
self._get_swhids_in_args(method_name, bound_args.arguments)
)
return result
return newf
RANDOM_ATTEMPTS = 5
def _getter_random(self, method_name: str):
"""Handles methods returning a random object. Try
:const:`RANDOM_ATTEMPTS` times for a non-masked object, and return it,
else return :const:`None`."""
@functools.wraps(getattr(self.storage, method_name))
def newf(*args, **kwargs):
with masking_overhead_timer(method_name) as t:
method = getattr(self.storage, method_name)
for _ in range(self.RANDOM_ATTEMPTS):
with t.inner():
result = method(*args, **kwargs)
if result is None:
return None
if not self._masked_result(method_name, result):
return result
return newf
def _getter_optional(self, method_name: str):
"""Handles methods returning an optional object: if the return value is
:const:`None`, return it, else, raise a :exc:`MaskedObjectException` if
the return value should be masked."""
@functools.wraps(getattr(self.storage, method_name))
def newf(*args, **kwargs):
with masking_overhead_timer(method_name) as t:
method = getattr(self.storage, method_name)
with t.inner():
result = method(*args, **kwargs)
if result is None:
return None
self._raise_if_masked_result(method_name, result)
return result
return newf
def _raise_if_masked_result_in_list(
self, method_name: str, results: Iterable[TResult]
) -> List[TResult]:
"""Raise a :exc:`MaskedObjectException` if any non-:const:`None` object
in ``results`` is masked."""
result_swhids = set()
results = list(results)
for result in results:
if result is not None:
result_swhids.update(self._get_swhids_in_result(method_name, result))
if result_swhids:
self._raise_if_masked_swhids(list(result_swhids))
return results
def _getter_list(
self,
method_name: str,
):
"""Handle methods returning a list (or a generator) of optional objects,
raising :exc:`MaskedObjectException` for all the masked objects in the
batch."""
@functools.wraps(getattr(self.storage, method_name))
def newf(*args, **kwargs):
with masking_overhead_timer(method_name) as t:
method = getattr(self.storage, method_name)
with t.inner():
results = list(method(*args, **kwargs))
self._raise_if_masked_result_in_list(method_name, results)
return results
return newf
def _getter_pagedresult(self, method_name: str) -> Callable[..., PagedResult]:
"""Handle methods returning a :cls:`PagedResult`, raising
:exc:`MaskedObjectException` if some objects in the returned page are
masked."""
@functools.wraps(getattr(self.storage, method_name))
def newf(*args, **kwargs) -> PagedResult:
with masking_overhead_timer(method_name) as t:
method = getattr(self.storage, method_name)
with t.inner():
results = method(*args, **kwargs)
self._raise_if_masked_result_in_list(method_name, results.results)
return results
return newf
# Patching proxy feature set
TRevision = TypeVar("TRevision", Revision, Optional[Revision])
def _apply_revision_display_names(
self, revisions: List[TRevision]
) -> List[TRevision]:
emails = set()
for rev in revisions:
if (
rev is not None
and rev.author is not None
and rev.author.email # ignore None or empty email addresses
):
emails.add(rev.author.email)
if (
rev is not None
and rev.committer is not None
and rev.committer.email # ignore None or empty email addresses
):
emails.add(rev.committer.email)
with self._masking_query() as q:
display_names = q.display_name(list(emails))
# Short path for the common case
if not display_names:
return revisions
persons: Dict[Optional[bytes], Person] = {
email: Person.from_fullname(display_name)
for (email, display_name) in display_names.items()
}
return [
(
None
if revision is None
else attr.evolve(
revision,
author=(
revision.author
if revision.author is None
else persons.get(revision.author.email, revision.author)
),
committer=(
revision.committer
if revision.committer is None
else persons.get(revision.committer.email, revision.committer)
),
)
)
for revision in revisions
]
TRelease = TypeVar("TRelease", Release, Optional[Release])
def _apply_release_display_names(self, releases: List[TRelease]) -> List[TRelease]:
emails = set()
for rel in releases:
if (
rel is not None
and rel.author is not None
and rel.author.email # ignore None or empty email addresses
):
emails.add(rel.author.email)
with self._masking_query() as q:
display_names = q.display_name(list(emails))
# Short path for the common case
if not display_names:
return releases
persons: Dict[Optional[bytes], Person] = {
email: Person.from_fullname(display_name)
for (email, display_name) in display_names.items()
}
return [
(
None
if release is None
else attr.evolve(
release,
author=(
release.author
if release.author is None
else persons.get(release.author.email, release.author)
),
)
)
for release in releases
]
[docs]
def revision_get(
self, revision_ids: List[Sha1Git], ignore_displayname: bool = False
) -> List[Optional[Revision]]:
revisions = self.storage.revision_get(revision_ids)
self._raise_if_masked_result_in_list("revision_get", revisions)
return self._apply_revision_display_names(revisions)
[docs]
def revision_log(
self,
revisions: List[Sha1Git],
ignore_displayname: bool = False,
limit: Optional[int] = None,
) -> Iterable[Optional[Dict[str, Any]]]:
revision_batches = grouper(
self.storage.revision_log(revisions, limit=limit), BATCH_SIZE
)
yield from map(
Revision.to_dict,
itertools.chain.from_iterable(
self._apply_revision_display_names(
self._raise_if_masked_result_in_list(
"revision_log", list(map(Revision.from_dict, revision_batch))
)
)
for revision_batch in revision_batches
),
)
[docs]
def revision_get_partition(
self,
partition_id: int,
nb_partitions: int,
page_token: Optional[str] = None,
limit: int = 1000,
) -> PagedResult[Revision]:
page: PagedResult[Revision] = self.storage.revision_get_partition(
partition_id, nb_partitions, page_token, limit
)
return PagedResult(
results=self._apply_revision_display_names(
self._raise_if_masked_result_in_list(
"revision_get_parition", page.results
)
),
next_page_token=page.next_page_token,
)
[docs]
def release_get(
self, releases: List[Sha1Git], ignore_displayname: bool = False
) -> List[Optional[Release]]:
return self._apply_release_display_names(
self._raise_if_masked_result_in_list(
"release_get", self.storage.release_get(releases)
)
)
[docs]
def release_get_partition(
self,
partition_id: int,
nb_partitions: int,
page_token: Optional[str] = None,
limit: int = 1000,
) -> PagedResult[Release]:
page = self.storage.release_get_partition(
partition_id, nb_partitions, page_token, limit
)
return PagedResult(
results=self._apply_release_display_names(
self._raise_if_masked_result_in_list(
"release_get_partition", page.results
)
),
next_page_token=page.next_page_token,
)