Source code for swh.counters.redis

# Copyright (C) 2021  The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information

import logging
from typing import Any, Dict, Iterable, List

from redis.client import Redis as RedisClient
from redis.exceptions import ConnectionError

# Port used by ``Redis`` when the host string does not carry an explicit
# ``:port`` suffix (6379 is the standard redis server port).
DEFAULT_REDIS_PORT = 6379


# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


class Redis:
    """Redis based implementation of the counters.

    It uses one HyperLogLog collection per counter.

    Args:
        host: address of the redis server, either ``"hostname"`` or
            ``"hostname:port"``. When no port is given,
            ``DEFAULT_REDIS_PORT`` is used.

    Raises:
        ValueError: if ``host`` contains more than one ``:`` separator or if
            its port part is not a valid integer.
    """

    # Created on first access of the ``redis_client`` property, then reused.
    _redis_client = None

    def __init__(self, host: str):
        host_port = host.split(":")
        if len(host_port) > 2:
            raise ValueError("Invalid server url `%s`" % host)

        self.host = host_port[0]
        if len(host_port) > 1:
            try:
                self.port = int(host_port[1])
            except ValueError as err:
                # Report a malformed port the same way as any other
                # malformed url instead of leaking the raw int() message.
                raise ValueError("Invalid server url `%s`" % host) from err
        else:
            self.port = DEFAULT_REDIS_PORT

    @property
    def redis_client(self) -> "RedisClient":
        """The redis client, instantiated lazily on first access."""
        if self._redis_client is None:
            self._redis_client = RedisClient(host=self.host, port=self.port)
        return self._redis_client

    def check(self):
        """Check that the redis server can be reached.

        Returns:
            the result of the client ``ping()`` call, or ``False`` when a
            ``ConnectionError`` is raised (the error is logged).
        """
        try:
            return self.redis_client.ping()
        except ConnectionError:
            logger.exception("Unable to connect to the redis server")
            return False

    def add(self, collection: str, keys: Iterable[Any]) -> None:
        """Add ``keys`` to the HyperLogLog named ``collection``.

        One ``pfadd`` command per key is queued on a non-transactional
        pipeline, which is then executed once.

        Args:
            collection: name of the counter to update
            keys: the elements to add to the counter
        """
        pipeline = self.redis_client.pipeline(transaction=False)
        for key in keys:
            pipeline.pfadd(collection, key)
        pipeline.execute()

    def get_count(self, collection: str) -> int:
        """Return the ``pfcount`` value of the counter named ``collection``."""
        return self.redis_client.pfcount(collection)

    def get_counts(self, collections: List[str]) -> Dict[str, int]:
        """Return the count of each counter listed in ``collections``.

        Returns:
            a dict mapping each collection name to its ``get_count`` value
        """
        return {coll: self.get_count(coll) for coll in collections}

    def get_counters(self) -> Iterable[str]:
        """Return every key known to the redis server, as given by ``keys()``.

        NOTE(review): redis-py returns ``bytes`` keys unless the client is
        created with ``decode_responses=True``, which is not the case here;
        confirm callers do not rely on ``str`` values.
        """
        return self.redis_client.keys()