# Copyright (C) 2021-2026 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from functools import partial
import logging
from threading import Event, Thread
import time
from typing import (
Callable,
ContextManager,
Dict,
Iterator,
List,
Optional,
Protocol,
Set,
Tuple,
)
import psycopg
from ...exc import ReadOnlyObjStorageError
from .database import Database
# Module-level logger.
logger = logging.getLogger(__name__)
# True when psycopg runs on its pure-Python libpq wrapper; in that case
# RWShard.all() converts the memoryview objects yielded by COPY into bytes.
PQ_IS_PYTHON: bool = psycopg.pq.__impl__ == "python"
[docs]
@dataclass
class ShardObjectsCount:
    """Tally of the objects held by a shard.

    Instances are mutable: :class:`RWShard` increments both fields as it
    inserts new objects.
    """

    # Number of objects.
    count: int
    # Sum of the objects' content lengths (in bytes).
    volume: int
[docs]
class IdleHandler(Thread):
    """Call the `callback` after being idle for `timeout` seconds.

    The thread fires once the deadline (``timeout`` seconds after construction
    or after the last :meth:`reset`) has passed while the handler is not
    quiesced. It then calls ``callback`` once, from this thread, and exits.
    Calling :meth:`join` stops the thread without calling ``callback``.
    """

    def __init__(self, name: str, timeout: float, callback: Callable[[], None]):
        super().__init__(name=f"IdleHandler-{name}")
        self.timeout = timeout
        self.callback = callback
        # Absolute time.monotonic() value after which the handler is idle.
        self.deadline = time.monotonic() + timeout
        # Set while work is in progress; the deadline is ignored meanwhile.
        self.quiesced = Event()
        # Set by join() to request an exit without calling the callback.
        self.terminated = Event()

    def quiesce(self):
        """Quiesce the timeout.

        This should generally be used via the :func:`quiesce_then_reset` context
        manager, which wraps a block of code to quiesce the timeout while the
        code runs, then resets the timeout on completion.
        """
        self.quiesced.set()

    def reset(self):
        """Reset the timeout clock.

        This should generally be used via the :func:`quiesce_then_reset` context
        manager, which wraps a block of code to quiesce the timeout while the
        code runs, then resets the timeout on completion.
        """
        self.deadline = time.monotonic() + self.timeout
        self.quiesced.clear()

    @contextmanager
    def quiesce_then_reset(self):
        """Wrap a block of code to quiesce the timeout while the code runs,
        then reset the timeout on completion.

        The timeout is reset even if the wrapped block raises. Otherwise a
        failed operation would leave the handler quiesced forever, and the
        idle callback could never fire.
        """
        self.quiesce()
        try:
            yield
        finally:
            self.reset()

    def join(self, timeout=None):
        """Gracefully terminate the thread."""
        self.terminated.set()
        # Trigger exit from the main loop by setting the quiesced event
        self.quiesced.set()
        return super().join(timeout)

    def run(self):
        """Thread body: wait for the idle deadline, then call the callback
        unless termination was requested through :meth:`join`."""
        while True:
            # Each wait lasts at least 1 second unless an event interrupts it.
            # This also means a timeout shorter than 1s effectively becomes 1s,
            # and the wait never gets a negative value once the deadline is past.
            wait_for = max(self.deadline - time.monotonic(), 1)
            quiesced = self.quiesced.wait(timeout=wait_for)
            if self.terminated.is_set():
                break
            if quiesced:
                # Work in progress: poll until reset() clears the event.
                time.sleep(0.1)
                continue
            if time.monotonic() > self.deadline:
                break
        if not self.terminated.is_set():
            logger.debug("Idle timeout reached, calling idle callback")
            self.callback()
[docs]
class ShardIdleTimeoutCallback(Protocol):
    """Structural type of the idle timeout callback accepted by
    :class:`RWShard`.

    Implementations receive the idle :class:`RWShard` through the ``shard``
    keyword argument and return nothing.
    """

    def __call__(self, shard: "RWShard") -> None:
        """React to ``shard`` having been idle for its configured timeout."""
[docs]
class RWShard(Database):
    """Read-write shard of the Winery object storage.

    Objects live in a PostgreSQL table named ``shard_<name>``, created with the
    same definition as ``shard_template``. The shard keeps an in-memory
    :class:`ShardObjectsCount` (:attr:`obj_count`). It is initialized from the
    table and incremented by :meth:`add_batch`. :meth:`is_full` compares it
    against ``shard_max_size``.

    When both ``idle_timeout`` and ``idle_timeout_cb`` are given, an
    :class:`IdleHandler` thread calls ``idle_timeout_cb(shard=self)`` after
    ``idle_timeout`` seconds without a write through :meth:`add_batch`.
    """

    def __init__(
        self,
        name: str,
        base_dsn: str,
        shard_max_size: int,
        application_name: Optional[str] = None,
        idle_timeout_cb: Optional[ShardIdleTimeoutCallback] = None,
        idle_timeout: Optional[float] = 5,
        readonly: bool = False,
        **kwargs,
    ):
        """
        Args:
          name: shard name, used to build :attr:`table_name`
          base_dsn: connection string, passed to :class:`Database` as ``dsn``
          shard_max_size: volume (sum of content lengths) at which
            :meth:`is_full` becomes True
          application_name: PostgreSQL application name; defaults to
            ``"SWH Winery RW Shard <name>"``
          idle_timeout_cb: called with ``shard=self`` once the shard has been
            idle for ``idle_timeout`` seconds
          idle_timeout: idle delay in seconds; a falsy value disables the
            idle handler
          readonly: if True, the table is not created and writes raise
            :exc:`ReadOnlyObjStorageError`
          kwargs: accepted and ignored (NOTE(review): presumably for
            configuration compatibility; confirm)
        """
        self._name = name
        if application_name is None:
            application_name = f"SWH Winery RW Shard {name}"
        super().__init__(dsn=base_dsn, application_name=application_name)
        self.readonly = readonly
        if not self.readonly:
            self.create()
        self.obj_count = self.total_size()
        self.limit = shard_max_size
        self.idle_handler: Optional[IdleHandler] = None
        # No-op context manager unless an idle handler is running.
        self.quiesce_then_reset_idle: Callable[[], ContextManager] = nullcontext
        if idle_timeout and idle_timeout_cb:
            self.idle_handler = IdleHandler(
                name=name,
                timeout=idle_timeout,
                callback=partial(idle_timeout_cb, shard=self),
            )
            self.idle_handler.start()
            self.quiesce_then_reset_idle = self.idle_handler.quiesce_then_reset

    def disable_idle_handler(self):
        """Stop the idle handler thread, if any, and stop tracking idleness."""
        # The explicit default makes this safe even if ``idle_handler`` was
        # never assigned; getattr without a default would raise AttributeError.
        if thread := getattr(self, "idle_handler", None):
            thread.join()
        self.idle_handler = None
        self.quiesce_then_reset_idle = nullcontext

    @property
    def name(self) -> str:
        return self._name

    @property
    def table_name(self) -> str:
        return f"shard_{self._name}"

    def is_full(self) -> bool:
        """Return True once the tracked volume reaches ``shard_max_size``."""
        return self.obj_count.volume >= self.limit

    def create(self) -> None:
        """Create the shard table from ``shard_template`` if it does not exist.

        Autovacuum is disabled on the new table.
        """
        with self.pool.connection() as db:
            db.execute(
                f"CREATE TABLE IF NOT EXISTS {self.table_name} "
                "(LIKE shard_template INCLUDING ALL) "
                "WITH (autovacuum_enabled = false)"
            )

    def drop(self) -> None:
        """Drop the shard table."""
        with self.pool.connection() as db:
            db.execute(f"DROP TABLE {self.table_name}")

    def total_size(self) -> ShardObjectsCount:
        """Return the number of entries and their total volume size."""
        with self.pool.connection() as db, db.cursor() as c:
            c.execute(f"SELECT COUNT(*), SUM(LENGTH(content)) FROM {self.table_name}")
            result = c.fetchone()
            if result is None:
                return ShardObjectsCount(0, 0)
            # SUM() is NULL on an empty table.
            return ShardObjectsCount(result[0], result[1] or 0)

    def add(self, db: psycopg.Connection, obj_id: bytes, content: bytes) -> None:
        """Insert a single object; see :meth:`add_batch`."""
        self.add_batch(db, [(obj_id, content)])

    def add_batch(
        self, db: psycopg.Connection, contents: List[Tuple[bytes, bytes]]
    ) -> Dict:
        """Insert ``contents``, a list of ``(obj_id, content)`` pairs, using ``db``.

        Keys already present in the table are left untouched. The idle timer
        is suspended while the insert runs. No commit is issued here.
        NOTE(review): the caller presumably owns the transaction on ``db``.

        Returns:
          a dict with the number of newly inserted objects (``object:add``)
          and their total content length (``object:add:bytes``)

        Raises:
          ReadOnlyObjStorageError: if the shard is read-only
        """
        if self.readonly:
            raise ReadOnlyObjStorageError(
                f"Cannot write to shard {self._name}, objstorage is readonly"
            )
        num_added = 0
        num_bytes_added = 0
        # insert in consistent order to avoid deadlocks
        contents = sorted(contents, key=lambda obj_id_and_content: obj_id_and_content[0])
        with self.quiesce_then_reset_idle():
            with db.cursor() as cur:
                cur.executemany(
                    f"""
                    INSERT INTO {self.table_name} (key, content)
                    VALUES (%s, %s)
                    ON CONFLICT (key) DO NOTHING
                    RETURNING key
                    """,
                    contents,
                    returning=True,
                )
                # Keys actually inserted; conflicting rows return nothing.
                inserted: Set[bytes] = set()
                for _ in cur.results():
                    for (obj_id,) in cur:
                        inserted.add(obj_id)
            for obj_id, content in contents:
                if obj_id in inserted:
                    # Remove the key so that only its first occurrence in the
                    # batch is counted (later duplicates hit the conflict).
                    inserted.remove(obj_id)
                    num_added += 1
                    num_bytes_added += len(content)
            self.obj_count.count += num_added
            self.obj_count.volume += num_bytes_added
        return {
            "object:add": num_added,
            "object:add:bytes": num_bytes_added,
        }

    def get(self, obj_id: bytes) -> Optional[bytes]:
        """Return the content stored under ``obj_id``, or None if absent."""
        with self.pool.connection() as db, db.cursor() as c:
            c.execute(
                f"SELECT content FROM {self.table_name} WHERE key = %s",
                (obj_id,),
                binary=True,
            )
            row = c.fetchone()
            return None if row is None else row[0]

    def delete(self, obj_id: bytes) -> None:
        """Delete ``obj_id`` from the shard table.

        NOTE(review): :attr:`obj_count` is not decremented here, unlike the
        increments done in :meth:`add_batch`; confirm this is intended.

        Raises:
          ReadOnlyObjStorageError: if the shard is read-only
          KeyError: if ``obj_id`` is not in the table
        """
        if self.readonly:
            raise ReadOnlyObjStorageError(
                f"Cannot drop {obj_id.hex()} from shard {self._name}, objstorage is readonly"
            )
        with self.pool.connection() as db, db.cursor() as c:
            c.execute(f"DELETE FROM {self.table_name} WHERE key = %s", (obj_id,))
            if c.rowcount == 0:
                raise KeyError(obj_id)

    def all(self) -> Iterator[Tuple[bytes, bytes]]:
        """Yield every ``(key, content)`` pair of the shard via a binary COPY.

        A pooled connection stays checked out until the generator is exhausted
        or closed.
        """
        with self.pool.connection() as db, db.cursor() as c:
            with c.copy(
                f"COPY {self.table_name} (key, content) TO STDOUT (FORMAT BINARY)"
            ) as copy:
                copy.set_types(["bytea", "bytea"])
                if PQ_IS_PYTHON:
                    # pure python implem of psycopg.pq does return memoryview
                    # objects, not bytes
                    for x, y in copy.rows():
                        yield (x.tobytes(), y.tobytes())
                else:
                    # C/Binary versions of psycopg.pq (aka psycopg[c|binary]) do
                    # return bytes objects directly
                    yield from copy.rows()

    def count(self) -> int:
        """Return the number of rows in the shard table."""
        with self.pool.connection() as db, db.cursor() as c:
            c.execute(f"SELECT COUNT(*) FROM {self.table_name}")
            return c.fetchone()[0]