# Source code for swh.graph.http_rpc_server

# Copyright (C) 2019-2023  The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information

"""
A proxy HTTP server for swh-graph, talking to the Rust code via the gRPC API.
"""

import json
import logging
import os
from typing import Optional

import aiohttp.test_utils
import aiohttp.web
from google.protobuf import json_format
from google.protobuf.field_mask_pb2 import FieldMask
import grpc

from swh.core.config import read as config_read
from swh.graph.grpc.swhgraph_pb2 import (
    GetNodeRequest,
    NodeFilter,
    StatsRequest,
    TraversalRequest,
)
from swh.graph.grpc.swhgraph_pb2_grpc import TraversalServiceStub
from swh.graph.grpc_server import spawn_rust_grpc_server, stop_grpc_server
from swh.model.swhids import EXTENDED_SWHID_TYPES

try:
    from contextlib import asynccontextmanager
except ImportError:
    # Compatibility with 3.6 backport
    from async_generator import asynccontextmanager  # type: ignore


# Maximum number of retries for random walks.
# NOTE(review): no code visible in this module reads this constant -- confirm
# whether it is still needed before relying on it.
RANDOM_RETRIES = 10  # TODO make this configurable via rpc-serve configuration

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


async def _aiorpcerror_middleware(app, handler):
    """Middleware factory converting gRPC failures into HTTP 503 responses.

    Args:
        app: the aiohttp application (unused here).
        handler: the next request handler in the chain.

    Returns:
        A coroutine function that awaits ``handler(request)`` and re-raises
        any :class:`grpc.aio.AioRpcError` as
        :class:`aiohttp.web.HTTPServiceUnavailable`.
    """

    async def handle_rpc_errors(request):
        try:
            response = await handler(request)
        except grpc.aio.AioRpcError as rpc_error:
            # The default error handler of the RPC framework tries to serialize
            # this exception with msgpack, which for some unknown reason raises
            # ValueError("recursion limit exceeded") with a lot of context and
            # floods Sentry with gigabytes of logs (160KB per event, with
            # potentially hundreds of thousands of events per day).
            # Serializing the exception to a plain string avoids that.
            # https://sentry.softwareheritage.org/share/issue/d6d4db971e4b47728a6c1dd06cb9b8a5/
            raise aiohttp.web.HTTPServiceUnavailable(text=str(rpc_error))
        return response

    return handle_rpc_errors


class GraphServerApp(aiohttp.web.Application):
    """aiohttp application proxying HTTP requests to the swh-graph gRPC API.

    The application registers a startup hook that opens the gRPC channel
    (using the ``rpc_url`` item stored on the application) and a shutdown hook
    that closes it again and stops the local gRPC server, if one was stored
    under the ``local_server`` item.
    """

    def __init__(self, *args, middlewares=(), **kwargs):
        """Build the application.

        Args:
            *args: forwarded to :class:`aiohttp.web.Application`.
            middlewares: extra middlewares (any iterable); they are installed
                after the gRPC-error middleware.
            **kwargs: forwarded to :class:`aiohttp.web.Application`.
        """
        # tuple() lets callers pass a list (or any iterable) instead of
        # failing with TypeError on tuple + list concatenation.
        middlewares = (_aiorpcerror_middleware,) + tuple(middlewares)
        super().__init__(*args, middlewares=middlewares, **kwargs)
        self.on_startup.append(self._start)
        self.on_shutdown.append(self._stop)

    @staticmethod
    async def _start(app):
        """Open the gRPC channel and create the client stub."""
        app["channel"] = grpc.aio.insecure_channel(app["rpc_url"])
        await app["channel"].__aenter__()
        app["rpc_client"] = TraversalServiceStub(app["channel"])
        # Issue a first call with wait_for_ready=True so that startup only
        # completes once the gRPC server actually answers.
        await app["rpc_client"].Stats(StatsRequest(), wait_for_ready=True)

    @staticmethod
    async def _stop(app):
        """Close the gRPC channel and stop the local gRPC server, if any."""
        try:
            await app["channel"].__aexit__(None, None, None)
        finally:
            # Always stop the spawned server, even if closing the channel
            # failed; otherwise the child process would be leaked.
            if app.get("local_server"):
                stop_grpc_server(app["local_server"])
async def index(request):
    """Serve the static HTML landing page pointing at the API documentation."""
    # The pieces below concatenate to exactly the original single-line page.
    page = (
        "<html> "
        "<head><title>Software Heritage graph server</title></head> "
        "<body> "
        '<p>You have reached the <a href="https://www.softwareheritage.org/"> '
        "Software Heritage</a> graph API server.</p> "
        '<p>See its <a href="https://docs.softwareheritage.org/devel/swh-graph/api.html">'
        "API documentation</a> for more information.</p> "
        "</body> "
        "</html>"
    )
    return aiohttp.web.Response(content_type="text/html", body=page)
class GraphView(aiohttp.web.View):
    """Base class for views working on the graph, with utility functions"""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # gRPC client stub stored on the application at startup.
        self.rpc_client: TraversalServiceStub = self.request.app["rpc_client"]

    def get_direction(self):
        """Validate HTTP query parameter `direction`

        Returns:
            The upper-cased direction (``"FORWARD"`` or ``"BACKWARD"``);
            defaults to forward when the parameter is absent.

        Raises:
            aiohttp.web.HTTPBadRequest: for any other value.
        """
        s = self.request.query.get("direction", "forward")
        if s not in ("forward", "backward"):
            raise aiohttp.web.HTTPBadRequest(text=f"invalid direction: {s}")
        return s.upper()

    def get_edges(self):
        """Validate HTTP query parameter `edges`, i.e., edge restrictions

        Returns:
            The raw parameter value (``"*"`` when absent).

        Raises:
            aiohttp.web.HTTPBadRequest: if any node type in the restriction
                is neither ``*`` nor a member of ``EXTENDED_SWHID_TYPES``.
        """
        s = self.request.query.get("edges", "*")
        # Generator instead of a list: any() can stop at the first bad type.
        if any(
            node_type != "*" and node_type not in EXTENDED_SWHID_TYPES
            for edge in s.split(":")
            for node_type in edge.split(",", maxsplit=1)
        ):
            raise aiohttp.web.HTTPBadRequest(text=f"invalid edge restriction: {s}")
        return s

    def get_return_types(self):
        """Validate HTTP query parameter 'return types', i.e,
        a set of types which we will filter the query results with

        Returns:
            ``"*"`` if a star appears in the value (or the parameter is
            absent), otherwise the raw comma-separated value.

        Raises:
            aiohttp.web.HTTPBadRequest: if any listed type is neither ``*``
                nor a member of ``EXTENDED_SWHID_TYPES``.
        """
        s = self.request.query.get("return_types", "*")
        if any(
            node_type != "*" and node_type not in EXTENDED_SWHID_TYPES
            for node_type in s.split(",")
        ):
            raise aiohttp.web.HTTPBadRequest(
                text=f"invalid type for filtering res: {s}"
            )
        # if the user puts a star,
        # then we filter nothing, we don't need the other information
        if "*" in s:
            return "*"
        return s

    def _get_int_query_param(self, name):
        """Parse the integer query parameter `name`, defaulting to 0.

        Raises:
            aiohttp.web.HTTPBadRequest: if the value is not a valid integer.
        """
        s = self.request.query.get(name, "0")
        try:
            return int(s)
        except ValueError as exc:
            raise aiohttp.web.HTTPBadRequest(
                text=f"invalid {name} value: {s}"
            ) from exc

    def get_max_matching_nodes(self):
        """Validate HTTP query parameter `max_matching_nodes`, i.e., number of
        results"""
        return self._get_int_query_param("max_matching_nodes")

    def get_max_edges(self):
        """Validate HTTP query parameter 'max_edges', i.e.,
        the limit of the number of edges that can be visited"""
        return self._get_int_query_param("max_edges")

    async def check_swhid(self, swhid):
        """Validate that the given SWHID exists in the graph

        Raises:
            aiohttp.web.HTTPBadRequest: if the gRPC server reports the SWHID
                as invalid or not found.
            grpc.aio.AioRpcError: for any other gRPC failure; the check could
                not be performed, so the error is propagated instead of being
                silently treated as success.
        """
        try:
            await self.rpc_client.GetNode(
                GetNodeRequest(swhid=swhid, mask=FieldMask(paths=["swhid"]))
            )
        except grpc.aio.AioRpcError as e:
            if e.code() in (
                grpc.StatusCode.INVALID_ARGUMENT,
                grpc.StatusCode.NOT_FOUND,  # Not used by Java backend, always "invalid"
            ):
                raise aiohttp.web.HTTPBadRequest(text=str(e.details())) from e
            raise
class StreamingGraphView(GraphView):
    """Base class for views streaming their response line by line."""

    content_type = "text/plain"

    @asynccontextmanager
    async def response_streamer(self, *args, **kwargs):
        """Context manager to prepare then close a StreamResponse"""
        response = aiohttp.web.StreamResponse(*args, **kwargs)
        response.content_type = self.content_type
        await response.prepare(self.request)
        yield response
        await response.write_eof()

    async def get(self):
        """Handle GET: run the setup hook, then stream the buffered lines."""
        await self.prepare_response()
        async with self.response_streamer() as self.response_stream:
            self._buf = []
            try:
                await self.stream_response()
            finally:
                # Send whatever is still buffered, even on error.
                await self._flush_buffer()
            return self.response_stream

    async def prepare_response(self):
        """This can be overridden with some setup to be run before the response
        actually starts streaming.
        """
        pass

    async def stream_response(self):
        """Override this to perform the response streaming. Implementations of
        this should await self.stream_line(line) to write each line.
        """
        raise NotImplementedError

    async def stream_line(self, line):
        """Write a line in the response stream."""
        self._buf.append(line)
        if len(self._buf) > 100:
            await self._flush_buffer()

    async def _flush_buffer(self):
        """Write the buffered lines, newline-terminated, and empty the buffer.

        Nothing is written when the buffer is empty; otherwise an empty
        result, or a final flush right after a mid-stream flush, would emit
        a spurious blank line.
        """
        if not self._buf:
            return
        await self.response_stream.write("\n".join(self._buf).encode() + b"\n")
        self._buf = []
class StatsView(GraphView):
    """View showing some statistics on the graph"""

    async def get(self):
        """Return the gRPC ``Stats`` reply as a pretty-printed JSON document."""
        reply = await self.rpc_client.Stats(StatsRequest())
        payload = json_format.MessageToDict(
            reply, including_default_value_fields=True, preserving_proto_field_name=True
        )

        # Int64 fields are serialized as strings by default; turn the ones
        # present in the dict back into integers.
        int64_names = [
            field.name
            for field in reply.DESCRIPTOR.fields
            if field.type == field.TYPE_INT64
        ]
        for name in int64_names:
            if name in payload:
                payload[name] = int(payload[name])

        return aiohttp.web.Response(
            body=json.dumps(payload, indent=4, sort_keys=True),
            content_type="application/json",
        )
class SimpleTraversalView(StreamingGraphView):
    """Base class for views of simple traversals"""

    async def prepare_response(self):
        """Build the traversal request from the URL and query, then start it.

        The source SWHID comes from the ``src`` URL segment; the other
        request fields come from the validated query parameters.
        """
        start = self.request.match_info["src"]
        edges = self.get_edges()
        direction = self.get_direction()
        node_filter = NodeFilter(types=self.get_return_types())
        request = TraversalRequest(
            src=[start],
            edges=edges,
            direction=direction,
            return_nodes=node_filter,
            mask=FieldMask(paths=["swhid"]),
            max_matching_nodes=self.get_max_matching_nodes(),
        )
        self.traversal_request = request

        # max_edges is only set when the query gives a non-zero value.
        edge_budget = self.get_max_edges()
        if edge_budget:
            request.max_edges = edge_budget

        await self.check_swhid(start)
        self.configure_request()
        self.nodes_stream = self.rpc_client.Traverse(self.traversal_request)

        # Force gRPC to query the server and fetch the first nodes; so errors
        # are raised early, so we can return HTTP 503 before HTTP 200
        await self.nodes_stream.wait_for_connection()

    def configure_request(self):
        """Hook for subclasses to adjust ``self.traversal_request``; no-op here."""

    async def stream_response(self):
        """Stream the SWHID of every node returned by the traversal."""
        async for visited in self.nodes_stream:
            await self.stream_line(visited.swhid)
class LeavesView(SimpleTraversalView):
    """Traversal view that sets ``max_traversal_successors`` to 0 on the
    returned-node filter.

    NOTE(review): presumably this keeps only nodes without traversed
    successors, i.e. leaves -- confirm against the gRPC API documentation.
    """

    def configure_request(self):
        node_filter = self.traversal_request.return_nodes
        node_filter.max_traversal_successors = 0
class NeighborsView(SimpleTraversalView):
    """Traversal view pinning both the minimum and maximum depth to 1."""

    def configure_request(self):
        request = self.traversal_request
        request.min_depth = 1
        request.max_depth = 1
class VisitNodesView(SimpleTraversalView):
    """Traversal view using the base traversal request without any change."""
class VisitEdgesView(SimpleTraversalView):
    """Traversal view streaming one ``<node> <successor>`` line per edge."""

    def configure_request(self):
        """Also request the SWHID of each successor in the response mask."""
        self.traversal_request.mask.paths.append("successor.swhid")

    async def stream_response(self):
        """Stream ``source target`` SWHID pairs, separated by a space."""
        async for visited in self.nodes_stream:
            for successor in visited.successor:
                await self.stream_line(f"{visited.swhid} {successor.swhid}")
class CountView(GraphView):
    """Base class for counting views."""

    count_type: Optional[str] = None

    async def get(self):
        """Build the traversal request, then return the ``CountNodes`` count.

        The count is sent as the stringified ``count`` field of the reply,
        with an ``application/json`` content type.
        """
        start = self.request.match_info["src"]
        edges = self.get_edges()
        direction = self.get_direction()
        node_filter = NodeFilter(types=self.get_return_types())
        request = TraversalRequest(
            src=[start],
            edges=edges,
            direction=direction,
            return_nodes=node_filter,
            mask=FieldMask(paths=["swhid"]),
            max_matching_nodes=self.get_max_matching_nodes(),
        )
        self.traversal_request = request

        # max_edges is only set when the query gives a non-zero value.
        edge_budget = self.get_max_edges()
        if edge_budget:
            request.max_edges = edge_budget

        self.configure_request()
        counted = await self.rpc_client.CountNodes(self.traversal_request)
        return aiohttp.web.Response(
            body=str(counted.count), content_type="application/json"
        )

    def configure_request(self):
        """Hook for subclasses to adjust ``self.traversal_request``; no-op here."""
class CountNeighborsView(CountView):
    """Counting view pinning both the minimum and maximum depth to 1."""

    def configure_request(self):
        request = self.traversal_request
        request.min_depth = 1
        request.max_depth = 1
class CountLeavesView(CountView):
    """Counting view that sets ``max_traversal_successors`` to 0 on the
    returned-node filter.

    NOTE(review): presumably this counts only nodes without traversed
    successors, i.e. leaves -- confirm against the gRPC API documentation.
    """

    def configure_request(self):
        node_filter = self.traversal_request.return_nodes
        node_filter.max_traversal_successors = 0
class CountVisitNodesView(CountView):
    """Counting view using the base traversal request without any change."""
def make_app(config=None):
    """Create an aiohttp server for the HTTP RPC frontend to the swh-graph API.

    It may either connect to an existing grpc server (cls="remote") or spawn a
    local grpc server (cls="local").

    ``config`` is expected to be a dict like::

        graph:
          cls: "local"
          grpc_server:
            port: 50091
          http_rpc_server:
            debug: true

    or::

        graph:
          cls: "remote"
          url: "localhost:50091"
          http_rpc_server:
            debug: true

    See:

    - :mod:`swh.graph.grpc_server` for more details of the content of the
      grpc_server section,
    - :class:`~.GraphServerApp` class for more details of the content of the
      http_rpc_server section.

    Args:
        config: configuration dict; when ``None`` or lacking a ``graph`` key,
            a local grpc server on port 50091 is spawned.

    Returns:
        The configured :class:`GraphServerApp`, with routes registered and the
        ``rpc_url`` item set.

    Raises:
        KeyError: if ``cls`` is missing, or if ``cls`` is ``"remote"`` and no
            ``url`` is configured.
        ValueError: if ``cls`` is not one of ``remote``, ``local``,
            ``local_rust``.
    """
    if config is None:
        config = {}
    if "graph" not in config:
        logger.info(
            "Missing 'graph' configuration; defaulting to a locally spawned "
            "grpc server listening on 0.0.0.0:50091"
        )
        cfg = {"cls": "local", "grpc_server": {"port": 50091}}
    else:
        # Shallow copy so the pops below do not mutate the caller's dict.
        cfg = config["graph"].copy()
    cls = cfg.pop("cls")
    grpc_cfg = cfg.pop("grpc_server", {})
    app = GraphServerApp(**cfg.get("http_rpc_server", {}))

    if cls == "remote":
        if "url" not in cfg:
            raise KeyError("Missing 'url' configuration entry in the [graph] section")
        rpc_url = cfg["url"]
    elif cls in ("local", "local_rust"):
        # Keep the server handle on the app so it can be stopped at shutdown.
        app["local_server"], port = spawn_rust_grpc_server(**grpc_cfg)
        rpc_url = f"localhost:{port}"
    else:
        raise ValueError(f"Unknown swh.graph class cls={cls}")

    app.add_routes(
        [
            aiohttp.web.get("/", index),
            aiohttp.web.get("/graph", index),
            aiohttp.web.view("/graph/stats", StatsView),
            aiohttp.web.view("/graph/leaves/{src}", LeavesView),
            aiohttp.web.view("/graph/neighbors/{src}", NeighborsView),
            aiohttp.web.view("/graph/visit/nodes/{src}", VisitNodesView),
            aiohttp.web.view("/graph/visit/edges/{src}", VisitEdgesView),
            aiohttp.web.view("/graph/neighbors/count/{src}", CountNeighborsView),
            aiohttp.web.view("/graph/leaves/count/{src}", CountLeavesView),
            aiohttp.web.view("/graph/visit/nodes/count/{src}", CountVisitNodesView),
        ]
    )

    app["rpc_url"] = rpc_url
    return app
def make_app_from_configfile():
    """Build the application from the configuration file named by the
    ``SWH_CONFIG_FILENAME`` environment variable.

    The variable's value (``None`` when unset) is handed to ``config_read``
    and the resulting configuration to :func:`make_app`.
    """
    return make_app(config=config_read(os.environ.get("SWH_CONFIG_FILENAME")))