"""NetworkX-backed graph engine for building and querying the knowledge graph."""
from __future__ import annotations
import json
import logging
from typing import Any
import networkx as nx
from .config import GraphMode
from .grounding import prepare_knowledge_graph
from .models import Edge, KnowledgeGraph, Node
logger = logging.getLogger(__name__)
[docs]
class GraphEngine:
    """Thin wrapper around a NetworkX MultiDiGraph that speaks our domain model.

    Nodes are stored under ``node.id`` with the JSON-mode dump of the ``Node``
    model as attributes.  Edges are stored under a generated string key (see
    ``_edge_key``) with the JSON-mode dump of the ``Edge`` model as attributes,
    so parallel edges between the same pair of nodes are preserved.
    """

    def __init__(self) -> None:
        self.G: nx.MultiDiGraph = nx.MultiDiGraph()

    # ── Mutations ────────────────────────────────────────────────────────

    def add_node(self, node: Node) -> None:
        """Add *node* to the graph, keyed by ``node.id``.

        ``None``-valued fields are omitted from the stored attributes.  If the
        id already exists, NetworkX updates the given attributes in place, so
        attributes that are omitted here are left as they were.
        """
        self.G.add_node(node.id, **node.model_dump(exclude_none=True, mode="json"))

    def _edge_key(self, edge: Edge) -> str:
        """Return a stable edge key while preserving parallel edges.

        The first edge of a given ``source``/``type``/``target`` triple gets
        the bare key ``"<source>-<type>-<target>"``; later ones get the first
        unused ``-2``, ``-3``, ... suffix.
        """
        base_key = f"{edge.source}-{edge.type.value}-{edge.target}"
        if not self.G.has_edge(edge.source, edge.target, base_key):
            return base_key
        suffix = 2
        while self.G.has_edge(edge.source, edge.target, f"{base_key}-{suffix}"):
            suffix += 1
        return f"{base_key}-{suffix}"

    def add_edge(self, edge: Edge) -> None:
        """Add *edge* as a new (possibly parallel) edge.

        Raises:
            ValueError: if the source node, the target node, or the optional
                carcinogen context node is not already in the graph.
        """
        if edge.source not in self.G:
            raise ValueError(f"Missing source node: {edge.source}")
        if edge.target not in self.G:
            raise ValueError(f"Missing target node: {edge.target}")
        if edge.carcinogen and edge.carcinogen not in self.G:
            raise ValueError(f"Missing carcinogen context node: {edge.carcinogen}")
        self.G.add_edge(
            edge.source,
            edge.target,
            key=self._edge_key(edge),
            **edge.model_dump(exclude_none=True, mode="json"),
        )

    def remove_node(self, node_id: str) -> None:
        """Remove *node_id* from the graph; do nothing if it is absent."""
        if node_id in self.G:
            self.G.remove_node(node_id)

    def remove_edge(self, source: str, target: str, key: str | None = None) -> None:
        """Remove an edge from *source* to *target*; do nothing if none matches.

        Args:
            source: id of the edge's source node.
            target: id of the edge's target node.
            key: when given, only the edge stored under exactly this key is
                removed.  When ``None``, a single edge between the two nodes
                is removed (which one is decided by NetworkX).
        """
        if key is not None:
            # An explicit key targets one specific parallel edge.  If that key
            # does not exist we must not fall back to deleting some *other*
            # edge between the same endpoints.
            if self.G.has_edge(source, target, key):
                self.G.remove_edge(source, target, key)
            return
        if self.G.has_edge(source, target):
            self.G.remove_edge(source, target)

    # ── Bulk operations ──────────────────────────────────────────────────

    def _validated_reference_graph(self) -> KnowledgeGraph | None:
        """Return the current graph prepared in STRICT mode, or ``None``.

        ``None`` is returned when the engine holds no nodes, or when the
        prepared graph ends up with no nodes.  Warnings produced by this
        preparation pass are discarded.
        """
        if self.node_count == 0:
            return None
        current_graph = self.to_knowledge_graph()
        validated_graph, _warnings = prepare_knowledge_graph(
            current_graph,
            mode=GraphMode.STRICT,
        )
        if not validated_graph.nodes:
            return None
        return validated_graph

    def load(
        self,
        kg: KnowledgeGraph,
        *,
        mode: GraphMode | str = GraphMode.EXPLORATORY,
    ) -> list[str]:
        """Replace the current graph with *kg*.

        Clears all existing nodes and edges before loading.

        Returns a list of warning messages for any skipped edges.
        """
        self.clear()
        return self.merge(kg, mode=mode)

    def merge(
        self,
        kg: KnowledgeGraph,
        *,
        mode: GraphMode | str = GraphMode.EXPLORATORY,
    ) -> list[str]:
        """Additive merge — new nodes/edges are added, existing nodes updated.

        *kg* is first passed through ``prepare_knowledge_graph`` with the
        given *mode*; when the engine already holds data, a STRICT-prepared
        copy of it is supplied as the ``"current_graph"`` reference graph.
        Every edge of the prepared graph is then added via ``add_edge``, which
        always stores it under a fresh key.

        Returns a list of warning messages: those reported by the preparation
        step followed by one message per edge skipped because ``add_edge``
        raised ``ValueError``.
        """
        reference_graphs: list[tuple[str, KnowledgeGraph]] = []
        validated_graph = self._validated_reference_graph()
        if validated_graph is not None:
            reference_graphs.append(("current_graph", validated_graph))
        prepared_graph, prepare_warnings = prepare_knowledge_graph(
            kg,
            mode=mode,
            reference_graphs=reference_graphs or None,
        )
        # Copy so that appending our own messages never mutates a list that
        # the preparation step may still hold a reference to.
        warnings = list(prepare_warnings)
        for node in prepared_graph.nodes:
            self.add_node(node)
        for edge in prepared_graph.edges:
            try:
                self.add_edge(edge)
            except ValueError as exc:
                warnings.append(str(exc))
                logger.warning("Skipped edge during merge: %s", exc)
        return warnings

    def clear(self) -> None:
        """Remove every node and edge from the graph."""
        self.G.clear()

    # ── Queries ──────────────────────────────────────────────────────────

    @property
    def node_count(self) -> int:
        """Number of nodes currently in the graph."""
        return int(self.G.number_of_nodes())

    @property
    def edge_count(self) -> int:
        """Number of edges currently in the graph (parallel edges counted)."""
        return int(self.G.number_of_edges())

    def get_node(self, node_id: str) -> dict[str, Any] | None:
        """Return a copy of the attributes of *node_id*, or ``None`` if absent."""
        if node_id in self.G:
            return dict(self.G.nodes[node_id])
        return None

    def neighbors(self, node_id: str) -> list[str]:
        """Return the ids adjacent to *node_id* in either direction.

        Successors are listed first, then predecessors.  Each id appears once
        even when it is both a successor and a predecessor (bidirectional
        links, self-loops).  An unknown *node_id* yields an empty list.
        """
        if node_id not in self.G:
            return []
        adjacent = [*self.G.successors(node_id), *self.G.predecessors(node_id)]
        # dict.fromkeys drops duplicates while keeping first-seen order.
        return list(dict.fromkeys(adjacent))

    def nodes_by_type(self, node_type: str) -> list[dict[str, Any]]:
        """Return attribute dicts of all nodes whose ``"type"`` equals *node_type*.

        The dicts are copies (as in ``get_node`` and ``to_dict``), so mutating
        the result does not alter the graph.
        """
        return [
            dict(data)
            for _, data in self.G.nodes(data=True)
            if data.get("type") == node_type
        ]

    # ── Serialization ────────────────────────────────────────────────────

    def to_dict(self) -> dict[str, list[Any]]:
        """Return ``{"nodes": [...], "edges": [...]}`` with copied attribute dicts."""
        nodes = [dict(data) for _, data in self.G.nodes(data=True)]
        edges = [dict(data) for _, _, _, data in self.G.edges(keys=True, data=True)]
        return {"nodes": nodes, "edges": edges}

    def to_knowledge_graph(self) -> KnowledgeGraph:
        """Rebuild a ``KnowledgeGraph`` model from the stored node/edge attributes."""
        data = self.to_dict()
        return KnowledgeGraph(
            nodes=[Node(**n) for n in data["nodes"]],
            edges=[Edge(**e) for e in data["edges"]],
        )

    def to_json(self, indent: int = 2) -> str:
        """Serialize ``to_dict()`` to JSON.

        Values that ``json`` cannot encode natively are converted with ``str``.
        """
        return json.dumps(self.to_dict(), indent=indent, default=str)

    # ── Validation ───────────────────────────────────────────────────────

    def validate(self) -> list[str]:
        """Return a list of referential-integrity errors (empty when consistent).

        Checks every edge's endpoints and, when present, its ``carcinogen``
        attribute against the set of node ids in the graph.
        """
        errors: list[str] = []
        node_ids = set(self.G.nodes)
        for u, v, data in self.G.edges(data=True):
            # Endpoint checks are defensive: edges added through this class
            # always have both endpoints present.
            if u not in node_ids:
                errors.append(f"Edge references missing source node: {u}")
            if v not in node_ids:
                errors.append(f"Edge references missing target node: {v}")
            carcinogen = data.get("carcinogen")
            # The context node can go missing after the edge was added, e.g.
            # via remove_node, which does not touch edges that merely mention it.
            if carcinogen and carcinogen not in node_ids:
                errors.append(
                    f"Edge '{u}→{v}' references carcinogen '{carcinogen}' "
                    f"which is not in the graph"
                )
        return errors