Source code for swh.graph.pytest_plugin
# Copyright (C) 2019-2024 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import contextlib
import logging
import multiprocessing
import socket
import subprocess
import threading
from aiohttp.test_utils import TestClient, TestServer, loop_context
import grpc
import pytest
from swh.graph.example_dataset import DATASET_DIR
from swh.graph.grpc.swhgraph_pb2_grpc import TraversalServiceStub
from swh.graph.grpc_server import ExecutableNotFound
from swh.graph.http_client import RemoteGraphClient
from swh.graph.http_naive_client import NaiveClient
logger = logging.getLogger(__name__)
[docs]
class GraphServerProcess(multiprocessing.Process):
def __init__(self, config, *args, **kwargs):
self.config = config
self.q = multiprocessing.Queue()
super().__init__(*args, **kwargs)
[docs]
def run(self):
# Lazy import to allow debian packaging
from swh.graph.http_rpc_server import make_app
try:
with loop_context() as loop:
app = make_app(config=self.config)
client = TestClient(TestServer(app), loop=loop)
loop.run_until_complete(client.start_server())
url = client.make_url("/graph/")
self.q.put(
{
"server_url": url,
"rpc_url": app["rpc_url"],
"pid": app["local_server"].pid,
}
)
loop.run_forever()
except Exception as e:
if isinstance(e, ExecutableNotFound):
# hack to add a bit more context and help to the user,
# especially when this is used from another swh package...
# XXX on py>=3.11 we could use e.add_note() instead
e.args = (
*e.args,
"This probably means you need to build the rust grpc server "
'for swh-graph. Check the "Minimal setup for tests" section in '
"the rust/README.md file in the swh-graph "
"source code directory.",
)
logger.exception(e)
self.q.put(e)
[docs]
def start(self, *args, **kwargs):
super().start()
self.result = self.q.get()
[docs]
class StatsdServer:
def __init__(self):
self._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self._sock.bind(("127.0.0.1", 0))
self._sock.settimeout(0.1)
(self.host, self.port) = self._sock.getsockname()
self._closing = False
self._thread = threading.Thread(target=self._listen)
self._thread.start()
self.datagrams = []
self.new_datagram = threading.Event()
"""Woken up every time a datagram is added to self.datagrams."""
def _listen(self):
while not self._closing:
try:
(datagram, addr) = self._sock.recvfrom(4096)
except TimeoutError:
continue
self.datagrams.append(datagram)
self.new_datagram.set()
self._sock.close()
[docs]
def close(self):
self._closing = True
[docs]
@pytest.fixture(scope="session")
def graph_statsd_server():
with contextlib.closing(StatsdServer()) as statsd_server:
yield statsd_server
[docs]
@pytest.fixture(scope="session", params=["rust"])
def graph_grpc_backend_implementation(request):
return request.param
[docs]
@pytest.fixture(scope="session")
def graph_grpc_server_config(graph_grpc_backend_implementation, graph_statsd_server):
return {
"graph": {
"cls": f"local_{graph_grpc_backend_implementation}",
"grpc_server": {
"path": DATASET_DIR / "compressed/example",
"debug": True,
"statsd_host": graph_statsd_server.host,
"statsd_port": graph_statsd_server.port,
},
"http_rpc_server": {"debug": True},
}
}
[docs]
@pytest.fixture(scope="session")
def graph_grpc_server_process(graph_grpc_server_config, graph_statsd_server):
server = GraphServerProcess(graph_grpc_server_config)
yield server
try:
server.kill()
except AttributeError:
# server was never started
pass
[docs]
@pytest.fixture(scope="session")
def graph_grpc_server_started(graph_grpc_server_process):
server = graph_grpc_server_process
server.start()
if isinstance(server.result, Exception):
raise server.result
yield server
server.kill()
[docs]
@pytest.fixture(scope="module")
def graph_grpc_stub(graph_grpc_server):
with grpc.insecure_channel(graph_grpc_server) as channel:
stub = TraversalServiceStub(channel)
yield stub
[docs]
@pytest.fixture(scope="module")
def graph_grpc_server(graph_grpc_server_started):
yield graph_grpc_server_started.result["rpc_url"]
[docs]
@pytest.fixture(scope="module")
def remote_graph_client_url(graph_grpc_server_started):
yield str(graph_grpc_server_started.result["server_url"])
[docs]
@pytest.fixture(scope="module")
def remote_graph_client(graph_grpc_server_started):
yield RemoteGraphClient(str(graph_grpc_server_started.result["server_url"]))
[docs]
@pytest.fixture(scope="module")
def naive_graph_client():
def zstdcat(*files):
p = subprocess.run(["zstdcat", *files], stdout=subprocess.PIPE)
return p.stdout.decode()
edges_dataset = DATASET_DIR / "edges"
edge_files = edges_dataset.glob("*/*.edges.csv.zst")
node_files = edges_dataset.glob("*/*.nodes.csv.zst")
nodes = set(zstdcat(*node_files).strip().split("\n"))
edge_lines = [line.split() for line in zstdcat(*edge_files).strip().split("\n")]
edges = [(src, dst) for src, dst, *_ in edge_lines]
for src, dst in edges:
nodes.add(src)
nodes.add(dst)
yield NaiveClient(nodes=list(nodes), edges=edges)
[docs]
@pytest.fixture(scope="module", params=["remote", "naive"])
def graph_client(request):
if request.param == "remote":
yield request.getfixturevalue("remote_graph_client")
else:
yield request.getfixturevalue("naive_graph_client")