# Copyright (C) 2015-2024 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import abc
import bz2
import collections
from datetime import timedelta
import functools
from itertools import dropwhile, islice
import lzma
from typing import (
Callable,
Dict,
Iterable,
Iterator,
Literal,
Mapping,
Optional,
Tuple,
Union,
)
import zlib
from typing_extensions import Protocol
from swh.core import statsd
from swh.model import hashutil
from swh.model.model import Sha1
from swh.objstorage.constants import DEFAULT_LIMIT, ID_HASH_ALGO
from swh.objstorage.exc import ObjCorruptedError, ObjNotFoundError
from swh.objstorage.interface import CompositeObjId, ObjId, ObjStorageInterface
DURATION_METRICS = "swh_objstorage_request_duration_seconds"
[docs]
def timed(f):
"""A simple decorator used to add statsd probes on main ObjStorage methods
(add, get and __contains__)
"""
@functools.wraps(f)
def w(self, *a, **kw):
with statsd.statsd.timed(
DURATION_METRICS,
tags={"endpoint": f.__name__, "name": self.name},
):
return f(self, *a, **kw)
w._timed = True
w._f = f
return w
[docs]
def objid_to_default_hex(
obj_id: ObjId, algo: Literal["sha1", "sha256"] = ID_HASH_ALGO
) -> str:
"""Converts SHA1 hashes and multi-hashes to the hexadecimal representation
of the SHA1."""
if isinstance(obj_id, bytes):
return hashutil.hash_to_hex(obj_id)
elif isinstance(obj_id, str):
return obj_id
else:
return hashutil.hash_to_hex(obj_id[algo])
[docs]
def compute_hashes(
content: bytes, hash_names: Iterable[str] = hashutil.DEFAULT_ALGORITHMS
) -> Dict[str, bytes]:
"""Compute the content's hashes.
Args:
content: The raw content to hash
hash_names: Names of hashing algorithms
(default to :const:`swh.model.hashutil.DEFAULT_ALGORITHMS`)
Returns:
A dict mapping algo name to hash value
"""
return hashutil.MultiHash.from_data(
content,
hash_names=hash_names,
).digest()
[docs]
def compute_hash(content: bytes, algo: str = ID_HASH_ALGO) -> bytes:
"""Compute the content's hash.
Args:
content: The raw content to hash
hash_name: Hash's name
Returns:
The computed hash for the content
"""
return compute_hashes(content, [algo])[algo]
[docs]
class NullCompressor:
[docs]
def compress(self, data):
return data
[docs]
def flush(self):
return b""
[docs]
class NullDecompressor:
[docs]
def decompress(self, data: bytes) -> bytes:
return data
@property
def unused_data(self) -> bytes:
return b""
class _CompressorProtocol(Protocol):
def compress(self, data: bytes) -> bytes:
...
def flush(self) -> bytes:
...
class _DecompressorProtocol(Protocol):
def decompress(self, data: bytes) -> bytes:
...
unused_data: bytes
decompressors: Dict[str, Callable[[], _DecompressorProtocol]] = {
"bz2": bz2.BZ2Decompressor, # type: ignore
"lzma": lzma.LZMADecompressor, # type: ignore
"gzip": lambda: zlib.decompressobj(wbits=31),
"zlib": zlib.decompressobj,
"none": NullDecompressor, # type: ignore
}
compressors: Dict[str, Callable[[], _CompressorProtocol]] = {
"bz2": bz2.BZ2Compressor,
"lzma": lzma.LZMACompressor,
"gzip": lambda: zlib.compressobj(wbits=31),
"zlib": zlib.compressobj,
"none": NullCompressor,
}
CompressionFormat = Literal["bz2", "lzma", "gzip", "zlib", "none"]
[docs]
class ObjStorage(metaclass=abc.ABCMeta):
PRIMARY_HASH: Literal["sha1", "sha256"] = "sha1"
compression: CompressionFormat = "none"
name: str = "objstorage"
"""Default objstorage name; can be overloaded at instantiation time giving a
'name' argument to the constructor"""
def __init__(
self: ObjStorageInterface,
*,
allow_delete: bool = False,
**kwargs,
):
# A more complete permission system could be used in place of that if
# it becomes needed
self.allow_delete = allow_delete
# if no name is given in kwargs, default to name defined as class attribute
if "name" in kwargs:
self.name = kwargs["name"]
[docs]
def add_batch(
self: ObjStorageInterface,
contents: Union[Mapping[Sha1, bytes], Iterable[Tuple[ObjId, bytes]]],
check_presence: bool = True,
) -> Dict:
summary = {"object:add": 0, "object:add:bytes": 0}
contents_pairs: Iterable[Tuple[ObjId, bytes]]
if isinstance(contents, collections.abc.Mapping):
contents_pairs = contents.items()
else:
contents_pairs = contents
for obj_id, content in contents_pairs:
if check_presence and obj_id in self:
continue
self.add(content, obj_id, check_presence=False)
summary["object:add"] += 1
summary["object:add:bytes"] += len(content)
return summary
[docs]
def restore(self: ObjStorageInterface, content: bytes, obj_id: ObjId) -> None:
# check_presence to false will erase the potential previous content.
self.add(content, obj_id, check_presence=False)
[docs]
def get_batch(
self: ObjStorageInterface, obj_ids: Iterable[ObjId]
) -> Iterator[Optional[bytes]]:
for obj_id in obj_ids:
try:
yield self.get(obj_id)
except ObjNotFoundError:
yield None
[docs]
@abc.abstractmethod
def delete(self, obj_id: ObjId):
if not self.allow_delete:
raise PermissionError("Delete is not allowed.")
[docs]
def list_content(
self: ObjStorageInterface,
last_obj_id: Optional[ObjId] = None,
limit: Optional[int] = DEFAULT_LIMIT,
) -> Iterator[CompositeObjId]:
it = iter(self)
if last_obj_id:
last_obj_id_hex = objid_to_default_hex(last_obj_id)
it = dropwhile(lambda x: objid_to_default_hex(x) <= last_obj_id_hex, it)
return islice(it, limit)
[docs]
def download_url(
self,
obj_id: ObjId,
content_disposition: Optional[str] = None,
expiry: Optional[timedelta] = None,
) -> Optional[str]:
return None
[docs]
@abc.abstractmethod
def get(self, obj_id: ObjId) -> bytes:
raise NotImplementedError()
[docs]
def check(self, obj_id: ObjId) -> None:
"""Check if a content is found and recompute its hash to check integrity."""
obj_content = self.get(obj_id)
hash_algos = [str(self.PRIMARY_HASH)]
if isinstance(obj_id, dict):
hash_algos += [algo for algo in obj_id if algo != self.PRIMARY_HASH]
actual_hashes = compute_hashes(obj_content, hash_algos)
for algo in hash_algos:
actual_obj_id = actual_hashes[algo]
expected_obj_id = obj_id
if isinstance(obj_id, dict):
expected_obj_id = obj_id[algo] # type: ignore[literal-required]
if actual_obj_id != expected_obj_id:
raise ObjCorruptedError(
f"expected {algo} hash is {hashutil.hash_to_hex(expected_obj_id)}, "
f"actual {algo} hash is {hashutil.hash_to_hex(actual_obj_id)}"
)
[docs]
def compress(self, data: bytes) -> bytes:
compressor = compressors[self.compression]()
compressed = compressor.compress(data)
compressed += compressor.flush()
return compressed
[docs]
def decompress(self, data: bytes, hex_obj_id: str) -> bytes:
decompressor = decompressors[self.compression]()
try:
ret = decompressor.decompress(data)
except (zlib.error, lzma.LZMAError, OSError):
raise ObjCorruptedError(
f"content with {self.PRIMARY_HASH} hash {hex_obj_id} is not a proper "
"compressed file"
)
if decompressor.unused_data:
raise ObjCorruptedError(
f"trailing data found when decompressing content with {self.PRIMARY_HASH} "
f"{hex_obj_id}"
)
return ret