Source code for swh.loader.git.dumb

# Copyright (C) 2021-2024  The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information

from __future__ import annotations

from collections import defaultdict
import copy
import logging
import stat
import struct
from tempfile import SpooledTemporaryFile
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    Iterable,
    List,
    Protocol,
    Set,
    cast,
)
import urllib.parse

from dulwich.errors import NotGitRepository
from dulwich.objects import S_IFGITLINK, Commit, ShaFile, Tree
from dulwich.pack import Pack, PackData, PackIndex, load_pack_index_file
import requests
from tenacity.before_sleep import before_sleep_log

from swh.core.retry import http_retry
from swh.loader.git.utils import HexBytes, PackWriter

if TYPE_CHECKING:
    from .loader import RepoRepresentation

logger = logging.getLogger(__name__)
fetch_pack_logger = logger.getChild("fetch_pack")


class BytesWriter(Protocol):
    def write(self, data: bytes): ...


def requests_kwargs(kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Inject a User-Agent header and a default timeout into the requests kwargs"""
    ret = copy.deepcopy(kwargs)
    ret.setdefault("headers", {}).update(
        {"User-Agent": "Software Heritage dumb Git loader"}
    )
    ret.setdefault("timeout", (120, 60))
    return ret
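

# For illustration, a call with a hypothetical extra kwarg (the value below is
# an example, not a default used by the loader):
#
#     requests_kwargs({"verify": False})
#
# returns {"verify": False,
#          "headers": {"User-Agent": "Software Heritage dumb Git loader"},
#          "timeout": (120, 60)}; caller-provided headers and timeout are
# preserved, as setdefault() only fills in missing keys.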


@http_retry(
    before_sleep=before_sleep_log(logger, logging.WARNING),
)
def check_protocol(repo_url: str, requests_extra_kwargs: Dict[str, Any] = {}) -> bool:
    """Checks if a git repository can be cloned using the dumb protocol.

    Args:
        repo_url: Base URL of a git repository
        requests_extra_kwargs: extra keyword arguments to be passed to requests,
            e.g. `timeout`, `verify`.

    Returns:
        Whether the dumb protocol is supported.

    """
    if not repo_url.startswith("http"):
        return False
    url = urllib.parse.urljoin(
        repo_url.rstrip("/") + "/", "info/refs?service=git-upload-pack/"
    )
    logger.debug("Fetching %s", url)
    response = requests.get(url, **requests_kwargs(requests_extra_kwargs))
    response.raise_for_status()
    content_type = response.headers.get("Content-Type")
    return (
        response.status_code
        in (
            200,
            304,
        )
        # header is not mandatory in protocol specification
        and (content_type is None or not content_type.startswith("application/x-git-"))
    )
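

# A minimal probing sketch (hypothetical origin URL, network access required):
# a smart-HTTP server answers the info/refs probe with an "application/x-git-*"
# content type, so check_protocol() returns False for it, while plain static
# file hosting serves the raw refs file and yields True:
#
#     if check_protocol("https://example.org/some/repo.git"):
#         ...  # use the dumb protocol; see the sketch at the end of the module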


class GitObjectsFetcher:
    """Git objects fetcher using dumb HTTP protocol.

    Fetches a set of git objects for a repository according to its archival
    state by Software Heritage and provides iterators on them.

    Args:
        repo_url: Base URL of a git repository
        base_repo: State of repository archived by Software Heritage
        pack_size_limit: maximum size in bytes allowed for a pack file
            downloaded from the repository
        requests_extra_kwargs: extra keyword arguments to be passed to requests,
            e.g. `timeout`, `verify`.
    """

    def __init__(
        self,
        repo_url: str,
        base_repo: RepoRepresentation,
        pack_size_limit: int,
        requests_extra_kwargs: Dict[str, Any] = {},
    ):
        self._session = requests.Session()
        self.requests_extra_kwargs = requests_extra_kwargs
        self.repo_url = repo_url
        self.base_repo = base_repo
        self.pack_size_limit = pack_size_limit
        self.objects: Dict[bytes, Set[bytes]] = defaultdict(set)
        self.refs = self._get_refs()
        self.head = self._get_head() if self.refs else {}
        self.packs = self._get_packs()

    def fetch_object_ids(self) -> None:
        """Fetches identifiers of git objects to load into the archive."""
        wants = self.base_repo.determine_wants(self.refs)

        # process refs
        commit_objects = []
        for ref in wants:
            ref_object = self._get_git_object(ref)
            if ref_object.type_num == Commit.type_num:
                commit_objects.append(cast(Commit, ref_object))
                self.objects[b"commit"].add(ref)
            else:
                self.objects[b"tag"].add(ref)

        # perform DFS on commits graph
        while commit_objects:
            commit = commit_objects.pop()
            # fetch tree and blob ids recursively
            self._fetch_tree_objects(commit.tree)
            for parent in commit.parents:
                if (
                    # commit not already seen in the current load
                    parent not in self.objects[b"commit"]
                    # commit not already archived by a previous load
                    and parent not in self.base_repo.local_heads
                ):
                    commit_objects.append(cast(Commit, self._get_git_object(parent)))
                    self.objects[b"commit"].add(parent)

    def iter_objects(self, object_type: bytes) -> Iterable[ShaFile]:
        """Returns a generator over the fetched git objects of a given type.

        Args:
            object_type: Git object type, either b"blob", b"commit", b"tag" or b"tree"

        Returns:
            A generator fetching git objects on the fly.
        """
        return map(self._get_git_object, self.objects[object_type])

    @http_retry(
        before_sleep=before_sleep_log(logger, logging.WARNING),
    )
    def _http_get(self, path: str) -> SpooledTemporaryFile:
        url = urllib.parse.urljoin(self.repo_url.rstrip("/") + "/", path)
        logger.debug("Fetching %s", url)
        response = self._session.get(
            url, stream=True, **requests_kwargs(self.requests_extra_kwargs)
        )
        response.raise_for_status()
        buffer = SpooledTemporaryFile(max_size=100 * 1024 * 1024)
        bytes_writer: BytesWriter = buffer
        if path.startswith("objects/pack/") and path.endswith(".pack"):
            bytes_writer = PackWriter(
                pack_buffer=buffer,
                size_limit=self.pack_size_limit,
                origin_url=self.repo_url,
                fetch_pack_logger=fetch_pack_logger,
            )
        for chunk in response.iter_content(chunk_size=10 * 1024 * 1024):
            bytes_writer.write(chunk)
        buffer.flush()
        buffer.seek(0)
        return buffer

    def _get_refs(self) -> Dict[bytes, HexBytes]:
        refs = {}
        refs_resp_bytes = self._http_get("info/refs")
        for ref_line in refs_resp_bytes.readlines():
            ref_target, ref_name = ref_line.replace(b"\n", b"").split(b"\t")
            refs[ref_name] = ref_target
        return refs

    def _get_head(self) -> Dict[bytes, HexBytes]:
        head_resp_bytes = self._http_get("HEAD")
        _, head_target = head_resp_bytes.readline().replace(b"\n", b"").split(b" ")
        return {b"HEAD": head_target}

    def _get_pack_data(self, pack_name: str) -> Callable[[], PackData]:
        def _pack_data() -> PackData:
            pack_data_bytes = self._http_get(f"objects/pack/{pack_name}")
            return PackData(pack_name, file=pack_data_bytes)

        return _pack_data

    def _get_pack_idx(self, pack_idx_name: str) -> Callable[[], PackIndex]:
        def _pack_idx() -> PackIndex:
            pack_idx_bytes = self._http_get(f"objects/pack/{pack_idx_name}")
            return load_pack_index_file(pack_idx_name, pack_idx_bytes)

        return _pack_idx

    def _get_packs(self) -> List[Pack]:
        packs = []
        packs_info_bytes = self._http_get("objects/info/packs")
        packs_info = packs_info_bytes.read().decode()
        for pack_info in packs_info.split("\n"):
            if pack_info:
                pack_name = pack_info.split(" ")[1]
                pack_idx_name = pack_name.replace(".pack", ".idx")
                # pack index and data file will be lazily fetched when required
                packs.append(
                    Pack.from_lazy_objects(
                        self._get_pack_data(pack_name),
                        self._get_pack_idx(pack_idx_name),
                    )
                )
        return packs

    def _get_git_object(self, sha: bytes) -> ShaFile:
        # try to get the object from a pack file first to avoid flooding
        # git server with numerous HTTP requests
        for pack in list(self.packs):
            try:
                if sha in pack:
                    return pack[sha]
            except (NotGitRepository, struct.error, requests.HTTPError) as e:
                if (
                    isinstance(e, requests.HTTPError)
                    and e.response is not None
                    and e.response.status_code != 404
                ):
                    raise
                # missing or invalid pack index/content, remove it from global packs list
                logger.debug("A pack file is missing or its content is invalid")
                self.packs.remove(pack)
        # fetch it from objects/ directory otherwise
        sha_hex = sha.decode()
        object_path = f"objects/{sha_hex[:2]}/{sha_hex[2:]}"
        return ShaFile.from_file(self._http_get(object_path))

    def _fetch_tree_objects(self, sha: bytes) -> None:
        if sha not in self.objects[b"tree"]:
            tree = cast(Tree, self._get_git_object(sha))
            self.objects[b"tree"].add(sha)
            for item in tree.items():
                if item.mode == S_IFGITLINK:
                    # skip submodules as objects are not stored in repository
                    continue
                if item.mode & stat.S_IFDIR:
                    self._fetch_tree_objects(item.sha)
                else:
                    self.objects[b"blob"].add(item.sha)
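

# A minimal end-to-end sketch under stated assumptions: the origin URL is
# hypothetical, and _DummyBaseRepo is a hypothetical stand-in for the loader's
# RepoRepresentation (it only provides the determine_wants() method and the
# local_heads attribute used above); it is not part of this module.
if __name__ == "__main__":

    class _DummyBaseRepo:
        # nothing archived yet, so no head is known locally
        local_heads: Set[HexBytes] = set()

        def determine_wants(self, refs: Dict[bytes, HexBytes]) -> Set[HexBytes]:
            # want every advertised ref target
            return set(refs.values())

    origin_url = "https://example.org/some/repo.git"  # hypothetical origin
    if check_protocol(origin_url):
        fetcher = GitObjectsFetcher(
            origin_url,
            cast("RepoRepresentation", _DummyBaseRepo()),
            pack_size_limit=4 * 1024**3,
        )
        fetcher.fetch_object_ids()
        for object_type in (b"commit", b"tag", b"tree", b"blob"):
            for obj in fetcher.iter_objects(object_type):
                logger.info("fetched %s %s", object_type.decode(), obj.id.decode())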