from dataclasses import dataclass, field
from enum import StrEnum
from itertools import product
from pathlib import Path
from string import ascii_uppercase, ascii_lowercase
from loguru import logger
import random

class Out(StrEnum): graph = "data/gt_graph_paper.txt"

@dataclass(slots=True, frozen=True)
class Edge:
    u: str
    v: str
    label: str
    def token(self) -> str: return f"{self.u}{self.label}{self.v}"

@dataclass(slots=True)
class Graph:
    """Undirected labelled graph without self-loops or parallel edges.

    Edges are stored with their endpoints ordered so that ``edge.u < edge.v``
    (plain string comparison), which makes ``(u, v)`` and ``(v, u)`` the same
    edge for duplicate detection.
    """

    nodes: set[str]
    edges: list[Edge] = field(default_factory=list)
    # Ordered endpoint pairs of every stored edge; gives O(1) duplicate checks.
    _pairs: set[tuple[str, str]] = field(default_factory=set)

    def __post_init__(self) -> None:
        """Index edges supplied to the constructor so duplicates are detected."""
        # Without this, a Graph built with pre-populated ``edges`` would start
        # with an empty ``_pairs`` and ``add_edge`` would accept a parallel edge.
        self._pairs.update(
            (e.u, e.v) if e.u < e.v else (e.v, e.u) for e in self.edges
        )

    def add_edge(self, u: str, v: str, label: str) -> None:
        """Add an undirected edge between ``u`` and ``v`` carrying ``label``.

        Raises:
            ValueError: if either endpoint is not in ``nodes``, if ``u == v``
                (self-loop), or if the two nodes are already connected.
        """
        if u not in self.nodes or v not in self.nodes:
            raise ValueError("unknown node")
        if u == v:
            raise ValueError("no loops")
        # Normalise the endpoint order so both directions map to one key.
        a, b = (u, v) if u < v else (v, u)
        if (a, b) in self._pairs:
            raise ValueError("parallel edge")
        self._pairs.add((a, b))
        self.edges.append(Edge(a, b, label))

    def write_tokens(self, path: Path) -> None:
        """Write one edge token per line to ``path`` as UTF-8 text.

        Missing parent directories are created first. A graph with no edges
        produces an empty file rather than a file holding one blank line.
        """
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_text(
            "".join(f"{e.token()}\n" for e in self.edges), encoding="utf-8"
        )

@dataclass(slots=True)
class Config:
    """Settings that control graph generation and where the result is written."""

    # Number of uppercase letters in each node name.
    letters_per_node: int = 3
    # Desired mean number of edges per node.
    avg_degree: float = 6.0
    # Seed handed to random.Random; None is also accepted.
    seed: int | None = 7
    # File that receives the edge tokens.
    out: Path = Path(Out.graph.value)

def node_names(k: int) -> list[str]:
    """Return every ``k``-letter uppercase ASCII string, in product order.

    The result has ``26 ** k`` entries, starting at ``"A" * k`` and ending at
    ``"Z" * k``.

    Raises:
        ValueError: if ``k`` is less than 1.
    """
    if k < 1:
        raise ValueError("letters_per_node must be >= 1")
    letter_tuples = product(ascii_uppercase, repeat=k)
    return list(map("".join, letter_tuples))

def generate(cfg: Config) -> Graph:
    """Build a random labelled graph according to ``cfg``.

    Nodes are all names of ``cfg.letters_per_node`` uppercase letters. The
    target edge count is ``round(cfg.avg_degree * n / 2)``, clamped to the
    range ``[0, n * (n - 1) // 2]``. Distinct unordered node pairs are drawn
    by rejection sampling from a ``random.Random(cfg.seed)`` generator, and
    each accepted pair gets a random lowercase ASCII letter as its label.

    Raises:
        ValueError: propagated from ``node_names`` when
            ``cfg.letters_per_node`` is less than 1.
    """
    names = node_names(cfg.letters_per_node)
    n = len(names)
    total = n * (n - 1) // 2
    if n < 2:
        target = 0
    else:
        target = max(0, min(round(cfg.avg_degree * n / 2), total))
    logger.debug(f"seed={cfg.seed}, letters={cfg.letters_per_node}, n={n}, target_edges={target}, total_possible={total}")

    rng = random.Random(cfg.seed)
    g = Graph(nodes=set(names))
    seen: set[tuple[str, str]] = set()
    while len(seen) < target:
        first, second = rng.sample(names, 2)
        pair = (first, second) if first < second else (second, first)
        if pair not in seen:
            seen.add(pair)
            # The label is drawn only for accepted pairs, which keeps the
            # sequence of RNG calls (and thus seeded output) stable.
            g.add_edge(pair[0], pair[1], rng.choice(ascii_lowercase))

    logger.info(f"built graph: n={n}, e={len(g.edges)}, avg_deg={0 if n==0 else round(2*len(g.edges)/n, 2)}")
    return g

if __name__ == "__main__":
    # Script entry point: generate a graph from the default configuration,
    # write its edge tokens to the configured file, and log the absolute path.
    cfg = Config()
    graph = generate(cfg)
    graph.write_tokens(cfg.out)
    logger.info(f"wrote {cfg.out.resolve()}")
