
# Copyright (C) 2022-2024 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information

# WARNING: do not import unnecessary things here to keep cli startup time under
# control
from pathlib import Path
from typing import Dict, List, Tuple

import luigi

from swh.export.luigi import AthenaDatabaseTarget

OBJECT_TYPES = {"ori", "snp", "rel", "rev", "dir", "cnt"}


# singleton written to signal to workers they should stop
class _EndOfQueue:
    pass


_END_OF_QUEUE = _EndOfQueue()
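

# A minimal sketch of how such a sentinel is typically consumed (the queue,
# producer, and worker below are hypothetical, not part of this module): a
# producer enqueues work items followed by one sentinel per worker, and each
# worker loops until it dequeues the sentinel.
def _example_worker(queue) -> None:  # pragma: no cover - illustration only
    while True:
        item = queue.get()
        if isinstance(item, _EndOfQueue):
            break  # the producer signalled that there is no more work
        print("processing", item)  # hypothetical processing step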


def estimate_node_count(
    local_graph_path: Path, graph_name: str, object_types: str
) -> int:
    """Returns the number of nodes of the given types (in the
    'cnt,dir,rev,rel,snp,ori' format) in the graph.

    This is meant to estimate RAM usage, and will return an overapproximation
    if the actual number of nodes is not available yet.
    """
    try:
        return count_nodes(local_graph_path, graph_name, object_types)
    except FileNotFoundError:
        # The graph was not compressed yet, so we can't estimate the memory it
        # will take to run this task.
        # As a hack, we return a very large value, which will force Luigi to
        # stop before running this; when starting again (after the compressed
        # graph is available), it will call this function again and get an
        # accurate estimate.
        return 10**100


def count_nodes(local_graph_path: Path, graph_name: str, object_types: str) -> int:
    """Returns the number of nodes of the given types (in the
    'cnt,dir,rev,rel,snp,ori' format) in the graph.
    """
    node_stats = (local_graph_path / f"{graph_name}.nodes.stats.txt").read_text()
    nb_nodes_per_type = dict(line.split() for line in node_stats.split("\n") if line)
    return sum(int(nb_nodes_per_type[type_]) for type_ in object_types.split(","))
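

# A minimal usage sketch, assuming a compressed graph directory containing a
# ``<graph_name>.nodes.stats.txt`` file with one ``<type> <count>`` pair per
# line (e.g. ``cnt 1000`` and ``dir 500``). The path and graph name below are
# hypothetical, not part of this module.
def _example_count_nodes_usage() -> None:  # pragma: no cover - illustration only
    graph_dir = Path("/srv/graph/compressed")  # hypothetical path
    # sums the "cnt" and "dir" entries of graph.nodes.stats.txt -> 1500
    total = count_nodes(graph_dir, "graph", "cnt,dir")
    # same, but falls back to a huge sentinel value (10**100) if the stats
    # file does not exist yet because the graph has not been compressed
    estimate = estimate_node_count(graph_dir, "graph", "cnt,dir")
    print(total, estimate)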


class _ParquetToS3ToAthenaTask(luigi.Task):
    """Base class for tasks which take a Parquet dataset as input, upload it
    to S3, and create an Athena table for it."""

    parallelism = 10

    def _input_parquet_path(self) -> Path:
        raise NotImplementedError(f"{self.__class__.__name__}._input_parquet_path")

    def _s3_bucket(self) -> str:
        raise NotImplementedError(f"{self.__class__.__name__}._s3_bucket")

    def _s3_prefix(self) -> str:
        raise NotImplementedError(f"{self.__class__.__name__}._s3_prefix")

    def _parquet_columns(self) -> List[Tuple[str, str]]:
        """Returns a list of ``(column_name, parquet_type)``"""
        raise NotImplementedError(f"{self.__class__.__name__}._parquet_columns")

    def _approx_nb_rows(self) -> int:
        """Returns the number of rows in the Parquet dataset. Used only for
        progress reporting."""
        import pyarrow.parquet

        return sum(
            pyarrow.parquet.read_metadata(file).num_rows
            for file in self._input_parquet_path().iterdir()
        )

    def _athena_db_name(self) -> str:
        raise NotImplementedError(f"{self.__class__.__name__}._athena_db_name")

    def _athena_table_name(self) -> str:
        raise NotImplementedError(f"{self.__class__.__name__}._athena_table_name")

    def output(self) -> luigi.Target:
        return AthenaDatabaseTarget(self._athena_db_name(), {self._athena_table_name()})

    def run(self) -> None:
        """Uploads all Parquet files to S3 in parallel, then creates the
        Athena table pointing at them."""
        import multiprocessing

        import tqdm

        self.__status_messages: Dict[Path, str] = {}

        paths = list(self._input_parquet_path().glob("**/*.parquet"))

        with multiprocessing.Pool(self.parallelism) as p:
            for i, relative_path in tqdm.tqdm(
                enumerate(p.imap_unordered(self._upload_file, paths)),
                total=len(paths),
                desc=f"Uploading {self._input_parquet_path()} to "
                f"s3://{self._s3_bucket()}/{self._s3_prefix()}/",
            ):
                self.set_progress_percentage(int(i * 100 / len(paths)))
                self.set_status_message("\n".join(self.__status_messages.values()))

        self._create_athena_table()

    def _upload_file(self, path):
        import luigi.contrib.s3

        client = luigi.contrib.s3.S3Client()

        relative_path = path.relative_to(self._input_parquet_path())
        self.__status_messages[path] = f"Uploading {relative_path}"

        client.put_multipart(
            path,
            f"s3://{self._s3_bucket()}/{self._s3_prefix()}/{relative_path}",
            ACL="public-read",
        )

        del self.__status_messages[path]

        return relative_path

    def _create_athena_table(self):
        import boto3

        from swh.export.athena import query

        client = boto3.client("athena")
        client.output_location = self.s3_athena_output_location

        client.database_name = "default"  # we have to pick some existing database
        query(
            client,
            f"CREATE DATABASE IF NOT EXISTS {self._athena_db_name()};",
            desc=f"Creating {self._athena_db_name()} database",
        )

        client.database_name = self._athena_db_name()

        columns = ", ".join(
            f"{col} {type_}" for (col, type_) in self._parquet_columns()
        )

        query(
            client,
            f"""
            CREATE EXTERNAL TABLE IF NOT EXISTS
            {self._athena_db_name()}.{self._athena_table_name()}
            ({columns})
            {self.create_table_extras()}
            STORED AS PARQUET
            LOCATION 's3://{self._s3_bucket()}/{self._s3_prefix()}';
            """,
            desc=f"Creating table {self._athena_table_name()}",
        )

        # needed for partitioned tables
        query(
            client,
            f"MSCK REPAIR TABLE `{self._athena_table_name()}`",
            desc=f"'Repairing' table {self._athena_table_name()}",
        )

    def create_table_extras(self) -> str:
        """Extra clauses to add to the ``CREATE EXTERNAL TABLE`` statement."""
        return ""
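

# A minimal sketch of a concrete subclass (all names and values below are
# hypothetical, not part of swh.graph): a task only has to override the
# abstract accessors so the base class knows where the local Parquet dataset
# lives, where to upload it on S3, and which Athena database/table to create.
class _ExampleDatasetToAthena(_ParquetToS3ToAthenaTask):  # pragma: no cover
    local_path = luigi.Parameter()  # hypothetical task parameter
    # read by _create_athena_table() as the Athena query result location
    s3_athena_output_location = luigi.Parameter()

    def _input_parquet_path(self) -> Path:
        return Path(self.local_path)

    def _s3_bucket(self) -> str:
        return "example-bucket"  # hypothetical bucket

    def _s3_prefix(self) -> str:
        return "graph/example_dataset"  # hypothetical prefix

    def _parquet_columns(self) -> List[Tuple[str, str]]:
        return [("id", "string"), ("value", "bigint")]

    def _athena_db_name(self) -> str:
        return "example_db"

    def _athena_table_name(self) -> str:
        return "example_table"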