Source code for swh.journal.client

# Copyright (C) 2017-2023  The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information

from collections import defaultdict
import enum
from importlib import import_module
from itertools import cycle
import logging
import os
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
import warnings

from confluent_kafka import (
    OFFSET_BEGINNING,
    Consumer,
    KafkaError,
    KafkaException,
    TopicPartition,
)

from swh.core.statsd import Statsd
from swh.journal import DEFAULT_PREFIX

from .serializers import kafka_to_value

logger = logging.getLogger(__name__)
rdkafka_logger = logging.getLogger(__name__ + ".rdkafka")


# Only these offset reset policies are accepted
ACCEPTED_OFFSET_RESET = ["earliest", "latest"]

# Errors that Kafka raises too often and that are not useful; we therefore
# lower their log level to DEBUG instead of INFO.
_SPAMMY_ERRORS = [
    KafkaError._NO_OFFSET,
]


class EofBehavior(enum.Enum):
    """Possible behaviors when reaching the end of the log"""

    CONTINUE = "continue"
    STOP = "stop"
    RESTART = "restart"
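
# Editorial note (not part of the upstream module): EofBehavior is a value-based
# enum, so the strings accepted by JournalClient's ``on_eof`` argument map onto
# its variants through the enum constructor, e.g. EofBehavior("stop") is
# EofBehavior.STOP and EofBehavior("restart") is EofBehavior.RESTART.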


def get_journal_client(cls: str, **kwargs: Any):
    """Factory function to instantiate a journal client object.

    Currently, only the "kafka" journal client is supported.
    """
    if cls == "kafka":
        if "stats_cb" in kwargs:
            stats_cb = kwargs["stats_cb"]
            if isinstance(stats_cb, str):
                try:
                    module_path, func_name = stats_cb.split(":")
                except ValueError:
                    raise ValueError(
                        "Invalid stats_cb configuration option: "
                        "it should be a string like 'path.to.module:function'"
                    )
                try:
                    module = import_module(module_path, package=__package__)
                except ModuleNotFoundError:
                    raise ValueError(
                        "Invalid stats_cb configuration option: "
                        f"module {module_path} not found"
                    )
                try:
                    kwargs["stats_cb"] = getattr(module, func_name)
                except AttributeError:
                    raise ValueError(
                        "Invalid stats_cb configuration option: "
                        f"function {func_name} not found in module {module_path}"
                    )
        return JournalClient(**kwargs)
    raise ValueError("Unknown journal client class `%s`" % cls)
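

# Editorial example (not part of the upstream module): a sketch of how the
# factory above is typically called. The broker address, group id, object types
# and the stats_cb target are hypothetical placeholders.
def _example_get_journal_client() -> "JournalClient":
    return get_journal_client(
        "kafka",
        brokers=["broker1.example.org:9092"],
        group_id="example-consumer-group",
        object_types=["origin", "origin_visit_status"],
        stop_after_objects=1000,
        # a "module:function" string, resolved by the factory via import_module
        # and getattr before being handed to the Kafka consumer settings:
        stats_cb="example_package.metrics:on_kafka_stats",
    )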


def _error_cb(error):
    if error.fatal():
        raise KafkaException(error)
    if error.code() in _SPAMMY_ERRORS:
        logger.debug("Received non-fatal kafka error: %s", error)
    else:
        logger.info("Received non-fatal kafka error: %s", error)


def _on_commit(error, partitions):
    if error is not None:
        _error_cb(error)


class JournalClient:
    """A base client for the Software Heritage journal.

    The current implementation of the journal uses Apache Kafka
    brokers to publish messages under a given topic prefix, with each
    object type using a specific topic under that prefix. If the
    ``prefix`` argument is None (the default), it falls back to
    ``'swh.journal.objects'``.

    Clients subscribe to events specific to each object type as listed in the
    ``object_types`` argument (if unset, defaults to all existing kafka topics
    under the prefix).

    Clients can be sharded by setting the ``group_id`` to a common
    value across instances. The journal will share the message
    throughput across the nodes sharing the same ``group_id``.

    Messages are processed by the ``worker_fn`` callback passed to the
    ``process`` method, in batches of maximum ``batch_size`` messages
    (defaults to 200).

    The objects passed to the ``worker_fn`` callback are the result of the
    kafka message converted by the ``value_deserializer`` function. By default
    (if this argument is not given), it will produce dicts (using the
    ``kafka_to_value`` function). The signature of the function is::

        value_deserializer(object_type: str, kafka_msg: bytes) -> Any

    If the value returned by ``value_deserializer`` is None, it is ignored and
    not passed to the ``worker_fn`` function.

    Arguments:
        stop_after_objects: If set, the processing stops after processing
            this number of messages in total.
        on_eof: What to do when reaching the end of each partition (keep
            consuming, stop, or restart from earliest offsets); defaults to
            continuing. This can be either a :class:`EofBehavior` variant or a
            string containing the name of one of the variants.
        stop_on_eof: (deprecated) equivalent to passing ``on_eof=EofBehavior.STOP``
        auto_offset_reset: sets the behavior of the client when the consumer
            group initializes: ``'earliest'`` (the default) processes all
            objects since the inception of the topics; ``'latest'`` only
            processes objects published after the consumer group initialized.

    Any other named argument is passed directly to KafkaConsumer().

    """

    def __init__(
        self,
        brokers: Union[str, List[str]],
        group_id: str,
        prefix: Optional[str] = None,
        object_types: Optional[List[str]] = None,
        privileged: bool = False,
        stop_after_objects: Optional[int] = None,
        batch_size: int = 200,
        process_timeout: Optional[float] = None,
        auto_offset_reset: str = "earliest",
        stop_on_eof: Optional[bool] = None,
        on_eof: Optional[Union[EofBehavior, str]] = None,
        value_deserializer: Optional[Callable[[str, bytes], Any]] = None,
        **kwargs,
    ):
        if prefix is None:
            prefix = DEFAULT_PREFIX
        if auto_offset_reset not in ACCEPTED_OFFSET_RESET:
            raise ValueError(
                "Option 'auto_offset_reset' only accepts %s, not %s"
                % (ACCEPTED_OFFSET_RESET, auto_offset_reset)
            )

        if batch_size <= 0:
            raise ValueError("Option 'batch_size' needs to be positive")
        if value_deserializer:
            self.value_deserializer = value_deserializer
        else:
            self.value_deserializer = lambda _, value: kafka_to_value(value)

        if stop_on_eof is not None:
            if on_eof is not None:
                raise TypeError(
                    "stop_on_eof and on_eof are mutually exclusive (the former is "
                    "deprecated)"
                )
            elif stop_on_eof:
                warnings.warn(
                    "stop_on_eof=True should be replaced with "
                    "on_eof=EofBehavior.STOP ('on_eof: stop' in YAML)",
                    DeprecationWarning,
                    2,
                )
                on_eof = EofBehavior.STOP
            else:
                warnings.warn(
                    "stop_on_eof=False should be replaced with "
                    "on_eof=EofBehavior.CONTINUE ('on_eof: continue' in YAML)",
                    DeprecationWarning,
                    2,
                )
                on_eof = EofBehavior.CONTINUE
        self.on_eof = EofBehavior(on_eof or EofBehavior.CONTINUE)

        if isinstance(brokers, str):
            brokers = [brokers]

        debug_logging = rdkafka_logger.isEnabledFor(logging.DEBUG)
        if debug_logging and "debug" not in kwargs:
            kwargs["debug"] = "consumer"

        # Static group instance id management
        group_instance_id = os.environ.get("KAFKA_GROUP_INSTANCE_ID")
        if group_instance_id:
            kwargs["group.instance.id"] = group_instance_id

        if "group.instance.id" in kwargs:
            # When doing static consumer group membership, set a higher default
            # session timeout. The session timeout is the duration after which
            # the broker considers that a consumer has left the consumer group
            # for good, and triggers a rebalance. Considering our current
            # processing pattern, 10 minutes gives the consumer ample time to
            # restart before that happens.
            if "session.timeout.ms" not in kwargs:
                kwargs["session.timeout.ms"] = 10 * 60 * 1000  # 10 minutes

        if "session.timeout.ms" in kwargs:
            # When the session timeout is set, rdkafka requires the max poll
            # interval to be set to a higher value; the max poll interval is
            # rdkafka's way of figuring out whether the client's message
            # processing thread has stalled: when the max poll interval lapses
            # between two calls to consumer.poll(), rdkafka leaves the consumer
            # group and terminates the connection to the brokers.
            #
            # We default to 1.5 times the session timeout
            if "max.poll.interval.ms" not in kwargs:
                kwargs["max.poll.interval.ms"] = kwargs["session.timeout.ms"] // 2 * 3

        consumer_settings = {
            **kwargs,
            "bootstrap.servers": ",".join(brokers),
            "auto.offset.reset": auto_offset_reset,
            "group.id": group_id,
            "on_commit": _on_commit,
            "error_cb": _error_cb,
            "enable.auto.commit": False,
            "logger": rdkafka_logger,
        }

        if self.on_eof != EofBehavior.CONTINUE:
            consumer_settings["enable.partition.eof"] = True

        if logger.isEnabledFor(logging.DEBUG):
            filtered_keys = {"sasl.password"}
            logger.debug("Consumer settings:")
            for k, v in consumer_settings.items():
                if k in filtered_keys:
                    v = "**filtered**"
                logger.debug("    %s: %s", k, v)

        self.statsd = Statsd(namespace="swh_journal_client")
        self.consumer = Consumer(consumer_settings)
        if privileged:
            privileged_prefix = f"{prefix}_privileged"
        else:  # do not attempt to subscribe to privileged topics
            privileged_prefix = f"{prefix}"

        existing_topics = [
            topic
            for topic in self.consumer.list_topics(timeout=10).topics.keys()
            if (
                topic.startswith(f"{prefix}.")
                or topic.startswith(f"{privileged_prefix}.")
            )
        ]
        if not existing_topics:
            raise ValueError(
                f"The prefix {prefix} does not match any existing topic "
                "on the kafka broker"
            )

        if not object_types:
            object_types = list({topic.split(".")[-1] for topic in existing_topics})

        self.subscription = []
        unknown_types = []
        for object_type in object_types:
            topics = (f"{privileged_prefix}.{object_type}", f"{prefix}.{object_type}")
            for topic in topics:
                if topic in existing_topics:
                    self.subscription.append(topic)
                    break
            else:
                unknown_types.append(object_type)
        if unknown_types:
            raise ValueError(
                f"Topic(s) for object types {','.join(unknown_types)} "
                "are unknown on the kafka broker"
            )

        logger.debug(f"Upstream topics: {existing_topics}")
        self.subscribe()

        self.stop_after_objects = stop_after_objects

        self.eof_reached: Set[Tuple[str, str]] = set()
        self.batch_size = batch_size

        if process_timeout is not None:
            raise DeprecationWarning(
                "'process_timeout' argument is not supported anymore by "
                "JournalClient; please remove it from your configuration.",
            )

    def subscribe(self):
        """Subscribe to topics listed in self.subscription

        This can be overridden if you need, for instance, to manually assign
        partitions.
        """
        logger.debug(f"Subscribing to: {self.subscription}")
        self.consumer.subscribe(topics=self.subscription)

    def process(self, worker_fn: Callable[[Dict[str, List[dict]]], None]):
        """Polls Kafka for a batch of messages, and calls the worker_fn
        with these messages.

        Args:
            worker_fn: Function called with the messages as argument.
        """
        total_objects_processed = 0
        # timeout for message poll
        timeout = 1.0

        with self.statsd.status_gauge(
            "status", statuses=["idle", "processing", "waiting"]
        ) as set_status:
            set_status("idle")
            while True:
                batch_size = self.batch_size
                if self.stop_after_objects:
                    if total_objects_processed >= self.stop_after_objects:
                        break

                    # clamp batch size to avoid overrunning stop_after_objects
                    batch_size = min(
                        self.stop_after_objects - total_objects_processed,
                        batch_size,
                    )

                set_status("waiting")
                for i in cycle(reversed(range(10))):
                    messages = self.consumer.consume(
                        timeout=timeout, num_messages=batch_size
                    )
                    if messages:
                        break

                    # do check for an EOF condition iff we already consumed
                    # messages, otherwise we could detect an EOF condition
                    # before messages had a chance to reach us (e.g. in tests)
                    if total_objects_processed > 0 and i == 0:
                        if self.on_eof == EofBehavior.STOP:
                            at_eof = all(
                                (tp.topic, tp.partition) in self.eof_reached
                                for tp in self.consumer.assignment()
                            )
                            if at_eof:
                                break
                        elif self.on_eof == EofBehavior.RESTART:
                            for tp in self.consumer.assignment():
                                if (tp.topic, tp.partition) in self.eof_reached:
                                    self.eof_reached.remove((tp.topic, tp.partition))
                                    self.statsd.increment("partition_restart_total")
                                    new_tp = TopicPartition(
                                        tp.topic,
                                        tp.partition,
                                        OFFSET_BEGINNING,
                                    )
                                    self.consumer.seek(new_tp)
                        elif self.on_eof == EofBehavior.CONTINUE:
                            pass  # Nothing to do, we'll just keep consuming
                        else:
                            assert False, f"Unexpected on_eof behavior: {self.on_eof}"

                if messages:
                    set_status("processing")
                    batch_processed, at_eof = self.handle_messages(messages, worker_fn)

                    set_status("idle")
                    # report the number of handled messages
                    self.statsd.increment("handle_message_total", value=batch_processed)
                    total_objects_processed += batch_processed

                if self.on_eof == EofBehavior.STOP and at_eof:
                    self.statsd.increment("stop_total")
                    break

        return total_objects_processed

    def handle_messages(
        self, messages, worker_fn: Callable[[Dict[str, List[dict]]], None]
    ) -> Tuple[int, bool]:
        objects: Dict[str, List[Any]] = defaultdict(list)
        nb_processed = 0

        for message in messages:
            error = message.error()
            if error is not None:
                if error.code() == KafkaError._PARTITION_EOF:
                    self.eof_reached.add((message.topic(), message.partition()))
                else:
                    _error_cb(error)
                continue
            if message.value() is None:
                # ignore message with no payload, these can be generated in tests
                continue
            nb_processed += 1
            object_type = message.topic().split(".")[-1]
            deserialized_object = self.deserialize_message(
                message, object_type=object_type
            )
            if deserialized_object is not None:
                objects[object_type].append(deserialized_object)

        if objects:
            worker_fn(dict(objects))
        self.consumer.commit()

        if self.on_eof in (EofBehavior.STOP, EofBehavior.RESTART):
            at_eof = all(
                (tp.topic, tp.partition) in self.eof_reached
                for tp in self.consumer.assignment()
            )
        elif self.on_eof == EofBehavior.CONTINUE:
            at_eof = False
        else:
            assert False, f"Unexpected on_eof behavior: {self.on_eof}"

        return nb_processed, at_eof

    def deserialize_message(self, message, object_type=None):
        return self.value_deserializer(object_type, message.value())

    def close(self):
        self.consumer.close()
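

# Editorial usage sketch (not part of the upstream module): wiring a
# JournalClient to a worker function and a custom value_deserializer. The broker
# address, group id and function names are hypothetical placeholders.
def _example_process_releases() -> int:
    def value_deserializer(object_type: str, kafka_msg: bytes) -> Any:
        # Same decoding as the default deserializer, but returning None for
        # anything that is not a release; None values are dropped before
        # worker_fn is called.
        if object_type != "release":
            return None
        return kafka_to_value(kafka_msg)

    def worker_fn(objects: Dict[str, List[dict]]) -> None:
        # ``process`` hands over a mapping from object type to the batch of
        # deserialized objects of that type.
        for object_type, batch in objects.items():
            logger.info("received %d %s objects", len(batch), object_type)

    client = JournalClient(
        brokers=["broker1.example.org:9092"],
        group_id="example-release-consumer",
        object_types=["release"],
        on_eof=EofBehavior.STOP,
        value_deserializer=value_deserializer,
    )
    try:
        # returns the total number of handled messages once every assigned
        # partition has reached the end of the log (on_eof=STOP)
        return client.process(worker_fn)
    finally:
        client.close()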