# Copyright (C) 2015-2025 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import abc
import bz2
from concurrent.futures import ThreadPoolExecutor
from datetime import timedelta
import functools
import lzma
from typing import (
Callable,
Dict,
Iterable,
Iterator,
Literal,
Optional,
Protocol,
Tuple,
)
import zlib
from swh.core import statsd
from swh.model.hashutil import HashDict, MultiHash, hash_to_hex
from swh.objstorage.constants import LiteralPrimaryHash
from swh.objstorage.exc import ObjCorruptedError, ObjNotFoundError
from swh.objstorage.interface import ObjStorageInterface, objid_from_dict
DURATION_METRICS = "swh_objstorage_request_duration_seconds"
def timed(f):
"""A simple decorator used to add statsd probes on main ObjStorage methods
(add, get and __contains__)
"""
@functools.wraps(f)
def w(self, *a, **kw):
with statsd.statsd.timed(
DURATION_METRICS,
tags={"endpoint": f.__name__, "name": self.name},
):
return f(self, *a, **kw)
w._timed = True
w._f = f
return w
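# A minimal usage sketch (hypothetical subclass, assuming statsd is
# configured): each decorated call is reported under the
# ``swh_objstorage_request_duration_seconds`` metric, tagged with the
# endpoint and the objstorage name:
#
#   class MyObjStorage(ObjStorage):
#       name = "my-backend"
#
#       @timed
#       def get(self, obj_id):
#           ...  # timed with tags endpoint="get", name="my-backend"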
def objid_to_default_hex(obj_id: HashDict, algo: LiteralPrimaryHash) -> str:
"""Converts multi-hashes to the hexadecimal representation
of the given hash algo."""
return hash_to_hex(obj_id[algo])
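# Illustration (values assumed): with obj_id = {"sha1": b"\x00" * 20},
# objid_to_default_hex(obj_id, "sha1") returns forty hexadecimal zero
# digits, i.e. "00" * 20.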
def objid_for_content(content: bytes) -> HashDict:
"""Compute the content's hashes.
Args:
content: The raw content to hash
hash_names: Names of hashing algorithms
(default to :const:`swh.model.hashutil.DEFAULT_ALGORITHMS`)
Returns:
A dict mapping algo name to hash value
"""
return objid_from_dict(
MultiHash.from_data(
content,
).digest()
)
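# Usage sketch: the returned dict maps each algorithm in
# swh.model.hashutil.DEFAULT_ALGORITHMS to its digest, so it can be fed
# directly to the add/get/check methods defined below:
#
#   obj_id = objid_for_content(b"hello")
#   hex_id = objid_to_default_hex(obj_id, "sha1")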
class NullCompressor:
def compress(self, data):
return data
def flush(self):
return b""
class NullDecompressor:
def decompress(self, data: bytes) -> bytes:
return data
@property
def unused_data(self) -> bytes:
return b""
class _CompressorProtocol(Protocol):
def compress(self, data: bytes) -> bytes: ...
def flush(self) -> bytes: ...
class _DecompressorProtocol(Protocol):
def decompress(self, data: bytes) -> bytes: ...
@property
def unused_data(self) -> bytes: ...
decompressors: Dict[str, Callable[[], _DecompressorProtocol]] = {
"bz2": bz2.BZ2Decompressor,
"lzma": lzma.LZMADecompressor,
"gzip": lambda: zlib.decompressobj(wbits=31),
"zlib": zlib.decompressobj,
"none": NullDecompressor,
}
compressors: Dict[str, Callable[[], _CompressorProtocol]] = {
"bz2": bz2.BZ2Compressor,
"lzma": lzma.LZMACompressor,
"gzip": lambda: zlib.compressobj(wbits=31),
"zlib": zlib.compressobj,
"none": NullCompressor,
}
CompressionFormat = Literal["bz2", "lzma", "gzip", "zlib", "none"]
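# Round-trip sketch using the registries above (gzip shown; any
# CompressionFormat key behaves the same): a compressor instance must be
# flushed, and a decompressor exposes unused_data to detect trailing bytes:
#
#   compressor = compressors["gzip"]()
#   blob = compressor.compress(b"some data") + compressor.flush()
#   decompressor = decompressors["gzip"]()
#   assert decompressor.decompress(blob) == b"some data"
#   assert decompressor.unused_data == b""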
class ObjStorage(ObjStorageInterface, metaclass=abc.ABCMeta):
"""Abstract base class for an object storage backend.
It notably adds default implementation for the :meth:`add_batch` and
:meth:`get_batch` methods, by calling :meth:`add` or :meth:`get` method
in a loop. By setting the ``batch_with_threads`` constructor parameter to
:const:`True` (default is :const:`False`), those calls are made concurrently
using a pool of threads as it can improve performance depending on the final
object storage backend implementation (the
:class:`swh.objstorage.backends.pathslicing.PathSlicingObjStorage` backend performs
faster with threaded batch operations for instance). The maximum number of
threads in the pool can be adjusted with the ``max_batch_workers`` constructor
parameter, default is 10.
"""
primary_hash: Optional[LiteralPrimaryHash] = None
compression: CompressionFormat = "none"
name: str = "objstorage"
"""Default objstorage name; can be overloaded at instantiation time giving a
'name' argument to the constructor"""
def __init__(
self,
*,
allow_delete: bool = False,
primary_hash: Optional[LiteralPrimaryHash] = None,
batch_with_threads: bool = False,
max_batch_workers: int = 10,
**kwargs,
):
        # A more complete permission system could replace this flag if it
        # becomes needed
self.allow_delete = allow_delete
if primary_hash is not None:
self.primary_hash = primary_hash
# if no name is given in kwargs, default to name defined as class attribute
if "name" in kwargs:
self.name = kwargs["name"]
self.map_contents: Callable = map
if batch_with_threads:
self.executor = ThreadPoolExecutor(max_workers=max_batch_workers)
self.map_contents = self.executor.map
def _add(
self, content: Tuple[HashDict, bytes], check_presence: bool
) -> Tuple[int, int]:
obj_id, content_data = content
if check_presence and obj_id in self:
return (0, 0)
else:
self.add(content_data, obj_id, check_presence=False)
return (1, len(content_data))
def add_batch(
self,
contents: Iterable[Tuple[HashDict, bytes]],
check_presence: bool = True,
) -> Dict:
summary = {"object:add": 0, "object:add:bytes": 0}
add = functools.partial(self._add, check_presence=check_presence)
for added, length in self.map_contents(add, contents):
summary["object:add"] += added
summary["object:add:bytes"] += length
return summary
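    # Usage sketch (hypothetical object ids): the returned summary counts
    # only objects actually added, so re-adding known objects with
    # check_presence=True contributes nothing:
    #
    #   summary = storage.add_batch([(id1, b"foo"), (id2, b"bar")])
    #   # summary == {"object:add": 2, "object:add:bytes": 6}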
def restore(self: ObjStorageInterface, content: bytes, obj_id: HashDict) -> None:
        # check_presence=False overwrites any content previously stored
        # under this object id.
self.add(content, obj_id, check_presence=False)
def _get(self, obj_id: HashDict) -> Optional[bytes]:
try:
return self.get(obj_id)
except ObjNotFoundError:
return None
def get_batch(self, obj_ids: Iterable[HashDict]) -> Iterator[Optional[bytes]]:
yield from self.map_contents(self._get, obj_ids)
@abc.abstractmethod
def delete(self, obj_id: HashDict):
if not self.allow_delete:
raise PermissionError("Delete is not allowed.")
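    # Subclass sketch: concrete implementations are expected to call
    # super().delete(obj_id) first, so the allow_delete check runs before
    # anything is removed from the backend:
    #
    #   def delete(self, obj_id):
    #       super().delete(obj_id)
    #       self._remove_from_backend(obj_id)  # hypothetical helper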
def download_url(
self,
obj_id: HashDict,
content_disposition: Optional[str] = None,
expiry: Optional[timedelta] = None,
) -> Optional[str]:
return None
@abc.abstractmethod
def get(self, obj_id: HashDict) -> bytes:
raise NotImplementedError()
def check(self, obj_id: HashDict) -> None:
"""Check if a content is found and recompute its hash to check integrity."""
obj_data = self.get(obj_id)
data_hashes = objid_for_content(obj_data)
for algo, expected_hash in obj_id.items():
data_hash = data_hashes[algo] # type: ignore[literal-required]
if data_hash != expected_hash:
raise ObjCorruptedError(
f"expected {algo} hash is {hash_to_hex(expected_hash)}, "
f"data {algo} hash is {hash_to_hex(data_hash)}"
)
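    # Usage sketch: check() raises instead of returning a status, so
    # callers typically wrap it (get() may also raise ObjNotFoundError):
    #
    #   try:
    #       storage.check(obj_id)
    #   except ObjCorruptedError as exc:
    #       print(f"corrupted object: {exc}")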
def compress(self, data: bytes) -> bytes:
compressor = compressors[self.compression]()
compressed = compressor.compress(data)
compressed += compressor.flush()
return compressed
def decompress(self, data: bytes, hex_obj_id: str) -> bytes:
decompressor = decompressors[self.compression]()
try:
ret = decompressor.decompress(data)
except (zlib.error, lzma.LZMAError, OSError):
raise ObjCorruptedError(
f"content with {self.primary_hash} hash {hex_obj_id} is not a proper "
"compressed file"
)
if decompressor.unused_data:
raise ObjCorruptedError(
f"trailing data found when decompressing content with {self.primary_hash} "
f"{hex_obj_id}"
)
return ret
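# Round-trip sketch at the ObjStorage level (hypothetical subclass with
# compression = "gzip"): data compressed on write must be decompressed on
# read, and decompress() flags truncated streams or trailing data by
# raising ObjCorruptedError:
#
#   blob = storage.compress(raw_data)
#   ...  # blob is written to, then read back from, the backend
#   assert storage.decompress(blob, hex_obj_id) == raw_data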