Source code for ExposoGraph.storage

"""SQLite-backed graph storage with revision history."""

from __future__ import annotations

import json
import sqlite3
from contextlib import suppress
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from types import TracebackType

from .config import GraphVisibility, normalize_graph_visibility
from .engine import GraphEngine
from .exporter import to_interactive_html_string
from .graph_filters import filter_knowledge_graph
from .models import KnowledgeGraph


def _utc_now() -> str:
    return datetime.now(timezone.utc).isoformat()


@dataclass(frozen=True)
class GraphRevisionSummary:
    """Immutable metadata for one stored graph revision.

    Holds the descriptive columns of a ``graph_revisions`` row without the
    serialized graph or HTML payloads.
    """

    # Primary key of the row in ``graph_revisions``.
    revision_id: int
    # Key of the graph this revision belongs to.
    graph_key: str
    # Graph name recorded with this revision.
    graph_name: str
    # Per-graph revision counter (unique together with ``graph_key``).
    revision_number: int
    # Timestamp string stored when the revision was written.
    created_at: str
    # Number of nodes recorded for the revision.
    node_count: int
    # Number of edges recorded for the revision.
    edge_count: int
    # Optional free-text note attached to the revision.
    note: str | None = None
    # Visibility setting the revision was saved with.
    visibility: GraphVisibility = GraphVisibility.ALL
@dataclass(frozen=True)
class GraphRevision(GraphRevisionSummary):
    """A revision summary together with its stored payloads."""

    # Graph serialized as JSON text.
    graph_json: str = ""
    # HTML document stored with the revision.
    html: str = ""

    def to_knowledge_graph(self) -> KnowledgeGraph:
        """Rebuild a ``KnowledgeGraph`` from the JSON text in ``graph_json``.

        The decoded JSON object's members are passed to ``KnowledgeGraph``
        as keyword arguments.
        """
        payload = json.loads(self.graph_json)
        return KnowledgeGraph(**payload)
class GraphRepository:
    """Persist graphs and their revisions in a local SQLite database.

    A single connection is kept open for the lifetime of the repository and
    is transparently re-created if it has been closed or has become unusable.
    The repository can be used as a context manager; leaving the ``with``
    block closes the connection.

    Two tables are maintained: ``graphs`` (one row per graph key) and
    ``graph_revisions`` (one row per saved revision, numbered per graph).
    """

    def __init__(self, db_path: str | Path) -> None:
        """Open the database at *db_path* and make sure the schema exists.

        Missing parent directories of *db_path* are created.
        """
        self.db_path = Path(db_path)
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        self._conn: sqlite3.Connection | None = self._create_connection()
        self._initialize()

    def _create_connection(self) -> sqlite3.Connection:
        """Return a new connection with row access by name and foreign keys on."""
        conn = sqlite3.connect(str(self.db_path), check_same_thread=False)
        conn.row_factory = sqlite3.Row
        # SQLite only enforces foreign keys when enabled per connection; the
        # ON DELETE CASCADE on graph_revisions depends on it.
        conn.execute("PRAGMA foreign_keys = ON")
        return conn

    @staticmethod
    def _coerce_visibility(visibility: GraphVisibility | str) -> GraphVisibility:
        """Return *visibility* as a ``GraphVisibility``, normalizing strings."""
        if isinstance(visibility, GraphVisibility):
            return visibility
        return normalize_graph_visibility(visibility)

    @property
    def connection(self) -> sqlite3.Connection:
        """Return the persistent connection, reconnecting if closed."""
        if self._conn is not None:
            try:
                # Cheap probe: raises if the handle was closed or is broken.
                self._conn.execute("SELECT 1")
            except (sqlite3.ProgrammingError, sqlite3.OperationalError):
                # Release the unusable handle before replacing it so it does
                # not linger until garbage collection.
                stale, self._conn = self._conn, None
                with suppress(sqlite3.Error):
                    stale.close()
            else:
                return self._conn
        self._conn = self._create_connection()
        return self._conn

    def close(self) -> None:
        """Close the persistent connection.

        Safe to call more than once; a later use of :attr:`connection`
        opens a fresh connection.
        """
        conn, self._conn = self._conn, None
        if conn is None:
            return
        with suppress(sqlite3.ProgrammingError):
            conn.close()

    def __enter__(self) -> GraphRepository:
        """Return the repository itself for use in a ``with`` statement."""
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        tb: TracebackType | None,
    ) -> None:
        """Close the connection; exceptions from the block are not suppressed."""
        self.close()

    def __del__(self) -> None:
        """Best-effort close when the repository is garbage collected."""
        # Deliberately broad: finalizers must never raise, and ``_conn`` may
        # not exist if ``__init__`` failed early.
        with suppress(Exception):
            self.close()

    def _initialize(self) -> None:
        """Create tables and indexes if missing and apply column migrations."""
        with self.connection as conn:
            conn.executescript(
                """
                CREATE TABLE IF NOT EXISTS graphs (
                    graph_key TEXT PRIMARY KEY,
                    graph_name TEXT NOT NULL,
                    created_at TEXT NOT NULL,
                    updated_at TEXT NOT NULL
                );

                CREATE TABLE IF NOT EXISTS graph_revisions (
                    revision_id INTEGER PRIMARY KEY AUTOINCREMENT,
                    graph_key TEXT NOT NULL,
                    graph_name TEXT NOT NULL,
                    visibility TEXT NOT NULL DEFAULT 'all',
                    revision_number INTEGER NOT NULL,
                    created_at TEXT NOT NULL,
                    note TEXT,
                    node_count INTEGER NOT NULL,
                    edge_count INTEGER NOT NULL,
                    graph_json TEXT NOT NULL,
                    html TEXT NOT NULL,
                    UNIQUE(graph_key, revision_number),
                    FOREIGN KEY(graph_key) REFERENCES graphs(graph_key) ON DELETE CASCADE
                );

                CREATE INDEX IF NOT EXISTS idx_graph_revisions_graph
                ON graph_revisions(graph_key, revision_number DESC);
                """
            )
            # Databases whose graph_revisions table predates the
            # ``visibility`` column get it added in place.
            columns = {
                row["name"]
                for row in conn.execute("PRAGMA table_info(graph_revisions)").fetchall()
            }
            if "visibility" not in columns:
                conn.execute(
                    "ALTER TABLE graph_revisions "
                    "ADD COLUMN visibility TEXT NOT NULL DEFAULT 'all'"
                )

    @staticmethod
    def _summary_from_row(row: sqlite3.Row) -> GraphRevisionSummary:
        """Build a ``GraphRevisionSummary`` from a ``graph_revisions`` row."""
        return GraphRevisionSummary(
            revision_id=row["revision_id"],
            graph_key=row["graph_key"],
            graph_name=row["graph_name"],
            visibility=normalize_graph_visibility(row["visibility"]),
            revision_number=row["revision_number"],
            created_at=row["created_at"],
            node_count=row["node_count"],
            edge_count=row["edge_count"],
            note=row["note"],
        )

    @staticmethod
    def _revision_from_row(row: sqlite3.Row) -> GraphRevision:
        """Build a full ``GraphRevision`` (with payloads) from a table row."""
        return GraphRevision(
            revision_id=row["revision_id"],
            graph_key=row["graph_key"],
            graph_name=row["graph_name"],
            visibility=normalize_graph_visibility(row["visibility"]),
            revision_number=row["revision_number"],
            created_at=row["created_at"],
            node_count=row["node_count"],
            edge_count=row["edge_count"],
            note=row["note"],
            graph_json=row["graph_json"],
            html=row["html"],
        )

    def save_graph(
        self,
        *,
        graph_key: str,
        graph_name: str,
        kg: KnowledgeGraph,
        html: str,
        visibility: GraphVisibility | str = GraphVisibility.ALL,
        note: str | None = None,
    ) -> GraphRevisionSummary:
        """Store *kg* and *html* as the next revision of the graph *graph_key*.

        The ``graphs`` row is created on first save and otherwise has its
        name and ``updated_at`` refreshed. The new revision receives the next
        revision number for that graph key (starting at 1).

        Args:
            graph_key: Stable identifier of the graph.
            graph_name: Name recorded on both the graph and the revision.
            kg: Graph to serialize; its node and edge counts are stored.
            html: HTML document stored alongside the revision.
            visibility: Visibility value (enum member or string) to record.
            note: Optional free-text note for the revision.

        Returns:
            Summary of the revision that was just written.
        """
        timestamp = _utc_now()
        normalized_visibility = self._coerce_visibility(visibility)
        graph_json = json.dumps(kg.model_dump(mode="json"), indent=2)

        # Both writes share one transaction: committed together on success,
        # rolled back together if either statement raises.
        with self.connection as conn:
            # ``created_at`` is only written when the row is first inserted;
            # the DO UPDATE clause leaves it untouched for existing graphs,
            # so no prior lookup of the stored value is needed.
            conn.execute(
                """
                INSERT INTO graphs(graph_key, graph_name, created_at, updated_at)
                VALUES (?, ?, ?, ?)
                ON CONFLICT(graph_key) DO UPDATE SET
                    graph_name = excluded.graph_name,
                    updated_at = excluded.updated_at
                """,
                (graph_key, graph_name, timestamp, timestamp),
            )
            # The aggregate SELECT always yields exactly one row, so the next
            # revision number is computed and inserted in a single statement.
            cursor = conn.execute(
                """
                INSERT INTO graph_revisions(
                    graph_key,
                    graph_name,
                    visibility,
                    revision_number,
                    created_at,
                    note,
                    node_count,
                    edge_count,
                    graph_json,
                    html
                )
                SELECT ?, ?, ?, COALESCE(MAX(revision_number), 0) + 1, ?, ?, ?, ?, ?, ?
                FROM graph_revisions
                WHERE graph_key = ?
                """,
                (
                    graph_key,
                    graph_name,
                    normalized_visibility.value,
                    timestamp,
                    note,
                    len(kg.nodes),
                    len(kg.edges),
                    graph_json,
                    html,
                    graph_key,
                ),
            )
            row = conn.execute(
                "SELECT * FROM graph_revisions WHERE revision_id = ?",
                (cursor.lastrowid,),
            ).fetchone()
        return self._summary_from_row(row)

    def save_engine(
        self,
        *,
        graph_key: str,
        graph_name: str,
        engine: GraphEngine,
        template_path: str | Path | None = None,
        visibility: GraphVisibility | str = GraphVisibility.ALL,
        note: str | None = None,
    ) -> GraphRevisionSummary:
        """Render *engine* and store the result as a new revision.

        The engine's knowledge graph is filtered by *visibility*, the
        interactive HTML is rendered with the same visibility (and optional
        *template_path*), and both are passed to :meth:`save_graph`.

        Returns:
            Summary of the revision that was just written.
        """
        normalized_visibility = self._coerce_visibility(visibility)
        kg = filter_knowledge_graph(engine.to_knowledge_graph(), normalized_visibility)
        html = to_interactive_html_string(
            engine,
            template_path=template_path,
            visibility=normalized_visibility,
        )
        return self.save_graph(
            graph_key=graph_key,
            graph_name=graph_name,
            kg=kg,
            html=html,
            visibility=normalized_visibility,
            note=note,
        )

    def list_graphs(self) -> list[GraphRevisionSummary]:
        """Return the latest revision of every graph, newest first.

        Ordered by revision ``created_at`` descending, then graph name.
        """
        with self.connection as conn:
            rows = conn.execute(
                """
                SELECT r.*
                FROM graph_revisions r
                JOIN (
                    SELECT graph_key, MAX(revision_number) AS max_revision
                    FROM graph_revisions
                    GROUP BY graph_key
                ) latest
                ON latest.graph_key = r.graph_key
                AND latest.max_revision = r.revision_number
                ORDER BY r.created_at DESC, r.graph_name ASC
                """
            ).fetchall()
        return [self._summary_from_row(row) for row in rows]

    def list_revisions(self, graph_key: str) -> list[GraphRevisionSummary]:
        """Return all revisions of *graph_key*, highest revision number first.

        An unknown key yields an empty list.
        """
        with self.connection as conn:
            rows = conn.execute(
                """
                SELECT *
                FROM graph_revisions
                WHERE graph_key = ?
                ORDER BY revision_number DESC
                """,
                (graph_key,),
            ).fetchall()
        return [self._summary_from_row(row) for row in rows]

    def get_revision(self, revision_id: int) -> GraphRevision | None:
        """Return the revision with *revision_id*, or ``None`` if absent."""
        with self.connection as conn:
            row = conn.execute(
                "SELECT * FROM graph_revisions WHERE revision_id = ?",
                (revision_id,),
            ).fetchone()
        if row is None:
            return None
        return self._revision_from_row(row)

    def get_latest_revision(self, graph_key: str) -> GraphRevision | None:
        """Return the highest-numbered revision of *graph_key*, or ``None``."""
        with self.connection as conn:
            row = conn.execute(
                """
                SELECT *
                FROM graph_revisions
                WHERE graph_key = ?
                ORDER BY revision_number DESC
                LIMIT 1
                """,
                (graph_key,),
            ).fetchone()
        if row is None:
            return None
        return self._revision_from_row(row)