# Copyright (C) 2024 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information

from collections.abc import Mapping
from functools import partial
import hashlib
import itertools
import operator
from typing import (
Callable,
Collection,
Dict,
Iterable,
List,
Optional,
Set,
TypeVar,
cast,
)

from swh.model.swhids import ExtendedObjectType, ExtendedSWHID
from swh.storage.interface import StorageInterface

T = TypeVar("T")
C = TypeVar("C", covariant=True)


def iter_swhids_grouped_by_type(
swhids: Iterable[ExtendedSWHID],
*,
handlers: Mapping[ExtendedObjectType, Callable[[C], Iterable[T]]],
chunker: Optional[Callable[[Collection[ExtendedSWHID]], Iterable[C]]] = None,
) -> Iterable[T]:
"""Work on a iterable of SWHIDs grouped by their type, running a different
handler for each type.
The object types will be in the same order as in ``handlers``.
Arguments:
swhids: an iterable over some SWHIDs
handlers: a dictionary mapping each object type to an handler, taking
a collection of swhids and returning an iterable
chunker: an optional function to split the SWHIDs of same
object type into multiple “chunks”. It can also transform
the iterable into a more convenient collection.
Returns: an iterable over the handlers’ results
"""
def _default_chunker(it: Collection[ExtendedSWHID]) -> Iterable[C]:
yield cast(C, it)

    chunker = chunker or _default_chunker

    # groupby() splits consecutive groups, so we need to order the list first
ordering: Dict[ExtendedObjectType, int] = {
object_type: order for order, object_type in enumerate(handlers.keys())
}

    def key(swhid: ExtendedSWHID) -> int:
return ordering[swhid.object_type]

    sorted_swhids = sorted(swhids, key=key)

    # Now we can use itertools.groupby()
for object_type, grouped_swhids in itertools.groupby(
sorted_swhids, key=operator.attrgetter("object_type")
):
for chunk in chunker(list(grouped_swhids)):
yield from handlers[object_type](chunk)
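

# A minimal usage sketch: regroup interleaved SWHIDs so that each per-type
# handler only sees objects of its own type. The SWHIDs and the string
# formatting handlers below are placeholders made up for the example.
def _example_iter_swhids_grouped_by_type() -> List[str]:
    swhids = [
        ExtendedSWHID.from_string("swh:1:cnt:" + "00" * 20),
        ExtendedSWHID.from_string("swh:1:dir:" + "11" * 20),
        ExtendedSWHID.from_string("swh:1:cnt:" + "22" * 20),
    ]
    handlers: Dict[
        ExtendedObjectType, Callable[[Collection[ExtendedSWHID]], Iterable[str]]
    ] = {
        # Dict order drives output order: all contents, then all directories.
        ExtendedObjectType.CONTENT: lambda chunk: [f"cnt {s}" for s in chunk],
        ExtendedObjectType.DIRECTORY: lambda chunk: [f"dir {s}" for s in chunk],
    }
    # The two contents come out first, even though they were interleaved
    # with the directory on input.
    return list(iter_swhids_grouped_by_type(swhids, handlers=handlers))

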
def _filter_missing_contents(
storage: StorageInterface, requested_object_ids: Set[bytes]
) -> Iterable[ExtendedSWHID]:
missing_object_ids = set(
storage.content_missing_per_sha1_git(list(requested_object_ids))
)
yield from (
ExtendedSWHID(object_type=ExtendedObjectType.CONTENT, object_id=object_id)
for object_id in requested_object_ids - missing_object_ids
)


def _filter_missing_directories(
storage: StorageInterface, requested_object_ids: Set[bytes]
) -> Iterable[ExtendedSWHID]:
missing_object_ids = set(storage.directory_missing(list(requested_object_ids)))
yield from (
ExtendedSWHID(object_type=ExtendedObjectType.DIRECTORY, object_id=object_id)
for object_id in requested_object_ids - missing_object_ids
)


def _filter_missing_revisions(
storage: StorageInterface, requested_object_ids: Set[bytes]
) -> Iterable[ExtendedSWHID]:
missing_object_ids = set(storage.revision_missing(list(requested_object_ids)))
yield from (
ExtendedSWHID(object_type=ExtendedObjectType.REVISION, object_id=object_id)
for object_id in requested_object_ids - missing_object_ids
)


def _filter_missing_releases(
storage: StorageInterface, requested_object_ids: Set[bytes]
) -> Iterable[ExtendedSWHID]:
missing_object_ids = set(storage.release_missing(list(requested_object_ids)))
yield from (
ExtendedSWHID(object_type=ExtendedObjectType.RELEASE, object_id=object_id)
for object_id in requested_object_ids - missing_object_ids
)


def _filter_missing_snapshots(
storage: StorageInterface, requested_object_ids: Set[bytes]
) -> Iterable[ExtendedSWHID]:
missing_object_ids = set(storage.snapshot_missing(list(requested_object_ids)))
yield from (
ExtendedSWHID(object_type=ExtendedObjectType.SNAPSHOT, object_id=object_id)
for object_id in requested_object_ids - missing_object_ids
)


def _filter_missing_origins(
storage: StorageInterface, requested_object_ids: Set[bytes]
) -> Iterable[ExtendedSWHID]:
# XXX: We should add a better method in swh.storage
yield from (
ExtendedSWHID(
object_type=ExtendedObjectType.ORIGIN,
object_id=hashlib.sha1(d["url"].encode("utf-8")).digest(),
)
for d in storage.origin_get_by_sha1(list(requested_object_ids))
if d is not None
)
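

# A small sketch of the identity used above: in the Software Heritage data
# model, an origin's extended SWHID takes the SHA1 of the origin URL as its
# object_id. The URL here is a placeholder.
def _example_origin_swhid(url: str = "https://example.org/repo.git") -> ExtendedSWHID:
    return ExtendedSWHID(
        object_type=ExtendedObjectType.ORIGIN,
        object_id=hashlib.sha1(url.encode("utf-8")).digest(),
    )

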
def filter_objects_missing_from_storage(
storage: StorageInterface, swhids: Iterable[ExtendedSWHID]
) -> List[ExtendedSWHID]:
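    """Return the subset of ``swhids`` whose objects are actually present
    in ``storage``, grouped by object type."""
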
def chunker(swhids: Iterable[ExtendedSWHID]) -> Iterable[Set[bytes]]:
yield {swhid.object_id for swhid in swhids}
    handlers: Dict[
        ExtendedObjectType, Callable[[Set[bytes]], Iterable[ExtendedSWHID]]
    ] = {
ExtendedObjectType.CONTENT: partial(_filter_missing_contents, storage),
ExtendedObjectType.DIRECTORY: partial(_filter_missing_directories, storage),
ExtendedObjectType.REVISION: partial(_filter_missing_revisions, storage),
ExtendedObjectType.RELEASE: partial(_filter_missing_releases, storage),
ExtendedObjectType.SNAPSHOT: partial(_filter_missing_snapshots, storage),
ExtendedObjectType.ORIGIN: partial(_filter_missing_origins, storage),
}
    return list(
        iter_swhids_grouped_by_type(swhids, handlers=handlers, chunker=chunker)
    )
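

# A minimal usage sketch, assuming the standard ``swh.storage`` factory: a
# freshly created in-memory storage contains nothing, so every SWHID is
# filtered out. The SWHIDs below are placeholders.
def _example_filter_objects_missing_from_storage() -> List[ExtendedSWHID]:
    from swh.storage import get_storage

    storage = get_storage("memory")
    swhids = [
        ExtendedSWHID.from_string("swh:1:rev:" + "aa" * 20),
        ExtendedSWHID.from_string("swh:1:snp:" + "bb" * 20),
    ]
    # Nothing was loaded into `storage`, so the filtered list is empty.
    return filter_objects_missing_from_storage(storage, swhids)

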
def get_filtered_objects(
storage: StorageInterface,
get_objects: Callable[[int], Collection[ExtendedSWHID]],
max_results: int,
) -> Collection[ExtendedSWHID]:
"""Call `get_objects(limit)` filtering out results with
`filter_objects_missing_from_storage`.
If some objects were filtered, call the function again with an increasing
limit until `max_results` objects are returned.
"""
limit = max_results
while True:
results = get_objects(limit)
filtered_results = filter_objects_missing_from_storage(storage, results)
filtered_count = len(results) - len(filtered_results)
if len(filtered_results) >= max_results:
return filtered_results[:max_results]
elif len(results) >= limit and filtered_count > 0:
            # Some results were filtered out and the initial call reached
            # the requested limit, so we may have missed some entries. The
            # limit must grow to at least `limit + filtered_count`;
            # doubling that value converges faster.
limit = 2 * (limit + filtered_count)
else:
return filtered_results
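

# A minimal usage sketch: ``backlog`` is a made-up stand-in for any
# limit-aware source of candidate SWHIDs (e.g. a database query). The limit
# is widened automatically until enough objects are confirmed in storage.
def _example_get_filtered_objects(
    storage: StorageInterface, backlog: List[ExtendedSWHID]
) -> Collection[ExtendedSWHID]:
    def get_objects(limit: int) -> Collection[ExtendedSWHID]:
        return backlog[:limit]

    # Returns at most 10 SWHIDs present in `storage`, asking `backlog` for
    # more than 10 candidates whenever some were filtered out.
    return get_filtered_objects(storage, get_objects, max_results=10)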