# Copyright (C) 2015-2022 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
from collections import defaultdict
import datetime
import functools
import itertools
import random
import threading
from typing import (
Any,
Callable,
Dict,
Generic,
Iterable,
Iterator,
List,
Optional,
Tuple,
Type,
TypeVar,
Union,
)
from cassandra.util import Date
from swh.model.model import Content, Sha1Git, SkippedContent
from swh.model.swhids import ExtendedSWHID
from swh.storage.cassandra import CassandraStorage
from swh.storage.cassandra.model import (
BaseRow,
ContentRow,
DirectoryEntryRow,
DirectoryRow,
ExtIDByTargetRow,
ExtIDRow,
MetadataAuthorityRow,
MetadataFetcherRow,
ObjectCountRow,
ObjectReferenceRow,
ObjectReferencesTableRow,
OriginRow,
OriginVisitRow,
OriginVisitStatusRow,
RawExtrinsicMetadataByIdRow,
RawExtrinsicMetadataRow,
ReleaseRow,
RevisionParentRow,
RevisionRow,
SkippedContentRow,
SnapshotBranchRow,
SnapshotRow,
)
from swh.storage.cassandra.schema import HASH_ALGORITHMS
from swh.storage.exc import NonRetryableException
from swh.storage.interface import ListOrder, TotalHashDict
from swh.storage.objstorage import ObjStorage
from .common import origin_url_to_sha1
from .writer import JournalWriter
TRow = TypeVar("TRow", bound=BaseRow)
class Table(Generic[TRow]):
def __init__(self, row_class: Type[TRow]):
self.row_class = row_class
self.primary_key_cols = row_class.PARTITION_KEY + row_class.CLUSTERING_KEY
# Map from tokens to clustering keys to rows
# These are not actually partitions (or rather, there is one partition
# for each token) and they aren't sorted.
# But it is good enough if we don't care about performance;
# and makes the code a lot simpler.
self.data: Dict[int, Dict[Tuple, TRow]] = defaultdict(dict)
def __repr__(self):
return f"<__module__.Table[{self.row_class.__name__}] object>"
def partition_key(self, row: Union[TRow, Dict[str, Any]]) -> Tuple:
"""Returns the partition key of a row (ie. the cells which get hashed
into the token."""
if isinstance(row, dict):
row_d = row
else:
row_d = row.to_dict()
return tuple(row_d[col] for col in self.row_class.PARTITION_KEY)
def clustering_key(self, row: Union[TRow, Dict[str, Any]]) -> Tuple:
"""Returns the clustering key of a row (ie. the cells which are used
for sorting rows within a partition."""
if isinstance(row, dict):
row_d = row
else:
row_d = row.to_dict()
return tuple(row_d[col] for col in self.row_class.CLUSTERING_KEY)
def primary_key(self, row):
return self.partition_key(row) + self.clustering_key(row)
def primary_key_from_dict(self, d: Dict[str, Any]) -> Tuple:
"""Returns the primary key (ie. concatenation of partition key and
clustering key) of the given dictionary interpreted as a row."""
return tuple(d[col] for col in self.primary_key_cols)
def token(self, key: Tuple):
"""Returns the token of a row (ie. the hash of its partition key)."""
return hash(key)
def get_partition(self, token: int) -> Dict[Tuple, TRow]:
"""Returns the partition that contains this token."""
return self.data.get(token, {})
def insert(self, row: TRow) -> None:
partition = self.data[self.token(self.partition_key(row))]
partition[self.clustering_key(row)] = row
def delete(self, predicate: Callable[[TRow], bool]) -> None:
self.data = {
pk: dict((ck, row) for (ck, row) in partition.items() if not predicate(row))
for (pk, partition) in self.data.items()
}
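    # Note: unlike a real Cassandra DELETE (which targets a single primary
    # key), delete() above scans every row of every partition; as the comment
    # in __init__ says, that is good enough when we don't care about
    # performance.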
def split_primary_key(self, key: Tuple) -> Tuple[Tuple, Tuple]:
"""Returns (partition_key, clustering_key) from a partition key"""
assert len(key) == len(self.primary_key_cols)
partition_key = key[0 : len(self.row_class.PARTITION_KEY)]
clustering_key = key[len(self.row_class.PARTITION_KEY) :]
return (partition_key, clustering_key)
def get_from_partition_key(self, partition_key: Tuple) -> Iterable[TRow]:
"""Returns at most one row, from its partition key."""
token = self.token(partition_key)
for row in self.get_from_token(token):
if self.partition_key(row) == partition_key:
yield row
def get_from_primary_key(self, primary_key: Tuple) -> Optional[TRow]:
"""Returns at most one row, from its primary key."""
(partition_key, clustering_key) = self.split_primary_key(primary_key)
token = self.token(partition_key)
partition = self.get_partition(token)
return partition.get(clustering_key)
def get_from_token(self, token: int) -> Iterable[TRow]:
"""Returns all rows whose token (ie. non-cryptographic hash of the
partition key) is the one passed as argument."""
return (v for (k, v) in sorted(self.get_partition(token).items()))
def iter_all(self) -> Iterator[Tuple[Tuple, TRow]]:
return (
(self.primary_key(row), row)
for (token, partition) in self.data.items()
for (clustering_key, row) in sorted(
partition.items(), key=lambda ck_and_row: ck_and_row[0]
)
)
def get_random(self) -> Optional[TRow]:
        rows = [row for (pk, row) in self.iter_all()]
        return random.choice(rows) if rows else None
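# A minimal usage sketch of Table, illustrative only: given any of the row
# classes imported above (their PARTITION_KEY/CLUSTERING_KEY attributes drive
# all the keying logic) and a row instance `row` of that class:
#
#     table = Table(ContentRow)
#     table.insert(row)
#     token = table.token(table.partition_key(row))
#     rows = list(table.get_from_token(token))  # all rows sharing that token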
class InMemoryCqlRunner:
def __init__(self):
self._contents = Table(ContentRow)
self._content_indexes = defaultdict(lambda: defaultdict(set))
        self._skipped_contents = Table(SkippedContentRow)
self._skipped_content_indexes = defaultdict(lambda: defaultdict(set))
self._directories = Table(DirectoryRow)
self._directory_entries = Table(DirectoryEntryRow)
self._revisions = Table(RevisionRow)
self._revision_parents = Table(RevisionParentRow)
self._releases = Table(ReleaseRow)
self._snapshots = Table(SnapshotRow)
self._snapshot_branches = Table(SnapshotBranchRow)
self._origins = Table(OriginRow)
self._origin_visits = Table(OriginVisitRow)
self._origin_visit_statuses = Table(OriginVisitStatusRow)
self._metadata_authorities = Table(MetadataAuthorityRow)
self._metadata_fetchers = Table(MetadataFetcherRow)
self._raw_extrinsic_metadata = Table(RawExtrinsicMetadataRow)
self._raw_extrinsic_metadata_by_id = Table(RawExtrinsicMetadataByIdRow)
self._extid = Table(ExtIDRow)
        self._object_references: Dict[str, Table[ObjectReferenceRow]] = {}
self._object_references_tables_lock = threading.Lock()
self._object_references_tables = Table(ObjectReferencesTableRow)
self._stat_counters = defaultdict(int)
def __getstate__(self):
"""Overrides default :meth:`__getstate__` to exclude the lock, because
:file:`migrate_extrinsic_metadata/test_debian.py` needs this object to be deepcopiable
"""
try:
state = super().__getstate__()
except AttributeError:
            # Python < 3.10 does not provide a default __getstate__ implementation.
state = self.__dict__.copy()
state.pop("_object_references_tables_lock", None)
return state
def _get_token_range(
self, table: Table[TRow], start: int, end: int, limit: int
) -> Iterator[Tuple[int, TRow]]:
matches = [
(token, row)
for (token, partition) in table.data.items()
for (clustering_key, row) in partition.items()
if start <= token <= end
]
        # Sort by token only: row objects are not orderable, so a bare sort()
        # could fail on token ties.
        matches.sort(key=lambda token_and_row: token_and_row[0])
return iter(matches[0:limit])
def increment_counter(self, object_type: str, nb: int):
self._stat_counters[object_type] += nb
def stat_counters(self) -> Iterable[ObjectCountRow]:
for object_type, count in self._stat_counters.items():
yield ObjectCountRow(partition_key=0, object_type=object_type, count=count)
##########################
# 'content' table
##########################
def _content_add_finalize(self, content: ContentRow) -> None:
self._contents.insert(content)
self.increment_counter("content", 1)
def content_add_prepare(self, content: ContentRow):
finalizer = functools.partial(self._content_add_finalize, content)
return (self._contents.token(self._contents.partition_key(content)), finalizer)
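    # content_add_prepare() implements a two-phase insert: the caller receives
    # the row's token plus a finalizer, indexes the token, then commits. A
    # hedged sketch of the intended flow, assuming `runner` is an
    # InMemoryCqlRunner and `content` the swh.model Content the row was built
    # from:
    #
    #     (token, finalizer) = runner.content_add_prepare(content_row)
    #     for algo in HASH_ALGORITHMS:
    #         runner.content_index_add_one(algo, content, token)
    #     finalizer()  # inserts the row and bumps the 'content' counter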
def content_get_from_pk(
self, content_hashes: Dict[str, bytes]
) -> Optional[ContentRow]:
primary_key = self._contents.primary_key_from_dict(content_hashes)
return self._contents.get_from_primary_key(primary_key)
def content_get_from_tokens(self, tokens: List[int]) -> Iterable[ContentRow]:
return itertools.chain.from_iterable(map(self._contents.get_from_token, tokens))
def content_get_random(self) -> Optional[ContentRow]:
return self._contents.get_random()
def content_get_token_range(
self,
start: int,
end: int,
limit: int,
) -> Iterable[Tuple[int, ContentRow]]:
return self._get_token_range(self._contents, start, end, limit)
def content_missing_from_all_hashes(
self, contents_hashes: List[Dict[str, bytes]]
) -> Iterator[Dict[str, bytes]]:
for content_hashes in contents_hashes:
if not self.content_get_from_pk(content_hashes):
yield content_hashes
def content_delete(self, content_hashes: TotalHashDict) -> None:
self._contents.delete(
lambda row: all(
getattr(row, k) == content_hashes[k] # type: ignore[literal-required]
for k in HASH_ALGORITHMS
)
)
##########################
# 'content_by_*' tables
##########################
def content_missing_by_sha1_git(self, ids: List[bytes]) -> List[bytes]:
missing = []
for id_ in ids:
if id_ not in self._content_indexes["sha1_git"]:
missing.append(id_)
return missing
def content_index_add_one(self, algo: str, content: Content, token: int) -> None:
self._content_indexes[algo][content.get_hash(algo)].add(token)
def content_get_tokens_from_single_algo(
self, algo: str, hashes: List[bytes]
) -> Iterable[int]:
for hash_ in hashes:
yield from self._content_indexes[algo][hash_]
##########################
# 'skipped_content' table
##########################
def _skipped_content_add_finalize(self, content: SkippedContentRow) -> None:
self._skipped_contents.insert(content)
self.increment_counter("skipped_content", 1)
def skipped_content_add_prepare(self, content: SkippedContentRow):
finalizer = functools.partial(self._skipped_content_add_finalize, content)
return (
            self._skipped_contents.token(
                self._skipped_contents.partition_key(content)
            ),
finalizer,
)
def skipped_content_get_from_pk(
self, content_hashes: Dict[str, bytes]
) -> Optional[SkippedContentRow]:
primary_key = self._skipped_contents.primary_key_from_dict(content_hashes)
return self._skipped_contents.get_from_primary_key(primary_key)
    def skipped_content_get_from_token(
        self, token: int
    ) -> Iterable[SkippedContentRow]:
return self._skipped_contents.get_from_token(token)
def skipped_content_delete(self, content_hashes: TotalHashDict) -> None:
self._skipped_contents.delete(
lambda row: all(
getattr(row, k) == content_hashes[k] # type: ignore[literal-required]
for k in HASH_ALGORITHMS
)
)
##########################
# 'skipped_content_by_*' tables
##########################
def skipped_content_index_add_one(
self, algo: str, content: SkippedContent, token: int
) -> None:
self._skipped_content_indexes[algo][content.get_hash(algo)].add(token)
def skipped_content_get_tokens_from_single_hash(
self, algo: str, hash_: bytes
) -> Iterable[int]:
return self._skipped_content_indexes[algo][hash_]
##########################
# 'directory' table
##########################
def directory_missing(self, ids: List[bytes]) -> List[bytes]:
missing = []
for id_ in ids:
if self._directories.get_from_primary_key((id_,)) is None:
missing.append(id_)
return missing
def directory_add_one(self, directory: DirectoryRow) -> None:
self._directories.insert(directory)
self.increment_counter("directory", 1)
def directory_get_random(self) -> Optional[DirectoryRow]:
return self._directories.get_random()
def directory_get(self, directory_ids: List[Sha1Git]) -> Iterable[DirectoryRow]:
for id_ in directory_ids:
row = self._directories.get_from_primary_key((id_,))
if row:
yield row
def directory_get_token_range(
self,
start: int,
end: int,
limit: int,
) -> Iterable[Tuple[int, DirectoryRow]]:
return self._get_token_range(self._directories, start, end, limit)
def directory_delete(self, directory_id: Sha1Git) -> None:
self._directories.delete(lambda row: row.id == directory_id)
##########################
# 'directory_entry' table
##########################
def directory_entry_add_one(self, entry: DirectoryEntryRow) -> None:
self._directory_entries.insert(entry)
def directory_entry_get(
self, directory_ids: List[Sha1Git]
) -> Iterable[DirectoryEntryRow]:
for id_ in directory_ids:
yield from self._directory_entries.get_from_partition_key((id_,))
def directory_entry_get_from_name(
self, directory_id: Sha1Git, from_: bytes, limit: int
) -> Iterable[DirectoryEntryRow]:
# Get all entries
entries = self._directory_entries.get_from_partition_key((directory_id,))
# Filter out the ones before from_
entries = itertools.dropwhile(lambda entry: entry.name < from_, entries)
# Apply limit
return itertools.islice(entries, limit)
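    # The dropwhile() above relies on get_from_partition_key() yielding rows
    # sorted by clustering key (get_from_token() sorts each partition's
    # items), which for directory entries means sorted by name.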
def directory_entry_delete(self, directory_id: Sha1Git) -> None:
self._directory_entries.delete(lambda row: row.directory_id == directory_id)
##########################
# 'revision' table
##########################
    def revision_missing(self, ids: List[bytes]) -> List[bytes]:
missing = []
for id_ in ids:
if self._revisions.get_from_primary_key((id_,)) is None:
missing.append(id_)
return missing
def revision_add_one(self, revision: RevisionRow) -> None:
self._revisions.insert(revision)
self.increment_counter("revision", 1)
    def revision_get_ids(self, revision_ids: List[Sha1Git]) -> Iterable[Sha1Git]:
for id_ in revision_ids:
if self._revisions.get_from_primary_key((id_,)) is not None:
yield id_
def revision_get(
self, revision_ids: List[Sha1Git], ignore_displayname: bool = False
) -> Iterable[RevisionRow]:
for id_ in revision_ids:
row = self._revisions.get_from_primary_key((id_,))
if row:
yield row
def revision_get_token_range(
self,
start: int,
end: int,
limit: int,
) -> Iterable[Tuple[int, RevisionRow]]:
return self._get_token_range(self._revisions, start, end, limit)
def revision_get_random(self) -> Optional[RevisionRow]:
return self._revisions.get_random()
def revision_delete(self, revision_id: Sha1Git) -> None:
self._revisions.delete(lambda row: row.id == revision_id)
##########################
# 'revision_parent' table
##########################
def revision_parent_add_one(self, revision_parent: RevisionParentRow) -> None:
self._revision_parents.insert(revision_parent)
def revision_parent_get(self, revision_id: Sha1Git) -> Iterable[bytes]:
for parent in self._revision_parents.get_from_partition_key((revision_id,)):
yield parent.parent_id
def revision_parent_delete(self, revision_id: Sha1Git) -> None:
self._revision_parents.delete(lambda row: row.id == revision_id)
##########################
# 'release' table
##########################
def release_missing(self, ids: List[bytes]) -> List[bytes]:
missing = []
for id_ in ids:
if self._releases.get_from_primary_key((id_,)) is None:
missing.append(id_)
return missing
def release_add_one(self, release: ReleaseRow) -> None:
self._releases.insert(release)
self.increment_counter("release", 1)
def release_get(
        self, release_ids: List[Sha1Git], ignore_displayname: bool = False
) -> Iterable[ReleaseRow]:
for id_ in release_ids:
row = self._releases.get_from_primary_key((id_,))
if row:
yield row
def release_get_token_range(
self,
start: int,
end: int,
limit: int,
) -> Iterable[Tuple[int, ReleaseRow]]:
return self._get_token_range(self._releases, start, end, limit)
def release_get_random(self) -> Optional[ReleaseRow]:
return self._releases.get_random()
def release_delete(self, release_id: Sha1Git) -> None:
self._releases.delete(lambda row: row.id == release_id)
##########################
# 'snapshot' table
##########################
def snapshot_missing(self, ids: List[bytes]) -> List[bytes]:
missing = []
for id_ in ids:
if self._snapshots.get_from_primary_key((id_,)) is None:
missing.append(id_)
return missing
def snapshot_add_one(self, snapshot: SnapshotRow) -> None:
self._snapshots.insert(snapshot)
self.increment_counter("snapshot", 1)
def snapshot_get_token_range(
self,
start: int,
end: int,
limit: int,
) -> Iterable[Tuple[int, SnapshotRow]]:
return self._get_token_range(self._snapshots, start, end, limit)
def snapshot_get_random(self) -> Optional[SnapshotRow]:
return self._snapshots.get_random()
def snapshot_branch_get_from_name(
self, snapshot_id: Sha1Git, from_: bytes, limit: int
) -> Iterable[SnapshotBranchRow]:
        return self.snapshot_branch_get(
            snapshot_id=snapshot_id, from_=from_, limit=limit
        )
def snapshot_delete(self, snapshot_id: Sha1Git) -> None:
self._snapshots.delete(lambda row: row.id == snapshot_id)
##########################
# 'snapshot_branch' table
##########################
def snapshot_branch_add_one(self, branch: SnapshotBranchRow) -> None:
self._snapshot_branches.insert(branch)
def snapshot_count_branches(
self,
snapshot_id: Sha1Git,
branch_name_exclude_prefix: Optional[bytes] = None,
) -> Dict[Optional[str], int]:
"""Returns a dictionary from type names to the number of branches
of that type."""
counts: Dict[Optional[str], int] = defaultdict(int)
for branch in self._snapshot_branches.get_from_partition_key((snapshot_id,)):
if branch_name_exclude_prefix and branch.name.startswith(
branch_name_exclude_prefix
):
continue
            counts[branch.target_type] += 1
return counts
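    # For example, a snapshot with three revision branches, one alias and two
    # dangling branches (no target, hence a None target_type) would yield
    # {"revision": 3, "alias": 1, None: 2}.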
def snapshot_branch_get(
self,
snapshot_id: Sha1Git,
from_: bytes,
limit: int,
branch_name_exclude_prefix: Optional[bytes] = None,
) -> Iterable[SnapshotBranchRow]:
count = 0
for branch in self._snapshot_branches.get_from_partition_key((snapshot_id,)):
prefix = branch_name_exclude_prefix
if branch.name >= from_ and (
prefix is None or not branch.name.startswith(prefix)
):
count += 1
yield branch
if count >= limit:
break
def snapshot_branch_delete(self, snapshot_id: Sha1Git) -> None:
self._snapshot_branches.delete(lambda row: row.snapshot_id == snapshot_id)
##########################
# 'origin' table
##########################
def origin_add_one(self, origin: OriginRow) -> None:
self._origins.insert(origin)
self.increment_counter("origin", 1)
def origin_get_by_sha1(self, sha1: bytes) -> Iterable[OriginRow]:
return self._origins.get_from_partition_key((sha1,))
def origin_get_by_url(self, url: str) -> Iterable[OriginRow]:
return self.origin_get_by_sha1(origin_url_to_sha1(url))
def origin_list(
self, start_token: int, limit: int
) -> Iterable[Tuple[int, OriginRow]]:
"""Returns an iterable of (token, origin)"""
matches = [
(token, row)
for (token, partition) in self._origins.data.items()
for (clustering_key, row) in partition.items()
if token >= start_token
]
        matches.sort(key=lambda token_and_row: token_and_row[0])
return matches[0:limit]
def origin_iter_all(self) -> Iterable[OriginRow]:
return (
row
for (token, partition) in self._origins.data.items()
for (clustering_key, row) in partition.items()
)
def origin_bump_next_visit_id(self, origin_url: str, visit_id: int) -> None:
origin = list(self.origin_get_by_url(origin_url))[0]
origin.next_visit_id = max(origin.next_visit_id, visit_id + 1)
def origin_generate_unique_visit_id(self, origin_url: str) -> int:
origin = list(self.origin_get_by_url(origin_url))[0]
visit_id = origin.next_visit_id
origin.next_visit_id += 1
return visit_id
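    # Sketch of the allocation contract between the two methods above: after
    # origin_bump_next_visit_id(url, 42), origin_generate_unique_visit_id(url)
    # returns an id >= 43 and increments the per-origin counter.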
def origin_delete(self, sha1: bytes) -> None:
self._origins.delete(lambda row: row.sha1 == sha1)
##########################
# 'origin_visit' table
##########################
def origin_visit_get(
self,
origin_url: str,
last_visit: Optional[int],
limit: int,
order: ListOrder,
) -> Iterable[OriginVisitRow]:
visits = list(self._origin_visits.get_from_partition_key((origin_url,)))
if last_visit is not None:
if order == ListOrder.ASC:
visits = [v for v in visits if v.visit > last_visit]
else:
visits = [v for v in visits if v.visit < last_visit]
visits.sort(key=lambda v: v.visit, reverse=order == ListOrder.DESC)
visits = visits[0:limit]
return visits
def origin_visit_add_one(self, visit: OriginVisitRow) -> None:
self._origin_visits.insert(visit)
self.increment_counter("origin_visit", 1)
def origin_visit_get_one(
self, origin_url: str, visit_id: int
) -> Optional[OriginVisitRow]:
return self._origin_visits.get_from_primary_key((origin_url, visit_id))
def origin_visit_iter_all(self, origin_url: str) -> Iterable[OriginVisitRow]:
return reversed(list(self._origin_visits.get_from_partition_key((origin_url,))))
def origin_visit_iter(self, start_token: int) -> Iterator[OriginVisitRow]:
"""Returns all origin visits in order from this token,
and wraps around the token space."""
return (
row
for (token, partition) in self._origin_visits.data.items()
for (clustering_key, row) in partition.items()
)
def origin_visit_delete(self, origin_url: str) -> None:
self._origin_visits.delete(lambda row: row.origin == origin_url)
##########################
# 'origin_visit_status' table
##########################
def origin_visit_status_get_range(
self,
origin: str,
visit: int,
date_from: Optional[datetime.datetime],
limit: int,
order: ListOrder,
) -> Iterable[OriginVisitStatusRow]:
statuses = list(self.origin_visit_status_get(origin, visit))
if date_from is not None:
if order == ListOrder.ASC:
statuses = [s for s in statuses if s.date >= date_from]
else:
statuses = [s for s in statuses if s.date <= date_from]
statuses.sort(key=lambda s: s.date, reverse=order == ListOrder.DESC)
return statuses[0:limit]
def origin_visit_status_get_all_range(
self, origin: str, first_visit: int, last_visit: int
) -> Iterable[OriginVisitStatusRow]:
statuses = [
s
for s in self._origin_visit_statuses.get_from_partition_key((origin,))
if s.visit >= first_visit and s.visit <= last_visit
]
statuses.sort(key=lambda s: (s.visit, s.date))
return statuses
def origin_visit_status_add_one(self, visit_update: OriginVisitStatusRow) -> None:
self._origin_visit_statuses.insert(visit_update)
self.increment_counter("origin_visit_status", 1)
def origin_visit_status_get_latest(
self,
origin: str,
visit: int,
) -> Optional[OriginVisitStatusRow]:
"""Given an origin visit id, return its latest origin_visit_status"""
return next(self.origin_visit_status_get(origin, visit), None)
def origin_visit_status_get(
self,
origin: str,
visit: int,
) -> Iterator[OriginVisitStatusRow]:
"""Return all origin visit statuses for a given visit"""
statuses = [
s
for s in self._origin_visit_statuses.get_from_partition_key((origin,))
if s.visit == visit
]
statuses.sort(key=lambda s: s.date, reverse=True)
return iter(statuses)
def origin_snapshot_get_all(self, origin: str) -> Iterator[Sha1Git]:
"""Return all snapshots for a given origin"""
return iter(
{
s.snapshot
for s in self._origin_visit_statuses.get_from_partition_key((origin,))
if s.snapshot is not None
}
)
def origin_visit_status_delete(self, origin_url: str) -> None:
self._origin_visit_statuses.delete(lambda row: row.origin == origin_url)
##########################
# 'metadata_authority' table
##########################
##########################
# 'metadata_fetcher' table
##########################
#########################
# 'raw_extrinsic_metadata_by_id' table
#########################
#########################
# 'raw_extrinsic_metadata' table
#########################
#########################
# 'extid' table
#########################
def _extid_add_finalize(self, extid: ExtIDRow) -> None:
self._extid.insert(extid)
self.increment_counter("extid", 1)
def extid_add_prepare(self, extid: ExtIDRow):
finalizer = functools.partial(self._extid_add_finalize, extid)
return (self._extid.token(self._extid.partition_key(extid)), finalizer)
def extid_index_add_one(self, row: ExtIDByTargetRow) -> None:
pass
def extid_delete(
self,
extid_type: str,
extid: bytes,
extid_version: int,
target_type: str,
target: bytes,
) -> None:
self._extid.delete(
lambda row: row.extid_type == extid_type
and row.extid == extid
and row.extid_version == extid_version
and row.target_type == target_type
and row.target == target
)
def extid_delete_from_by_target_table(
self, target_type: str, target: bytes
) -> None:
self._extid.delete(
lambda row: row.target_type == target_type and row.target == target
)
def extid_get_from_pk(
self,
extid_type: str,
extid: bytes,
extid_version: int,
target: ExtendedSWHID,
) -> Optional[ExtIDRow]:
primary_key = self._extid.primary_key_from_dict(
dict(
extid_type=extid_type,
extid=extid,
extid_version=extid_version,
target_type=target.object_type.value,
target=target.object_id,
)
)
return self._extid.get_from_primary_key(primary_key)
def extid_get_from_extid(
self,
extid_type: str,
extid: bytes,
) -> Iterable[ExtIDRow]:
return (
row
for pk, row in self._extid.iter_all()
if row.extid_type == extid_type and row.extid == extid
)
def extid_get_from_extid_and_version(
self,
extid_type: str,
extid: bytes,
extid_version: int,
) -> Iterable[ExtIDRow]:
return (
row
for pk, row in self._extid.iter_all()
if row.extid_type == extid_type
and row.extid == extid
and (extid_version is None or row.extid_version == extid_version)
)
def _extid_get_from_target_with_type_and_version(
self,
target_type: str,
target: bytes,
extid_type: str,
extid_version: int,
) -> Iterable[ExtIDRow]:
return (
row
for pk, row in self._extid.iter_all()
if row.target_type == target_type
and row.target == target
and row.extid_version == extid_version
and row.extid_type == extid_type
)
def _extid_get_from_target(
self,
target_type: str,
target: bytes,
) -> Iterable[ExtIDRow]:
return (
row
for pk, row in self._extid.iter_all()
if row.target_type == target_type and row.target == target
)
def extid_get_from_target(
self,
target_type: str,
target: bytes,
extid_type: Optional[str] = None,
extid_version: Optional[int] = None,
) -> Iterable[ExtIDRow]:
if (extid_version is not None and extid_type is None) or (
extid_version is None and extid_type is not None
):
raise ValueError("You must provide both extid_type and extid_version")
if extid_type is not None and extid_version is not None:
extids = self._extid_get_from_target_with_type_and_version(
target_type, target, extid_type, extid_version
)
else:
extids = self._extid_get_from_target(target_type, target)
return extids
def object_reference_add_concurrent(
self, entries: List[ObjectReferenceRow]
) -> None:
today = datetime.date.today()
for entry in entries:
try:
table = next(
table
for table in self.object_references_list_tables()
if table.start.date() <= today < table.end.date()
)
except StopIteration:
raise NonRetryableException(
"No 'object_references_*' table open for writing."
)
self._object_references[table.name].insert(entry)
def object_reference_get(
self, target: Sha1Git, target_type: str, limit: int
) -> Iterable[ObjectReferenceRow]:
return itertools.islice(
itertools.chain.from_iterable(
self._object_references[table.name].get_from_partition_key(
(target_type, target)
)
for table in self.object_references_list_tables()
),
limit,
)
def object_references_list_tables(self) -> List[ObjectReferencesTableRow]:
return [row for (pk, row) in self._object_references_tables.iter_all()]
def object_references_create_table(
self,
date: Tuple[int, int], # _prepared_insert_statement supports only one arg
) -> Tuple[datetime.date, datetime.date]:
(year, week) = date
# This date is guaranteed to be in week 1 by the ISO standard
in_week1 = datetime.date(year=year, month=1, day=4)
monday_of_week1 = in_week1 + datetime.timedelta(days=-in_week1.weekday())
monday = monday_of_week1 + datetime.timedelta(weeks=week - 1)
next_monday = monday + datetime.timedelta(days=7)
name = "object_references_%04dw%02d" % (year, week)
row = ObjectReferencesTableRow(
pk=0, # always the same value, puts everything in the same Cassandra partition
name=name,
year=year,
week=week,
start=Date(monday), # datetime.date -> cassandra.util.Date
end=Date(next_monday), # ditto
)
self._object_references[name] = Table(ObjectReferenceRow)
self._object_references_tables.insert(row)
return (monday, next_monday)
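    # Worked example of the ISO-week arithmetic above: for date=(2023, 1),
    # January 4th 2023 is a Wednesday (weekday() == 2), so monday_of_week1 is
    # 2023-01-02; the table is named "object_references_2023w01" and covers
    # [2023-01-02, 2023-01-09).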
def object_references_drop_table(self, year: int, week: int) -> None:
name = "object_references_%04dw%02d" % (year, week)
self._object_references_tables.delete(lambda row: row.name == name)
del self._object_references[name]
class InMemoryStorage(CassandraStorage):
_cql_runner: InMemoryCqlRunner # type: ignore
def __init__(self, journal_writer=None):
self.reset()
self.journal_writer = JournalWriter(journal_writer)
self._allow_overwrite = False
self._directory_entries_insert_algo = "one-by-one"
def reset(self):
self._cql_runner = InMemoryCqlRunner()
self.objstorage = ObjStorage(self, {"cls": "memory"})
def check_config(self, *, check_write: bool) -> bool:
return True
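# A hedged usage sketch ("memory" is the cls name this backend is registered
# under in the swh.storage factory):
#
#     from swh.storage import get_storage
#     storage = get_storage("memory")
#     assert storage.check_config(check_write=True)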