Source code for swh.objstorage.replayer.replay

# Copyright (C) 2019-2024 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information

import logging
from queue import Empty, Queue
import sys
from threading import Event, Thread
from time import time
from traceback import format_tb
from typing import Any, Callable, Dict, List, Optional, Tuple

from humanize import naturaldelta, naturalsize
import msgpack
import sentry_sdk

from swh.objstorage.interface import (
    CompositeObjId,
    ObjStorageInterface,
    objid_from_dict,
)

try:
    from systemd.daemon import notify
except ImportError:
    notify = None

from tenacity import (
    RetryCallState,
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_random_exponential,
)
from tenacity.retry import retry_base

from swh.core.statsd import statsd
from swh.model.hashutil import MultiHash, hash_to_hex
from swh.model.model import SHA1_SIZE
from swh.objstorage.exc import Error, ObjNotFoundError

# importing the factory module is needed to make tests work (get_objstorage is patched)
import swh.objstorage.factory as factory

logger = logging.getLogger(__name__)
REPORTER: Optional[Callable[[str, bytes], Any]] = None

CONTENT_OPERATIONS_METRIC = "swh_content_replayer_operations_total"
CONTENT_RETRY_METRIC = "swh_content_replayer_retries_total"
CONTENT_BYTES_METRIC = "swh_content_replayer_bytes"
CONTENT_DURATION_METRIC = "swh_content_replayer_duration_seconds"


class LengthMismatch(Exception):
    def __init__(self, expected, received):
        self.expected = expected
        self.received = received

    def __str__(self):
        return f"Length mismatch: received {self.received} != expected {self.expected}"

class HashMismatch(Exception):
    def __init__(self, expected, received):
        self.mismatched = {}
        self.matched = {}
        for algo, value in expected.items():
            received_value = received.get(algo)
            if received_value != value:
                self.mismatched[algo] = (received_value, value)
            else:
                self.matched[algo] = value

    def __str__(self):
        return "\n".join(
            ["Hash Mismatch:"]
            + [
                f"  {algo}: {v[0].hex() if v[0] else None} != expected {v[1].hex()}"
                for algo, v in self.mismatched.items()
            ]
            + (
                (
                    ["Matched hashes:"]
                    + [f"  {algo}: {v.hex()}" for algo, v in self.matched.items()]
                )
                if self.matched
                else []
            )
        )

def format_obj_id(obj_id: CompositeObjId) -> str:
    return ";".join(
        "%s:%s" % (algo, hash_to_hex(hash))
        for algo, hash in sorted(obj_id.items())
        if hash
    )

def hex_obj_id(obj_id: CompositeObjId) -> Dict[str, str]:
    return {algo: hash_to_hex(hash) for algo, hash in obj_id.items() if hash}

def logger_debug_obj_id(msg, args, **kwargs):
    if logger.isEnabledFor(logging.DEBUG):
        if sys.version_info >= (3, 8):
            # Ignore this helper in line/function calculation
            kwargs = {**kwargs, "stacklevel": kwargs.get("stacklevel", 1) + 1}
        logger.debug(msg, {**args, "obj_id": format_obj_id(args["obj_id"])}, **kwargs)

def is_hash_in_bytearray(hash_, array, nb_hashes, hash_size=SHA1_SIZE):
    """
    Checks if the given hash is in the provided `array`. The array must be
    a *sorted* list of sha1 hashes, and contain `nb_hashes` hashes (so its
    size must be `nb_hashes*hash_size` bytes).

    Args:
        hash_ (bytes): the hash to look for
        array (bytes): a sorted concatenated array of hashes (may be of any
            type supporting slice indexing, eg. :class:`mmap.mmap`)
        nb_hashes (int): number of hashes in the array
        hash_size (int): size of a hash (defaults to 20, for SHA1)

    Example:

    >>> import os
    >>> hash1 = os.urandom(20)
    >>> hash2 = os.urandom(20)
    >>> hash3 = os.urandom(20)
    >>> array = b''.join(sorted([hash1, hash2]))
    >>> is_hash_in_bytearray(hash1, array, 2)
    True
    >>> is_hash_in_bytearray(hash2, array, 2)
    True
    >>> is_hash_in_bytearray(hash3, array, 2)
    False
    """
    if len(hash_) != hash_size:
        raise ValueError("hash_ does not match the provided hash_size.")

    def get_hash(position):
        return array[position * hash_size : (position + 1) * hash_size]

    # Regular dichotomy:
    left = 0
    right = nb_hashes
    while left < right - 1:
        middle = int((right + left) / 2)
        pivot = get_hash(middle)
        if pivot == hash_:
            return True
        elif pivot < hash_:
            left = middle
        else:
            right = middle
    return get_hash(left) == hash_

class ReplayError(Exception):
    """An error occurred during the replay of an object"""

    def __init__(self, *, obj_id: CompositeObjId, exc) -> None:
        self.obj_id = obj_id
        self.exc = exc

    def __str__(self) -> str:
        return "ReplayError(%s, %r, %s)" % (
            format_obj_id(self.obj_id),
            self.exc,
            format_tb(self.exc.__traceback__),
        )

def log_replay_retry(
    retry_state: RetryCallState, sleep: Optional[float] = None, last_result: Any = None
) -> None:
    """Log a retry of the content replayer"""
    assert retry_state.outcome is not None
    exc = retry_state.outcome.exception()
    assert isinstance(exc, ReplayError)
    assert retry_state.fn is not None
    operation = retry_state.fn.__name__

    logger_debug_obj_id(
        "Retry operation %(operation)s on %(obj_id)s: %(exc)s",
        {
            "operation": operation,
            "obj_id": exc.obj_id,
            "exc": str(exc.exc),
        },
    )

def log_replay_error(
    obj_id: CompositeObjId, exc: Exception, operation: str, retries: int
) -> None:
    with sentry_sdk.push_scope() as scope:
        scope.set_tag("operation", operation)
        scope.set_extra("obj_id", hex_obj_id(obj_id))
        sentry_sdk.capture_exception(exc)

    error_context = {
        "obj_id": format_obj_id(obj_id),
        "operation": operation,
        "exc": str(exc),
        "retries": retries,
    }

    logger.error(
        "Failed operation %(operation)s on %(obj_id)s after %(retries)s"
        " retries; last exception: %(exc)s",
        error_context,
    )

    # if we have a global error (redis) reporter
    if REPORTER is not None:
        oid = f"blob:{format_obj_id(obj_id)}"
        msg = msgpack.dumps(error_context)
        REPORTER(oid, msg)

def retry_error_callback(retry_state: RetryCallState) -> None:
    """Log a replay error to sentry"""
    assert retry_state.outcome
    exc = retry_state.outcome.exception()
    assert isinstance(exc, ReplayError)
    assert retry_state.fn
    operation = retry_state.fn.__name__

    log_replay_error(
        obj_id=exc.obj_id,
        exc=exc.exc,
        operation=operation,
        retries=retry_state.attempt_number,
    )

    raise exc

CONTENT_REPLAY_RETRIES = 3

class retry_log_if_success(retry_base):
    """Log in statsd the number of attempts required to succeed"""

    def __call__(self, retry_state: RetryCallState):
        assert retry_state.outcome
        if not retry_state.outcome.failed:
            assert retry_state.fn
            statsd.increment(
                CONTENT_RETRY_METRIC,
                tags={
                    "operation": retry_state.fn.__name__,
                    "attempt": str(retry_state.attempt_number),
                },
            )
        return False

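# Note on the retry policy built below (explanatory comment, not upstream code):
# tenacity's `|` operator combines retry predicates into a `retry_any`, so a call
# is retried when it raises ReplayError, while retry_log_if_success never triggers
# a retry (it always returns False) and is chained in purely for its statsd side
# effect of recording how many attempts a successful call needed.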
content_replay_retry = retry(
    retry=retry_if_exception_type(ReplayError) | retry_log_if_success(),
    stop=stop_after_attempt(CONTENT_REPLAY_RETRIES),
    wait=wait_random_exponential(multiplier=1, max=60),
    before_sleep=log_replay_retry,
    retry_error_callback=retry_error_callback,
)

@content_replay_retry
def get_object(objstorage: ObjStorageInterface, obj_id: CompositeObjId) -> bytes:
    try:
        with statsd.timed(CONTENT_DURATION_METRIC, tags={"request": "get"}):
            obj = objstorage.get(obj_id)
            logger_debug_obj_id("retrieved %(obj_id)s", {"obj_id": obj_id})
        return obj
    except ObjNotFoundError:
        logger.error(
            "Failed to retrieve %(obj_id)s: object not found",
            {"obj_id": format_obj_id(obj_id)},
        )
        raise
    except Exception as exc:
        raise ReplayError(obj_id=obj_id, exc=exc) from None

def check_hashes(obj: bytes, obj_id: CompositeObjId):
    h = MultiHash.from_data(obj, hash_names=obj_id.keys())
    computed = h.digest()
    if computed != obj_id:
        exc = HashMismatch(obj_id, computed)
        log_replay_error(obj_id=obj_id, exc=exc, operation="check_hashes", retries=1)
        raise exc

@content_replay_retry
def put_object(objstorage: ObjStorageInterface, obj_id: CompositeObjId, obj: bytes):
    try:
        logger_debug_obj_id("putting %(obj_id)s", {"obj_id": obj_id})
        with statsd.timed(CONTENT_DURATION_METRIC, tags={"request": "put"}):
            logger_debug_obj_id("storing %(obj_id)s", {"obj_id": obj_id})
            objstorage.add(obj, obj_id, check_presence=False)
            logger_debug_obj_id("stored %(obj_id)s", {"obj_id": obj_id})
    except Exception as exc:
        logger.error(
            "putting %(obj_id)s failed: %(exc)r",
            {"obj_id": format_obj_id(obj_id), "exc": exc},
        )
        raise ReplayError(obj_id=obj_id, exc=exc) from None

def copy_object(
    obj_id: CompositeObjId,
    obj_len: int,
    src: ObjStorageInterface,
    dst: ObjStorageInterface,
    check_src_hashes: bool = False,
) -> int:
    obj = get_object(src, obj_id)
    if obj is not None:
        if len(obj) != obj_len:
            raise LengthMismatch(obj_len, len(obj))
        if check_src_hashes:
            check_hashes(obj, obj_id)
        put_object(dst, obj_id, obj)
        statsd.increment(CONTENT_BYTES_METRIC, len(obj))
        return len(obj)
    return 0

@content_replay_retry
def obj_in_objstorage(obj_id: CompositeObjId, dst: ObjStorageInterface) -> bool:
    """Check if an object is already in an objstorage, tenaciously"""
    try:
        return obj_id in dst
    except Exception as exc:
        raise ReplayError(obj_id=obj_id, exc=exc) from None

class ContentReplayer:
    def __init__(
        self,
        src: Dict[str, Any],
        dst: Dict[str, Any],
        exclude_fn: Optional[Callable[[Dict[str, Any]], bool]] = None,
        check_dst: bool = True,
        check_obj: bool = False,
        check_src_hashes: bool = False,
        concurrency: int = 16,
    ):
        """Helper class that takes a list of records from Kafka (see
        :py:func:`swh.journal.client.JournalClient.process`) and copies them
        from the `src` objstorage to the `dst` objstorage, if:

        * `obj['status']` is `'visible'`
        * `exclude_fn(obj)` is `False` (if `exclude_fn` is provided)
        * `CompositeObjId(**obj) not in dst` (if `check_dst` is True)

        Args:
            src: An object storage configuration dict (see
                :py:func:`swh.objstorage.get_objstorage`)
            dst: An object storage configuration dict (see
                :py:func:`swh.objstorage.get_objstorage`)
            exclude_fn: Determines whether an object should be copied.
            check_dst: Determines whether we should check the destination
                objstorage before copying.
            check_obj: If check_dst is True, determines whether we should check
                that the existing object in the destination objstorage is
                valid; if it is not, put the replayed object.
            check_src_hashes: Checks the object hashes before sending it to the
                dst objstorage.
            concurrency: Number of worker threads doing the replication process
                (retrieve, check, store).

        See swh/objstorage/replayer/tests/test_replay.py for usage examples.
        """
        self.src_cfg = src
        self.dst_cfg = dst
        self.exclude_fn = exclude_fn
        self.check_dst = check_dst
        self.check_obj = check_obj
        self.check_src_hashes = check_src_hashes
        self.concurrency = concurrency
        self.obj_queue: Queue = Queue()
        self.return_queue: Queue = Queue()
        self.stop_event = Event()
        self.workers = [Thread(target=self._worker) for _ in range(self.concurrency)]
        for w in self.workers:
            w.start()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.stop()

    def stop(self):
        """Stop replayer's worker threads"""
        self.stop_event.set()
        for worker in self.workers:
            worker.join()

    def _copy_object(
        self, obj: Dict[str, Any], src: ObjStorageInterface, dst: ObjStorageInterface
    ):
        obj_id = objid_from_dict(obj)
        if not obj_id:
            raise ValueError(
                "Object is missing the keys expected in CompositeObjId", obj
            )

        logger_debug_obj_id("Starting copy object %(obj_id)s", {"obj_id": obj_id})
        decision = None
        copied_bytes = 0
        tags = {}

        if obj["status"] != "visible":
            logger_debug_obj_id(
                "skipped %(obj_id)s (status=%(status)s)",
                {"obj_id": obj_id, "status": obj["status"]},
            )
            decision = "skipped"
            tags["status"] = obj["status"]
        elif self.exclude_fn and self.exclude_fn(obj):
            logger_debug_obj_id(
                "skipped %(obj_id)s (manually excluded)", {"obj_id": obj_id}
            )
            decision = "excluded"
        elif self.check_dst and obj_in_objstorage(obj_id, dst):
            decision = "in_dst"
            if self.check_obj:
                try:
                    dst.check(obj_id)
                except Error:
                    logger.info(
                        "invalid object found in dst %s", format_obj_id(obj_id)
                    )
                    decision = None
                    tags["status"] = "invalid_in_dst"

        if decision is None:
            try:
                copied_bytes = copy_object(
                    obj_id,
                    obj_len=obj["length"],
                    src=src,
                    dst=dst,
                    check_src_hashes=self.check_src_hashes,
                )
            except ObjNotFoundError:
                logger_debug_obj_id("not found %(obj_id)s", {"obj_id": obj_id})
                decision = "not_found"
                if not self.check_dst and obj_in_objstorage(obj_id, dst):
                    tags["status"] = "found_in_dst"
            except LengthMismatch as exc:
                logger.info(
                    "length mismatch %s", format_obj_id(obj_id), exc_info=exc
                )
                decision = "length_mismatch"
                if not self.check_dst and obj_in_objstorage(obj_id, dst):
                    tags["status"] = "found_in_dst"
            except HashMismatch as exc:
                logger.info("hash mismatch %s", format_obj_id(obj_id), exc_info=exc)
                decision = "hash_mismatch"
            except Exception as exc:
                logger.info("failed %s", format_obj_id(obj_id), exc_info=exc)
                decision = "failed"
            else:
                if copied_bytes is None:
                    logger_debug_obj_id(
                        "failed %(obj_id)s (None)", {"obj_id": obj_id}
                    )
                    decision = "failed"
                else:
                    logger_debug_obj_id(
                        "copied %(obj_id)s (%(bytes)d)",
                        {"obj_id": obj_id, "bytes": copied_bytes},
                    )
                    decision = "copied"

        tags["decision"] = decision
        statsd.increment(CONTENT_OPERATIONS_METRIC, tags=tags)
        return decision, copied_bytes

    def _worker(self):
        src = factory.get_objstorage(**self.src_cfg)
        dst = factory.get_objstorage(**self.dst_cfg)
        while not self.stop_event.is_set():
            try:
                obj = self.obj_queue.get(timeout=1)
            except Empty:
                continue
            try:
                decision, nbytes = self._copy_object(obj, src=src, dst=dst)
            except Exception as exc:
                self.return_queue.put(("error", 0, exc))
            else:
                self.return_queue.put((decision, nbytes, None))

    def replay(
        self,
        all_objects: Dict[str, List[dict]],
    ):
        vol = 0
        stats = dict.fromkeys(
            [
                "skipped",
                "excluded",
                "not_found",
                "failed",
                "copied",
                "in_dst",
                "hash_mismatch",
                "length_mismatch",
            ],
            0,
        )
        t0 = time()
        nobjs = 0

        for object_type, objects in all_objects.items():
            if object_type != "content":
                logger.warning(
                    "Received a series of %s, this should not happen", object_type
                )
                continue
            for obj in objects:
                self.obj_queue.put(obj)
                nobjs += 1

        logger.debug("Waiting for the obj queue to be processed")
        results: List[Tuple[str, int, Optional[Exception]]] = []
        while (not self.stop_event.is_set()) and (len(results) < nobjs):
            try:
                result = self.return_queue.get(timeout=1)
            except Empty:
                continue
            else:
                results.append(result)

        logger.debug("Checking results")
        for decision, nbytes, exc in results:
            if exc:
                # XXX this should not happen, so it is probably wrong...
                raise exc
            else:
                if nbytes is not None:
                    vol += nbytes
                stats[decision] += 1

        dt = time() - t0
        logger.info(
            "processed %s content objects (%s) in %s "
            "(%.1f obj/sec, %s/sec) "
            "- %d copied - %d in dst - %d skipped "
            "- %d excluded - %d not found - %d failed",
            nobjs,
            naturalsize(vol),
            naturaldelta(dt),
            nobjs / dt,
            naturalsize(vol / dt),
            stats["copied"],
            stats["in_dst"],
            stats["skipped"],
            stats["excluded"],
            stats["not_found"],
            stats["failed"],
        )
        if notify:
            notify("WATCHDOG=1")
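

# Usage sketch (illustrative, not part of the upstream module): driving a
# ContentReplayer by hand, assuming hypothetical in-memory objstorage configs
# and a hand-built journal record. Note that each worker thread instantiates
# its own objstorage from the config dict, which is why the tests patch
# swh.objstorage.factory.get_objstorage to share pre-populated storages; see
# swh/objstorage/replayer/tests/test_replay.py for real usage.
#
#     data = b"foo"
#     obj_id = MultiHash.from_data(data).digest()  # dict: algo name -> digest
#     record = {**obj_id, "status": "visible", "length": len(data)}
#
#     with ContentReplayer(src={"cls": "memory"}, dst={"cls": "memory"}) as replayer:
#         replayer.replay({"content": [record]})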