# Copyright (C) 2022-2026 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
from hashlib import sha1
import logging
import attr
from google.protobuf.field_mask_pb2 import FieldMask
import grpc
from swh.graph.grpc.swhgraph_pb2 import (
FindPathBetweenRequest,
FindPathToRequest,
NodeFilter,
TraversalRequest,
)
from swh.graph.grpc.swhgraph_pb2_grpc import TraversalServiceStub
from swh.model.cli import swhid_of_file
from swh.model.exceptions import ValidationError
# documentation: https://docs.softwareheritage.org/devel/apidoc/swh.model.swhids.html
from swh.model.swhids import CoreSWHID, ExtendedObjectType, ExtendedSWHID
# Address of a swh-graph gRPC server (host:port); presumably the default for
# the ``graph_grpc_server`` argument of main() in a CLI wrapper not shown
# here — TODO confirm.
GRAPH_GRPC_SERVER: str = "localhost:50091"
# Module-level logger; main() uses it to report unexpected errors.
logger: logging.Logger = logging.getLogger(name=__name__)
[docs]
def fqswhid_of_traversal(response, verbose):
    """Build the fully qualified SWHID for a gRPC path response.

    The path in ``response.labeled_node`` is walked from its first node (the
    object whose SWHID gets qualified) to its last node. Each step looks at
    a (target, source) pair of consecutive nodes and collects:

    * the directory entry names leading to the object (``path`` qualifier,
      only when the object is a content or a directory),
    * the first revision and first release met (``anchor`` qualifier,
      revision preferred),
    * the snapshot met (``visit`` qualifier),
    * the first non-empty origin URL met (``origin`` qualifier).

    Args:
        response: Response from the gRPC server (``FindPathTo`` /
            ``FindPathBetween``); its ``labeled_node`` field is the path.
        verbose: If true, print every node as it is examined.

    Returns:
        The fully qualified SWHID corresponding to the response from the gRPC
        server, as a string.

    Raises:
        ValueError: if the response contains an empty path.
    """
    # Directory entry names, collected from the object upwards (reversed
    # at the end to get a root-first path).
    path_items = []
    revision = None
    release = None
    snapshot = None
    origin = ""

    shortest_path = iter(response.labeled_node)
    target_labeled_node = next(shortest_path, None)
    if target_labeled_node is None:
        raise ValueError("The gRPC response contains an empty path")

    swhid = CoreSWHID.from_string(target_labeled_node.node.swhid).to_qualified()
    core_swhid = target_node_swhid = ExtendedSWHID.from_string(
        target_labeled_node.node.swhid
    )
    for source_labeled_node in shortest_path:
        if verbose:
            print(f"Examining node: {target_labeled_node.node.swhid}")
        # Presumably only origin nodes carry a non-empty URL (the field
        # defaults to ""); keep the first one found instead of toggling.
        if not origin:
            origin = source_labeled_node.node.ori.url
        source_node_swhid = ExtendedSWHID.from_string(source_labeled_node.node.swhid)
        target_type = target_node_swhid.object_type
        if target_type in (
            ExtendedObjectType.CONTENT,
            ExtendedObjectType.DIRECTORY,
        ):
            # The label of a node names the edge towards the next node, i.e.
            # the entry name of the target inside its parent directory.
            if (
                source_node_swhid.object_type == ExtendedObjectType.DIRECTORY
                and len(target_labeled_node.label) > 0
            ):
                # Keep the raw bytes: entry names need not be valid UTF-8.
                path_items.append(target_labeled_node.label[0].name)
        elif target_type == ExtendedObjectType.REVISION:
            if revision is None:
                revision = target_labeled_node.node.swhid
        elif target_type == ExtendedObjectType.RELEASE:
            if release is None:
                release = target_labeled_node.node.swhid
        elif target_type == ExtendedObjectType.SNAPSHOT:
            snapshot = target_labeled_node.node.swhid
        target_labeled_node = source_labeled_node
        target_node_swhid = source_node_swhid

    visit = (
        CoreSWHID.from_string(snapshot)
        if snapshot is not None
        and core_swhid.object_type != ExtendedObjectType.SNAPSHOT
        else None
    )
    if core_swhid.object_type in (
        ExtendedObjectType.CONTENT,
        ExtendedObjectType.DIRECTORY,
    ):
        anchor_str = revision or release
        anchor = CoreSWHID.from_string(anchor_str) if anchor_str else None
        # The path qualifier is bytes; build it without decoding entry names.
        path = b"/" + b"/".join(reversed(path_items))
    else:
        anchor = path = None
    fqswhid = attr.evolve(
        swhid, origin=origin or None, visit=visit, anchor=anchor, path=path
    )
    return str(fqswhid)
[docs]
def _origin_swhid(origin_url):
    """Return the SWHID, as a string, of the origin with URL ``origin_url``.

    The object id is the SHA1 digest of the UTF-8 encoded URL.
    """
    return str(
        ExtendedSWHID(
            object_type=ExtendedObjectType.ORIGIN,
            object_id=sha1(origin_url.encode("UTF-8")).digest(),
        )
    )


def main(
    content_swhid,
    origin_url,
    all_origins,
    random_origin,
    filename,
    graph_grpc_server,
    fqswhid,
    trace,
):
    """Look up the origin(s) of an object in the swh-graph gRPC server.

    Args:
        content_swhid: SWHID of the object to look up. Replaced by the SWHID
            of ``filename`` when that is given.
        origin_url: If set, print the fully qualified SWHID of the object
            along a backward path to this origin.
        all_origins: If true, print every origin reachable backward from the
            object (as URLs, or as fully qualified SWHIDs if ``fqswhid``).
            This disables ``random_origin``.
        random_origin: If true, find a backward path to one origin
            (``FindPathTo``) and print it.
        filename: Optional local file whose SWHID is used as ``content_swhid``.
        graph_grpc_server: Address passed to ``grpc.insecure_channel``.
        fqswhid: Print fully qualified SWHIDs instead of origin URLs (for
            ``all_origins`` and ``random_origin``).
        trace: Passed as ``verbose`` to ``fqswhid_of_traversal``.
    """
    # A SWHID computed from ``filename`` replaces ``content_swhid`` below,
    # so only a user-supplied SWHID needs validating here.
    if not filename:
        try:
            CoreSWHID.from_string(content_swhid)
        except ValidationError:
            print(f"Error: '{content_swhid}' is not a valid SWHID")
            return

    with grpc.insecure_channel(graph_grpc_server) as channel:
        client = TraversalServiceStub(channel)
        # Fields read by fqswhid_of_traversal() on path responses.
        field_mask_findpath = FieldMask(
            paths=[
                "labeled_node.node.swhid",
                "labeled_node.node.ori.url",
                "labeled_node.label.name",
            ]
        )
        field_mask_traverse = FieldMask(
            paths=[
                "node.swhid",
                "node.ori.url",
                "node.successor.swhid",
                "node.successor.label.name",
            ]
        )
        try:
            if filename:
                content_swhid = str(swhid_of_file(filename))

            # Traversal request: get all origins
            if all_origins:
                random_origin = False
                origin_nodes = client.Traverse(
                    TraversalRequest(
                        src=[content_swhid],
                        edges="cnt:dir,dir:dir,dir:rev,rev:rev,rev:snp,rev:rel,rel:snp,snp:ori",
                        direction="BACKWARD",
                        return_nodes=NodeFilter(types="ori"),
                        mask=field_mask_traverse,
                    )
                )
                for node in origin_nodes:
                    if fqswhid:
                        path = client.FindPathBetween(
                            FindPathBetweenRequest(
                                src=[content_swhid],
                                dst=[node.swhid],
                                direction="BACKWARD",
                                mask=field_mask_findpath,
                            )
                        )
                        print(fqswhid_of_traversal(path, verbose=trace))
                    else:
                        print(node.ori.url)

            # Traversal request to a (random) origin
            if random_origin:
                path = client.FindPathTo(
                    FindPathToRequest(
                        src=[content_swhid],
                        target=NodeFilter(types="ori"),
                        direction="BACKWARD",
                        mask=field_mask_findpath,
                    )
                )
                if fqswhid:
                    print(fqswhid_of_traversal(path, verbose=trace))
                else:
                    # ``labeled_node`` is the sequence of path nodes (it has
                    # no ``.node`` attribute itself); print the non-empty
                    # origin URL(s) it contains.
                    for labeled_node in path.labeled_node:
                        if labeled_node.node.ori.url:
                            print(labeled_node.node.ori.url)

            # Traversal request to a given origin URL
            if origin_url:
                path = client.FindPathBetween(
                    FindPathBetweenRequest(
                        src=[content_swhid],
                        dst=[_origin_swhid(origin_url)],
                        direction="BACKWARD",
                        mask=field_mask_findpath,
                    )
                )
                print(fqswhid_of_traversal(path, verbose=trace))
        except grpc.RpcError as e:
            print("Error from the GRPC API call: {}".format(e.details()))
            if filename:
                print(filename + " has SWHID " + content_swhid)
        except Exception:
            # Top-level boundary: log the traceback rather than crash.
            logger.exception("Unexpected error occurred:")