# Source code for swh.scanner.client
# Copyright (C) 2021 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
"""
Minimal async web client for the Software Heritage Web API.
This module could be removed when
`T2635 <https://forge.softwareheritage.org/T2635>`_ is implemented.
"""
import asyncio
import itertools
import logging
import time
from typing import Any, Dict, List, Optional, Tuple
import aiohttp
from swh.model.swhids import CoreSWHID
from .exceptions import error_response
logger = logging.getLogger(__name__)

# The /known/ Web API endpoint accepts at most this many SWHIDs per call;
# larger batches are split into chunks of this size.
QUERY_LIMIT: int = 1000
# Number of failed attempts allowed for a single /known/ request before the
# error is reported.
MAX_RETRY: int = 10

# Endpoint paths, appended to the API root URL.
KNOWN_EP: str = "known/"
GRAPH_RANDOMWALK_EP: str = "graph/randomwalk/"
def _get_chunk(swhids):
    """Yield consecutive slices of ``swhids``, each at most QUERY_LIMIT long.

    The last slice holds whatever is left over; an empty input yields nothing.
    """
    start = 0
    total = len(swhids)
    while start < total:
        yield swhids[start : start + QUERY_LIMIT]
        start += QUERY_LIMIT
def _parse_limit_header(response) -> Tuple[Optional[int], Optional[int], Optional[int]]:
"""parse the X-RateLimit Headers if any"""
limit = response.headers.get("X-RateLimit-Limit")
if limit is not None:
limit = int(limit)
remaining = response.headers.get("X-RateLimit-Remaining")
if remaining is not None:
remaining = int(remaining)
reset = response.headers.get("X-RateLimit-Reset")
if reset is not None:
reset = int(reset)
return (limit, remaining, reset)
class Client:
    """Manage requests to the Software Heritage Web API.

    Requests to the ``known/`` endpoint are paced: after every response the
    delay stored in ``self._sleep`` (in seconds) is recomputed from the
    server's X-RateLimit headers, and it is awaited before the next request.
    """

    def __init__(
        self,
        api_url: str,
        session: aiohttp.ClientSession,
    ):
        # Delay, in seconds, to wait before the next /known/ request.
        self._sleep: float = 0
        # Serializes the waits on ``self._sleep`` between concurrent requests.
        # Created lazily (inside a coroutine) so it belongs to the running loop.
        self._pace_lock: Optional[asyncio.Lock] = None
        self.api_url = api_url
        self.session = session
        self._known_endpoint = self.api_url + KNOWN_EP

    async def get_origin(self, swhid: CoreSWHID) -> Optional[str]:
        """Walk the compressed graph to discover the origin of a given swhid.

        Args:
            swhid: the object whose origin is looked up

        Returns:
            The response body, with trailing whitespace stripped, when the API
            answers 200; None when it answers 404. Any other status is handed
            to ``error_response``.
        """
        endpoint = (
            f"{self.api_url}{GRAPH_RANDOMWALK_EP}{str(swhid)}/ori/?direction="
            f"backward&limit=-1&resolve_origins=true"
        )
        async with self.session.get(endpoint) as resp:
            if resp.status == 200:
                text = await resp.text()
                return text.rstrip()
            if resp.status != 404:
                error_response(resp.reason, resp.status, endpoint)
        return None

    async def known(self, swhids: List[CoreSWHID]) -> Dict[str, Dict[str, bool]]:
        """Query the Web API /known/ endpoint for the given SWHIDs.

        Lists longer than QUERY_LIMIT are split into chunks that are queried
        concurrently, and the per-chunk answers are merged.

        Args:
            swhids: a list of CoreSWHID instances

        Returns:
            A dictionary keyed by the string form of each SWHID searched.
            Each value's ``'known'`` entry is True if the SWHID is found and
            False otherwise.
        """
        swh_ids = [str(swhid) for swhid in swhids]
        if not swh_ids:
            # Nothing to ask: skip the network round-trip.
            return {}
        if len(swh_ids) <= QUERY_LIMIT:
            return await self._make_request(swh_ids)
        responses = await asyncio.gather(
            *(self._make_request(chunk) for chunk in _get_chunk(swh_ids))
        )
        # Concatenate the per-chunk dictionaries into a single one.
        return dict(itertools.chain.from_iterable(r.items() for r in responses))

    def _mark_success(self, limit=None, remaining=None, reset=None):
        """Call after a successful request to adjust the request rate.

        The optional arguments carry the X-RateLimit information sent by the
        server: the credit limit, the remaining credit, and the reset time
        (compared against ``time.time()``). When all three are known and the
        reset time is still in the future, ``self._sleep`` is set to:

        * the whole time left in the window when no credit remains;
        * 0 while more than 60% of the credit remains;
        * otherwise ``time_window / remaining`` scaled by a factor that grows
          from 1 to about 3.95 as the remaining credit approaches zero.

        In every other case ``self._sleep`` is reset to 0.
        """
        is_dbg = logger.isEnabledFor(logging.DEBUG)
        self._sleep = 0
        current = time.time()
        dbg_msg = f"HTTP GOOD {current:.2f}:" if is_dbg else ""
        if limit is None or remaining is None or reset is None:
            if is_dbg:
                dbg_msg += " no rate limit data;"
        else:
            time_window = reset - current
            if is_dbg:
                dbg_msg += f" requests={remaining}/{limit}"
                dbg_msg += f" reset-in={time_window:.2f}"
            if time_window > 0:
                if remaining <= 0:
                    # No more credit: sit and wait for the window to reset.
                    #
                    # XXX we should warn the user. This can get very long.
                    self._sleep = time_window
                else:
                    # Fraction of the credit still available. A non-positive
                    # limit makes no sense; treat it as ample credit rather
                    # than dividing by zero.
                    remaining_ratio = remaining / limit if limit > 0 else 1.0
                    if remaining_ratio > 0.6:
                        # Do not limit the first flight of requests.
                        factor = 0.0
                    else:
                        # The deeper we dig into the credit, the harder we
                        # brake: factor is 1 at 60% left and tends to
                        # 0.4 ** -1.5 (about 3.95) as the credit runs out.
                        factor = (0.4 + remaining_ratio) ** -1.5
                    # Spread the remaining credit over the rest of the window.
                    self._sleep = (time_window / remaining) * factor
        if is_dbg:
            dbg_msg += f"; sleep={self._sleep:.3f}"
            logger.debug(dbg_msg)

    def _mark_failure(self, limit=None, remaining=None, reset=None):
        """Call after a failed request; this reduces the request rate.

        The optional arguments carry the X-RateLimit information sent by the
        server. If no credit remains and the reset time (plus a 10% margin)
        is in the future and not shorter than the current delay, the delay
        becomes the time until reset. Otherwise exponential back-off is used:
        the delay becomes 1 second if it was 0, and is doubled otherwise.
        """
        is_dbg = logger.isEnabledFor(logging.DEBUG)
        current = time.time()
        dbg_msg = f"HTTP BAD {current:.2f}:" if is_dbg else ""
        time_set = False
        if remaining is None or reset is None:
            if is_dbg:
                dbg_msg += " no rate limit data"
        else:
            wait_for = reset - current
            if is_dbg:
                dbg_msg += f" requests={remaining}/{limit}"
                dbg_msg += f" reset-in={wait_for:.2f}"
            if remaining <= 0:
                # Add some margin to please the rate limiting code
                wait_for *= 1.1
                if wait_for > 0 and wait_for >= self._sleep:
                    self._sleep = wait_for
                    time_set = True
        if not time_set:
            if self._sleep <= 0:
                self._sleep = 1
            else:
                self._sleep *= 2
        if is_dbg:
            dbg_msg += f"; sleep={self._sleep:.3f}"
            logger.debug(dbg_msg)

    async def _wait_for_slot(self) -> None:
        """Wait for the current back-off delay without blocking the event loop.

        Waits are taken one task at a time, so concurrent chunk requests stay
        spaced out instead of all resuming at the same moment.
        """
        if self._pace_lock is None:
            self._pace_lock = asyncio.Lock()
        async with self._pace_lock:
            if self._sleep > 0:
                await asyncio.sleep(self._sleep)

    async def _make_request(self, swhids: List[str]) -> Dict[str, Dict[str, bool]]:
        """POST one chunk of SWHID strings to /known/, retrying on failure.

        Each attempt first waits for the current back-off delay. A 200
        response is decoded as JSON and returned. Any other status increases
        the delay, and ``error_response`` is called on a 413 (payload too
        large) or once MAX_RETRY attempts have failed.
        """
        endpoint = self._known_endpoint
        retries_left = MAX_RETRY
        while True:
            await self._wait_for_slot()
            async with self.session.post(endpoint, json=swhids) as resp:
                rate_limit = _parse_limit_header(resp)
                if resp.status == 200:
                    # Update the pacing state before awaiting the body: other
                    # tasks may run during that await and read self._sleep.
                    self._mark_success(*rate_limit)
                    return await resp.json()
                self._mark_failure(*rate_limit)
                retries_left -= 1
                if retries_left <= 0 or resp.status == 413:  # 413: Payload Too Large
                    # NOTE(review): this loop relies on error_response raising;
                    # if it returned, the request would be attempted again.
                    error_response(resp.reason, resp.status, endpoint)