# Copyright (C) 2015-2024 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
from copy import deepcopy
from functools import lru_cache
from itertools import chain
import logging
import os
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple
from backports.entry_points_selectable import entry_points as get_entry_points
from deprecated import deprecated
import yaml
logger = logging.getLogger(__name__)
SWH_CONFIG_DIRECTORIES = [
"~/.config/swh",
"~/.swh",
"/etc/softwareheritage",
]
SWH_GLOBAL_CONFIG = "global.yml"
SWH_DEFAULT_GLOBAL_CONFIG = {
"max_content_size": ("int", 100 * 1024 * 1024),
}
SWH_CONFIG_EXTENSIONS = [
".yml",
]
# conversion per type
_map_convert_fn: Dict[str, Callable] = {
"int": int,
"bool": lambda x: x.lower() == "true",
"list[str]": lambda x: [value.strip() for value in x.split(",")],
"list[int]": lambda x: [int(value.strip()) for value in x.split(",")],
}
_map_check_fn: Dict[str, Callable] = {
"int": lambda x: isinstance(x, int),
"bool": lambda x: isinstance(x, bool),
"list[str]": lambda x: (isinstance(x, list) and all(isinstance(y, str) for y in x)),
"list[int]": lambda x: (isinstance(x, list) and all(isinstance(y, int) for y in x)),
}
[docs]
def exists_accessible(filepath: str) -> bool:
"""Check whether a file exists, and is accessible.
Returns:
True if the file exists and is accessible
False if the file does not exist
Raises:
PermissionError if the file cannot be read.
"""
try:
os.stat(filepath)
except PermissionError:
raise
except (FileNotFoundError, NotADirectoryError):
return False
else:
if os.access(filepath, os.R_OK):
return True
else:
raise PermissionError("Permission denied: {filepath!r}")
[docs]
def read_raw_config(base_config_path: str) -> Dict[str, Any]:
"""Read the raw config corresponding to base_config_path.
Can read yml files.
"""
yml_file = config_path(base_config_path)
if yml_file is None:
logging.error("Config file %s does not exist, ignoring it.", base_config_path)
return {}
else:
logger.debug("Loading config file %s", yml_file)
with open(yml_file) as f:
return yaml.safe_load(f)
[docs]
@deprecated(
version="2.23.0",
reason="pass config paths as-is to read_raw_config/read, and rely on click.Path",
)
def config_exists(path):
"""Check whether the given config exists"""
path = config_path(path)
return path is not None and exists_accessible(path)
[docs]
@deprecated(version="2.23.0", reason="pass config paths as-is to read_raw_config/read")
def config_basepath(config_path: str) -> str:
"""Return the base path of a configuration file"""
if config_path.endswith(".yml"):
return config_path[:-4]
return config_path
[docs]
def config_path(config_path):
"""Check whether the given config exists"""
if exists_accessible(config_path):
return config_path
for extension in SWH_CONFIG_EXTENSIONS:
if exists_accessible(config_path + extension):
logger.warning(
"%s does not exist, using %s instead",
config_path,
config_path + extension,
)
return config_path + extension
return None
[docs]
def read(
conf_file: Optional[str] = None,
default_conf: Optional[Dict[str, Tuple[str, Any]]] = None,
) -> Dict[str, Any]:
"""Read the user's configuration file.
Fill in the gap using `default_conf`. `default_conf` is similar to this::
DEFAULT_CONF = {
'a': ('str', '/tmp/swh-loader-git/log'),
'b': ('str', 'dbname=swhloadergit')
'c': ('bool', true)
'e': ('bool', None)
'd': ('int', 10)
}
If conf_file is None, return the default config.
"""
conf: Dict[str, Any] = {}
if conf_file:
base_config_path = os.path.expanduser(conf_file)
conf = read_raw_config(base_config_path) or {}
if not default_conf:
return conf
# remaining missing default configuration key are set
# also type conversion is enforced for underneath layer
for key, (nature_type, default_value) in default_conf.items():
val = conf.get(key, None)
if val is None: # fallback to default value
conf[key] = default_value
elif not _map_check_fn.get(nature_type, lambda x: True)(val):
# value present but not in the proper format, force type conversion
conf[key] = _map_convert_fn.get(nature_type, lambda x: x)(val)
return conf
[docs]
def priority_read(
conf_filenames: List[str], default_conf: Optional[Dict[str, Tuple[str, Any]]] = None
):
"""Try reading the configuration files from conf_filenames, in order,
and return the configuration from the first one that exists.
default_conf has the same specification as it does in read.
"""
# Try all the files in order
for filename in conf_filenames:
full_filename = config_path(os.path.expanduser(filename))
if full_filename is not None:
return read(full_filename, default_conf)
# Else, return the default configuration
return read(None, default_conf)
[docs]
def merge_default_configs(base_config, *other_configs):
"""Merge several default config dictionaries, from left to right"""
full_config = base_config.copy()
for config in other_configs:
full_config.update(config)
return full_config
[docs]
def merge_configs(base: Optional[Dict[str, Any]], other: Optional[Dict[str, Any]]):
"""Merge two config dictionaries
This does merge config dicts recursively, with the rules, for every value
of the dicts (with 'val' not being a dict):
- None + type -> type
- type + None -> None
- dict + dict -> dict (merged)
- val + dict -> TypeError
- dict + val -> TypeError
- val + val -> val (other)
for instance:
>>> d1 = {
... 'key1': {
... 'skey1': 'value1',
... 'skey2': {'sskey1': 'value2'},
... },
... 'key2': 'value3',
... }
with
>>> d2 = {
... 'key1': {
... 'skey1': 'value4',
... 'skey2': {'sskey2': 'value5'},
... },
... 'key3': 'value6',
... }
will give:
>>> d3 = {
... 'key1': {
... 'skey1': 'value4', # <-- note this
... 'skey2': {
... 'sskey1': 'value2',
... 'sskey2': 'value5',
... },
... },
... 'key2': 'value3',
... 'key3': 'value6',
... }
>>> assert merge_configs(d1, d2) == d3
Note that no type checking is done for anything but dicts.
"""
if not isinstance(base, dict) or not isinstance(other, dict):
raise TypeError("Cannot merge a %s with a %s" % (type(base), type(other)))
output = {}
for k in chain(base.keys(), other.keys()):
if k in output:
continue
vb = base.get(k)
vo = other.get(k)
if isinstance(vo, dict):
output[k] = merge_configs(vb is not None and vb or {}, vo)
elif isinstance(vb, dict) and k in other and other[k] is not None:
output[k] = merge_configs(vb, vo is not None and vo or {})
elif k in other:
output[k] = deepcopy(vo)
else:
output[k] = deepcopy(vb)
return output
[docs]
def swh_config_paths(base_filename: str) -> List[str]:
"""Return the Software Heritage specific configuration paths for the given
filename."""
return [os.path.join(dirname, base_filename) for dirname in SWH_CONFIG_DIRECTORIES]
[docs]
def prepare_folders(conf, *keys):
"""Prepare the folder mentioned in config under keys."""
def makedir(folder):
if not os.path.exists(folder):
os.makedirs(folder)
for key in keys:
makedir(conf[key])
[docs]
def load_global_config():
"""Load the global Software Heritage config"""
return priority_read(
swh_config_paths(SWH_GLOBAL_CONFIG),
SWH_DEFAULT_GLOBAL_CONFIG,
)
[docs]
def load_named_config(name, default_conf=None, global_conf=True):
"""Load the config named `name` from the Software Heritage
configuration paths.
If global_conf is True (default), read the global configuration
too.
"""
conf = {}
if global_conf:
conf.update(load_global_config())
conf.update(priority_read(swh_config_paths(name), default_conf))
return conf
[docs]
def load_from_envvar(default_config: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
"""Load configuration yaml file from the environment variable SWH_CONFIG_FILENAME,
eventually enriched with default configuration key/value from the default_config
dict if provided.
Returns:
Configuration dict
Raises:
AssertionError if SWH_CONFIG_FILENAME is undefined
"""
assert (
"SWH_CONFIG_FILENAME" in os.environ
), "SWH_CONFIG_FILENAME environment variable is undefined."
cfg_path = os.environ["SWH_CONFIG_FILENAME"]
cfg = read_raw_config(cfg_path)
cfg = merge_configs(default_config or {}, cfg)
return cfg
[docs]
@lru_cache()
def get_swh_backend_module(swh_package: str, cls: str) -> Tuple[str, Optional[type]]:
entry_points = get_entry_points(group=f"swh.{swh_package}.classes")
if not entry_points:
# it's an "old-style" swh package, not declaring its classes entry point
logger.warning(
f"swh package does not yet declare the swh.{swh_package}.classes "
"endpoint. Make sure all your swh dependencies are up to date."
)
if not swh_package.startswith("swh."):
swh_package = f"swh.{swh_package}"
return swh_package, None
try:
entry_point = entry_points[cls]
except KeyError:
raise ValueError(
"Unknown %s class `%s`. Supported: %s"
% (
swh_package,
cls,
", ".join(entry_point.name for entry_point in entry_points),
)
) from None
BackendCls = entry_point.load()
return entry_point.module, BackendCls
[docs]
@lru_cache()
def get_swh_backend_from_fullmodule(
fullmodule: str,
) -> Tuple[Optional[str], Optional[str]]:
if not fullmodule.startswith("swh."):
fullmodule = f"swh.{fullmodule}"
package = fullmodule.split(".")[1]
entry_points = get_entry_points(group=f"swh.{package}.classes")
for entry_point in entry_points:
if entry_point.module == fullmodule:
return package, entry_point.name
return None, None
[docs]
def list_swh_backends(package: str) -> List[str]:
if package.startswith("swh."):
package = package[4:]
entry_points = get_entry_points(group=f"swh.{package}.classes")
return [ep.name for ep in entry_points]
[docs]
def list_db_config_entries(cfg) -> Generator[Tuple[str, str, dict, str], None, None]:
"""List all the db config entries in the given config structure
Generates quadruplets (module, path, cfg, cnxstr) where:
- the swh module name (aka top level config entries, eg. 'storage',
'scheduler', etc.)
- path: the path within the config structure of the (sub)config entry in
which the db connection has been found,
- cfg: the config subentry from the given gcfg in which the db config has
been found; it contains at least a 'cls' key,
- db: the db connection string
"""
def look(cfg, path):
if "cls" in cfg:
for key, value in cfg.items():
if key == "db" or key.endswith("_db"):
yield path, cfg, value
elif isinstance(value, list):
for i, subcfg in enumerate(value):
yield from look(subcfg, path=f"{path}.{key}.{i}")
elif isinstance(value, dict):
yield from look(value, path=f"{path}.{key}")
for rootmodule, subcfg in cfg.items():
for path, cfg_entry, cnxstr in look(subcfg, rootmodule):
yield rootmodule, path, cfg_entry, cnxstr