# Source code for swh.counters.redis
# Copyright (C) 2021 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import logging
from typing import Any, Dict, Iterable, List
from redis.client import Redis as RedisClient
from redis.exceptions import ConnectionError
# Port used by ``Redis`` when its ``host`` argument has no ``:port`` suffix.
DEFAULT_REDIS_PORT = 6379

logger = logging.getLogger(__name__)
class Redis:
    """Redis based implementation of the counters.

    It uses one HyperLogLog collection per counter: each counter is a Redis
    key, elements are added with PFADD and the value is read with PFCOUNT.
    """

    # Client cache for the ``redis_client`` property; declared on the class so
    # the attribute exists (as None) before the first access.
    _redis_client = None

    def __init__(self, host: str):
        """Parse ``host`` given as ``"hostname"`` or ``"hostname:port"``.

        Args:
            host: server address; the port falls back to
                ``DEFAULT_REDIS_PORT`` when omitted.

        Raises:
            ValueError: if ``host`` contains more than one ``:`` or its port
                part is not an integer.
        """
        host_port = host.split(":")
        if len(host_port) > 2:
            raise ValueError("Invalid server url `%s`" % host)
        self.host = host_port[0]
        if len(host_port) > 1:
            try:
                self.port = int(host_port[1])
            except ValueError as err:
                # Report the whole url instead of int()'s bare parse error.
                raise ValueError("Invalid server url `%s`" % host) from err
        else:
            self.port = DEFAULT_REDIS_PORT

    @property
    def redis_client(self) -> "RedisClient":
        """Return the redis-py client for ``host``/``port``.

        The client is created on first access and reused afterwards.
        """
        if self._redis_client is None:
            self._redis_client = RedisClient(host=self.host, port=self.port)
        return self._redis_client

    def check(self):
        """Ping the server.

        Returns:
            the result of PING, or False when a connection error occurs (the
            error is logged with its traceback).
        """
        try:
            return self.redis_client.ping()
        except ConnectionError:
            logger.exception("Unable to connect to the redis server")
            return False

    def add(self, collection: str, keys: Iterable[Any]) -> None:
        """Add ``keys`` to the HyperLogLog stored under ``collection``.

        One PFADD per key is queued on a non-transactional pipeline and the
        whole batch is sent by a single ``execute()`` call.
        """
        pipeline = self.redis_client.pipeline(transaction=False)
        for key in keys:
            pipeline.pfadd(collection, key)
        pipeline.execute()

    def get_count(self, collection: str) -> int:
        """Return the PFCOUNT (estimated cardinality) of ``collection``."""
        return self.redis_client.pfcount(collection)

    def get_counts(self, collections: List[str]) -> Dict[str, int]:
        """Return a mapping of each name in ``collections`` to its count.

        Issues one ``get_count`` call (one PFCOUNT) per collection.
        """
        return {coll: self.get_count(coll) for coll in collections}

    def get_counters(self) -> Iterable[str]:
        """Return the names of the keys stored in the Redis database.

        Uses the KEYS command with its default pattern, so every key of the
        database is returned.
        """
        # NOTE(review): redis-py returns bytes here unless the client was
        # built with decode_responses=True, which disagrees with the
        # Iterable[str] annotation; KEYS also scans the whole keyspace.
        return self.redis_client.keys()