# Source code for swh.objstorage.backends.http

# Copyright (C) 2021-2023  The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information

from datetime import timedelta
import logging
from typing import Iterator, Optional
from urllib.parse import urljoin

import requests

from swh.model import hashutil
from swh.objstorage import exc
from swh.objstorage.constants import ID_HASH_ALGO
from swh.objstorage.interface import CompositeObjId, ObjId
from swh.objstorage.objstorage import (
    DEFAULT_LIMIT,
    ObjStorage,
    compute_hash,
    decompressors,
    objid_to_default_hex,
)

# Module-level logger, named after this module per the stdlib convention.
LOGGER = logging.getLogger(__name__)
# NOTE(review): setting a level on a library logger overrides whatever the
# application configured and silences INFO/WARNING records from this module;
# confirm this is intentional rather than leaving level control to callers.
LOGGER.setLevel(logging.ERROR)


class HTTPReadOnlyObjStorage(ObjStorage):
    """Simple ObjStorage retrieving objects from an HTTP server.

    For example, can be used to retrieve objects from S3:

    objstorage:
      cls: http
      url: https://softwareheritage.s3.amazonaws.com/content/
    """

    def __init__(self, url=None, compression=None, **kwargs):
        """Initialize the read-only HTTP backend.

        Args:
            url: base URL under which objects are published; a trailing
                slash is appended when missing.
            compression: name of a decompressor from ``decompressors``
                applied to fetched payloads, or None for raw content.
        """
        super().__init__(**kwargs)
        # A single pooled session, reused for every request to the backend.
        self.session = requests.sessions.Session()
        self.root_path = url
        # urljoin() discards the base's last path segment unless the base
        # ends with "/", so normalize the root here once.
        if not self.root_path.endswith("/"):
            self.root_path += "/"
        self.compression = compression
[docs] def check_config(self, *, check_write): """Check the configuration for this object storage""" return True
def __contains__(self, obj_id: ObjId) -> bool: resp = self.session.head(self._path(obj_id)) return resp.status_code == 200 def __iter__(self) -> Iterator[CompositeObjId]: raise exc.NonIterableObjStorage("__iter__") def __len__(self): raise exc.NonIterableObjStorage("__len__")
[docs] def add(self, content: bytes, obj_id: ObjId, check_presence: bool = True) -> None: raise exc.ReadOnlyObjStorage("add")
[docs] def delete(self, obj_id: ObjId): raise exc.ReadOnlyObjStorage("delete")
[docs] def restore(self, content: bytes, obj_id: ObjId) -> None: raise exc.ReadOnlyObjStorage("restore")
[docs] def list_content( self, last_obj_id: Optional[ObjId] = None, limit: Optional[int] = DEFAULT_LIMIT, ) -> Iterator[CompositeObjId]: raise exc.NonIterableObjStorage("__len__")
[docs] def get(self, obj_id: ObjId) -> bytes: try: resp = self.session.get(self._path(obj_id)) resp.raise_for_status() except Exception: raise exc.ObjNotFoundError(obj_id) ret: bytes = resp.content if self.compression: d = decompressors[self.compression]() ret = d.decompress(ret) if d.unused_data: hex_obj_id = objid_to_default_hex(obj_id) raise exc.Error("Corrupt object %s: trailing data found" % hex_obj_id) return ret
[docs] def download_url( self, obj_id: ObjId, content_disposition: Optional[str] = None, expiry: Optional[timedelta] = None, ) -> Optional[str]: return self._path(obj_id)
[docs] def check(self, obj_id: ObjId) -> None: # Check the content integrity obj_content = self.get(obj_id) content_obj_id = compute_hash(obj_content) if content_obj_id != self._hash(obj_id): raise exc.Error(obj_id)
def _hash(self, obj_id: ObjId) -> bytes: if isinstance(obj_id, dict): return obj_id[ID_HASH_ALGO] else: return obj_id def _path(self, obj_id): return urljoin(self.root_path, hashutil.hash_to_hex(self._hash(obj_id)))