# Copyright (C) 2023 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import collections
import contextlib
from datetime import datetime, timezone
import itertools
import logging
import operator
import os
from pathlib import Path
import re
import shlex
import shutil
import subprocess
import tempfile
import typing
from typing import (
Any,
BinaryIO,
Callable,
Collection,
Dict,
Iterable,
Iterator,
List,
Optional,
Sequence,
Set,
TextIO,
Tuple,
Type,
Union,
)
from zipfile import ZipFile
import attrs
import shamir_mnemonic
from typing_extensions import Self
import yaml
from swh.core.api.classes import stream_results
from swh.core.utils import grouper
from swh.journal.serializers import kafka_to_value, value_to_kafka
from swh.model.exceptions import ValidationError
from swh.model.model import (
BaseModel,
Content,
Directory,
ExtID,
Origin,
OriginVisit,
OriginVisitStatus,
RawExtrinsicMetadata,
Release,
Revision,
SkippedContent,
Snapshot,
)
from swh.model.swhids import ExtendedObjectType, ExtendedSWHID
from swh.model.swhids import ObjectType as CoreSWHIDObjectType
import swh.storage.algos.directory
import swh.storage.algos.snapshot
from swh.storage.interface import HashDict, StorageInterface
from .bech32 import Encoding as Bech32Encoding
from .bech32 import bech32_decode, bech32_encode, convert_bits
from .progressbar import ProgressBar, ProgressBarInit, no_progressbar
from .utils import filter_objects_missing_from_storage, iter_swhids_grouped_by_type
logger = logging.getLogger(__name__)
RAGE_PATH = shutil.which("rage")
RAGE_KEYGEN_PATH = shutil.which("rage-keygen")
if RAGE_PATH is None:
raise ImportError("`rage` not found in path")
if RAGE_KEYGEN_PATH is None:
raise ImportError("`rage-keygen` not found in path")
RECOVERY_BUNDLE_RESTORE_CHUNK_SIZE = 200
class _ManifestDumper(yaml.SafeDumper):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.add_representer(str, self._represent_str)
self.add_representer(datetime, self._represent_datetime)
self.add_representer(ExtendedSWHID, self._represent_swhid)
self.add_representer(Origin, self._represent_origin)
def _represent_str(self, dumper, data):
if "\n" in data:
return dumper.represent_scalar("tag:yaml.org,2002:str", data, style="|")
return dumper.represent_scalar("tag:yaml.org,2002:str", data)
def _represent_datetime(self, dumper, data):
return dumper.represent_scalar(
"tag:yaml.org,2002:timestamp", data.isoformat(timespec="seconds")
)
def _represent_swhid(self, dumper, data):
return dumper.represent_scalar("tag:yaml.org,2002:str", str(data), style="")
def _represent_origin(self, dumper, data):
return dumper.represent_scalar("tag:yaml.org,2002:str", data.url, style="")
def check_call(command: Sequence[str], **kwargs) -> subprocess.CompletedProcess:
try:
return subprocess.run(command, capture_output=True, check=True, **kwargs)
except subprocess.CalledProcessError as e:
logger.warning(
"Command `%s` failed with exit code %s", shlex.join(command), e.returncode
)
for line in e.stderr.strip().splitlines():
logger.warning(" stderr: %s", line.strip())
raise e
@attrs.define
class Manifest:
version: int = attrs.field(
validator=[
attrs.validators.instance_of(int),
attrs.validators.ge(1),
attrs.validators.le(3),
]
)
removal_identifier: str = attrs.field(validator=[attrs.validators.instance_of(str)])
created: datetime = attrs.field(validator=attrs.validators.instance_of(datetime))
requested: List[Origin | ExtendedSWHID] = attrs.field(
validator=attrs.validators.instance_of(list)
)
swhids: List[ExtendedSWHID] = attrs.field(
validator=attrs.validators.instance_of(list)
)
referencing: List[ExtendedSWHID] = attrs.field(
validator=attrs.validators.instance_of(list)
)
@requested.validator
def _ensure_requested_length(self, attribute, value):
if self.version >= 3 and len(value) == 0:
raise ValueError("“requested” must contain at least one Origin or ExtendedSWHID")
@swhids.validator
def _ensure_swhids_type(self, attribute, value):
if not all(isinstance(swhid, ExtendedSWHID) for swhid in value):
raise ValueError("“swhids” must be a list of ExtendedSWHID")
@referencing.validator
def _ensure_referencing_type(self, attribute, value):
if not all(isinstance(swhid, ExtendedSWHID) for swhid in value):
raise ValueError("“referencing” must be a list of ExtendedSWHID")
decryption_key_shares: Dict[str, str] = attrs.field(
validator=[attrs.validators.instance_of(dict), attrs.validators.min_len(2)]
)
reason: Optional[str] = attrs.field(
default=None,
validator=attrs.validators.optional(attrs.validators.instance_of(str)),
)
expire: Optional[datetime] = attrs.field(
default=None,
validator=attrs.validators.optional(attrs.validators.instance_of(datetime)),
)
def dump(self, stream: Optional[TextIO] = None) -> Optional[str]:
# recurse=False because we don’t want ExtendedSWHID to be turned into dicts
d = attrs.asdict(self, recurse=False)
for optionals in ("reason", "expire"):
if d[optionals] is None:
del d[optionals]
if self.version < 3:
d.pop("requested", None)
d.pop("referencing", None)
return yaml.dump(
d,
stream=stream,
Dumper=_ManifestDumper,
sort_keys=False,
)
@classmethod
def load(cls, str_or_stream: Union[str, TextIO]) -> "Manifest":
d = yaml.safe_load(str_or_stream)
if not isinstance(d, dict):
raise ValueError("Invalid manifest: not a mapping")
if not isinstance(d.get("version"), int):
raise ValueError("Invalid manifest: version missing or not an int")
# Convert `swhids`
if (
"swhids" not in d
or not isinstance(d["swhids"], list)
or len(d["swhids"]) == 0
):
raise ValueError("Invalid manifest: “swhids” is not a list or is empty")
d["swhids"] = [ExtendedSWHID.from_string(s) for s in d["swhids"]]
if d["version"] >= 3:
# Convert `requested`
if (
"requested" not in d
or not isinstance(d["requested"], list)
or len(d["requested"]) < 1
):
raise ValueError("Invalid manifest: “requested” is not a list or is empty")
requested: List[Origin | ExtendedSWHID] = []
for s in d["requested"]:
try:
requested.append(ExtendedSWHID.from_string(s))
except ValidationError:
requested.append(Origin(url=s))
d["requested"] = requested
# Convert `referencing`
if "referencing" in d and not isinstance(d["referencing"], list):
raise ValueError("Invalid manifest: “referencing” is not a list")
d["referencing"] = [
ExtendedSWHID.from_string(s) for s in d.get("referencing", [])
]
else:
d["requested"] = []
d["referencing"] = []
return Manifest(**d)
ShareIdentifier = str
AgePublicKey = str
AgeSecretKey = str
AgeEncryptedPayload = bytes
AgeArmoredEncryptedPayload = str
class WrongDecryptionKey(Exception):
pass
def age_encrypt(
public_key: AgePublicKey, cleartext: bytes, armored_output=False
) -> AgeEncryptedPayload:
# Make mypy happy
assert RAGE_PATH is not None
cmdline = [RAGE_PATH, "--encrypt", "--recipient", public_key]
# Armored output is ASCII text rather than binary
if armored_output:
cmdline.append("--armor")
cmdline.extend(["--output", "-", "-"])
age_proc = check_call(cmdline, input=cleartext)
return age_proc.stdout
def age_encrypt_armored(
public_key: AgePublicKey, cleartext: bytes
) -> AgeArmoredEncryptedPayload:
return age_encrypt(public_key, cleartext, armored_output=True).decode("us-ascii")
def age_decrypt(
secret_key: AgeSecretKey,
ciphertext: Union[AgeEncryptedPayload, AgeArmoredEncryptedPayload],
) -> bytes:
with tempfile.NamedTemporaryFile("w") as identity_file:
os.chmod(identity_file.name, 0o400)
identity_file.write(secret_key)
identity_file.write("\n")
identity_file.flush()
return age_decrypt_from_identity(identity_file.name, ciphertext)
def age_decrypt_from_identity(
identity_file: str,
ciphertext: Union[AgeEncryptedPayload, AgeArmoredEncryptedPayload],
) -> bytes:
if len(ciphertext) == 0:
raise ValueError("ciphertext cannot be empty")
# Make mypy happy
assert RAGE_PATH is not None
if isinstance(ciphertext, str):
ciphertext = ciphertext.encode("us-ascii")
cmdline = [
RAGE_PATH,
"--decrypt",
"--identity",
identity_file,
"--output",
"-",
"-",
]
age_proc = subprocess.run(cmdline, input=ciphertext, capture_output=True)
if age_proc.returncode != 0 and b"No matching keys found" in age_proc.stderr:
raise WrongDecryptionKey()
age_proc.check_returncode()
return age_proc.stdout
def generate_age_keypair() -> Tuple[AgePublicKey, AgeSecretKey]:
# Make mypy happy
assert RAGE_KEYGEN_PATH is not None
cmdline = [RAGE_KEYGEN_PATH]
rage_keygen_proc = check_call(cmdline, text=True)
public_key_matches = re.search(
r"^# public key: (age1.*)$", rage_keygen_proc.stdout, re.MULTILINE
)
assert public_key_matches
public_key = public_key_matches.group(1)
secret_key_matches = re.search(
r"^(AGE-SECRET-KEY-.*)$", rage_keygen_proc.stdout, re.MULTILINE
)
assert secret_key_matches
secret_key = secret_key_matches.group(1)
return (public_key, secret_key)
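# A minimal round-trip sketch of the age helpers above (payload is hypothetical):
#
#     public_key, secret_key = generate_age_keypair()
#     ciphertext = age_encrypt(public_key, b"hello")
#     assert age_decrypt(secret_key, ciphertext) == b"hello"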
def list_yubikey_identities() -> List[Tuple[ShareIdentifier, AgeSecretKey]]:
age_plugin_yubikey_path = shutil.which("age-plugin-yubikey")
if age_plugin_yubikey_path is None:
raise FileNotFoundError("`age-plugin-yubikey` not found in path")
cmdline = [age_plugin_yubikey_path, "--identity"]
age_plugin_yubikey_proc = check_call(cmdline, text=True)
# Split on empty lines
descriptions = age_plugin_yubikey_proc.stdout.split("\n\n")
identities = []
for description in descriptions:
if len(description.strip()) == 0:
continue
yubikey_identifier_matches = re.search(
r"^#[ ]+Serial: ([^,]+), Slot: (\S+)", description, flags=re.MULTILINE
)
age_secret_key_matches = re.search(
r"^(AGE-PLUGIN-YUBIKEY-\S+)", description, flags=re.MULTILINE
)
if not yubikey_identifier_matches or not age_secret_key_matches:
raise ValueError("Unable to parse `age-yubikey-plugin --identity` output")
yubikey_identifier = (
f"YubiKey serial {yubikey_identifier_matches.group(1)} "
f"slot {yubikey_identifier_matches.group(2)}"
)
age_secret_key = age_secret_key_matches.group(1)
identities.append((yubikey_identifier, age_secret_key))
return identities
@attrs.define
class _SecretSharingGroup:
minimum_required_shares: int = attrs.field(
validator=attrs.validators.instance_of(int)
)
recipient_keys: Dict[ShareIdentifier, AgePublicKey] = attrs.field(
validator=attrs.validators.instance_of(dict)
)
def group_parameters(self) -> Tuple[int, int]:
if self.minimum_required_shares == 1:
return (1, 1)
return (self.minimum_required_shares, len(self.recipient_keys))
@attrs.define
class SecretSharing:
minimum_required_groups: int = attrs.field(
validator=[attrs.validators.instance_of(int), attrs.validators.ge(2)]
)
groups: Dict[str, _SecretSharingGroup] = attrs.field(
validator=attrs.validators.instance_of(dict)
)
@groups.validator
def _check_groups_len(self, attribute, groups):
if len(groups) < self.minimum_required_groups:
raise ValueError("Not enough groups according to the minimum required")
@groups.validator
def _check_groups_no_duplicate_identifier(self, _attribute, groups):
share_ids = [
share_id for g in groups.values() for share_id in g.recipient_keys.keys()
]
if len(share_ids) != len(set(share_ids)):
raise ValueError("Duplicate share identifier")
@groups.validator
def _check_groups_no_duplicate_recipient(self, _attribute, groups):
recipients = [
recipient_key
for g in groups.values()
for recipient_key in g.recipient_keys.values()
]
if len(recipients) != len(set(recipients)):
raise ValueError("Duplicate recipient public key")
@classmethod
def from_dict(cls, d: dict) -> Self:
if not isinstance(d, dict):
raise ValueError("Secret sharing configuration is missing")
if "groups" not in d:
raise ValueError("Configuration is missing group description")
if not isinstance(d["groups"], dict):
raise ValueError("`groups` must be a dict")
if "minimum_required_groups" not in d:
raise ValueError("Configuration is missing `minimum_required_groups`")
return cls(
minimum_required_groups=d["minimum_required_groups"],
groups={
name: _SecretSharingGroup(**group_d)
for name, group_d in d["groups"].items()
},
)
@property
def share_ids(self) -> Set[ShareIdentifier]:
return {
share_id
for group in self.groups.values()
for share_id in group.recipient_keys.keys()
}
def _generate_mnemonics(
self, secret_key: AgeSecretKey
) -> Iterable[Tuple[_SecretSharingGroup, List[str]]]:
hrp, data, spec = bech32_decode(secret_key)
assert hrp == "age-secret-key-"
secret_key_data = bytes(convert_bits(data, 5, 8))
assert len(secret_key_data) == 32
mnemonics = shamir_mnemonic.generate_mnemonics(
self.minimum_required_groups,
[g.group_parameters() for g in self.groups.values()],
secret_key_data,
)
return zip(self.groups.values(), mnemonics)
def generate_encrypted_shares(
self, identifier: str, secret_key: AgeSecretKey
) -> Dict[ShareIdentifier, AgeArmoredEncryptedPayload]:
encrypted_shares: Dict[ShareIdentifier, AgeArmoredEncryptedPayload] = {}
for group, mnemonics in self._generate_mnemonics(secret_key):
# When a group requires only one share to be complete, we actually
# implement this by giving everyone in the group the same
# share (as required by shamir-mnemonic because of “the maths”).
mnemonics_iter = (
itertools.repeat(mnemonics[0])
if group.minimum_required_shares == 1
else iter(mnemonics)
)
for share_id, recipient_key, mnemonic in zip(
group.recipient_keys.keys(),
group.recipient_keys.values(),
mnemonics_iter,
):
# We prefix each mnemonic with the removal identifier in our encrypted
# payload. In case a share holder remotely decrypts their payload, they can
# verify it came from the right recovery bundle before sending back the
# decrypted mnemonics.
cleartext = f"[{identifier}] {mnemonic}"
encrypted_shares[share_id] = age_encrypt_armored(
recipient_key, cleartext.encode("us-ascii")
)
return encrypted_shares
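# A sketch of the configuration shape expected by SecretSharing.from_dict()
# (share holder names and age public keys below are hypothetical):
#
#     secret_sharing = SecretSharing.from_dict({
#         "minimum_required_groups": 2,
#         "groups": {
#             "legal": {
#                 "minimum_required_shares": 1,
#                 "recipient_keys": {"CEO": "age1ceo…", "DPO": "age1dpo…"},
#             },
#             "sysadmins": {
#                 "minimum_required_shares": 2,
#                 "recipient_keys": {"alice": "age1alice…", "bob": "age1bob…"},
#             },
#         },
#     })
#     encrypted_shares = secret_sharing.generate_encrypted_shares(
#         "example-removal", object_secret_key
#     )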
class SecretRecoveryError(Exception):
pass
ObjectDecryptionKeyProvider = Callable[[Manifest], AgeSecretKey]
ShareDecryptionKeys = Iterator[Tuple[ShareIdentifier, AgeSecretKey]]
ShareDecryptionKeysProvider = Callable[[], ShareDecryptionKeys]
def recover_object_decryption_key_from_encrypted_shares(
encrypted_shares: Dict[ShareIdentifier, AgeArmoredEncryptedPayload],
share_decryption_keys_provider: ShareDecryptionKeysProvider,
decrypted_mnemonic_processor: Optional[Callable[[str, Optional[str]], None]] = None,
known_mnemonics: Optional[List[str]] = None,
) -> str:
from shamir_mnemonic.recovery import RecoveryState
from shamir_mnemonic.share import Share
def mnemonics_from_known() -> Iterator[str]:
if known_mnemonics:
yield from known_mnemonics
def mnemonics_from_provider() -> Iterator[str]:
for share_id, secret_key in share_decryption_keys_provider():
mnemonic = age_decrypt(secret_key, encrypted_shares[share_id]).decode(
"us-ascii"
)
if decrypted_mnemonic_processor:
decrypted_mnemonic_processor(mnemonic, share_id)
yield mnemonic
# TODO: We could provide better feedback on our progress. This
# would require a different API, though.
# shamir-mnemonic command line outputs something like:
# Completed 1 of 3 groups needed:
# ✓ 1 of 1 shares needed from group union echo acrobat
# ● 1 of 2 shares needed from group union echo beard
# ✗ 0 shares from group union echo check
# For implementation see:
# https://github.com/trezor/python-shamir-mnemonic/blob/c919df72/shamir_mnemonic/cli.py#L156-196
recovery_state = RecoveryState()
for mnemonic in itertools.chain(mnemonics_from_known(), mnemonics_from_provider()):
# Strip bundle removal identifier if it was given
mnemonic = re.sub(r"^\[.*\] ([a-z ]+)$", R"\1", mnemonic)
share = Share.from_mnemonic(mnemonic)
recovery_state.add_share(share)
if recovery_state.is_complete():
# no passphrase has been set when creating the mnemonics, so we need
# to pass an empty string
secret_key_data = recovery_state.recover(passphrase=b"")
assert len(secret_key_data) == 32
return bech32_encode(
"age-secret-key-",
list(convert_bits(secret_key_data, 8, 5, True)),
Bech32Encoding.BECH32,
).upper()
raise SecretRecoveryError("Unable to decrypt enough secrets")
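# A minimal sketch of recovering the object decryption key with YubiKey-held
# shares (assumes `manifest` is a Manifest loaded from a recovery bundle):
#
#     def yubikey_provider() -> ShareDecryptionKeys:
#         yield from list_yubikey_identities()
#
#     object_decryption_key = recover_object_decryption_key_from_encrypted_shares(
#         manifest.decryption_key_shares, yubikey_provider
#     )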
class UnsupportedFeatureException(Exception):
pass
MANIFEST_ARCNAME = "manifest.yml"
class RecoveryBundle:
def __init__(
self,
path: str,
object_decryption_key_provider: Optional[ObjectDecryptionKeyProvider] = None,
):
self._zip = ZipFile(path, "r")
self._manifest = Manifest.load(self._zip.read(MANIFEST_ARCNAME).decode("utf-8"))
self._cached_object_decryption_key: Optional[str] = None
if object_decryption_key_provider:
self._object_decryption_key_provider = object_decryption_key_provider
else:
def failing_provider(_):
raise ValueError(
"No `object_decryption_key_provider` has been given for this bundle."
)
self._object_decryption_key_provider = failing_provider
@property
def version(self) -> int:
return self._manifest.version
@property
def removal_identifier(self) -> str:
return self._manifest.removal_identifier
@property
def created(self) -> datetime:
return self._manifest.created
@property
def requested(self) -> List[Origin | ExtendedSWHID]:
if self.version < 3:
raise UnsupportedFeatureException(
f"`requested` is not available on recovery bundle version {self.version}"
)
return self._manifest.requested
@property
def swhids(self) -> List[ExtendedSWHID]:
return self._manifest.swhids
@property
def referencing(self) -> List[ExtendedSWHID]:
if self.version < 3:
raise UnsupportedFeatureException(
f"`referencing` is not available on recovery bundle version {self.version}"
)
return self._manifest.referencing
@property
def reason(self) -> Optional[str]:
return self._manifest.reason
@property
def expire(self) -> Optional[datetime]:
return self._manifest.expire
@property
def share_ids(self) -> Set[ShareIdentifier]:
return set(self._manifest.decryption_key_shares.keys())
@property
def object_decryption_key(self) -> AgeSecretKey:
if self._cached_object_decryption_key is None:
self._cached_object_decryption_key = self._object_decryption_key_provider(
self._manifest
)
return self._cached_object_decryption_key
def encrypted_secret(self, share_id: ShareIdentifier) -> AgeArmoredEncryptedPayload:
return self._manifest.decryption_key_shares[share_id]
def dump_manifest(self) -> str:
result = self._manifest.dump()
# make mypy happy
assert result is not None
return result
def _extract(self, arcname: str) -> bytes:
with self._zip.open(arcname) as f:
return age_decrypt(self.object_decryption_key, f.read())
def get_dict(self, swhid: ExtendedSWHID) -> Dict[str, Any]:
arcname = _swhid_to_arcname(swhid)
return kafka_to_value(self._extract(arcname))
def write_content_data(self, swhid: ExtendedSWHID, dest: BinaryIO):
content: Content = Content.from_dict(self.get_dict(swhid))
if content.data is None:
raise ValueError("Unserialized Content has no data")
dest.write(content.data)
def _objects(
self,
dir: str,
cls: Type[BaseModel],
name_filter: Optional[Callable[[str], bool]] = None,
):
if name_filter is None:
name_filter = lambda name: True # noqa: E731
for zip_info in sorted(
self._zip.infolist(), key=operator.attrgetter("filename")
):
if not zip_info.filename.startswith(f"{dir}/"):
continue
if zip_info.is_dir():
continue
if not name_filter(zip_info.filename.split("/")[-1]):
continue
d = kafka_to_value(
age_decrypt(self.object_decryption_key, self._zip.read(zip_info))
)
yield cls.from_dict(d)
def contents(self) -> Iterator[Content]:
yield from self._objects("contents", Content)
def skipped_contents(self) -> Iterator[SkippedContent]:
yield from self._objects("skipped_contents", SkippedContent)
def directories(self) -> Iterator[Directory]:
yield from self._objects("directories", Directory)
def revisions(self) -> Iterator[Revision]:
yield from self._objects("revisions", Revision)
def releases(self) -> Iterator[Release]:
yield from self._objects("releases", Release)
def snapshots(self) -> Iterator[Snapshot]:
yield from self._objects("snapshots", Snapshot)
def origins(self) -> Iterator[Origin]:
yield from self._objects("origins", Origin)
def origin_visits(self, origin: Origin) -> Iterator[OriginVisit]:
basename = str(origin.swhid()).replace(":", "_")
yield from self._objects(
"origin_visits", OriginVisit, lambda name: name.startswith(basename)
)
def origin_visit_statuses(self, origin: Origin) -> Iterator[OriginVisitStatus]:
basename = str(origin.swhid()).replace(":", "_")
yield from self._objects(
"origin_visit_statuses",
OriginVisitStatus,
lambda name: name.startswith(basename),
)
def extids(self) -> Iterator[ExtID]:
if self.version < 2:
return
yield from self._objects("extids", ExtID)
def get_missing_referenced_objects(
self, storage: StorageInterface
) -> Set[ExtendedSWHID]:
if self.version < 3:
raise UnsupportedFeatureException(
"`get_missing_referenced_objects` is not supported on "
f"recovery bundle version {self.version}"
)
available = filter_objects_missing_from_storage(storage, self.referencing)
return set(self.referencing) - set(available)
def restore(
self, storage: StorageInterface, progressbar: ProgressBarInit = no_progressbar
) -> Dict[str, int]:
def _origin_add(origins: List[Origin]) -> Dict[str, int]:
origin_result: collections.Counter[str] = collections.Counter()
origin_result += storage.origin_add(origins)
for origin in origins:
# Interestingly enough, origin_visit_add() and origin_visit_status_add()
# do not return result info.
# Also you _do_ need to pass a list and not an iterator otherwise
# nothing gets added.
origin_visits = list(self.origin_visits(origin))
storage.origin_visit_add(origin_visits)
origin_result["origin_visit:add"] += len(origin_visits)
origin_visit_statuses = list(self.origin_visit_statuses(origin))
storage.origin_visit_status_add(origin_visit_statuses)
origin_result["origin_visit_status:add"] += len(origin_visit_statuses)
return dict(origin_result)
steps: List[
Tuple[Callable[[List[Any]], Dict[str, int]], Callable[[], Iterator[Any]]]
] = [
(storage.content_add, self.contents),
(storage.skipped_content_add, self.skipped_contents),
(storage.directory_add, self.directories),
(storage.revision_add, self.revisions),
(storage.release_add, self.releases),
(storage.snapshot_add, self.snapshots),
(_origin_add, self.origins),
]
if self.version >= 2:
steps.extend(
[
(storage.raw_extrinsic_metadata_add, self.raw_extrinsic_metadata),
(storage.extid_add, self.extids),
]
)
result: collections.Counter[str] = collections.Counter()
bar: ProgressBar[int]
with progressbar(
length=len(self.swhids), label="Restoring recovery bundle…"
) as bar:
for add, source in steps:
for chunk_it in grouper(source(), RECOVERY_BUNDLE_RESTORE_CHUNK_SIZE):
chunk = list(chunk_it)
result += add(chunk)
result += storage.flush()
bar.update(n_steps=len(chunk))
log_lines = [
"Restoration complete. Results: ",
"- Content objects added: %(content:add)s",
"- Total bytes added to objstorage: %(content:add:bytes)s",
"- SkippedContent objects added: %(skipped_content:add)s",
"- Directory objects added: %(directory:add)s",
"- Revision objects added: %(revision:add)s",
"- Release objects added: %(release:add)s",
"- Snapshot objects added: %(snapshot:add)s",
"- Origin objects added: %(origin:add)s",
"- OriginVisit objects added: %(origin_visit:add)s",
"- OriginVisitStatus objects added: %(origin_visit_status:add)s",
]
if "ori_metadata:add" in result:
log_lines.append(
"- RawExtrinsicMetadata objects for origins added: %(ori_metadata:add)s"
)
if "snp_metadata:add" in result:
log_lines.append(
"- RawExtrinsicMetadata objects for snapshots added: %(snp_metadata:add)s"
)
if "rev_metadata:add" in result:
log_lines.append(
"- RawExtrinsicMetadata objects for revisions added: %(rev_metadata:add)s"
)
if "rel_metadata:add" in result:
log_lines.append(
"- RawExtrinsicMetadata objects for releases added: %(rel_metadata:add)s"
)
if "dir_metadata:add" in result:
log_lines.append(
"- RawExtrinsicMetadata objects for directories added: %(dir_metadata:add)s"
)
if "cnt_metadata:add" in result:
log_lines.append(
"- RawExtrinsicMetadata objects for contents added: %(cnt_metadata:add)s"
)
if "extid:add" in result:
log_lines.append("- ExtID objects added: %(extid:add)s")
logger.info("\n".join(log_lines), result)
return dict(result)
def rollover(self, secret_sharing: SecretSharing):
"""Update the recovery bundle encrypted shared secrets using the given
configuration.
It is useful when a secret holder needs to be added or removed, or to
switch to an entirely new scheme.
This method splits the decryption key into new encrypted shares. The
decryption key stays the same. The mnemonics will be new.
A new recovery bundle file is created with an updated manifest which
then atomically replaces the existing file.
"""
new_decryption_key_shares = secret_sharing.generate_encrypted_shares(
self.removal_identifier, self.object_decryption_key
)
# Turns out there is currently no way to update or delete a member of a Zip archive
# in Python’s zipfile module. See https://github.com/python/cpython/issues/51067
# So we are going to manually copy all members from the original zip to a new one.
# This means fully unpacking each file in memory.
#
# We are going to create the new zip in the same directory as the original
# so we can do an atomic replacement using rename(2).
assert self._zip.filename is not None
bundle_path = Path(self._zip.filename)
bundle_dir = bundle_path.parent
with tempfile.NamedTemporaryFile(
mode="w+b",
dir=bundle_dir,
delete=False,
prefix=f".{bundle_path.stem}_",
suffix=".swh-recovery-bundle",
) as f:
try:
self._manifest.decryption_key_shares = new_decryption_key_shares
with ZipFile(f, "a") as new_zip:
for zipinfo in self._zip.infolist():
# We skip the old manifest…
if zipinfo.filename == MANIFEST_ARCNAME:
continue
new_zip.writestr(zipinfo, self._zip.read(zipinfo))
# …and add the new one at the end.
new_zip.writestr(
MANIFEST_ARCNAME, typing.cast(str, self._manifest.dump())
)
f.close()
os.rename(f.name, bundle_path)
# Reopen the current zip file
self._zip = ZipFile(bundle_path, "r")
finally:
# Always unlink the temporary file path. Either it has already
# been renamed over the old file, or something went wrong.
with contextlib.suppress(FileNotFoundError):
os.unlink(f.name)
class ContentDataNotFound(Exception):
"""Raised when data for a given Content object cannot be retrieved."""
def __init__(self, swhid: ExtendedSWHID):
self.swhid = swhid
def __str__(self):
return f"No data found for {self.swhid}"
def _swhid_to_arcname(swhid: ExtendedSWHID):
basename = str(swhid).replace(":", "_")
if swhid.object_type == ExtendedObjectType.CONTENT:
return f"contents/{basename}.age"
if swhid.object_type == ExtendedObjectType.DIRECTORY:
return f"directories/{basename}.age"
if swhid.object_type == ExtendedObjectType.REVISION:
return f"revisions/{basename}.age"
if swhid.object_type == ExtendedObjectType.RELEASE:
return f"releases/{basename}.age"
if swhid.object_type == ExtendedObjectType.SNAPSHOT:
return f"snapshots/{basename}.age"
if swhid.object_type == ExtendedObjectType.ORIGIN:
return f"origins/{basename}.age"
raise NotImplementedError(f"Unknown object type {swhid.object_type}")
def _from_hashes(
sha1: Optional[bytes] = None,
sha1_git: Optional[bytes] = None,
sha256: Optional[bytes] = None,
blake2s256: Optional[bytes] = None,
) -> HashDict:
d = HashDict()
if sha1 is not None:
d["sha1"] = sha1
if sha1_git is not None:
d["sha1_git"] = sha1_git
if sha256 is not None:
d["sha256"] = sha256
if blake2s256 is not None:
d["blake2s256"] = blake2s256
return d
class RecoveryBundleCreator:
def __init__(
self,
path: str,
storage: StorageInterface,
removal_identifier: str,
requested: List[Origin | ExtendedSWHID],
referencing: List[ExtendedSWHID],
object_public_key: AgePublicKey,
decryption_key_shares: Dict[str, str],
registration_callback: Optional[Callable[[BaseModel], None]] = None,
allow_empty_content_objects: bool = False,
):
self._path = path
self._storage = storage
self._removal_identifier = removal_identifier
self._requested = requested
self._swhids: List[ExtendedSWHID] = []
self._referencing = referencing
self._created = datetime.now(timezone.utc)
self._pk = object_public_key
if len(decryption_key_shares) == 0:
raise ValueError("`decryption_key_shares` has not been set")
self._decryption_key_shares = decryption_key_shares
self._reason: Optional[str] = None
self._expire: Optional[datetime] = None
if registration_callback:
self._registration_callback = registration_callback
else:
self._registration_callback = lambda _: None
self._allow_empty_content_objects = allow_empty_content_objects
# Total number of RawExtrinsicMetadata objects that will be added
self._total_emds: Optional[int] = None
# Current number of RawExtrinsicMetadata objects seen
self._seen_emds = 0
def __enter__(self) -> Self:
self._zip = ZipFile(self._path, "x")
return self
def __exit__(self, exc_type, exc_val, exc_tb):
try:
if exc_type is not None:
with contextlib.suppress(FileNotFoundError):
os.unlink(self._path)
return False
if len(self._swhids) == 0:
raise ValueError("Refusing to create an empty recovery bundle")
manifest = Manifest(
version=3,
removal_identifier=self._removal_identifier,
created=self._created,
requested=self._requested,
swhids=self._swhids,
referencing=self._referencing,
decryption_key_shares=self._decryption_key_shares,
reason=self._reason,
expire=self._expire,
)
self._zip.writestr(MANIFEST_ARCNAME, manifest.dump())
except: # noqa: E722
with contextlib.suppress(FileNotFoundError):
os.unlink(self._path)
raise
finally:
self._zip.close()
def set_reason(self, reason: str):
self._reason = reason
def set_expire(self, expire: datetime):
if expire < self._created:
raise ValueError("expiration date is in the past")
self._expire = expire
def _write(self, arcname: str, data: bytes):
self._zip.writestr(arcname, age_encrypt(self._pk, data))
def _add_skipped_content(
self, swhid: ExtendedSWHID, index: int, skipped_content: SkippedContent
):
basename = str(swhid).replace(":", "_")
arcname = f"skipped_contents/{basename}_{index}.age"
self._write(arcname, value_to_kafka(skipped_content.to_dict()))
def _add_contents(
self, content_swhids: List[ExtendedSWHID]
) -> Iterable[Content | SkippedContent | ExtID]:
assert all(
swhid.object_type == ExtendedObjectType.CONTENT for swhid in content_swhids
)
for swhid, content in zip(
content_swhids,
self._storage.content_get(
[swhid.object_id for swhid in content_swhids], algo="sha1_git"
),
):
if content is None:
# content_get() gave us nothing… maybe the SWHID matches some SkippedContent?
skipped_contents = self._storage.skipped_content_find(
{"sha1_git": swhid.object_id}
)
if len(skipped_contents) == 0:
raise ValueError(f"Unable to find {swhid} in storage")
for index, skipped_content in enumerate(skipped_contents, start=1):
self._add_skipped_content(swhid, index, skipped_content)
yield skipped_content
else:
data = self._storage.content_get_data(_from_hashes(**content.hashes()))
if data is not None:
populated_content = content.from_data(
data,
status=content.status,
ctime=content.ctime,
)
else:
if self._allow_empty_content_objects:
logger.warning(
"No data available for %s. "
"Recording empty Content object as requested.",
swhid,
)
populated_content = content
else:
raise ContentDataNotFound(swhid)
self._write(
_swhid_to_arcname(swhid),
value_to_kafka(populated_content.to_dict()),
)
yield populated_content
yield from self._add_extids(
CoreSWHIDObjectType.CONTENT, [swhid.object_id for swhid in content_swhids]
)
def _add_directories(
self, directory_swhids: List[ExtendedSWHID]
) -> Iterable[Directory | ExtID]:
assert all(
swhid.object_type == ExtendedObjectType.DIRECTORY
for swhid in directory_swhids
)
directory_ids = [swhid.object_id for swhid in directory_swhids]
it = zip(
directory_swhids,
swh.storage.algos.directory.directory_get_many_with_possibly_duplicated_entries( # noqa: B950
self._storage, directory_ids
),
)
for swhid, result in it:
if result is None:
raise ValueError(f"Unable to find {swhid} in storage")
_corrupted, directory = result
# If it's corrupted, we should still back it up anyway
self._write(_swhid_to_arcname(swhid), value_to_kafka(directory.to_dict()))
yield directory
yield from self._add_extids(
CoreSWHIDObjectType.DIRECTORY,
[swhid.object_id for swhid in directory_swhids],
)
def _add_revisions(
self, revision_swhids: List[ExtendedSWHID]
) -> Iterator[Revision | ExtID]:
assert all(
swhid.object_type == ExtendedObjectType.REVISION
for swhid in revision_swhids
)
for swhid, revision in zip(
revision_swhids,
self._storage.revision_get(
[swhid.object_id for swhid in revision_swhids], ignore_displayname=True
),
):
if revision is None:
raise ValueError(f"Unable to find {swhid} in storage")
self._write(_swhid_to_arcname(swhid), value_to_kafka(revision.to_dict()))
yield revision
yield from self._add_extids(
CoreSWHIDObjectType.REVISION, [swhid.object_id for swhid in revision_swhids]
)
def _add_releases(
self, release_swhids: List[ExtendedSWHID]
) -> Iterator[Release | ExtID]:
assert all(
swhid.object_type == ExtendedObjectType.RELEASE for swhid in release_swhids
)
for swhid, release in zip(
release_swhids,
self._storage.release_get(
[swhid.object_id for swhid in release_swhids], ignore_displayname=True
),
):
if release is None:
raise ValueError(f"Unable to find {swhid} in storage")
self._write(_swhid_to_arcname(swhid), value_to_kafka(release.to_dict()))
yield release
yield from self._add_extids(
CoreSWHIDObjectType.RELEASE, [swhid.object_id for swhid in release_swhids]
)
def _add_snapshots(
self, snapshot_swhids: List[ExtendedSWHID]
) -> Iterator[Snapshot | ExtID]:
assert all(
swhid.object_type == ExtendedObjectType.SNAPSHOT
for swhid in snapshot_swhids
)
for swhid in snapshot_swhids:
snapshot = swh.storage.algos.snapshot.snapshot_get_all_branches(
self._storage, swhid.object_id
)
if snapshot is None:
raise ValueError(f"Unable to find {swhid} in storage")
self._write(_swhid_to_arcname(swhid), value_to_kafka(snapshot.to_dict()))
yield snapshot
yield from self._add_extids(
CoreSWHIDObjectType.SNAPSHOT, [swhid.object_id for swhid in snapshot_swhids]
)
def _add_origin_visit(self, basename: str, visit: OriginVisit):
arcname = f"origin_visits/{basename}_" f"{visit.visit}.age"
self._write(arcname, value_to_kafka(visit.to_dict()))
def _add_origin_visit_status(self, basename: str, status: OriginVisitStatus):
arcname = (
f"origin_visit_statuses/{basename}_"
f"{status.visit}_"
f"{status.date.isoformat(timespec='microseconds').replace(':', '_')}.age"
)
self._write(arcname, value_to_kafka(status.to_dict()))
def _add_origins(
self, origin_swhids: List[ExtendedSWHID]
) -> Iterator[Origin | OriginVisit | OriginVisitStatus]:
assert all(
swhid.object_type == ExtendedObjectType.ORIGIN for swhid in origin_swhids
)
for swhid, origin_d in zip(
origin_swhids,
self._storage.origin_get_by_sha1(
[swhid.object_id for swhid in origin_swhids]
),
):
if origin_d is None:
raise ValueError(f"Unable to find {swhid} in storage")
origin = Origin.from_dict(origin_d)
basename = str(swhid).replace(":", "_")
arcname = f"origins/{basename}.age"
self._write(arcname, value_to_kafka(origin_d))
yield origin
for origin_visit_with_statuses in stream_results(
self._storage.origin_visit_get_with_statuses, origin.url
):
self._add_origin_visit(basename, origin_visit_with_statuses.visit)
yield origin_visit_with_statuses.visit
for origin_visit_status in origin_visit_with_statuses.statuses:
self._add_origin_visit_status(basename, origin_visit_status)
yield origin_visit_status
def _add_raw_extrinsic_metadata(
self, emd_swhids: List[ExtendedSWHID]
) -> Iterator[RawExtrinsicMetadata]:
assert all(
swhid.object_type == ExtendedObjectType.RAW_EXTRINSIC_METADATA
for swhid in emd_swhids
)
assert (
self._total_emds is not None
), "this should have been set by backup_swhids.chunker()"
emds = self._storage.raw_extrinsic_metadata_get_by_ids(
[swhid.object_id for swhid in emd_swhids]
)
missing_emd_swhids = set(emd_swhids) - {emd.swhid() for emd in emds}
if missing_emd_swhids:
raise ValueError(
f"Unable to retrieve {', '.join(str(swhid) for swhid in missing_emd_swhids)}"
)
# Here we do some tricks to 0-pad the index to the right width.
# self._total_emds contains the total number of RawExtrinsicMetadata
# objects. If it’s 1234, `len(str(self._total_emds))` will be 4.
# We thus get `emd_index_format` to be `{:04d}`.
# Calling `emd_index_format.format(33)` then results in `0033`.
emd_index_format = "{" f":0{len(str(self._total_emds))}d" "}"
for emd in emds:
self._seen_emds += 1
basename = str(emd.swhid()).replace(":", "_")
arcname = (
"raw_extrinsic_metadata/"
f"{emd_index_format.format(self._seen_emds)}_{basename}.age"
)
self._write(arcname, value_to_kafka(emd.to_dict()))
yield emd
def _add_extids(
self, target_type: CoreSWHIDObjectType, target_ids: List[bytes]
) -> Iterator[ExtID]:
for extid in self._storage.extid_get_from_target(target_type, target_ids):
assert extid.id, "id for ExtID should have been computed!"
arcname = f"extids/{extid.id.hex()}.age"
self._write(arcname, value_to_kafka(extid.to_dict()))
yield extid
def backup_swhids(
self,
swhids: Iterable[ExtendedSWHID],
progressbar: ProgressBarInit = no_progressbar,
) -> None:
def chunker(
grouped_swhids: Collection[ExtendedSWHID],
) -> Iterable[List[ExtendedSWHID]]:
assert (
len(grouped_swhids) > 0
), "iter_swhids_grouped_by_type() should not give us an empty list"
swhids_it = iter(sorted(grouped_swhids))
first_swhid = next(swhids_it)
match first_swhid.object_type:
case ExtendedObjectType.RAW_EXTRINSIC_METADATA:
# As we want to number RawExtrinsicMetadata with proper padding,
# we need to record the total size of the group
if self._total_emds is None:
self._total_emds = len(grouped_swhids)
else:
raise ValueError(
"RawExtrinsingMetdata objects must all be added in one batch"
)
chunk_size = 20
case ExtendedObjectType.CONTENT:
# Content can be slower as we need to retrieve from the objstorage,
# so let’s update progress more often.
chunk_size = 10
case _:
chunk_size = 50
bar: ProgressBar[int]
with progressbar(
length=len(grouped_swhids),
label=f"Backing up {first_swhid.object_type.name.capitalize()} objects…",
) as bar:
yield [first_swhid]
bar.update(n_steps=1)
for chunk_it in grouper(swhids_it, chunk_size):
chunk = list(chunk_it)
yield chunk
bar.update(n_steps=len(chunk))
handlers: Dict[
ExtendedObjectType,
Callable[[List[ExtendedSWHID]], Iterable[BaseModel]],
] = {
ExtendedObjectType.CONTENT: self._add_contents,
ExtendedObjectType.DIRECTORY: self._add_directories,
ExtendedObjectType.REVISION: self._add_revisions,
ExtendedObjectType.RELEASE: self._add_releases,
ExtendedObjectType.SNAPSHOT: self._add_snapshots,
ExtendedObjectType.ORIGIN: self._add_origins,
ExtendedObjectType.RAW_EXTRINSIC_METADATA: self._add_raw_extrinsic_metadata,
}
for obj in iter_swhids_grouped_by_type(
swhids,
handlers=handlers,
chunker=chunker,
):
self._registration_callback(obj)
if hasattr(obj, "swhid"):
swhid = obj.swhid()
self._swhids.append(
swhid.to_extended() if hasattr(swhid, "to_extended") else swhid
)
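# A minimal sketch of assembling a bundle end-to-end (paths, identifiers,
# SWHIDs and `secret_sharing` are hypothetical; see the SecretSharing sketch
# above):
#
#     public_key, secret_key = generate_age_keypair()
#     with RecoveryBundleCreator(
#         path="example.swh-recovery-bundle",
#         storage=storage,
#         removal_identifier="example-removal",
#         requested=[Origin(url="https://example.org/repo.git")],
#         referencing=[],
#         object_public_key=public_key,
#         decryption_key_shares=secret_sharing.generate_encrypted_shares(
#             "example-removal", secret_key
#         ),
#     ) as creator:
#         creator.backup_swhids(swhids_to_remove)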