# Copyright (C) 2022-2024 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information

# WARNING: do not import unnecessary things here to keep cli startup time under
# control
from pathlib import Path
from typing import Dict, List, Tuple

import luigi

from swh.dataset.luigi import AthenaDatabaseTarget

# Short codes of the six object types; these are the same abbreviations as the
# ones listed in the docstring of count_nodes() ('cnt,dir,rev,rel,snp,ori').
OBJECT_TYPES = {"ori", "snp", "rel", "rev", "dir", "cnt"}

# singleton written to signal to workers they should stop
class _EndOfQueue:
    """Marker type for the end-of-queue sentinel.

    It carries no data: only the identity of its single module-level instance
    matters, so consumers should compare with ``is``.
    """


# NOTE(review): the name looks like a typo for ``_END_OF_QUEUE``; it is kept
# unchanged because code outside this module may already refer to it.
_ENF_OF_QUEUE = _EndOfQueue()

def count_nodes(local_graph_path: Path, graph_name: str, object_types: str) -> int:
    """Returns the number of nodes of the given types (in the
    'cnt,dir,rev,rel,snp,ori' format) in the graph.

    The counts are read from ``<graph_name>.nodes.stats.txt`` inside
    ``local_graph_path``. Every non-blank line of that file must hold exactly
    two whitespace-separated fields: an object type and its node count.

    Raises:
        FileNotFoundError: if the stats file does not exist.
        KeyError: if a requested type is not listed in the stats file.
        ValueError: if a line does not have two fields, or a count is not an
            integer.
    """
    node_stats = (local_graph_path / f"{graph_name}.nodes.stats.txt").read_text()
    # splitlines() + strip(): whitespace-only lines (e.g. a trailing '\r' or an
    # indented blank line) are skipped instead of producing an empty record
    # that dict() would reject.
    nb_nodes_per_type = dict(
        line.split() for line in node_stats.splitlines() if line.strip()
    )
    return sum(int(nb_nodes_per_type[type_]) for type_ in object_types.split(","))
class _ParquetToS3ToAthenaTask(luigi.Task):
    """Base class for tasks which upload a directory of Parquet files to S3
    and create an Athena table backed by these files.

    Concrete subclasses must override :meth:`_input_parquet_path`,
    :meth:`_s3_bucket`, :meth:`_s3_prefix`, :meth:`_parquet_columns`,
    :meth:`_athena_db_name` and :meth:`_athena_table_name`.

    :meth:`_create_athena_table` also reads ``self.s3_athena_output_location``,
    which this base class does not declare (presumably a luigi parameter
    provided by subclasses -- TODO confirm).
    """

    # Size of the thread pool used by run() to upload files concurrently.
    parallelism = 10

    def _input_parquet_path(self) -> Path:
        """Returns the local directory containing the ``*.parquet`` files."""
        raise NotImplementedError(f"{self.__class__.__name__}._input_parquet_path")

    def _s3_bucket(self) -> str:
        """Returns the name of the destination S3 bucket."""
        raise NotImplementedError(f"{self.__class__.__name__}._s3_bucket")

    def _s3_prefix(self) -> str:
        """Returns the key prefix (inside the bucket) files are uploaded under."""
        raise NotImplementedError(f"{self.__class__.__name__}._s3_prefix")

    def _parquet_columns(self) -> List[Tuple[str, str]]:
        """Returns a list of ``(column_name, column_type)``; each pair is
        inserted verbatim in the ``CREATE EXTERNAL TABLE`` statement."""
        raise NotImplementedError(f"{self.__class__.__name__}._parquet_columns")

    def _approx_nb_rows(self) -> int:
        """Returns the total number of rows of the input Parquet files.
        Used only for progress reporting"""
        import pyarrow.parquet

        # Same file selection as run(), so that a stray non-Parquet file in
        # the directory cannot make the metadata reader fail.
        return sum(
            pyarrow.parquet.read_metadata(file).num_rows
            for file in self._input_parquet_path().glob("*.parquet")
        )

    def _athena_db_name(self) -> str:
        """Returns the name of the Athena database to create the table in."""
        raise NotImplementedError(f"{self.__class__.__name__}._athena_db_name")

    def _athena_table_name(self) -> str:
        """Returns the name of the Athena table to create."""
        raise NotImplementedError(f"{self.__class__.__name__}._athena_table_name")

    def output(self) -> luigi.Target:
        """Returns the Athena target (database + table) produced by this task."""
        return AthenaDatabaseTarget(self._athena_db_name(), {self._athena_table_name()})

    def run(self) -> None:
        """Uploads every ``*.parquet`` file of the input directory to S3, then
        creates the Athena database and table."""
        import multiprocessing.dummy

        import tqdm

        # Maps each file currently being uploaded to its status line. It must
        # exist before the pool starts, because _upload_file() writes to it
        # from the worker threads (it was previously never initialised).
        self.__status_messages: Dict[Path, str] = {}

        paths = list(self._input_parquet_path().glob("*.parquet"))

        # multiprocessing.dummy provides a thread-based pool, which is enough
        # for network-bound uploads.
        with multiprocessing.dummy.Pool(self.parallelism) as pool:
            uploads = pool.imap_unordered(self._upload_file, paths)
            # Count from 1 so that the progress reaches 100% once the last
            # file is done.
            for nb_done, _ in tqdm.tqdm(
                enumerate(uploads, start=1),
                total=len(paths),
                desc="Uploading compressed graph",
            ):
                self.set_progress_percentage(int(nb_done * 100 / len(paths)))
                self.set_status_message("\n".join(self.__status_messages.values()))

        self._create_athena_table()

    def _upload_file(self, path: Path) -> Path:
        """Uploads one file to ``s3://<bucket>/<prefix>/<relative path>`` and
        returns its path relative to the input directory.

        Called from the pool threads started by :meth:`run`.
        """
        import luigi.contrib.s3

        client = luigi.contrib.s3.S3Client()

        relative_path = path.relative_to(self._input_parquet_path())

        self.__status_messages[path] = f"Uploading {relative_path}"
        try:
            client.put_multipart(
                path,
                f"s3://{self._s3_bucket()}/{self._s3_prefix()}/{relative_path}",
                ACL="public-read",
            )
        finally:
            # Drop the status line even when the upload fails, so that a
            # failed file is not reported as still uploading.
            del self.__status_messages[path]

        return relative_path

    def _create_athena_table(self) -> None:
        """Creates the Athena database (if needed) and an external table
        pointing at the uploaded Parquet files."""
        import boto3

        from swh.dataset.athena import query

        client = boto3.client("athena")
        client.output_location = self.s3_athena_output_location
        client.database_name = "default"  # we have to pick some existing database
        query(
            client,
            f"CREATE DATABASE IF NOT EXISTS {self._athena_db_name()};",
            desc=f"Creating {self._athena_db_name()} database",
        )
        client.database_name = self._athena_db_name()

        # The code used to call ``_orc_columns()``, a hook this class does not
        # define. Keep honouring it in case a subclass provides it, and fall
        # back to the ``_parquet_columns()`` hook otherwise.
        columns_hook = getattr(self, "_orc_columns", self._parquet_columns)
        columns = ", ".join(f"{col} {type_}" for (col, type_) in columns_hook())

        query(
            client,
            f"""
            CREATE EXTERNAL TABLE IF NOT EXISTS
            {self._athena_db_name()}.{self._athena_table_name()}
            ({columns})
            STORED AS PARQUET
            LOCATION 's3://{self._s3_bucket()}/{self._s3_prefix()}';
            """,
            # f-string: the table name was previously shown as a literal
            # "{self._athena_table_name()}" placeholder.
            desc=f"Creating table {self._athena_table_name()}",
        )