Source code for ExposoGraph.engine

"""NetworkX-backed graph engine for building and querying the knowledge graph."""

from __future__ import annotations

import json
import logging
from typing import Any

import networkx as nx

from .config import GraphMode
from .grounding import prepare_knowledge_graph
from .models import Edge, KnowledgeGraph, Node

logger = logging.getLogger(__name__)


class GraphEngine:
    """Thin wrapper around a NetworkX MultiDiGraph that speaks our domain model.

    Nodes are keyed by ``Node.id`` and carry the JSON-mode ``model_dump`` of the
    node as attributes. Edges carry the JSON-mode ``model_dump`` of the ``Edge``
    and are keyed by a string built from source, edge type and target (see
    :meth:`_edge_key`), so parallel edges between the same endpoints are kept.
    """

    def __init__(self) -> None:
        self.G: nx.MultiDiGraph = nx.MultiDiGraph()

    # ── Mutations ────────────────────────────────────────────────────────

    def add_node(self, node: Node) -> None:
        """Add *node*, or update the attributes of an existing node with the same id.

        Fields that are ``None`` are excluded from the dump, so they never
        overwrite attributes already stored on an existing node.
        """
        self.G.add_node(node.id, **node.model_dump(exclude_none=True, mode="json"))

    def _edge_key(self, edge: Edge) -> str:
        """Return a stable edge key while preserving parallel edges.

        The base key is ``"<source>-<type value>-<target>"``. If that key is
        already used between the same endpoints, the first free numeric suffix
        (``-2``, ``-3``, ...) is appended instead.
        """
        base_key = f"{edge.source}-{edge.type.value}-{edge.target}"
        if not self.G.has_edge(edge.source, edge.target, base_key):
            return base_key
        suffix = 2
        while self.G.has_edge(edge.source, edge.target, f"{base_key}-{suffix}"):
            suffix += 1
        return f"{base_key}-{suffix}"

    def add_edge(self, edge: Edge) -> None:
        """Add *edge* under a fresh key from :meth:`_edge_key`.

        Raises:
            ValueError: If the edge's source, target, or (when set) carcinogen
                context node is not already in the graph.
        """
        # Check endpoints explicitly: NetworkX would otherwise silently create
        # attribute-less nodes for unknown ids.
        if edge.source not in self.G:
            raise ValueError(f"Missing source node: {edge.source}")
        if edge.target not in self.G:
            raise ValueError(f"Missing target node: {edge.target}")
        if edge.carcinogen and edge.carcinogen not in self.G:
            raise ValueError(f"Missing carcinogen context node: {edge.carcinogen}")
        self.G.add_edge(
            edge.source,
            edge.target,
            key=self._edge_key(edge),
            **edge.model_dump(exclude_none=True, mode="json"),
        )

    def remove_node(self, node_id: str) -> None:
        """Remove *node_id* and its incident edges; a no-op for unknown ids.

        Edges elsewhere whose ``carcinogen`` attribute names this node are not
        touched. :meth:`validate` reports them afterwards.
        """
        if node_id in self.G:
            self.G.remove_node(node_id)

    def remove_edge(self, source: str, target: str, key: str | None = None) -> None:
        """Remove an edge from *source* to *target*; a no-op if it does not exist.

        Args:
            source: Source node id.
            target: Target node id.
            key: Key of the edge to remove. When given, only the edge with
                exactly this key is removed. When ``None``, NetworkX removes a
                single (the most recently added) edge between the endpoints.
        """
        if key is not None:
            # Never fall back to "any edge" here: that would silently delete a
            # different parallel edge than the one the caller asked for.
            if self.G.has_edge(source, target, key):
                self.G.remove_edge(source, target, key)
        elif self.G.has_edge(source, target):
            self.G.remove_edge(source, target)

    # ── Bulk operations ──────────────────────────────────────────────────

    def _validated_reference_graph(self) -> KnowledgeGraph | None:
        """Return the current graph re-validated in STRICT mode, or ``None``.

        ``None`` is returned when the engine is empty or when strict validation
        leaves no nodes. Validation warnings are discarded. :meth:`merge` uses
        the result as a reference graph for the incoming data.
        """
        if self.node_count == 0:
            return None
        current_graph = self.to_knowledge_graph()
        validated_graph, _warnings = prepare_knowledge_graph(
            current_graph,
            mode=GraphMode.STRICT,
        )
        if not validated_graph.nodes:
            return None
        return validated_graph

    def load(
        self, kg: KnowledgeGraph, *, mode: GraphMode | str = GraphMode.EXPLORATORY
    ) -> list[str]:
        """Replace the current graph with *kg*.

        Clears all existing nodes and edges before loading. The previous
        contents are not restored if the subsequent merge raises.

        Returns a list of warning messages for any skipped edges.
        """
        self.clear()
        return self.merge(kg, mode=mode)

    def merge(
        self, kg: KnowledgeGraph, *, mode: GraphMode | str = GraphMode.EXPLORATORY
    ) -> list[str]:
        """Additively merge *kg* into the current graph.

        *kg* is first passed through ``prepare_knowledge_graph`` with *mode*.
        When the engine is non-empty, the strictly validated current graph is
        supplied as a ``"current_graph"`` reference. Prepared nodes are added,
        and existing nodes with the same id get their attributes updated.
        Prepared edges are always added as new edges. An edge equal to one
        already present becomes a parallel edge with a suffixed key, unless
        ``prepare_knowledge_graph`` drops it (NOTE(review): confirm whether it
        de-duplicates against the reference graph).

        Returns:
            The warnings from ``prepare_knowledge_graph``, followed by one
            message per edge that could not be added (missing endpoint or
            carcinogen context node).
        """
        reference_graphs: list[tuple[str, KnowledgeGraph]] = []
        validated_graph = self._validated_reference_graph()
        if validated_graph is not None:
            reference_graphs.append(("current_graph", validated_graph))

        prepared_graph, prepared_warnings = prepare_knowledge_graph(
            kg,
            mode=mode,
            reference_graphs=reference_graphs or None,
        )
        # Copy so we never mutate a list owned by prepare_knowledge_graph.
        warnings = list(prepared_warnings)

        # Nodes first, so that edges can find their endpoints.
        for node in prepared_graph.nodes:
            self.add_node(node)
        for edge in prepared_graph.edges:
            try:
                self.add_edge(edge)
            except ValueError as exc:
                warnings.append(str(exc))
                logger.warning("Skipped edge during merge: %s", exc)
        return warnings

    def clear(self) -> None:
        """Remove every node and edge (the underlying graph object is reused)."""
        self.G.clear()

    # ── Queries ──────────────────────────────────────────────────────────

    @property
    def node_count(self) -> int:
        """Number of nodes currently in the graph."""
        return int(self.G.number_of_nodes())

    @property
    def edge_count(self) -> int:
        """Number of edges currently in the graph, counting parallel edges."""
        return int(self.G.number_of_edges())

    def get_node(self, node_id: str) -> dict[str, Any] | None:
        """Return a shallow copy of *node_id*'s attributes, or ``None`` if absent."""
        if node_id in self.G:
            return dict(self.G.nodes[node_id])
        return None

    def neighbors(self, node_id: str) -> list[str]:
        """Return ids of nodes adjacent to *node_id* in either direction.

        Successors are listed first, then predecessors. Each id appears once,
        even when the nodes are connected in both directions or via a
        self-loop. Unknown ids yield an empty list.
        """
        if node_id not in self.G:
            return []
        adjacent = [*self.G.successors(node_id), *self.G.predecessors(node_id)]
        # dict.fromkeys de-duplicates while preserving first-seen order.
        return list(dict.fromkeys(adjacent))

    def nodes_by_type(self, node_type: str) -> list[dict[str, Any]]:
        """Return the attribute dicts of all nodes whose ``type`` equals *node_type*.

        Types are stored in JSON mode, so *node_type* is compared against the
        serialized value. The returned dicts are the live attribute mappings,
        not copies.
        """
        return [
            data for _, data in self.G.nodes(data=True) if data.get("type") == node_type
        ]

    # ── Serialization ────────────────────────────────────────────────────

    def to_dict(self) -> dict[str, list[Any]]:
        """Return ``{"nodes": [...], "edges": [...]}`` with shallow-copied attribute dicts.

        Edge keys are not included. Parallel edges appear as separate entries.
        """
        nodes = [dict(data) for _, data in self.G.nodes(data=True)]
        edges = [dict(data) for _, _, data in self.G.edges(data=True)]
        return {"nodes": nodes, "edges": edges}

    def to_knowledge_graph(self) -> KnowledgeGraph:
        """Rebuild a ``KnowledgeGraph`` model from the stored node/edge attributes."""
        data = self.to_dict()
        return KnowledgeGraph(
            nodes=[Node(**n) for n in data["nodes"]],
            edges=[Edge(**e) for e in data["edges"]],
        )

    def to_json(self, indent: int = 2) -> str:
        """Serialize :meth:`to_dict` to JSON; non-JSON values fall back to ``str``."""
        return json.dumps(self.to_dict(), indent=indent, default=str)

    # ── Validation ───────────────────────────────────────────────────────

    def validate(self) -> list[str]:
        """Return human-readable integrity errors; an empty list means no issues.

        Checks that every edge's endpoints exist and that any ``carcinogen``
        context an edge references is a node in the graph. The latter can
        break after :meth:`remove_node`.
        """
        errors: list[str] = []
        node_ids = set(self.G.nodes)
        for u, v, data in self.G.edges(data=True):
            # Defensive: NetworkX itself keeps edge endpoints in the node set.
            if u not in node_ids:
                errors.append(f"Edge references missing source node: {u}")
            if v not in node_ids:
                errors.append(f"Edge references missing target node: {v}")
            if data.get("carcinogen") and data["carcinogen"] not in node_ids:
                errors.append(
                    f"Edge '{u} → {v}' references carcinogen '{data['carcinogen']}' "
                    f"which is not in the graph"
                )
        return errors