# Copyright (C) 2019-2020 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import json
import logging
from swh.core.api import RPCClient
# Module-level logger, named after this module (``__name__``) so its output
# can be filtered or routed through the standard logging hierarchy.
logger = logging.getLogger(name=__name__)
class GraphAPIError(Exception):
    """Graph API Error.

    Passed as ``api_exception`` to ``RPCClient`` by ``RemoteGraphClient``;
    presumably raised by ``RPCClient`` for unexpected backend failures --
    confirm against ``swh.core.api``.
    """

    def __str__(self):
        """Return a two-line message embedding the exception's ``args`` tuple."""
        # ``self.args`` is formatted as a whole, so the message shows the tuple
        # repr, e.g. "('boom',)".
        return "An unexpected error occurred\nin the Graph backend: {}".format(
            self.args
        )
class GraphArgumentException(Exception):
    """Error carrying the HTTP response that triggered it.

    ``RemoteGraphClient.raise_for_status`` raises it for 4xx answers, with the
    decoded response body as the message.

    Args:
        *args: regular exception arguments (stored in ``self.args``)
        response: the HTTP response object that caused the error, or ``None``
            when none is available
    """

    def __init__(self, *args, response=None):
        super().__init__(*args)
        # Kept so callers can inspect it (e.g. its ``status_code``); may be None.
        self.response = response
class RemoteGraphClient(RPCClient):
    """Client to the Software Heritage Graph.

    Talks to the graph service's web API through ``RPCClient``.

    * Traversal endpoints (``leaves``, ``neighbors``, ``visit_*``, ``walk``,
      ``random_walk``) stream the non-empty lines of the HTTP answer.
    * ``stats`` and the ``count_*`` endpoints return the result of
      ``RPCClient._get``.

    Query parameters such as ``edges``, ``direction``, ``max_edges``,
    ``return_types`` and ``max_matching_nodes`` are forwarded verbatim. The
    server defines their semantics; the ``"*"`` / ``0`` defaults presumably
    mean "no restriction" -- confirm against the graph server API docs.
    """

    def __init__(self, url, timeout=None):
        """Create the client and sanity-check the server at *url*.

        The ``stats`` endpoint is queried immediately, so a misconfigured URL
        fails at construction time rather than on the first real query.

        Args:
            url: base URL of the graph service
            timeout: request timeout forwarded to ``RPCClient``

        Raises:
            ValueError: if ``stats`` answers 404, or if its result has no
                ``num_nodes`` entry
            GraphArgumentException: re-raised unchanged when ``stats`` fails
                with anything other than a 404 response
        """
        super().__init__(api_exception=GraphAPIError, url=url, timeout=timeout)
        try:
            stats = self.stats()
        except GraphArgumentException as e:
            # ``response`` is optional on the exception: only inspect it if set.
            if e.response is not None and e.response.status_code == 404:
                raise ValueError(
                    "URL is incorrect (got 404 while trying to retrieve stats)"
                ) from None
            raise
        if "num_nodes" not in stats:
            raise ValueError(
                "stats returned unexpected results (no `num_nodes` entry)"
            )
        if "export_started_at" in stats:
            from datetime import datetime, timezone

            logger.debug(
                "Graph export started at %s (%d nodes)",
                datetime.fromtimestamp(
                    int(stats["export_started_at"]), tz=timezone.utc
                ).isoformat(),
                stats["num_nodes"],
            )

    def raw_verb_lines(self, verb, endpoint, **kwargs):
        """Issue a streamed *verb* request to *endpoint* and yield its lines.

        Behavior:

        * Each line is decoded with the default (UTF-8) codec.
        * Leading newlines are stripped and empty lines are skipped.
        * The response is closed when the generator is exhausted, fails, or
          is closed early (e.g. the caller stops iterating). This releases
          the streamed connection promptly instead of at garbage collection.

        Args:
            verb: HTTP verb name passed to ``raw_verb`` (e.g. ``"get"``)
            endpoint: endpoint path, relative to the client's base URL
            **kwargs: extra arguments forwarded to ``raw_verb``

        Raises:
            GraphArgumentException: on 4xx answers (see ``raise_for_status``)
        """
        response = self.raw_verb(verb, endpoint, stream=True, **kwargs)
        try:
            self.raise_for_status(response)
            for line in response.iter_lines():
                content = line.decode().lstrip("\n")
                if content:
                    yield content
        finally:
            response.close()

    def get_lines(self, endpoint, **kwargs):
        """Stream the non-empty lines of a GET request to *endpoint*."""
        yield from self.raw_verb_lines("get", endpoint, **kwargs)

    def raise_for_status(self, response) -> None:
        """Raise on error responses, mapping 4xx answers to GraphArgumentException.

        For a 4xx answer, the response body becomes the exception message.
        Every other status is delegated to ``RPCClient.raise_for_status``.

        Raises:
            GraphArgumentException: on any 4xx status, with ``response`` attached
        """
        if response.status_code // 100 == 4:
            # Decode leniently: a non-ASCII error body must not turn into a
            # UnicodeDecodeError that hides the actual client error.
            raise GraphArgumentException(
                response.content.decode("utf-8", errors="replace"),
                response=response,
            )
        super().raise_for_status(response)

    @staticmethod
    def _traversal_params(
        edges, direction, max_edges, return_types, max_matching_nodes
    ):
        """Build the query parameters shared by leaves/neighbors/visit_nodes."""
        return {
            "edges": edges,
            "direction": direction,
            "max_edges": max_edges,
            "return_types": return_types,
            "max_matching_nodes": max_matching_nodes,
        }

    # Web API endpoints

    def stats(self):
        """Return the result of the ``stats`` endpoint.

        ``__init__`` expects this result to contain a ``num_nodes`` entry.
        """
        return self._get("stats")

    def leaves(
        self,
        src,
        edges="*",
        direction="forward",
        max_edges=0,
        return_types="*",
        max_matching_nodes=0,
    ):
        """Stream the non-empty answer lines of ``leaves/<src>``."""
        return self.get_lines(
            "leaves/{}".format(src),
            params=self._traversal_params(
                edges, direction, max_edges, return_types, max_matching_nodes
            ),
        )

    def neighbors(
        self,
        src,
        edges="*",
        direction="forward",
        max_edges=0,
        return_types="*",
        max_matching_nodes=0,
    ):
        """Stream the non-empty answer lines of ``neighbors/<src>``."""
        return self.get_lines(
            "neighbors/{}".format(src),
            params=self._traversal_params(
                edges, direction, max_edges, return_types, max_matching_nodes
            ),
        )

    def visit_nodes(
        self,
        src,
        edges="*",
        direction="forward",
        max_edges=0,
        return_types="*",
        max_matching_nodes=0,
    ):
        """Stream the non-empty answer lines of ``visit/nodes/<src>``."""
        return self.get_lines(
            "visit/nodes/{}".format(src),
            params=self._traversal_params(
                edges, direction, max_edges, return_types, max_matching_nodes
            ),
        )

    def visit_edges(self, src, edges="*", direction="forward", max_edges=0):
        """Stream the edges answered by ``visit/edges/<src>``.

        Each answer line is split on whitespace and yielded as a tuple
        (presumably ``(source, destination)`` -- confirm with the server).
        """
        for edge in self.get_lines(
            "visit/edges/{}".format(src),
            params={"edges": edges, "direction": direction, "max_edges": max_edges},
        ):
            yield tuple(edge.split())

    def visit_paths(self, src, edges="*", direction="forward", max_edges=0):
        """Stream the paths answered by ``visit/paths/<src>``.

        Each answer line is parsed with ``json.loads``. Parsing is lazy and
        happens as the returned iterator is consumed.
        """
        lines = self.get_lines(
            "visit/paths/{}".format(src),
            params={"edges": edges, "direction": direction, "max_edges": max_edges},
        )
        return (json.loads(line) for line in lines)

    def walk(
        self, src, dst, edges="*", traversal="dfs", direction="forward", limit=None
    ):
        """Stream the non-empty answer lines of ``walk/<src>/<dst>``.

        ``limit`` is forwarded as-is, including when it is ``None``.
        """
        endpoint = "walk/{}/{}"
        return self.get_lines(
            endpoint.format(src, dst),
            params={
                "edges": edges,
                "traversal": traversal,
                "direction": direction,
                "limit": limit,
            },
        )

    def random_walk(
        self, src, dst, edges="*", direction="forward", limit=None, return_types="*"
    ):
        """Stream the non-empty answer lines of ``randomwalk/<src>/<dst>``."""
        endpoint = "randomwalk/{}/{}"
        return self.get_lines(
            endpoint.format(src, dst),
            params={
                "edges": edges,
                "direction": direction,
                "limit": limit,
                "return_types": return_types,
            },
        )

    def count_leaves(self, src, edges="*", direction="forward", max_matching_nodes=0):
        """Return the result of the ``leaves/count/<src>`` endpoint."""
        return self._get(
            "leaves/count/{}".format(src),
            params={
                "edges": edges,
                "direction": direction,
                "max_matching_nodes": max_matching_nodes,
            },
        )

    def count_neighbors(self, src, edges="*", direction="forward"):
        """Return the result of the ``neighbors/count/<src>`` endpoint."""
        return self._get(
            "neighbors/count/{}".format(src),
            params={"edges": edges, "direction": direction},
        )

    def count_visit_nodes(
        self, src, edges="*", direction="forward", max_matching_nodes=0
    ):
        """Return the result of the ``visit/nodes/count/<src>`` endpoint."""
        return self._get(
            "visit/nodes/count/{}".format(src),
            params={
                "edges": edges,
                "direction": direction,
                "max_matching_nodes": max_matching_nodes,
            },
        )