Source code for swh.graph.pytest_plugin

# Copyright (C) 2019-2024  The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information

import contextlib
import logging
import multiprocessing
import queue
import socket
import subprocess
import threading

from aiohttp.test_utils import TestClient, TestServer, loop_context
import grpc
import pytest

from swh.graph.example_dataset import DATASET_DIR
from swh.graph.grpc.swhgraph_pb2_grpc import TraversalServiceStub
from swh.graph.grpc_server import ExecutableNotFound
from swh.graph.http_client import RemoteGraphClient
from swh.graph.http_naive_client import NaiveClient

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


class GraphServerProcess(multiprocessing.Process):
    """Child process that runs the swh-graph HTTP RPC test server.

    The child reports its outcome to the parent through ``self.q``: either a
    dict with the keys ``server_url``, ``rpc_url`` and ``pid``, or the
    exception that prevented the server from starting. :meth:`start` stores
    that report in ``self.result``.
    """

    # Seconds to wait on the queue between two liveness checks of the child.
    _POLL_INTERVAL = 0.1

    def __init__(self, config, *args, **kwargs):
        """Store ``config`` and create the queue used to report back.

        Args:
            config: configuration passed to ``make_app`` in the child.
            *args, **kwargs: forwarded to ``multiprocessing.Process``.
        """
        self.config = config
        self.q = multiprocessing.Queue()
        super().__init__(*args, **kwargs)

    def run(self):
        """Start the server in the child process and serve forever.

        On success, a dict describing the running server is put on
        ``self.q`` before entering the event loop. Any ``Exception`` raised
        while starting (including a failure of the lazy import) is logged and
        put on ``self.q`` instead.
        """
        try:
            # Lazy import to allow debian packaging.  It lives inside the
            # ``try`` so that an import failure is reported to the parent
            # instead of silently killing the child.
            from swh.graph.http_rpc_server import make_app

            with loop_context() as loop:
                app = make_app(config=self.config)
                client = TestClient(TestServer(app), loop=loop)
                loop.run_until_complete(client.start_server())
                url = client.make_url("/graph/")
                self.q.put(
                    {
                        "server_url": url,
                        "rpc_url": app["rpc_url"],
                        "pid": app["local_server"].pid,
                    }
                )
                loop.run_forever()
        except Exception as e:
            if isinstance(e, ExecutableNotFound):
                # hack to add a bit more context and help to the user,
                # especially when this is used from another swh package...
                # XXX on py>=3.11 we could use e.add_note() instead
                e.args = (
                    *e.args,
                    "This probably means you need to build the rust grpc server "
                    'for swh-graph. Check the "Minimal setup for tests" section in '
                    "the rust/README.md file in the swh-graph "
                    "source code directory.",
                )
            logger.exception(e)
            self.q.put(e)

    def start(self, *args, **kwargs):
        """Start the child and wait for its report.

        Sets ``self.result`` to whatever the child put on ``self.q``. If the
        child exits without reporting anything (e.g. it was killed, or its
        report could not be sent), ``self.result`` is set to a
        ``RuntimeError`` instead of blocking forever.

        ``*args`` and ``**kwargs`` are accepted for compatibility and ignored.
        """
        super().start()
        while True:
            try:
                self.result = self.q.get(timeout=self._POLL_INTERVAL)
                return
            except queue.Empty:
                if not self.is_alive():
                    break
        # The child is gone.  Its report may have landed in the pipe between
        # the last poll and the liveness check, so look one final time.
        try:
            self.result = self.q.get(timeout=self._POLL_INTERVAL)
        except queue.Empty:
            self.result = RuntimeError(
                f"graph server process exited (exit code {self.exitcode}) "
                "without reporting its status"
            )
class StatsdServer:
    """UDP listener on 127.0.0.1 that records every datagram it receives.

    The socket is bound to an OS-assigned port, exposed with the address as
    ``self.host`` and ``self.port``. Received datagrams are appended to
    ``self.datagrams`` by a background thread, which runs until
    :meth:`close` is called.
    """

    def __init__(self):
        """Bind the socket and start the listener thread."""
        self._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        self._sock.bind(("127.0.0.1", 0))
        # Short timeout so the listener regularly re-checks ``_closing``.
        self._sock.settimeout(0.1)
        (self.host, self.port) = self._sock.getsockname()
        self._closing = False
        # Everything the listener touches must exist before the thread
        # starts, otherwise an early datagram hits a missing attribute.
        self.datagrams = []
        # Woken up every time a datagram is added to self.datagrams.
        self.new_datagram = threading.Event()
        self._thread = threading.Thread(target=self._listen)
        self._thread.start()

    def _listen(self):
        """Record datagrams until closing is requested, then close the socket."""
        while not self._closing:
            try:
                (datagram, _addr) = self._sock.recvfrom(4096)
            except socket.timeout:
                # ``socket.timeout`` is an alias of ``TimeoutError`` since
                # Python 3.10 and a distinct class before that; naming it
                # explicitly works on both.
                continue
            self.datagrams.append(datagram)
            self.new_datagram.set()
        self._sock.close()

    def close(self):
        """Ask the listener to stop and wait until it has closed the socket."""
        self._closing = True
        # Joining from the listener thread itself would raise RuntimeError.
        if self._thread is not threading.current_thread():
            self._thread.join()
@pytest.fixture(scope="session")
def graph_statsd_server():
    """Session-wide :class:`StatsdServer`, closed at teardown."""
    statsd_server = StatsdServer()
    try:
        yield statsd_server
    finally:
        statsd_server.close()
@pytest.fixture(scope="session", params=["rust"])
def graph_grpc_backend_implementation(request):
    """Name of the gRPC backend implementation under test.

    Parametrized fixture; the only parameter is ``"rust"``.
    """
    implementation = request.param
    return implementation
@pytest.fixture(scope="session")
def graph_grpc_server_config(graph_grpc_backend_implementation, graph_statsd_server):
    """Configuration dict for a local graph server on the example dataset.

    Statsd host and port are taken from the ``graph_statsd_server`` fixture.
    """
    grpc_server_settings = {
        "path": DATASET_DIR / "compressed/example",
        "debug": True,
        "statsd_host": graph_statsd_server.host,
        "statsd_port": graph_statsd_server.port,
    }
    graph_settings = {
        "cls": f"local_{graph_grpc_backend_implementation}",
        "grpc_server": grpc_server_settings,
        "http_rpc_server": {"debug": True},
    }
    return {"graph": graph_settings}
@pytest.fixture(scope="session")
def graph_grpc_server_process(graph_grpc_server_config, graph_statsd_server):
    """Yield a :class:`GraphServerProcess` that has not been started.

    The process is killed at teardown.
    """
    process = GraphServerProcess(graph_grpc_server_config)
    yield process
    # kill() raises AttributeError when the server was never started.
    with contextlib.suppress(AttributeError):
        process.kill()
@pytest.fixture(scope="session")
def graph_grpc_server_started(graph_grpc_server_process):
    """Start the graph server process and yield it.

    If the process reported an exception as its result, that exception is
    raised here. The process is killed at teardown.
    """
    process = graph_grpc_server_process
    process.start()
    outcome = process.result
    if isinstance(outcome, Exception):
        raise outcome
    yield process
    process.kill()
@pytest.fixture(scope="module")
def graph_grpc_stub(graph_grpc_server):
    """``TraversalServiceStub`` on an insecure channel to the test server.

    The channel is closed at teardown.
    """
    with grpc.insecure_channel(graph_grpc_server) as grpc_channel:
        yield TraversalServiceStub(grpc_channel)
@pytest.fixture(scope="module")
def graph_grpc_server(graph_grpc_server_started):
    """The ``rpc_url`` entry reported by the started server process."""
    server_info = graph_grpc_server_started.result
    yield server_info["rpc_url"]
@pytest.fixture(scope="module")
def remote_graph_client_url(graph_grpc_server_started):
    """The ``server_url`` reported by the started server process, as a ``str``."""
    server_info = graph_grpc_server_started.result
    yield str(server_info["server_url"])
@pytest.fixture(scope="module")
def remote_graph_client(graph_grpc_server_started):
    """``RemoteGraphClient`` built from the started server's ``server_url``."""
    server_url = str(graph_grpc_server_started.result["server_url"])
    yield RemoteGraphClient(server_url)
@pytest.fixture(scope="module")
def naive_graph_client():
    """``NaiveClient`` built from the zstd-compressed CSV files of the dataset.

    Nodes come from ``*/*.nodes.csv.zst`` and edges from
    ``*/*.edges.csv.zst`` under ``DATASET_DIR / "edges"``. The first two
    whitespace-separated fields of each edge line are used as source and
    destination, and every edge endpoint is also added to the node set.

    Raises:
        subprocess.CalledProcessError: if ``zstdcat`` exits with an error.
    """

    def zstdcat(*files):
        # With no file arguments zstdcat would read from stdin instead;
        # there is simply nothing to decompress in that case.
        if not files:
            return ""
        # check=True: a failed decompression must not silently produce a
        # truncated graph.
        p = subprocess.run(
            ["zstdcat", *files], stdout=subprocess.PIPE, check=True
        )
        return p.stdout.decode()

    def nonblank_lines(text):
        # Drop empty lines so they become neither a bogus '' node nor an
        # edge line that cannot be unpacked.
        stripped = (line.strip() for line in text.splitlines())
        return [line for line in stripped if line]

    edges_dataset = DATASET_DIR / "edges"
    edge_files = list(edges_dataset.glob("*/*.edges.csv.zst"))
    node_files = list(edges_dataset.glob("*/*.nodes.csv.zst"))

    nodes = set(nonblank_lines(zstdcat(*node_files)))
    edge_lines = [line.split() for line in nonblank_lines(zstdcat(*edge_files))]
    edges = [(src, dst) for src, dst, *_ in edge_lines]
    for src, dst in edges:
        nodes.update((src, dst))

    yield NaiveClient(nodes=list(nodes), edges=edges)
@pytest.fixture(scope="module", params=["remote", "naive"])
def graph_client(request):
    """Graph client under test, parametrized over ``"remote"`` and ``"naive"``.

    The matching client fixture is resolved lazily, so only the requested
    one is instantiated.
    """
    if request.param == "remote":
        fixture_name = "remote_graph_client"
    else:
        fixture_name = "naive_graph_client"
    yield request.getfixturevalue(fixture_name)