# Copyright (C) 2024 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
from contextlib import contextmanager
import functools
import inspect
import itertools
from typing import (
Any,
Callable,
Dict,
Iterable,
Iterator,
List,
Optional,
TypeVar,
Union,
)
import warnings
import attr
import psycopg2.pool
from swh.core.utils import grouper
from swh.model.hashutil import MultiHash
from swh.model.model import Origin, Person, Release, Revision, Sha1Git
from swh.model.swhids import ExtendedObjectType, ExtendedSWHID
from swh.storage import get_storage
from swh.storage.exc import MaskedObjectException
from swh.storage.interface import HashDict, PagedResult, Sha1, StorageInterface, TResult
from swh.storage.metrics import DifferentialTimer
from swh.storage.proxies.masking.db import MaskedStatus
from .db import MaskingQuery
BATCH_SIZE = 1024
MASKING_OVERHEAD_METRIC = "swh_storage_masking_overhead_seconds"
def get_datastore(cls, db=None, masking_db=None, **kwargs):
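    """Return a :class:`MaskingAdmin` connected to the masking proxy database.

    ``masking_db`` is accepted as a legacy alias for the ``db`` argument.
    """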
assert cls in ("postgresql", "masking")
from .db import MaskingAdmin
if db is None:
db = masking_db
return MaskingAdmin.connect(db)
def masking_overhead_timer(method_name: str) -> DifferentialTimer:
"""Return a properly setup DifferentialTimer for ``method_name`` of the storage"""
return DifferentialTimer(MASKING_OVERHEAD_METRIC, tags={"endpoint": method_name})
class MaskingProxyStorage:
"""Masking storage proxy
This proxy can return modified objects or stop them from being retrieved at
all.
It uses a specific PostgreSQL database (which for now is colocated with the
swh.storage PostgreSQL database), the access to which is implemented in the
:mod:`.db` submodule.
Sample configuration
    .. code-block:: yaml
storage:
cls: masking
db: 'dbname=swh-storage'
max_pool_conns: 10
storage:
- cls: remote
url: http://storage.internal.staging.swh.network:5002/
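
    Reading a masked object through this proxy raises
    :exc:`swh.storage.exc.MaskedObjectException`. The configuration above can
    also be built programmatically; a minimal sketch (the DSN and the remote
    storage URL are placeholders):

    .. code-block:: python

        from swh.storage import get_storage

        storage = get_storage(
            "masking",
            # DSN of the masking proxy database (placeholder)
            db="dbname=swh-storage",
            # the storage backend being proxied (placeholder URL)
            storage={
                "cls": "remote",
                "url": "http://storage.internal.staging.swh.network:5002/",
            },
        )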
"""
def __init__(
self,
storage: Union[Dict, StorageInterface],
db: Optional[str] = None,
masking_db: Optional[str] = None,
min_pool_conns: int = 1,
max_pool_conns: int = 5,
):
if db is None:
assert masking_db is not None
warnings.warn(
"'masking_db' field in the masking storage configuration "
"was renamed 'db' field",
DeprecationWarning,
)
db = masking_db
self.storage: StorageInterface = (
get_storage(**storage) if isinstance(storage, dict) else storage
)
self._masking_pool = psycopg2.pool.ThreadedConnectionPool(
min_pool_conns, max_pool_conns, db
)
# Generate the method dictionaries once per instantiation, instead of
# doing it on every (first) __getattr__ call.
self._gen_method_dicts()
@contextmanager
def _masking_query(self) -> Iterator[MaskingQuery]:
ret = None
try:
ret = MaskingQuery.from_pool(self._masking_pool)
yield ret
finally:
if ret:
ret.put_conn()
@staticmethod
def _get_swhids_in_result(method_name: str, result: Any) -> List[ExtendedSWHID]:
if result is None:
raise TypeError(f"Filtering of Nones missing in {method_name}")
result_type = getattr(result, "object_type", None)
if result_type == "raw_extrinsic_metadata":
            # Raw extrinsic metadata objects have a SWHID of their own, but we
            # also mask them if their target is masked
return [result.swhid(), result.target]
if hasattr(result, "swhid"):
swhid = result.swhid()
if hasattr(swhid, "to_extended"):
swhid = swhid.to_extended()
assert isinstance(
swhid, ExtendedSWHID
), f"{method_name} returned object with unexpected swhid() method return type"
return [swhid]
if result_type:
if result_type == "extid":
return [result.target.to_extended()]
if result_type.value.startswith("origin_visit"):
return [Origin(url=result.origin).swhid()]
if (
object_type := method_name.removesuffix("_get_random").removesuffix(
"_get_id_partition"
)
) != method_name:
# Returns bare sha1_gits
assert isinstance(result, bytes), f"{method_name} returned unexpected type"
return [
ExtendedSWHID(
object_type=ExtendedObjectType[object_type.upper()],
object_id=result,
)
]
elif method_name == "revision_log":
# returns dicts of revisions
assert (
isinstance(result, dict) and "id" in result
), f"{method_name} returned unexpected type"
return [
ExtendedSWHID(
object_type=ExtendedObjectType.REVISION, object_id=result["id"]
)
]
elif method_name == "revision_shortlog":
# Returns tuples (revision.id, revision.parents)
assert isinstance(
result[0], bytes
), f"{method_name} returned unexpected type"
return [
ExtendedSWHID(
object_type=ExtendedObjectType.REVISION, object_id=result[0]
)
]
elif method_name == "origin_get_by_sha1":
# Returns origin dicts, because why not
assert (
isinstance(result, dict) and "url" in result
), f"{method_name} returned unexpected type"
return [Origin(url=result["url"]).swhid()]
elif method_name == "origin_visit_get_with_statuses":
# Returns an OriginVisitWithStatuses
return [Origin(url=result.visit.origin).swhid()]
elif method_name == "snapshot_get":
# Returns a snapshot dict
return [
ExtendedSWHID(
object_type=ExtendedObjectType.SNAPSHOT, object_id=result["id"]
)
]
raise ValueError(f"Cannot get swhid for result of method {method_name}")
def _masked_result(
self, method_name: str, result: Any
) -> Optional[Dict[ExtendedSWHID, List[MaskedStatus]]]:
"""Find the SWHIDs of the ``result`` object, and check if any of them is
masked, returning the associated masking information."""
with self._masking_query() as q:
return q.swhids_are_masked(self._get_swhids_in_result(method_name, result))
def _raise_if_masked_result(self, method_name: str, result: Any) -> None:
"""Raise a :exc:`MaskedObjectException` if ``result`` is masked."""
masked = self._masked_result(method_name, result)
if masked:
raise MaskedObjectException(masked)
def _raise_if_masked_swhids(self, swhids: List[ExtendedSWHID]) -> None:
"""Raise a :exc:`MaskedObjectException` if any SWHID is masked."""
with self._masking_query() as q:
masked = q.swhids_are_masked(swhids)
if masked:
raise MaskedObjectException(masked)
def __getattr__(self, key):
method = None
if key in self._methods_by_name:
method = self._methods_by_name[key](key)
else:
suffix = key.rsplit("_", 1)[-1]
if suffix in self._methods_by_suffix:
method = self._methods_by_suffix[suffix](key)
if method:
# Avoid going through __getattr__ again next time
setattr(self, key, method)
return method
# Raise a NotImplementedError to make sure we don't forget to add
# masking to any new storage functions
raise NotImplementedError(key)
def _gen_method_dicts(self):
"""Generate the :attr:`_methods_by_name` and :attr:`_methods_by_suffix`
used by :meth:`__getattr__`"""
_passthrough = functools.partial(getattr, self.storage)
self._methods_by_name = {
# Returns a single object
"snapshot_get": self._getter_optional,
"origin_visit_find_by_date": self._getter_optional,
"origin_visit_get_by": self._getter_optional,
"origin_visit_status_get_latest": self._getter_optional,
"origin_visit_get_latest": self._getter_optional,
# Returns a PagedResult
"origin_list": self._getter_pagedresult,
"origin_visit_get": self._getter_pagedresult,
"origin_search": self._getter_pagedresult,
"raw_extrinsic_metadata_get": self._getter_pagedresult,
"origin_visit_get_with_statuses": self._getter_pagedresult,
"origin_visit_status_get": self._getter_pagedresult,
# Returns a list of (optional) objects
"origin_get_by_sha1": self._getter_list,
"content_find": self._getter_list,
"skipped_content_find": self._getter_list,
"revision_shortlog": self._getter_list,
"extid_get_from_target": self._getter_list,
"raw_extrinsic_metadata_get_by_ids": self._getter_list,
# Filter arguments
"directory_entry_get_by_path": self._getter_filtering_arguments,
"directory_get_entries": self._getter_filtering_arguments,
"directory_get_raw_manifest": self._getter_filtering_arguments,
"directory_ls": self._getter_filtering_arguments,
"raw_extrinsic_metadata_get_authorities": self._getter_filtering_arguments,
"snapshot_branch_get_by_name": self._getter_filtering_arguments,
"snapshot_count_branches": self._getter_filtering_arguments,
"snapshot_get_branches": self._getter_filtering_arguments,
"origin_snapshot_get_all": self._getter_filtering_arguments,
# Content functions that don't match common getter or adder suffixes
"content_add_metadata": _passthrough,
"content_missing_per_sha1": _passthrough,
"content_missing_per_sha1_git": _passthrough,
"content_update": _passthrough,
# These objects aren't maskable
"extid_get_from_extid": _passthrough,
"object_find_by_sha1_git": _passthrough,
"object_find_recent_references": _passthrough,
"metadata_authority_get": _passthrough,
"metadata_fetcher_get": _passthrough,
# Utility methods
"check_config": _passthrough,
"clear_buffers": _passthrough,
"flush": _passthrough,
"origin_count": _passthrough,
"refresh_stat_counters": _passthrough,
"stat_counters": _passthrough,
# For tests
"journal_writer": _passthrough,
}
self._methods_by_suffix = {
            # These methods will never need to do any masking
"add": _passthrough,
"missing": _passthrough,
# Partitions return PagedResults
"partition": self._getter_pagedresult,
# Getters return lists of optional objects
"get": self._getter_list,
"random": self._getter_random,
}
def content_get_data(self, content: Union[HashDict, Sha1]) -> Optional[bytes]:
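        """Retrieve the data for a given content, raising
        :exc:`MaskedObjectException` if the corresponding SWHID is masked.

        When ``content`` does not include a ``sha1_git`` hash, the returned
        data is hashed to recover the SWHID used for the masking check.
        """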
with masking_overhead_timer("content_get_data") as t:
with t.inner():
ret = self.storage.content_get_data(content)
if ret is None:
return None
if isinstance(content, dict) and "sha1_git" in content:
self._raise_if_masked_swhids(
[
ExtendedSWHID(
object_type=ExtendedObjectType.CONTENT,
object_id=content["sha1_git"],
)
]
)
else:
# We did not get the SWHID of the object as argument, so we need to
# hash the resulting content to check if its SWHID was masked.
self._raise_if_masked_swhids(
[
ExtendedSWHID(
object_type=ExtendedObjectType.CONTENT,
object_id=MultiHash.from_data(ret, ["sha1_git"]).digest()[
"sha1_git"
],
)
]
)
return ret
def _get_swhids_in_args(
self, method_name: str, parsed_args: Dict[str, Any]
) -> List[ExtendedSWHID]:
"""Extract SWHIDs from the parsed arguments of ``method_name``.
Arguments:
method_name: name of the called method
parsed_args: arguments of the method parsed with :func:`inspect.getcallargs`
"""
if method_name in ("directory_entry_get_by_path", "directory_ls"):
return [
ExtendedSWHID(
object_type=ExtendedObjectType.DIRECTORY,
object_id=parsed_args["directory"],
)
]
elif method_name == "directory_get_entries":
return [
ExtendedSWHID(
object_type=ExtendedObjectType.DIRECTORY,
object_id=parsed_args["directory_id"],
)
]
elif method_name.startswith("snapshot_"):
return [
ExtendedSWHID(
object_type=ExtendedObjectType.SNAPSHOT,
object_id=parsed_args["snapshot_id"],
)
]
elif method_name == "raw_extrinsic_metadata_get_authorities":
return [parsed_args["target"]]
elif method_name == "directory_get_raw_manifest":
return [
ExtendedSWHID(
object_type=ExtendedObjectType.DIRECTORY,
object_id=object_id,
)
for object_id in parsed_args["directory_ids"]
]
elif method_name == "origin_snapshot_get_all":
return [Origin(url=parsed_args["origin_url"]).swhid()]
else:
raise ValueError(f"Cannot get swhid for arguments of method {method_name}")
def _getter_filtering_arguments(self, method_name: str):
"""Handles methods that should filter on their argument, instead of the
returned value. If the underlying storage returns :const:`None`, return
it, else, raise :exc:`MaskedObjectException` if the requested object is
masked"""
@functools.wraps(getattr(self.storage, method_name))
def newf(*args, **kwargs):
with masking_overhead_timer(method_name) as t:
method = getattr(self.storage, method_name)
with t.inner():
result = method(*args, **kwargs)
if result is None:
return None
signature = inspect.signature(getattr(StorageInterface, method_name))
bound_args = signature.bind(self, *args, **kwargs)
self._raise_if_masked_swhids(
self._get_swhids_in_args(method_name, bound_args.arguments)
)
return result
return newf
RANDOM_ATTEMPTS = 5
def _getter_random(self, method_name: str):
"""Handles methods returning a random object. Try
:const:`RANDOM_ATTEMPTS` times for a non-masked object, and return it,
else return :const:`None`."""
@functools.wraps(getattr(self.storage, method_name))
def newf(*args, **kwargs):
with masking_overhead_timer(method_name) as t:
method = getattr(self.storage, method_name)
for _ in range(self.RANDOM_ATTEMPTS):
with t.inner():
result = method(*args, **kwargs)
if result is None:
return None
if not self._masked_result(method_name, result):
return result
return newf
def _getter_optional(self, method_name: str):
"""Handles methods returning an optional object: if the return value is
:const:`None`, return it, else, raise a :exc:`MaskedObjectException` if
the return value should be masked."""
@functools.wraps(getattr(self.storage, method_name))
def newf(*args, **kwargs):
with masking_overhead_timer(method_name) as t:
method = getattr(self.storage, method_name)
with t.inner():
result = method(*args, **kwargs)
if result is None:
return None
self._raise_if_masked_result(method_name, result)
return result
return newf
def _raise_if_masked_result_in_list(
self, method_name: str, results: Iterable[TResult]
) -> List[TResult]:
"""Raise a :exc:`MaskedObjectException` if any non-:const:`None` object
in ``results`` is masked."""
result_swhids = set()
results = list(results)
for result in results:
if result is not None:
result_swhids.update(self._get_swhids_in_result(method_name, result))
if result_swhids:
self._raise_if_masked_swhids(list(result_swhids))
return results
def _getter_list(
self,
method_name: str,
):
"""Handle methods returning a list (or a generator) of optional objects,
raising :exc:`MaskedObjectException` for all the masked objects in the
batch."""
@functools.wraps(getattr(self.storage, method_name))
def newf(*args, **kwargs):
with masking_overhead_timer(method_name) as t:
method = getattr(self.storage, method_name)
with t.inner():
results = list(method(*args, **kwargs))
self._raise_if_masked_result_in_list(method_name, results)
return results
return newf
def _getter_pagedresult(self, method_name: str) -> Callable[..., PagedResult]:
"""Handle methods returning a :cls:`PagedResult`, raising
:exc:`MaskedObjectException` if some objects in the returned page are
masked."""
@functools.wraps(getattr(self.storage, method_name))
def newf(*args, **kwargs) -> PagedResult:
with masking_overhead_timer(method_name) as t:
method = getattr(self.storage, method_name)
with t.inner():
results = method(*args, **kwargs)
self._raise_if_masked_result_in_list(method_name, results.results)
return results
return newf
# Patching proxy feature set
TRevision = TypeVar("TRevision", Revision, Optional[Revision])
def _apply_revision_display_names(
self, revisions: List[TRevision]
) -> List[TRevision]:
emails = set()
for rev in revisions:
if (
rev is not None
and rev.author is not None
and rev.author.email # ignore None or empty email addresses
):
emails.add(rev.author.email)
if (
rev is not None
and rev.committer is not None
and rev.committer.email # ignore None or empty email addresses
):
emails.add(rev.committer.email)
with self._masking_query() as q:
display_names = q.display_name(list(emails))
# Short path for the common case
if not display_names:
return revisions
persons: Dict[Optional[bytes], Person] = {
email: Person.from_fullname(display_name)
for (email, display_name) in display_names.items()
}
return [
(
None
if revision is None
else attr.evolve(
revision,
author=(
revision.author
if revision.author is None
else persons.get(revision.author.email, revision.author)
),
committer=(
revision.committer
if revision.committer is None
else persons.get(revision.committer.email, revision.committer)
),
)
)
for revision in revisions
]
TRelease = TypeVar("TRelease", Release, Optional[Release])
def _apply_release_display_names(self, releases: List[TRelease]) -> List[TRelease]:
emails = set()
for rel in releases:
if (
rel is not None
and rel.author is not None
and rel.author.email # ignore None or empty email addresses
):
emails.add(rel.author.email)
with self._masking_query() as q:
display_names = q.display_name(list(emails))
# Short path for the common case
if not display_names:
return releases
persons: Dict[Optional[bytes], Person] = {
email: Person.from_fullname(display_name)
for (email, display_name) in display_names.items()
}
return [
(
None
if release is None
else attr.evolve(
release,
author=(
release.author
if release.author is None
else persons.get(release.author.email, release.author)
),
)
)
for release in releases
]
def revision_get(
self, revision_ids: List[Sha1Git], ignore_displayname: bool = False
) -> List[Optional[Revision]]:
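        """Retrieve revisions, raising :exc:`MaskedObjectException` if any of
        them is masked, and applying the display names configured in the
        masking database to revision authors and committers.

        ``ignore_displayname`` is accepted for compatibility with the storage
        interface but is not used by this proxy.
        """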
revisions = self.storage.revision_get(revision_ids)
self._raise_if_masked_result_in_list("revision_get", revisions)
return self._apply_revision_display_names(revisions)
def revision_log(
self,
revisions: List[Sha1Git],
ignore_displayname: bool = False,
limit: Optional[int] = None,
) -> Iterable[Optional[Dict[str, Any]]]:
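        """Iterate over the revision log as dicts, in batches of
        :const:`BATCH_SIZE` revisions, raising :exc:`MaskedObjectException` if
        a revision in the log is masked and applying the display names
        configured in the masking database.

        ``ignore_displayname`` is accepted for compatibility with the storage
        interface but is not used by this proxy.
        """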
revision_batches = grouper(
self.storage.revision_log(revisions, limit=limit), BATCH_SIZE
)
yield from map(
Revision.to_dict,
itertools.chain.from_iterable(
self._apply_revision_display_names(
self._raise_if_masked_result_in_list(
"revision_log", list(map(Revision.from_dict, revision_batch))
)
)
for revision_batch in revision_batches
),
)
def revision_get_partition(
self,
partition_id: int,
nb_partitions: int,
page_token: Optional[str] = None,
limit: int = 1000,
) -> PagedResult[Revision]:
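        """Retrieve a page of revisions from a partition of the revision
        table, with the same masking checks and display-name handling as
        :meth:`revision_get`."""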
page: PagedResult[Revision] = self.storage.revision_get_partition(
partition_id, nb_partitions, page_token, limit
)
return PagedResult(
results=self._apply_revision_display_names(
self._raise_if_masked_result_in_list(
"revision_get_parition", page.results
)
),
next_page_token=page.next_page_token,
)
def release_get(
self, releases: List[Sha1Git], ignore_displayname: bool = False
) -> List[Optional[Release]]:
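        """Retrieve releases, raising :exc:`MaskedObjectException` if any of
        them is masked, and applying the display names configured in the
        masking database to release authors.

        ``ignore_displayname`` is accepted for compatibility with the storage
        interface but is not used by this proxy.
        """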
return self._apply_release_display_names(
self._raise_if_masked_result_in_list(
"release_get", self.storage.release_get(releases)
)
)
def release_get_partition(
self,
partition_id: int,
nb_partitions: int,
page_token: Optional[str] = None,
limit: int = 1000,
) -> PagedResult[Release]:
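        """Retrieve a page of releases from a partition of the release table,
        with the same masking checks and display-name handling as
        :meth:`release_get`."""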
page = self.storage.release_get_partition(
partition_id, nb_partitions, page_token, limit
)
return PagedResult(
results=self._apply_release_display_names(
self._raise_if_masked_result_in_list(
"release_get_partition", page.results
)
),
next_page_token=page.next_page_token,
)