# Source code for swh.storage.cassandra.diagram
# Copyright (C) 2024 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
"""Generates a graphical representation of the Cassandra schema using
:mod:`swh.storage.cassandra.model`.
"""
import dataclasses
from typing import Tuple, Union
from . import model
# [docs]
def dot_diagram() -> str:
    """Generates a diagram of the database in Graphviz DOT format.

    One node is emitted per table:

    * every object of :mod:`swh.storage.cassandra.model` with a ``TABLE``
      attribute (its dataclass fields become the rows of the node), and
    * ``content_by_<algo>`` / ``skipped_content_by_<algo>`` for each algorithm
      of ``HASH_ALGORITHMS``, which have no Python class of their own.

    Each node is an HTML-like table listing the columns, their (Python) type,
    and whether they belong to the partition key (``PK``) or the clustering
    key (``CK``). Edges go from a column to the column it references, read
    from the field metadata: solid for ``"fk"``, dashed for ``"points_to"``.

    Returns:
        The complete ``digraph`` source, ready to be rendered with ``dot``.

    Raises:
        Exception: if an ``fk``/``points_to`` reference is not of the form
            ``table.column``, names an unknown table, or does not match exactly
            one column of the target table.
        AssertionError: (unless assertions are disabled) if a column is both
            in the partition and clustering key, or has a union type other
            than ``Optional[...]``.
    """
    import io
    import textwrap
    import types
    from typing import get_origin

    from .schema import HASH_ALGORITHMS

    # Origins identifying a union annotation: ``Union[...]``/``Optional[...]``,
    # plus PEP 604 ``X | None`` unions, which have their own origin type on
    # Python >= 3.10 (on older versions this is just ``Union`` twice).
    union_origins = (Union, getattr(types, "UnionType", Union))

    out = io.StringIO()

    # table name -> model class describing that table
    classes = {
        cls.TABLE: cls for cls in model.__dict__.values() if hasattr(cls, "TABLE")
    }

    # NOTE(review): the *_by_<algo> cluster members below are hard-coded, while
    # the nodes themselves are generated from HASH_ALGORITHMS further down;
    # keep both in sync.
    out.write(
        textwrap.dedent(
            """
            digraph g {
                graph [
                    rankdir = "LR",
                    concentrate = true,
                    ratio = auto
                ];
                node [
                    fontsize = "10",
                    shape = record
                ];
                edge [
                ];

                subgraph "logical_grouping" {
                    style = rounded;
                    bgcolor = gray95;
                    color = gray;

                    subgraph cluster_content {
                        label = <<b>content</b>>;
                        content;
                        content_by_sha1;
                        content_by_sha1_git;
                        content_by_sha256;
                        content_by_blake2s256;
                    }
                    subgraph cluster_skipped_content {
                        label = <<b>skipped_content</b>>;
                        skipped_content;
                        skipped_content_by_sha1;
                        skipped_content_by_sha1_git;
                        skipped_content_by_sha256;
                        skipped_content_by_blake2s256;
                    }
                    subgraph cluster_directory {
                        label = <<b>directories</b>>;
                        directory;
                        directory_entry;
                    }
                    subgraph cluster_revision {
                        label = <<b>revisions</b>>;
                        revision;
                        revision_parent;
                    }
                    subgraph cluster_release {
                        label = <<b>releases</b>>;
                        release;
                    }
                    subgraph cluster_snapshots {
                        label = <<b>snapshots</b>>;
                        snapshot;
                        snapshot_branch;
                    }
                    subgraph cluster_origins {
                        label = <<b>origins</b>>;
                        origin;
                        origin_visit;
                        origin_visit_status;
                    }
                    subgraph cluster_metadata {
                        label = <<b>metadata</b>>;
                        metadata_authority;
                        metadata_fetcher;
                        raw_extrinsic_metadata;
                        raw_extrinsic_metadata_by_id;
                    }
                    subgraph cluster_extid {
                        label = <<b>external identifiers</b>>;
                        extid;
                        extid_by_target;
                    }
                }
            """
        )
    )

    def write_table_header(table_name: str) -> None:
        """Opens the HTML-like label of the node ``table_name`` and writes its
        title row. The caller appends the column rows and closes the label."""
        out.write(
            f'"{table_name}" [shape = plaintext, label = < '
            '<TABLE BORDER="1" CELLBORDER="0" CELLSPACING="0">'
            # header row: the title spans the three middle columns (name,
            # type, key), so that the ltcol0/rtcol0 port cells line up with
            # the 5 cells (ltcol, name, type, key, rtcol) of every column row.
            # "\N" is replaced by Graphviz with the node name.
            '<TR ><TD PORT="ltcol0"> </TD> '
            '<TD bgcolor="grey90" border="1" COLSPAN="3"> \\N </TD> '
            '<TD PORT="rtcol0"></TD></TR>'
        )

    def get_target_field(field_full_name: str) -> Tuple[str, int]:
        """Given a string like 'table.col', returns the table name and the index
        of the column within that table (1-indexed, matching the ``ltcol<i>`` /
        ``rtcol<i>`` port numbers of the table's rows)."""
        try:
            (points_to_table, points_to_col) = field_full_name.split(".")
        except ValueError:
            raise Exception(
                f"Expected a reference of the form 'table.column', "
                f"got: {field_full_name!r}"
            ) from None
        try:
            target_cls = classes[points_to_table]
        except KeyError:
            raise Exception(f"Unknown table {points_to_table}") from None
        target_field_ids = [
            i
            for (i, field) in enumerate(dataclasses.fields(target_cls), start=1)
            if field.name == points_to_col
        ]
        try:
            (target_field_id,) = target_field_ids
        except ValueError:
            raise Exception(
                f"Expected exactly one field {target_cls.__name__}.{points_to_col}, "
                f"got: {target_field_ids}"
            ) from None
        return (points_to_table, target_field_id)

    def column_type_name(cls: type, field: dataclasses.Field) -> str:
        """Returns the displayed type of a column, with ``Optional`` unwrapped."""
        # TODO: use CQL types instead of Python types
        ty = field.type
        if get_origin(ty) in union_origins:
            assert (
                len(ty.__args__) == 2 and type(None) in ty.__args__
            ), f"{cls.__name__}.{field.name} has unsupported type: {ty}"
            # this is Optional[], unwrap it
            (ty,) = [arg for arg in ty.__args__ if arg is not type(None)]  # noqa
        # annotations that are not classes (e.g. string forward references)
        # have no __name__: fall back to their textual form
        return getattr(ty, "__name__", str(ty))

    # write main tables; row i (1-indexed) gets ports ltcol<i> / rtcol<i>
    for cls in classes.values():
        write_table_header(cls.TABLE)
        for i, field in enumerate(dataclasses.fields(cls), start=1):
            if field.name in cls.PARTITION_KEY:
                assert (
                    field.name not in cls.CLUSTERING_KEY
                ), f"{field.name} is both PK and CK"
                key = "PK"
            elif field.name in cls.CLUSTERING_KEY:
                key = "CK"
            else:
                key = ""
            col_type = column_type_name(cls, field)
            out.write(
                textwrap.dedent(
                    f"""
                    <TR><TD PORT="ltcol{i}" ></TD>
                        <TD align="left" > {field.name} </TD>
                        <TD align="left" > {col_type} </TD>
                        <TD align="left" > {key} </TD>
                        <TD align="left" PORT="rtcol{i}"> </TD></TR>
                    """
                )
            )
        out.write("</TABLE>> ];\n")

    # add content_by_* and skipped_content_by_*, which don't have their own
    # Python classes: a hash column (PK) and a token column (CK), with an edge
    # from the token column to the title row of the main table
    for algo in HASH_ALGORITHMS:
        for main_table in ("content", "skipped_content"):
            index_table = f"{main_table}_by_{algo}"
            write_table_header(index_table)
            out.write(
                textwrap.dedent(
                    f"""
                    <TR><TD PORT="ltcol1" ></TD>
                        <TD align="left" > {algo} </TD>
                        <TD align="left" > bytes </TD>
                        <TD align="left" > PK </TD>
                        <TD align="left" PORT="rtcol1"> </TD></TR>
                    <TR><TD PORT="ltcol2" ></TD>
                        <TD align="left" > token </TD>
                        <TD align="left" > token </TD>
                        <TD align="left" > CK </TD>
                        <TD align="left" PORT="rtcol2"> </TD></TR>
                    """
                )
            )
            out.write("</TABLE>> ];\n")
            out.write(
                f'"{index_table}":rtcol2 -> "{main_table}":ltcol0 [style = solid];\n'
            )

    # write "links" between tables, from the referencing column to the
    # referenced one
    for cls in classes.values():
        for i, field in enumerate(dataclasses.fields(cls), start=1):
            links = []  # pairs of (is_strong, target)
            for points_to in field.metadata.get("fk") or []:
                links.append((True, points_to))
            for points_to in field.metadata.get("points_to") or []:
                links.append((False, points_to))
            for is_strong, points_to in links:
                (target_table, target_field_id) = get_target_field(points_to)
                style = "[style = solid]" if is_strong else "[style = dashed]"
                out.write(
                    f'"{cls.TABLE}":rtcol{i} -> '
                    f'"{target_table}":ltcol{target_field_id} '
                    f"{style};\n"
                )

    # closes "digraph g {" opened in the preamble
    out.write("}\n")
    return out.getvalue()