"""Pydantic data models for the Carcinogen-Gene Knowledge Graph."""
from __future__ import annotations
import hashlib
import re
from enum import Enum
from typing import Optional
from pydantic import AliasChoices, BaseModel, ConfigDict, Field, field_validator, model_validator
# ── Enums ────────────────────────────────────────────────────────────────
[docs]
class NodeType(str, Enum):
    """Closed set of node categories accepted by ``Node.type``.

    Members are also ``str`` instances, so each one compares equal to its
    value (for example ``NodeType.GENE == "Gene"``).
    """

    CARCINOGEN = "Carcinogen"
    ENZYME = "Enzyme"
    GENE = "Gene"
    METABOLITE = "Metabolite"
    DNA_ADDUCT = "DNA_Adduct"  # the value keeps the underscore
    PATHWAY = "Pathway"
    TISSUE = "Tissue"
[docs]
class EdgeType(str, Enum):
    """Closed set of relationship kinds accepted by ``Edge.type``.

    Members are also ``str`` instances whose value equals the member name.
    ``CUSTOM`` is the one member that ``Edge`` validation treats specially:
    such an edge must also carry a ``custom_predicate``.
    """

    ACTIVATES = "ACTIVATES"
    DETOXIFIES = "DETOXIFIES"
    TRANSPORTS = "TRANSPORTS"
    FORMS_ADDUCT = "FORMS_ADDUCT"
    REPAIRS = "REPAIRS"
    PATHWAY = "PATHWAY"
    EXPRESSED_IN = "EXPRESSED_IN"
    INDUCES = "INDUCES"
    INHIBITS = "INHIBITS"
    ENCODES = "ENCODES"
    CUSTOM = "CUSTOM"
[docs]
class CurationStatus(str, Enum):
    """Allowed values for ``CurationRecord.status``.

    ``DRAFT`` is the value ``CurationRecord`` uses when no status is given.
    """

    DRAFT = "Draft"
    IN_REVIEW = "In Review"  # the value contains a space
    REVIEWED = "Reviewed"
    APPROVED = "Approved"
    REJECTED = "Rejected"
[docs]
class CurationConfidence(str, Enum):
    """Allowed values for the optional ``CurationRecord.confidence`` field."""

    LOW = "Low"
    MEDIUM = "Medium"
    HIGH = "High"
[docs]
class RecordOrigin(str, Enum):
    """Allowed values for the ``origin`` field of ``Node`` and ``Edge``.

    Both models default to ``IMPORTED``.
    """

    IMPORTED = "imported"
    SEEDED = "seeded"
    USER = "user"
    LLM = "llm"
[docs]
class MatchStatus(str, Enum):
    """Allowed values for the ``match_status`` field of ``Node`` and ``Edge``.

    Both models default to ``UNKNOWN``. Their validators attach extra
    requirements to ``CANONICAL``, ``ALIAS`` and ``CUSTOM``.
    """

    UNKNOWN = "unknown"
    CANONICAL = "canonical"
    ALIAS = "alias"
    UNMATCHED = "unmatched"
    CUSTOM = "custom"
[docs]
class ProvenanceRecord(BaseModel):
    """One provenance entry attached to a node or edge.

    Every field is an optional string that defaults to ``None``.
    """

    # populate_by_name lets input use the field name as well as an alias.
    model_config = ConfigDict(populate_by_name=True)

    source_db: Optional[str] = None
    # On input this field is read from either "record_id" or "accession".
    record_id: Optional[str] = Field(
        default=None, validation_alias=AliasChoices("record_id", "accession")
    )
    evidence: Optional[str] = None
    pmid: Optional[str] = None
    tissue: Optional[str] = None
    citation: Optional[str] = None
    url: Optional[str] = None
[docs]
class CurationRecord(BaseModel):
    """Curation metadata that can be attached to a node or edge.

    ``status`` defaults to ``CurationStatus.DRAFT``; every other field is
    optional and defaults to ``None``. ``reviewed_at`` is kept as a string
    and only checked for format, never converted.
    """

    status: CurationStatus = CurationStatus.DRAFT
    confidence: Optional[CurationConfidence] = None
    curator: Optional[str] = None
    reviewed_by: Optional[str] = None
    reviewed_at: Optional[str] = None
    notes: Optional[str] = None

    @field_validator("reviewed_at")
    @classmethod
    def _validate_reviewed_at(cls, v: str | None) -> str | None:
        """Check that ``reviewed_at`` is a recognised date or timestamp.

        Accepted shapes are ``YYYY-MM-DD`` and ``YYYY-MM-DDTHH:MM:SS``, the
        latter optionally followed by fractional seconds and/or a UTC
        offset. Fractional seconds are accepted because
        ``datetime.isoformat()`` emits them whenever microseconds are
        non-zero.

        Args:
            v: The incoming value, or ``None`` when the field is unset.

        Returns:
            ``v`` unchanged.

        Raises:
            ValueError: If ``v`` matches none of the accepted formats.
        """
        if v is None:
            return v
        import datetime

        accepted_formats = (
            "%Y-%m-%d",
            "%Y-%m-%dT%H:%M:%S",
            "%Y-%m-%dT%H:%M:%S%z",
            "%Y-%m-%dT%H:%M:%S.%f",
            "%Y-%m-%dT%H:%M:%S.%f%z",
        )
        for fmt in accepted_formats:
            try:
                # Parsed only to validate; the result is discarded.
                datetime.datetime.strptime(v, fmt)
            except ValueError:
                continue
            return v
        raise ValueError(
            f"reviewed_at must be a date string (YYYY-MM-DD or ISO 8601), got {v!r}"
        )
def _join_unique(values: list[str]) -> str | None:
unique = dict.fromkeys(v for value in values if (v := value.strip()))
return "; ".join(unique) if unique else None
def _first_nonempty(values: list[str]) -> str | None:
for value in values:
cleaned = value.strip()
if cleaned:
return cleaned
return None
def _normalize_provenance_fields(owner: BaseModel, *, summary_only_fields: tuple[str, ...]) -> None:
    """Reconcile an owner's flat provenance fields with its ``provenance`` list.

    When ``owner.provenance`` is empty, a single ``ProvenanceRecord`` is
    built from the owner's flat ``source_db``/``evidence``/``pmid``/``tissue``
    attributes and stored as the list, provided at least one of them is
    truthy. When provenance records exist, flat attributes that are still
    ``None`` are then filled from them.

    Args:
        owner: Model instance to update in place.
        summary_only_fields: Attribute names that receive only the first
            non-blank value found in the records. ``source_db``, ``pmid``
            and ``tissue`` always receive every distinct value joined by
            ``"; "``.
    """
    records = list(getattr(owner, "provenance", []))

    if not records:
        # No structured provenance: try to derive one record from the
        # owner's flat attributes.
        flat_values = {
            name: getattr(owner, name, None)
            for name in ("source_db", "evidence", "pmid", "tissue")
        }
        fallback = ProvenanceRecord(**flat_values)
        if not any(fallback.model_dump(exclude_none=True).values()):
            return  # nothing to normalise
        records = [fallback]
        setattr(owner, "provenance", records)

    def _collect(name: str) -> list[str]:
        # Truthy values of one attribute across all records, in order.
        return [value for record in records if (value := getattr(record, name))]

    for name in ("source_db", "pmid", "tissue"):
        if getattr(owner, name, None) is None:
            setattr(owner, name, _join_unique(_collect(name)))

    for name in summary_only_fields:
        if getattr(owner, name, None) is None:
            setattr(owner, name, _first_nonempty(_collect(name)))
# ── Node ─────────────────────────────────────────────────────────────────
[docs]
class Node(BaseModel):
    """A node of the knowledge graph.

    ``id``, ``label`` and ``type`` are required; everything else has a
    default. After field validation the model fills in an empty ``id``,
    reconciles the flat provenance fields with ``provenance``, and enforces
    the requirements that go with ``match_status``.
    """

    id: str
    label: str
    type: NodeType
    detail: str = ""

    # Optional descriptive attributes; this module attaches no validation
    # to them.
    group: Optional[str] = None
    iarc: Optional[str] = None
    phase: Optional[str] = None
    role: Optional[str] = None
    reactivity: Optional[str] = None

    # Flat provenance fields, reconciled with `provenance` during validation.
    source_db: Optional[str] = None
    evidence: Optional[str] = None
    pmid: Optional[str] = None
    tissue: Optional[str] = None

    variant: Optional[str] = None
    phenotype: Optional[str] = None
    activity_score: Optional[float] = None
    tier: Optional[int] = None

    # Origin and matching metadata.
    origin: RecordOrigin = RecordOrigin.IMPORTED
    match_status: MatchStatus = MatchStatus.UNKNOWN
    canonical_id: Optional[str] = None
    canonical_label: Optional[str] = None
    canonical_namespace: Optional[str] = None
    custom_type: Optional[str] = None

    provenance: list[ProvenanceRecord] = Field(default_factory=list)
    curation: Optional[CurationRecord] = None

    @classmethod
    def generate_id(cls, label: str) -> str:
        """Build a node ID from *label*.

        Each run of characters outside ``A-Z a-z 0-9 _ . -`` becomes a
        single ``_``, then leading and trailing ``_``/``.`` are removed. If
        that leaves the label unchanged (e.g. ``"CYP1A1"``) it is returned
        as is. Otherwise the first six hex digits of the label's SHA-256
        are appended after an underscore, so that different labels which
        sanitise to the same text still get different IDs (e.g.
        ``"Benzo[a]pyrene"`` becomes ``"Benzo_a_pyrene_"`` plus six hex
        digits).

        Raises:
            ValueError: If nothing is left of the label after sanitising.
        """
        candidate = re.sub(r"[^A-Za-z0-9_.-]+", "_", label).strip("_.")
        if not candidate:
            raise ValueError("Cannot generate an ID from an empty label")
        if candidate == label:
            # Nothing was replaced or trimmed, so no suffix is needed.
            return candidate
        digest = hashlib.sha256(label.encode()).hexdigest()
        return f"{candidate}_{digest[:6]}"

    @model_validator(mode="after")
    def _normalize(self) -> "Node":
        """Fill derived fields and enforce the ``match_status`` rules.

        Raises:
            ValueError: If an alias-matched node lacks ``canonical_id`` or
                ``canonical_label``, or a custom node lacks ``custom_type``.
        """
        if not self.id:
            self.id = self.generate_id(self.label)
        _normalize_provenance_fields(self, summary_only_fields=("evidence",))

        status = self.match_status
        if status == MatchStatus.CANONICAL:
            # A canonical node stands for itself unless told otherwise.
            if not self.canonical_id:
                self.canonical_id = self.id
            if not self.canonical_label:
                self.canonical_label = self.label
            return self

        if status == MatchStatus.ALIAS:
            if not self.canonical_id:
                raise ValueError("Alias-matched nodes must define canonical_id")
            if not self.canonical_label:
                raise ValueError("Alias-matched nodes must define canonical_label")
            return self

        if status == MatchStatus.CUSTOM and not self.custom_type:
            raise ValueError("Custom nodes must define custom_type")
        return self
# ── Edge ─────────────────────────────────────────────────────────────────
[docs]
class Edge(BaseModel):
    """A typed relationship between two nodes, referenced by node ID.

    ``source``, ``target`` and ``type`` are required. After field
    validation the model reconciles the flat provenance fields with
    ``provenance`` and enforces the rules tying ``type``, ``match_status``
    and the predicate fields together.
    """

    source: str
    target: str
    type: EdgeType
    label: Optional[str] = None
    # When set, KnowledgeGraph checks this against the graph's node IDs.
    carcinogen: Optional[str] = None

    # Flat provenance fields, reconciled with `provenance` during validation.
    source_db: Optional[str] = None
    evidence: Optional[str] = None
    pmid: Optional[str] = None
    tissue: Optional[str] = None

    # Origin and matching metadata.
    origin: RecordOrigin = RecordOrigin.IMPORTED
    match_status: MatchStatus = MatchStatus.UNKNOWN
    canonical_predicate: Optional[str] = None
    canonical_namespace: Optional[str] = None
    custom_predicate: Optional[str] = None

    provenance: list[ProvenanceRecord] = Field(default_factory=list)
    curation: Optional[CurationRecord] = None

    @model_validator(mode="after")
    def _normalize(self) -> "Edge":
        """Fill derived fields and enforce the predicate rules.

        Raises:
            ValueError: If a ``CUSTOM``-typed edge has no
                ``custom_predicate`` or is marked canonical/alias, or if a
                custom-matched edge has no ``custom_predicate``.
        """
        _normalize_provenance_fields(self, summary_only_fields=("evidence",))

        custom_type = self.type == EdgeType.CUSTOM
        matched = self.match_status in (MatchStatus.CANONICAL, MatchStatus.ALIAS)

        if custom_type:
            if not self.custom_predicate:
                raise ValueError("Edges with type CUSTOM must define custom_predicate")
            if matched:
                raise ValueError("Edges with type CUSTOM cannot be canonical or alias-matched")

        if matched:
            # Fall back to the edge type's own value as the predicate.
            if not self.canonical_predicate:
                self.canonical_predicate = self.type.value
        elif self.match_status == MatchStatus.CUSTOM and not self.custom_predicate:
            raise ValueError("Custom edges must define custom_predicate")
        return self
# ── Top-level graph container ────────────────────────────────────────────
[docs]
class KnowledgeGraph(BaseModel):
    """Top-level container holding the graph's nodes and edges.

    Both lists default to empty. After field validation, every edge is
    checked to make sure it only refers to node IDs present in ``nodes``.
    """

    nodes: list[Node] = Field(default_factory=list)
    edges: list[Edge] = Field(default_factory=list)

    @model_validator(mode="after")
    def _check_edge_references(self) -> "KnowledgeGraph":
        """Reject edges that point at node IDs missing from ``nodes``.

        ``source`` and ``target`` are always checked; ``carcinogen`` is
        checked only when it is set to a non-empty value. All problems are
        gathered first and reported together.

        Raises:
            ValueError: Listing every dangling reference found.
        """
        known_ids = {node.id for node in self.nodes}
        problems: list[str] = []

        for edge in self.edges:
            src, dst, carc = edge.source, edge.target, edge.carcinogen
            if src not in known_ids:
                problems.append(f"Edge references missing source node: {src!r}")
            if dst not in known_ids:
                problems.append(f"Edge references missing target node: {dst!r}")
            if carc and carc not in known_ids:
                problems.append(f"Edge references missing carcinogen node: {carc!r}")

        if not problems:
            return self

        details = "\n ".join(problems)
        raise ValueError(
            f"Referential integrity errors ({len(problems)}):\n " + details
        )