# Source code for swh.dataset.test.test_journal_processor

# Copyright (C) 2024  The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information

import functools
import gc
import multiprocessing
from multiprocessing.managers import ListProxy
from typing import Any, Dict, Sequence, Tuple

import pytest

from swh.dataset.exporter import Exporter
from swh.dataset.journalprocessor import ParallelJournalProcessor
from swh.journal.serializers import kafka_to_value, value_to_kafka
from swh.journal.writer import get_journal_writer
from swh.model import model
from swh.model.tests import swh_model_data


@pytest.fixture
def journal_client_config(
    kafka_server: str, kafka_prefix: str, kafka_consumer_group: str
):
    """Minimal journal-client configuration pointing at the test Kafka broker."""
    return {
        "brokers": kafka_server,
        "group_id": kafka_consumer_group,
        "prefix": kafka_prefix,
    }
@pytest.fixture
def journal_writer(kafka_server: str, kafka_prefix: str):
    """Kafka journal writer used by the tests to seed topics (no anonymization)."""
    writer = get_journal_writer(
        cls="kafka",
        brokers=[kafka_server],
        client_id="kafka_writer",
        prefix=kafka_prefix,
        anonymize=False,
    )
    return writer
def disable_gc(f):
    """Decorator for test functions; prevents segfaults in confluent-kafka.

    See https://github.com/confluentinc/confluent-kafka-python/issues/1761

    The garbage collector is disabled for the duration of the wrapped call.
    Fix over the naive disable/enable pair: the collector's *previous* state
    is restored instead of being unconditionally re-enabled, so nested use
    (or a caller that already disabled GC) is left untouched.
    """

    @functools.wraps(f)
    def newf(*args, **kwargs):
        # Remember whether GC was on so we can restore it exactly.
        was_enabled = gc.isenabled()
        gc.disable()
        try:
            return f(*args, **kwargs)
        finally:
            if was_enabled:
                gc.enable()

    return newf
class ListExporter(Exporter):
    """Exporter that collects every processed object into a shared list.

    The list is a :class:`multiprocessing.managers.ListProxy`, so worker
    processes spawned by the journal processor all append into the same
    sequence, which the test then inspects.
    """

    def __init__(self, objects: ListProxy, *args, **kwargs):
        # Keep a handle on the manager-backed list before delegating the
        # rest of the setup to the base Exporter.
        self._objects = objects
        super().__init__(*args, **kwargs)

    def process_object(
        self, object_type: model.ModelObjectType, obj: Dict[str, Any]
    ) -> None:
        """Record the (object_type, object dict) pair for later assertions."""
        self._objects.append((object_type, obj))
def assert_exported_objects(
    exported_objects: Sequence[Tuple[str, Dict]],
    expected_objects: Sequence[model.BaseModel],
) -> None:
    """Assert the exporter produced exactly ``expected_objects``, in any order.

    Expected model objects are round-tripped through the kafka serializer so
    both sides are compared in the same normalized dict form.
    """

    def sort_key(entry):
        # bare minimum to get a deterministic order
        object_type, d = entry
        return (
            object_type,
            d.get("id"),
            d.get("url"),
            d.get("origin"),
            d.get("visit"),
            d.get("date"),
        )

    normalized_expected = [
        # normalize through a kafka serialization round-trip
        (obj.object_type, kafka_to_value(value_to_kafka(obj.to_dict())))
        for obj in expected_objects
    ]
    assert sorted(exported_objects, key=sort_key) == sorted(
        normalized_expected, key=sort_key
    )
def test_parallel_journal_processor(
    journal_client_config, journal_writer, tmp_path
) -> None:
    """End-to-end export of revision objects through ParallelJournalProcessor."""
    # NOTE(review): unlike the later tests, this one is not wrapped in
    # @disable_gc — confirm whether that is intentional.
    journal_writer.write_additions("revision", swh_model_data.REVISIONS)

    config = {"journal": journal_client_config}
    node_sets_path = tmp_path / "node_sets"
    export_path = tmp_path / "export"
    (node_sets_path / "revision").mkdir(parents=True)

    with multiprocessing.Manager() as manager:
        collected = manager.list()
        processor = ParallelJournalProcessor(
            config=config,
            masked_swhids=set(),
            exporter_factories=[
                functools.partial(
                    ListExporter, collected, config=config, export_path=export_path
                )
            ],
            export_id="test_parallel_journal_processor",
            obj_type="revision",
            node_sets_path=node_sets_path,
        )
        processor.run()
        assert_exported_objects(collected, swh_model_data.REVISIONS)
def test_parallel_journal_processor_origin(
    journal_client_config, journal_writer, tmp_path
) -> None:
    """End-to-end export of origin objects through ParallelJournalProcessor."""
    # NOTE(review): unlike the later tests, this one is not wrapped in
    # @disable_gc — confirm whether that is intentional.
    journal_writer.write_additions("origin", swh_model_data.ORIGINS)

    config = {"journal": journal_client_config}
    node_sets_path = tmp_path / "node_sets"
    export_path = tmp_path / "export"
    (node_sets_path / "origin").mkdir(parents=True)

    with multiprocessing.Manager() as manager:
        collected = manager.list()
        processor = ParallelJournalProcessor(
            config=config,
            masked_swhids=set(),
            exporter_factories=[
                functools.partial(
                    ListExporter, collected, config=config, export_path=export_path
                )
            ],
            export_id="test_parallel_journal_processor",
            obj_type="origin",
            node_sets_path=node_sets_path,
        )
        processor.run()
        assert_exported_objects(collected, swh_model_data.ORIGINS)
@disable_gc
def test_parallel_journal_processor_origin_visit_status(
    journal_client_config, journal_writer, tmp_path
) -> None:
    """End-to-end export of origin_visit_status objects."""
    journal_writer.write_additions(
        "origin_visit_status", swh_model_data.ORIGIN_VISIT_STATUSES
    )

    config = {"journal": journal_client_config}
    node_sets_path = tmp_path / "node_sets"
    export_path = tmp_path / "export"
    (node_sets_path / "origin_visit_status").mkdir(parents=True)

    with multiprocessing.Manager() as manager:
        collected = manager.list()
        processor = ParallelJournalProcessor(
            config=config,
            masked_swhids=set(),
            exporter_factories=[
                functools.partial(
                    ListExporter, collected, config=config, export_path=export_path
                )
            ],
            export_id="test_parallel_journal_processor",
            obj_type="origin_visit_status",
            node_sets_path=node_sets_path,
        )
        processor.run()
        assert_exported_objects(collected, swh_model_data.ORIGIN_VISIT_STATUSES)
@disable_gc
def test_parallel_journal_processor_offsets(
    journal_client_config, journal_writer, tmp_path
) -> None:
    """The exporter must stop at the offsets computed when the export starts:
    objects written to the journal afterwards are not picked up."""
    journal_writer.write_additions("revision", swh_model_data.REVISIONS[0:2])

    config = {"journal": journal_client_config}
    node_sets_path = tmp_path / "node_sets"
    export_path = tmp_path / "export"
    (node_sets_path / "revision").mkdir(parents=True)

    with multiprocessing.Manager() as manager:
        collected = manager.list()
        processor = ParallelJournalProcessor(
            config=config,
            masked_swhids=set(),
            exporter_factories=[
                functools.partial(
                    ListExporter, collected, config=config, export_path=export_path
                )
            ],
            export_id="test_parallel_journal_processor",
            obj_type="revision",
            node_sets_path=node_sets_path,
        )
        # Fill the processor.offsets cache *before* running, pinning the
        # export to the journal's current high-water mark.
        processor.get_offsets()
        processor.run()
        assert_exported_objects(collected, swh_model_data.REVISIONS[0:2])
@disable_gc
def test_parallel_journal_processor_masked(
    journal_client_config, journal_writer, tmp_path
) -> None:
    """Revisions whose SWHID is masked must be dropped from the export."""
    journal_writer.write_additions("revision", swh_model_data.REVISIONS)
    masked_swhids = {swh_model_data.REVISIONS[2].swhid().to_extended()}

    config = {"journal": journal_client_config}
    node_sets_path = tmp_path / "node_sets"
    export_path = tmp_path / "export"
    (node_sets_path / "revision").mkdir(parents=True)

    with multiprocessing.Manager() as manager:
        collected = manager.list()
        processor = ParallelJournalProcessor(
            config=config,
            masked_swhids=masked_swhids,
            exporter_factories=[
                functools.partial(
                    ListExporter, collected, config=config, export_path=export_path
                )
            ],
            export_id="test_parallel_journal_processor",
            obj_type="revision",
            node_sets_path=node_sets_path,
        )
        processor.run()
        # Everything but the masked revision (index 2) should be exported.
        assert_exported_objects(
            collected, swh_model_data.REVISIONS[0:2] + swh_model_data.REVISIONS[3:]
        )
@disable_gc
def test_parallel_journal_processor_masked_origin(
    journal_client_config, journal_writer, tmp_path
) -> None:
    """Origins whose SWHID is masked must be dropped from the export."""
    journal_writer.write_additions("origin", swh_model_data.ORIGINS)
    masked_swhids = {swh_model_data.ORIGINS[1].swhid()}

    config = {"journal": journal_client_config}
    node_sets_path = tmp_path / "node_sets"
    export_path = tmp_path / "export"
    (node_sets_path / "origin").mkdir(parents=True)

    with multiprocessing.Manager() as manager:
        collected = manager.list()
        processor = ParallelJournalProcessor(
            config=config,
            masked_swhids=masked_swhids,
            exporter_factories=[
                functools.partial(
                    ListExporter, collected, config=config, export_path=export_path
                )
            ],
            export_id="test_parallel_journal_processor",
            obj_type="origin",
            node_sets_path=node_sets_path,
        )
        processor.run()
        # Everything but the masked origin (index 1) should be exported.
        assert_exported_objects(
            collected, swh_model_data.ORIGINS[0:1] + swh_model_data.ORIGINS[2:]
        )
@disable_gc
def test_parallel_journal_processor_masked_origin_visit_statuses(
    journal_client_config, journal_writer, tmp_path
) -> None:
    """Visit statuses of a masked origin must be dropped from the export."""
    journal_writer.write_additions(
        "origin_visit_status", swh_model_data.ORIGIN_VISIT_STATUSES
    )
    masked_origin = model.Origin(url=swh_model_data.ORIGIN_VISIT_STATUSES[1].origin)
    masked_swhids = {masked_origin.swhid()}

    config = {"journal": journal_client_config}
    node_sets_path = tmp_path / "node_sets"
    export_path = tmp_path / "export"
    (node_sets_path / "origin_visit_status").mkdir(parents=True)

    with multiprocessing.Manager() as manager:
        collected = manager.list()
        processor = ParallelJournalProcessor(
            config=config,
            masked_swhids=masked_swhids,
            exporter_factories=[
                functools.partial(
                    ListExporter, collected, config=config, export_path=export_path
                )
            ],
            export_id="test_parallel_journal_processor",
            obj_type="origin_visit_status",
            node_sets_path=node_sets_path,
        )
        processor.run()
        # Only statuses whose origin differs from the masked one remain.
        expected = [
            ovs
            for ovs in swh_model_data.ORIGIN_VISIT_STATUSES
            if ovs.origin != masked_origin.url
        ]
        assert_exported_objects(collected, expected)