# Copyright (C) 2024 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import datetime
import json
import logging
from typing import Any, Dict, Iterable, List, Optional, Tuple
from uuid import UUID
from uuid import uuid4 as uuid
from swh.scheduler.model import (
Task,
TaskPolicy,
TaskPriority,
TaskRun,
TaskRunStatus,
TaskStatus,
TaskType,
)
from swh.scheduler.utils import utcnow
from .exc import SchedulerException, StaleData, UnknownPolicy
from .interface import ListedOriginPageToken, PaginatedListedOriginList
from .model import ListedOrigin, Lister, OriginVisitStats, SchedulerMetrics
logger = logging.getLogger(__name__)
epoch = datetime.datetime.fromtimestamp(0, datetime.timezone.utc)
[docs]
class InMemoryScheduler:
def __init__(self):
    """Initialize the empty in-memory stores backing the scheduler API."""
    # Task type registry, recorded tasks and their runs.
    self._task_types = set()
    self._tasks = []
    self._task_runs = []
    # Lister registry and the origins those listers reported.
    self._listers = []
    self._listed_origins = []
    # Per-(url, visit_type) visit statistics and scheduling state.
    self._origin_visit_stats = {}
    self._visit_scheduler_queue_position = {}
    self._scheduler_metrics = {}
[docs]
def create_task_type(self, task_type: TaskType) -> None:
    """Register *task_type* in the in-memory task type registry."""
    self._task_types.add(task_type)
[docs]
def get_task_type(self, task_type_name: str) -> Optional[TaskType]:
    """Return the registered task type named *task_type_name*, or None."""
    return next(
        (tt for tt in self._task_types if tt.type == task_type_name),
        None,
    )
[docs]
def get_task_types(self) -> List[TaskType]:
    """Return every registered task type, as a fresh list."""
    return [*self._task_types]
[docs]
def get_listers(
    self,
    with_first_visits_to_schedule: bool = False,
) -> List[Lister]:
    """Retrieve information about all listers from the database.

    When *with_first_visits_to_schedule* is set, only keep listers whose
    last listing has finished, that declare a first-visits queue prefix,
    and whose first visits have not been scheduled yet.
    """
    if not with_first_visits_to_schedule:
        return self._listers
    return [
        lister
        for lister in self._listers
        if lister.last_listing_finished_at is not None
        and lister.first_visits_queue_prefix is not None
        and lister.first_visits_scheduled_at is None
    ]
[docs]
def get_listers_by_id(
    self,
    lister_ids: List[str],
) -> List[Lister]:
    """Return the stored listers whose id is in *lister_ids*.

    Ids are compared as strings: stored lister ids are UUID objects
    (see get_or_create_lister), while callers pass string ids per the
    declared interface — the previous ``_l.id in lister_ids`` test
    silently never matched in that case.
    """
    wanted = {str(lister_id) for lister_id in lister_ids}
    return [_l for _l in self._listers if str(_l.id) in wanted]
[docs]
def get_lister(
    self,
    name: str,
    instance_name: Optional[str] = None,
) -> Optional[Lister]:
    """Return the lister with the given name and instance name, or None.

    A null *instance_name* is normalized to the empty string, matching
    get_or_create_lister().
    """
    if instance_name is None:
        instance_name = ""
    # The previous `listers and listers[0] or None` idiom mis-evaluates
    # whenever the first element is falsy; use an explicit
    # "first match or None" lookup instead.
    return next(
        (
            _l
            for _l in self._listers
            if _l.name == name and _l.instance_name == instance_name
        ),
        None,
    )
[docs]
def get_or_create_lister(
    self,
    name: str,
    instance_name: Optional[str] = None,
    first_visits_queue_prefix: Optional[str] = None,
) -> Lister:
    """Return the lister for (name, instance_name), creating it if needed.

    A null *instance_name* is normalized to the empty string.
    """
    instance = "" if instance_name is None else instance_name
    if self.get_lister(name, instance) is None:
        new_lister = Lister(
            id=uuid(),
            name=name,
            instance_name=instance,
            first_visits_queue_prefix=first_visits_queue_prefix,
            updated=utcnow(),
        )
        self._listers.append(new_lister)
    lister = self.get_lister(name, instance)
    assert lister is not None
    return lister
[docs]
def update_lister(self, lister: Lister) -> Lister:
    """Replace the stored lister matching *lister* with a refreshed copy.

    The match requires both the same id AND the same ``updated``
    timestamp (optimistic concurrency control).

    Raises:
        StaleData: if no stored lister matches, i.e. the caller holds an
            outdated version of the lister.
    """
    for idx, stored in enumerate(self._listers):
        if stored.id == lister.id and stored.updated == lister.updated:
            # Drop the stale entry; the refreshed copy goes to the end.
            del self._listers[idx]
            refreshed = lister.evolve(updated=utcnow())
            self._listers.append(refreshed)
            return refreshed
    raise StaleData("Stale data; Lister state not updated")
[docs]
def record_listed_origins(
    self,
    listed_origins: Iterable[ListedOrigin],
) -> List[ListedOrigin]:
    """Upsert *listed_origins* into the in-memory store.

    Incoming origins are deduplicated on their primary-key columns (last
    one wins); ``last_seen`` is refreshed on every record, and
    ``first_seen`` is set only for origins never seen before.

    Returns:
        The stored version of each recorded origin.
    """
    pk_cols = ListedOrigin.primary_key_columns()

    def pk(origin):
        # Primary-key tuple identifying an origin.
        return tuple(getattr(origin, col) for col in pk_cols)

    known = {pk(origin): origin for origin in self._listed_origins}
    incoming = {pk(origin): origin for origin in listed_origins}
    now = utcnow()
    recorded = []
    for key, origin in incoming.items():
        if key in known:
            stored = origin.evolve(last_seen=now)
        else:
            stored = origin.evolve(last_seen=now, first_seen=now)
        known[key] = stored
        recorded.append(stored)
    self._listed_origins = list(known.values())
    return recorded
[docs]
def get_listed_origins(
    self,
    lister_id: Optional[UUID] = None,
    url: Optional[str] = None,
    urls: Optional[List[str]] = None,
    enabled: Optional[bool] = True,
    limit: int = 1000,
    page_token: Optional[ListedOriginPageToken] = None,
) -> PaginatedListedOriginList:
    """Return a page of listed origins matching the given filters.

    Pagination is keyed on the (str(lister_id), url) pair. Origins are
    now sorted on that key before the page token and limit are applied:
    without a total order, successive pages could skip or repeat origins
    whenever the underlying store was not already in key order.
    """
    origins = self._listed_origins
    if lister_id:
        origins = [o for o in origins if o.lister_id == lister_id]
    # A single `url` takes precedence over the `urls` list.
    wanted_urls = [url] if url is not None else (urls or [])
    if wanted_urls:
        origins = [o for o in origins if o.url in wanted_urls]
    if enabled is not None:
        origins = [o for o in origins if o.enabled == enabled]
    # Fix: establish a total order on the pagination key so that page
    # tokens are stable across calls.
    origins = sorted(origins, key=lambda o: (str(o.lister_id), o.url))
    if page_token is not None:
        origins = [
            o for o in origins if (str(o.lister_id), o.url) > tuple(page_token)
        ]
    origins = origins[:limit]
    if len(origins) == limit:
        next_token = (str(origins[-1].lister_id), origins[-1].url)
    else:
        next_token = None
    return PaginatedListedOriginList(origins, next_token)
[docs]
def get_visit_types_for_listed_origins(self, lister: Lister) -> List[str]:
    """Return the distinct visit types among origins recorded by *lister*."""
    visit_types = set()
    for origin in self._listed_origins:
        if origin.lister_id == lister.id:
            visit_types.add(origin.visit_type)
    return list(visit_types)
[docs]
def grab_next_visits(
    self,
    visit_type: str,
    count: int,
    policy: str,
    enabled: bool = True,
    lister_uuid: Optional[str] = None,
    lister_name: Optional[str] = None,
    lister_instance_name: Optional[str] = None,
    timestamp: Optional[datetime.datetime] = None,
    absolute_cooldown: Optional[datetime.timedelta] = datetime.timedelta(hours=12),
    scheduled_cooldown: Optional[datetime.timedelta] = datetime.timedelta(days=7),
    failed_cooldown: Optional[datetime.timedelta] = datetime.timedelta(days=14),
    not_found_cooldown: Optional[datetime.timedelta] = datetime.timedelta(days=31),
    tablesample: Optional[float] = None,
) -> List[ListedOrigin]:
    """Select the next origins of *visit_type* to grab for visiting.

    Candidate origins (matching *enabled*) are filtered through the
    cooldown windows, ordered and filtered according to *policy*,
    optionally restricted to a lister, then marked as scheduled: each
    returned origin's ``last_scheduled`` visit stat is bumped to
    *timestamp* (defaulting to now).

    NOTE(review): ``count`` and ``tablesample`` are accepted but never
    used in this in-memory implementation — the result is NOT truncated
    to ``count``; confirm against the scheduler interface.

    Raises:
        UnknownPolicy: if *policy* is not one of the supported policies.
    """
    if timestamp is None:
        timestamp = utcnow()
    # Candidate origins of the requested visit type / enabled state.
    origins = [
        o
        for o in self._listed_origins
        if o.enabled == enabled and o.visit_type == visit_type
    ]
    # Visit stats keyed by (url, visit_type); a missing entry means the
    # origin was never visited nor scheduled.
    stats = {
        (ovs.url, ovs.visit_type): ovs
        for ovs in self.origin_visit_stats_get(
            (o.url, o.visit_type) for o in origins
        )
    }
    origins_stats = [(o, stats.get((o.url, o.visit_type))) for o in origins]
    if absolute_cooldown:
        # Don't schedule visits if they've been scheduled since the absolute cooldown
        origins_stats = [
            (o, s)
            for (o, s) in origins_stats
            if s is None
            or s.last_scheduled is None
            or s.last_scheduled < (timestamp - absolute_cooldown)
        ]
    if scheduled_cooldown:
        # Don't re-schedule visits if they're already scheduled but we haven't
        # recorded a result yet, unless they've been scheduled more than a week
        # ago (it probably means we've lost them in flight somewhere).
        origins_stats = [
            (o, s)
            for (o, s) in origins_stats
            if s is None
            or (
                s.last_scheduled is None
                or s.last_scheduled
                < max((timestamp - scheduled_cooldown), s.last_visit or epoch)
            )
        ]
    if failed_cooldown:
        # Don't retry failed origins too often
        origins_stats = [
            (o, s)
            for (o, s) in origins_stats
            if s is None
            or s.last_visit_status is None
            or s.last_visit_status.value != "failed"
            or (
                s.last_visit is not None
                and s.last_visit < (timestamp - failed_cooldown)
            )
        ]
    if not_found_cooldown:
        # Don't retry not found origins too often
        origins_stats = [
            (o, s)
            for (o, s) in origins_stats
            if s is None
            or s.last_visit_status is None
            or s.last_visit_status.value != "not_found"
            or (
                s.last_visit is not None
                and s.last_visit < (timestamp - not_found_cooldown)
            )
        ]
    if policy == "oldest_scheduled_first":
        # Never-scheduled origins sort first (key -1), then by ascending
        # last_scheduled timestamp.
        origins_stats.sort(
            key=lambda e: e[1]
            and e[1].last_scheduled is not None
            and e[1].last_scheduled.timestamp()
            or -1
        )
    elif policy == "never_visited_oldest_update_first":
        # never visited origins have a NULL last_snapshot
        origins_stats = [
            (o, s)
            for (o, s) in origins_stats
            if (s is None or s.last_snapshot is None) and o.last_update is not None
        ]
        origins_stats.sort(key=lambda e: e[0].last_update)
    elif policy == "already_visited_order_by_lag":
        # Origins visited before, with upstream activity since the last
        # successful visit; most-lagging origins first.
        origins_stats = [
            (o, s)
            for (o, s) in origins_stats
            if s
            and s.last_snapshot is not None
            and o.last_update is not None
            and o.last_update > s.last_successful
        ]
        origins_stats.sort(
            key=lambda e: (
                e[0].last_update - (e[1] and e[1].last_successful or epoch)
            ),
            reverse=True,
        )
    elif policy == "origins_without_last_update":
        origins_stats = [
            (o, s) for (o, s) in origins_stats if o.last_update is None
        ]
        # Order by the recorded queue position (unknown -> -1), breaking
        # ties on first_seen.
        origins_stats.sort(
            key=lambda e: (
                (e[1] and e[1].next_visit_queue_position) or -1,
                e[0].first_seen,
            )
        )
        # Remember the highest queue position seen per visit type.
        for o, s in origins_stats:
            self._visit_scheduler_queue_position[o.visit_type] = max(
                self._visit_scheduler_queue_position.get(o.visit_type, 0),
                s and s.next_visit_queue_position or 0,
            )
    elif policy == "first_visits_after_listing":
        assert lister_uuid is not None or (
            lister_name is not None and lister_instance_name is not None
        ), "first_visits_after_listing policy requires lister info "
        if lister_uuid is not None:
            listers = self.get_listers_by_id([lister_uuid])
            lister = listers[0] if listers else None
        else:
            assert lister_name is not None
            assert lister_instance_name is not None
            lister = self.get_lister(lister_name, lister_instance_name)
        assert (
            lister is not None
        ), f"Lister with name {lister_name} and instance {lister_instance_name} not found !"
        # Keep origins never scheduled, or not scheduled since the
        # lister's last completed listing.
        origins_stats = [
            (o, s)
            for (o, s) in origins_stats
            if s is None
            or s.last_scheduled is None
            or (
                lister.last_listing_finished_at
                and s.last_scheduled < lister.last_listing_finished_at
            )
        ]
        origins_stats.sort(
            key=lambda e: e[1] is not None
            and e[1].last_scheduled is not None
            and e[1].last_scheduled.timestamp()
            or -1
        )
    else:
        raise UnknownPolicy(f"Unknown scheduling policy {policy}")
    # Optional lister restrictions.
    # NOTE(review): o.lister_id is presumably a UUID while lister_uuid is
    # a str — this equality may never match; confirm the expected types.
    if lister_uuid:
        origins_stats = [
            (o, s) for (o, s) in origins_stats if o.lister_id == lister_uuid
        ]
    if lister_name:
        listers = [_l.id for _l in self._listers if _l.name == lister_name]
        origins_stats = [
            (o, s) for (o, s) in origins_stats if o.lister_id in listers
        ]
    if lister_instance_name:
        listers = [
            _l.id
            for _l in self._listers
            if _l.instance_name == lister_instance_name
        ]
        origins_stats = [
            (o, s) for (o, s) in origins_stats if o.lister_id in listers
        ]
    # Mark every returned origin as scheduled at *timestamp*, never moving
    # an existing last_scheduled backwards.
    ovs = [
        OriginVisitStats(
            url=o.url,
            visit_type=o.visit_type,
            last_scheduled=(
                s and s.last_scheduled and max(s.last_scheduled, timestamp)
            )
            or timestamp,
        )
        for (o, s) in origins_stats
    ]
    self.origin_visit_stats_upsert(ovs)
    return [o for (o, _) in origins_stats]
[docs]
def create_tasks(
    self,
    tasks: List[Task],
    policy: TaskPolicy = "recurring",
) -> List[Task]:
    """Register new tasks, skipping ones that already exist.

    A task is a duplicate of a stored one when type, arguments, policy,
    priority and status all match — and, for oneshot tasks only,
    ``next_run`` matches as well. New tasks get a sequential id and
    default interval / retry budget from their (pre-registered) task
    type; *policy* is only used as a fallback when a task has no policy
    of its own.

    Returns:
        The stored version (existing or newly created) of each input
        task, deduplicated, in input order.
    """
    # Next free task id: one past the highest id currently stored.
    next_id = 0
    if self._tasks:
        next_id = max(t.id for t in self._tasks) + 1
    _tasks = []
    for t in tasks:
        # Look for an equivalent stored task (see dedup rule above).
        existing = [
            u
            for u in self._tasks
            if u.type == t.type
            and u.arguments == t.arguments
            and u.policy == t.policy
            and u.priority == t.priority
            and u.status == t.status
            and (u.policy != "oneshot" or u.next_run == t.next_run)
        ]
        if existing:
            assert len(existing) == 1
            if existing[0] not in _tasks:
                _tasks.append(existing[0])
            continue
        # The task type must have been registered beforehand; it provides
        # the default scheduling interval and retry budget.
        tt = self.get_task_type(t.type)
        assert tt is not None
        t = t.evolve(
            id=next_id,
            policy=t.policy or policy,
            current_interval=t.current_interval or tt.default_interval,
            retries_left=t.retries_left or tt.num_retries,
        )
        self._tasks.append(t)
        _tasks.append(t)
        next_id += 1
    return _tasks
[docs]
def set_status_tasks(
    self,
    task_ids: List[int],
    status: TaskStatus = "disabled",
    next_run: Optional[datetime.datetime] = None,
) -> None:
    """Set *status* (and optionally *next_run*) on the tasks in *task_ids*.

    Updated tasks are moved to the end of the internal task list;
    unknown ids are silently ignored.
    """
    if not task_ids:
        return
    wanted = set(task_ids)
    changes: Dict[str, Any] = {"status": status}
    if next_run:
        changes["next_run"] = next_run
    kept, updated = [], []
    for task in self._tasks:
        if task.id in wanted:
            updated.append(task.evolve(**changes))
        else:
            kept.append(task)
    self._tasks = kept + updated
[docs]
def disable_tasks(self, task_ids: List[int]) -> None:
    """Mark the given tasks as disabled."""
    self.set_status_tasks(task_ids, status="disabled")
[docs]
def search_tasks(
    self,
    task_id: Optional[int] = None,
    task_type: Optional[str] = None,
    status: Optional[TaskStatus] = None,
    priority: Optional[TaskPriority] = None,
    policy: Optional[TaskPolicy] = None,
    before: Optional[datetime.datetime] = None,
    after: Optional[datetime.datetime] = None,
    limit: Optional[int] = None,
) -> List[Task]:
    """Search tasks matching all the provided criteria.

    task_id / task_type / status / priority each accept a single value
    or a collection of values; *before* and *after* bound the task's
    next_run (inclusive); *limit* truncates the result.
    """
    found = list(self._tasks)
    # (criterion value, task attribute, scalar type(s)): a value of a
    # scalar type is an equality match, anything else a membership test.
    criteria = (
        (task_id, "id", (str, int)),
        (task_type, "type", str),
        (status, "status", str),
        (priority, "priority", str),
    )
    for value, attr, scalar_types in criteria:
        if not value:
            continue
        if isinstance(value, scalar_types):
            found = [t for t in found if getattr(t, attr) == value]
        else:
            found = [t for t in found if getattr(t, attr) in value]
    if policy:
        found = [t for t in found if t.policy == policy]
    if before:
        found = [t for t in found if t.next_run <= before]
    if after:
        found = [t for t in found if t.next_run >= after]
    return found[:limit] if limit else found
[docs]
def get_tasks(self, task_ids: List[int]) -> List[Task]:
    """Return the stored tasks whose id is in *task_ids*."""
    wanted = set(task_ids)
    return [t for t in self._tasks if t.id in wanted]
[docs]
def peek_ready_tasks(
    self,
    task_type: str,
    timestamp: Optional[datetime.datetime] = None,
    num_tasks: Optional[int] = None,
) -> List[Task]:
    """Return up to *num_tasks* non-priority tasks of *task_type* that
    are due (next_run <= *timestamp*, defaulting to now) and not yet
    scheduled, ordered by next_run. Does not modify the store.
    """
    due = timestamp if timestamp is not None else utcnow()
    ready = [
        t
        for t in self._tasks
        if t.type == task_type
        and t.status == "next_run_not_scheduled"
        and t.priority is None
        and t.next_run <= due
    ]
    ready.sort(key=lambda t: t.next_run)
    return ready[:num_tasks] if num_tasks else ready
[docs]
def grab_ready_tasks(
    self,
    task_type: str,
    timestamp: Optional[datetime.datetime] = None,
    num_tasks: Optional[int] = None,
) -> List[Task]:
    """Like peek_ready_tasks, but also mark the returned tasks as
    "next_run_scheduled" (updated tasks move to the end of the store).
    """
    if timestamp is None:
        timestamp = utcnow()
    grabbed = {
        t.id: t.evolve(status="next_run_scheduled")
        for t in self.peek_ready_tasks(task_type, timestamp, num_tasks)
    }
    self._tasks = [t for t in self._tasks if t.id not in grabbed]
    self._tasks.extend(grabbed.values())
    return list(grabbed.values())
[docs]
def peek_ready_priority_tasks(
    self,
    task_type: str,
    timestamp: Optional[datetime.datetime] = None,
    num_tasks: Optional[int] = None,
) -> List[Task]:
    """Return up to *num_tasks* PRIORITY tasks of *task_type* that are
    due (next_run <= *timestamp*, defaulting to now) and not yet
    scheduled, ordered by next_run. Does not modify the store.
    """
    due = timestamp if timestamp is not None else utcnow()
    ready = [
        t
        for t in self._tasks
        if t.type == task_type
        and t.status == "next_run_not_scheduled"
        and t.priority is not None
        and t.next_run <= due
    ]
    ready.sort(key=lambda t: t.next_run)
    return ready[:num_tasks] if num_tasks else ready
[docs]
def grab_ready_priority_tasks(
    self,
    task_type: str,
    timestamp: Optional[datetime.datetime] = None,
    num_tasks: Optional[int] = None,
) -> List[Task]:
    """Like peek_ready_priority_tasks, but also mark the returned tasks
    as "next_run_scheduled" (updated tasks move to the end of the store).
    """
    if timestamp is None:
        timestamp = utcnow()
    grabbed = {
        t.id: t.evolve(status="next_run_scheduled")
        for t in self.peek_ready_priority_tasks(task_type, timestamp, num_tasks)
    }
    self._tasks = [t for t in self._tasks if t.id not in grabbed]
    self._tasks.extend(grabbed.values())
    return list(grabbed.values())
[docs]
def schedule_task_run(
    self,
    task_id: int,
    backend_id: str,
    metadata: Optional[Dict[str, Any]] = None,
    timestamp: Optional[datetime.datetime] = None,
) -> TaskRun:
    """Record that a run of task *task_id* was scheduled on the backend.

    Returns:
        The newly stored TaskRun, with a fresh sequential id.
    """
    if metadata is None:
        metadata = {}
    if timestamp is None:
        timestamp = utcnow()
    next_id = 0
    if self._task_runs:
        # Bug fix: the computed "max id + 1" was previously discarded
        # (bare expression statement), so every run was stored with id 0.
        next_id = max(_tr.id for _tr in self._task_runs) + 1
    tr = TaskRun(
        task=task_id,
        id=next_id,
        backend_id=backend_id,
        metadata=metadata,
        scheduled=timestamp,
        status="scheduled",
    )
    self._task_runs.append(tr)
    return tr
[docs]
def mass_schedule_task_runs(
    self,
    task_runs: List[TaskRun],
) -> None:
    """Store *task_runs*, forcing their status to "scheduled"."""
    self._task_runs.extend(
        tr.evolve(status="scheduled") for tr in task_runs
    )
[docs]
def start_task_run(
    self,
    backend_id: str,
    metadata: Optional[Dict[str, Any]] = None,
    timestamp: Optional[datetime.datetime] = None,
) -> Optional[TaskRun]:
    """Mark the task run(s) identified by *backend_id* as started.

    Merges *metadata* into each matching run's stored metadata and stamps
    the start time (defaulting to now). Updated runs move to the end of
    the store. Returns the first updated run, or None (with a debug log)
    if nothing matches *backend_id*.
    """
    extra = {} if metadata is None else metadata
    started_at = utcnow() if timestamp is None else timestamp
    matched, others = [], []
    for tr in self._task_runs:
        if tr.backend_id == backend_id:
            matched.append(
                tr.evolve(
                    started=started_at,
                    status="started",
                    metadata={**(tr.metadata or {}), **extra},
                )
            )
        else:
            others.append(tr)
    self._task_runs = others + matched
    if matched:
        return matched[0]
    logger.debug(
        "Failed to mark task run %s as started",
        backend_id,
    )
    return None
[docs]
def end_task_run(
    self,
    backend_id: str,
    status: TaskRunStatus,
    metadata: Optional[Dict[str, Any]] = None,
    timestamp: Optional[datetime.datetime] = None,
) -> Optional[TaskRun]:
    """Mark the task run(s) identified by *backend_id* as ended.

    Records *status*, merges *metadata* into each matching run's stored
    metadata and stamps the end time (defaulting to now). Updated runs
    move to the end of the store. Returns the first updated run, or None
    (with a debug log) if nothing matches *backend_id*.
    """
    extra = {} if metadata is None else metadata
    ended_at = utcnow() if timestamp is None else timestamp
    matched, others = [], []
    for tr in self._task_runs:
        if tr.backend_id == backend_id:
            matched.append(
                tr.evolve(
                    ended=ended_at,
                    status=status,
                    metadata={**(tr.metadata or {}), **extra},
                )
            )
        else:
            others.append(tr)
    self._task_runs = others + matched
    if matched:
        return matched[0]
    logger.debug(
        "Failed to mark task run %s as ended",
        backend_id,
    )
    return None
[docs]
def filter_task_to_archive(
    self,
    after_ts: str,
    before_ts: str,
    limit: int = 10,
    page_token: Optional[str] = None,
) -> Dict[str, Any]:
    """Compute the task / task-run pairs eligible for archival.

    Eligible tasks are completed or disabled oneshot tasks and disabled
    recurring tasks whose runs started (or, lacking a start time, were
    scheduled) within [after_ts, before_ts).

    Returns:
        A dict with a "tasks" list of flattened task+run records and,
        when more data remains, a "next_page_token" to resume from.
    """
    assert not page_token or isinstance(page_token, str)
    last_id = -1 if page_token is None else int(page_token)

    def _aware(ts: str) -> datetime.datetime:
        # Parse an ISO timestamp; naive values are assumed to be UTC.
        dt = datetime.datetime.fromisoformat(ts)
        if dt.tzinfo is None:
            dt = dt.replace(tzinfo=datetime.timezone.utc)
        return dt

    after_dt = _aware(after_ts)
    before_dt = _aware(before_ts)
    task_runs = []
    for t in self._tasks:
        archivable = (
            t.policy == "oneshot" and t.status in ("completed", "disabled")
        ) or (t.policy == "recurring" and t.status == "disabled")
        if not archivable or t.id < last_id:
            continue
        for tr in self._task_runs:
            if tr.task != t.id:
                continue
            if tr.started is not None:
                in_range = after_dt <= tr.started < before_dt
            else:
                in_range = after_dt <= tr.scheduled < before_dt
            if in_range:
                task_runs.append((t, tr))
    # Bug fix: `started` may be None for runs selected on their scheduled
    # time; comparing None with a datetime raised TypeError during the
    # sort. Fall back to the run's scheduled time in that case.
    task_runs.sort(key=lambda x: (x[0].id, x[1].started or x[1].scheduled))
    tasks = [
        {
            "task_id": t.id,
            "task_policy": t.policy,
            "task_status": t.status,
            "task_run_id": tr.id,
            "arguments": t.arguments.to_dict(),
            "type": t.type,
            "backend_id": tr.backend_id,
            "metadata": tr.metadata,
            "scheduled": tr.scheduled,
            "started": tr.started,
            "ended": tr.ended,
            "status": tr.status,
        }
        for (t, tr) in task_runs
    ]
    for td in tasks:
        # Normalize "args" to an index->value mapping and JSON-serialize
        # "kwargs", matching the archive's expected layout.
        td["arguments"]["args"] = {
            i: v for i, v in enumerate(td["arguments"]["args"])
        }
        td["arguments"]["kwargs"] = json.dumps(td["arguments"]["kwargs"])
    if len(tasks) >= limit + 1:  # remains data, add pagination information
        return {
            "tasks": tasks[:limit],
            "next_page_token": str(tasks[limit]["task_id"]),
        }
    return {"tasks": tasks}
[docs]
def delete_archived_tasks(self, task_ids):
    """Delete archived tasks and their runs.

    *task_ids* is a list of ``{"task_id": ..., "task_run_id": ...}``
    dicts, as produced by filter_task_to_archive().
    """
    # Bug fix: `_task_ids = _task_run_ids = []` bound both names to the
    # SAME list, mixing task ids with task-run ids and over-deleting.
    _task_ids = []
    _task_run_ids = []
    for task_id in task_ids:
        _task_ids.append(task_id["task_id"])
        _task_run_ids.append(task_id["task_run_id"])
    # Drop the archived runs themselves, plus any run still attached to a
    # deleted task (no dangling task references).
    self._task_runs = [
        tr
        for tr in self._task_runs
        if tr.id not in _task_run_ids and tr.task not in _task_ids
    ]
    self._tasks = [t for t in self._tasks if t.id not in _task_ids]
[docs]
def get_task_runs(
    self,
    task_ids: List[int],
    limit: Optional[int] = None,
) -> List[TaskRun]:
    """Return the stored runs of the given tasks, optionally truncated
    to *limit* entries. An empty *task_ids* yields an empty list.
    """
    if not task_ids:
        return []
    wanted = set(task_ids)
    runs = [tr for tr in self._task_runs if tr.task in wanted]
    return runs[:limit] if limit else runs
[docs]
def origin_visit_stats_upsert(
    self,
    origin_visit_stats: Iterable[OriginVisitStats],
) -> None:
    """Merge *origin_visit_stats* into the per-(url, visit_type) store.

    Exact duplicate records are dropped first.

    Raises:
        SchedulerException: if two DISTINCT records share the same
            (url, visit_type) key (cardinality violation).
    """
    # Remove exact duplicates while preserving order.
    deduplicated: List[OriginVisitStats] = []
    for record in origin_visit_stats:
        if record not in deduplicated:
            deduplicated.append(record)
    by_key = {(r.url, r.visit_type): r for r in deduplicated}
    if len(by_key) < len(deduplicated):
        raise SchedulerException("CardinalityViolation")
    for key, new in by_key.items():
        current = self._origin_visit_stats.get(key)
        if current is None:
            self._origin_visit_stats[key] = new
            continue
        # Field-wise merge: a falsy incoming value keeps the stored one.
        self._origin_visit_stats[key] = current.evolve(
            last_scheduled=new.last_scheduled or current.last_scheduled,
            last_snapshot=new.last_snapshot or current.last_snapshot,
            last_successful=new.last_successful or current.last_successful,
            last_visit=new.last_visit or current.last_visit,
            last_visit_status=new.last_visit_status
            or current.last_visit_status,
            next_visit_queue_position=new.next_visit_queue_position
            or current.next_visit_queue_position,
            next_position_offset=new.next_position_offset
            or current.next_position_offset,
            successive_visits=new.successive_visits
            or current.successive_visits,
        )
[docs]
def origin_visit_stats_get(
    self,
    ids: Iterable[Tuple[str, str]],
) -> List[OriginVisitStats]:
    """Return the stored stats for each (url, visit_type) key in *ids*;
    unknown keys are silently skipped.
    """
    if not ids:
        return []
    stats = self._origin_visit_stats
    return [stats[key] for key in ids if key in stats]
[docs]
def visit_scheduler_queue_position_get(self) -> Dict[str, int]:
    """Return a copy of the per-visit-type scheduler queue positions."""
    return dict(self._visit_scheduler_queue_position)
[docs]
def visit_scheduler_queue_position_set(
    self,
    visit_type: str,
    position: int,
) -> None:
    """Record *position* as the current queue position for *visit_type*."""
    self._visit_scheduler_queue_position[visit_type] = position
[docs]
def update_metrics(
    self,
    lister_id: Optional[UUID] = None,
    timestamp: Optional[datetime.datetime] = None,
) -> List[SchedulerMetrics]:
    """Recompute scheduler metrics, optionally restricted to one lister.

    For every (lister_id, visit_type) pair among the (possibly filtered)
    listed origins, counts known / enabled / never-visited origins and
    origins with pending upstream changes, then stores and returns the
    metrics, all stamped with *timestamp* (defaulting to now).
    """
    if timestamp is None:
        timestamp = utcnow()
    origins = self._listed_origins
    if lister_id:
        origins = [lo for lo in origins if lo.lister_id == lister_id]
    # One row per origin: (lister_id, visit_type, enabled, last_snapshot,
    # last_update, last_successful).
    rows = []
    for lo in origins:
        ovs = self._origin_visit_stats.get((lo.url, lo.visit_type), None)
        rows.append(
            (
                lo.lister_id,
                lo.visit_type,
                lo.enabled,
                ovs and ovs.last_snapshot,
                lo.last_update,
                ovs and ovs.last_successful,
            )
        )
    metrics: Dict[Tuple[UUID, str], SchedulerMetrics] = {}
    # Iterate each key once, in first-occurrence order. The previous code
    # recomputed identical metrics once per origin sharing a key
    # (accidental O(n^2)), and its loop variable shadowed the `lister_id`
    # parameter.
    for key in dict.fromkeys((row[0], row[1]) for row in rows):
        _lister_id, visit_type = key
        # Per-origin data for this key: (enabled, snapshot, last_update,
        # last_successful).
        _rows = [row[2:] for row in rows if (row[0], row[1]) == key]
        origins_known = len(_rows)
        origins_enabled = len([row for row in _rows if row[0]])
        origins_never_visited = len(
            [row for row in _rows if row[0] and row[1] is None]
        )
        origins_with_pending_changes = len(
            [
                row
                for row in _rows
                if row[0]
                and (
                    row[3] is not None
                    and row[2] is not None
                    and row[2] > row[3]
                )
            ]
        )
        assert _lister_id is not None
        assert visit_type is not None
        metrics[key] = SchedulerMetrics(
            lister_id=_lister_id,
            visit_type=visit_type,
            last_update=timestamp,
            origins_known=origins_known,
            origins_enabled=origins_enabled,
            origins_never_visited=origins_never_visited,
            origins_with_pending_changes=origins_with_pending_changes,
        )
    self._scheduler_metrics.update(metrics)
    return list(metrics.values())
[docs]
def get_metrics(
    self,
    lister_id: Optional[UUID] = None,
    visit_type: Optional[str] = None,
) -> List[SchedulerMetrics]:
    """Return stored scheduler metrics, optionally filtered by lister
    and/or visit type.
    """
    return [
        sm
        for sm in self._scheduler_metrics.values()
        if (not lister_id or sm.lister_id == lister_id)
        and (not visit_type or sm.visit_type == visit_type)
    ]