# generate_paths_levels.py
from __future__ import annotations

from dataclasses import dataclass
from collections import deque, defaultdict
from pathlib import Path
from string import ascii_lowercase, ascii_uppercase
from loguru import logger
from tqdm.auto import tqdm
import json, math, random, numpy as np
from concurrent.futures import ProcessPoolExecutor
from os import cpu_count

# ---------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------
@dataclass(slots=True)
class CFG:
    """Default settings for building the eval and train path datasets.

    ``main`` instantiates this with no arguments; the generator functions in
    this file read the fields directly. No field is validated here.
    """

    graph_path: str = "data/gt_graph_paper.txt"  # edge-token file read by load_graph
    out_dir: str = "data/paths_paper"  # directory that receives every output file
    seed: int = 7  # seeds the `random` and NumPy generators created in main

    # Eval: base difficulty by hops, then add constraints by level
    max_hops_eval: int = 6  # base paths are sampled for hop counts 1..max_hops_eval
    eval_per_hop: int = 128  # target number of base paths per hop count
    levels_max: int = 6  # each base path yields one query per level 1..levels_max
    p_inc: float = 0.5  # chance that an added constraint is an include rather than an exclude

    # Train: size and search depth
    n_train_extra: int = 500_000  # searched examples added on top of the edge-coverage examples
    max_hops_train: int = 10  # hop limit handed to bfs_constrained for train searches

    # Difficulty weights (α_h, α_I, α_X, α_r)
    diff_w_hops: float = 1.0  # multiplies the hop count
    diff_w_inc: float = 0.8  # per-include-label term in the (1 + w * count) factor
    diff_w_exc: float = 0.8  # per-exclude-label term in the (1 + w * count) factor
    diff_w_rare: float = 1.0  # multiplies the path's rarity score
    weight_sigma: float = 0.7  # label rarity sampling (std-dev of the normal draw)

# ---------------------------------------------------------------------
# Graph + utils
# ---------------------------------------------------------------------
@dataclass(frozen=True, slots=True)
class Edge:
    """One labelled edge as produced by ``parse_edge_token``."""

    u: str  # node name on the left of the label
    l: str  # the edge label (a single lowercase letter when built by parse_edge_token)
    v: str  # node name on the right of the label

def parse_edge_token(t: str) -> Edge:
    """Split one edge token such as ``"ABxCD"`` into an ``Edge``.

    The first ASCII lowercase character is the label. Everything before it is
    the first node name and everything after it is the second node name; both
    names must be non-empty and consist only of ASCII uppercase letters.

    Args:
        t: The raw token (already stripped of surrounding whitespace).

    Returns:
        ``Edge(u, l, v)`` for the parsed token.

    Raises:
        ValueError: If the token has no lowercase label, or either node name
            is empty or contains a character that is not ASCII uppercase.
    """
    # A default of -1 avoids leaking a bare StopIteration to the caller.
    i = next((j for j, ch in enumerate(t) if ch in ascii_lowercase), -1)
    if i < 0:
        raise ValueError(f"edge token {t!r} has no lowercase label")
    u, l, v = t[:i], t[i], t[i + 1 :]
    # Explicit check instead of `assert`, which is skipped under `python -O`.
    if not (u and v and all(c in ascii_uppercase for c in u + v)):
        raise ValueError(
            f"edge token {t!r} must look like <UPPER><lower><UPPER> with non-empty node names"
        )
    return Edge(u, l, v)

def load_graph(path: Path) -> tuple[list[str], dict[str, list[tuple[str, str]]]]:
    """Read an edge-token file and build a symmetric adjacency list.

    Each non-blank line is parsed with ``parse_edge_token``. Every edge is
    stored in both directions as ``(label, neighbour)``.

    Returns:
        ``(sorted node names, adjacency mapping)``.
    """
    adj: dict[str, list[tuple[str, str]]] = defaultdict(list)
    nodes: set[str] = set()

    tokens = []
    for raw in path.read_text(encoding="utf-8").splitlines():
        token = raw.strip()
        if token:
            tokens.append(token)

    for token in tokens:
        edge = parse_edge_token(token)
        nodes.add(edge.u)
        nodes.add(edge.v)
        adj[edge.u].append((edge.l, edge.v))
        adj[edge.v].append((edge.l, edge.u))

    logger.info(f"graph loaded: |V|={len(nodes)} |E|={sum(map(len, adj.values()))//2}")
    return sorted(nodes), adj

def path_string(nodes: list[str], labels: list[str]) -> str:
    """Interleave node names and edge labels: ``n0 l0 n1 l1 n2 ...`` with no separators."""
    pieces = [nodes[0]]
    for label, node in zip(labels, nodes[1:]):
        pieces.append(label + node)
    return "".join(pieces)

def fmt_query(u: str, v: str, inc: set[str], exc: set[str], nodes: list[str], labels: list[str]) -> str:
    """Render one example line body: ``u-v[include](exclude):path``.

    The ``[...]`` and ``(...)`` groups hold the sorted labels and are left out
    entirely when the corresponding set is empty.
    """
    parts = [f"{u}-{v}"]
    if inc:
        parts.append("[" + "".join(sorted(inc)) + "]")
    if exc:
        parts.append("(" + "".join(sorted(exc)) + ")")
    parts.append(":" + path_string(nodes, labels))
    return "".join(parts)

def softmax(z: np.ndarray) -> np.ndarray:
    """Softmax over all entries of ``z``; the maximum is subtracted before exponentiating."""
    shifted = z - z.max()
    weights = np.exp(shifted)
    return weights / weights.sum()

def sample_label_weights(rng: np.random.Generator, sigma: float) -> dict[str, float]:
    """Draw one weight per lowercase letter.

    26 normal(0, sigma) values are pushed through ``softmax`` and mixed as
    95% softmax output plus 5% uniform.

    Returns:
        Mapping from each letter ``a``..``z`` to its weight as a Python float.
    """
    logits = rng.normal(0.0, sigma, 26)
    mixed = 0.95 * softmax(logits) + 0.05 / 26.0
    return {letter: float(weight) for letter, weight in zip(ascii_lowercase, mixed)}

# ---------------------------------------------------------------------
# Search
# ---------------------------------------------------------------------
# Bit position of each lowercase label (a -> 0 ... z -> 25) for the include masks.
IDX = dict(zip(ascii_lowercase, range(len(ascii_lowercase))))

def bfs_exact_hops(adj: dict[str, list[tuple[str, str]]], u: str, v: str, hops: int, rng: random.Random) -> tuple[list[str], list[str]] | None:
    """Find a walk from ``u`` to ``v`` that uses exactly ``hops`` edges.

    The search is breadth-first over ``(node, depth)`` states, so the same
    node can be visited again at a different depth and may repeat in the
    result. The neighbour list is shuffled with ``rng`` at every expansion.

    Returns:
        ``(nodes, labels)`` for the first matching walk found, or ``None``.
    """
    origin = (u, 0)
    parent: dict[tuple[str, int], tuple[tuple[str, int], str] | None] = {origin: None}
    frontier = deque([origin])
    while frontier:
        state = frontier.popleft()
        node, dist = state
        if dist >= hops:
            if dist == hops and node == v:
                # Follow parent links back to the origin, then flip to forward order.
                walk_nodes, walk_labels = [node], []
                cursor = state
                while parent[cursor] is not None:
                    cursor, label = parent[cursor]
                    walk_labels.append(label)
                    walk_nodes.append(cursor[0])
                return walk_nodes[::-1], walk_labels[::-1]
            continue
        neighbours = list(adj.get(node, []))
        if neighbours:
            rng.shuffle(neighbours)
        for label, nxt in neighbours:
            child = (nxt, dist + 1)
            if child not in parent:
                parent[child] = (state, label)
                frontier.append(child)
    return None

def bfs_constrained(adj: dict[str, list[tuple[str, str]]], u: str, v: str, include: set[str], exclude: set[str], max_hops: int) -> tuple[list[str], list[str]] | None:
    """Breadth-first search from ``u`` to ``v`` under label constraints.

    Edges whose label is in ``exclude`` are never followed. A state is
    ``(node, mask)`` where ``mask`` records which ``include`` labels have been
    used so far; the goal is ``v`` with every include label used. States at
    depth ``max_hops`` are not expanded further.

    Returns:
        ``(nodes, labels)`` of the first goal state dequeued, or ``None``.
    """
    bit_of = {label: 1 << IDX[label] for label in include}
    required = 0
    for bit in bit_of.values():
        required |= bit

    origin = (u, 0)
    parent: dict[tuple[str, int], tuple[tuple[str, int], str] | None] = {origin: None}
    hops_to: dict[tuple[str, int], int] = {origin: 0}
    frontier = deque([origin])
    found: tuple[str, int] | None = None

    while frontier:
        state = frontier.popleft()
        node, mask = state
        if node == v and (mask & required) == required:
            found = state
            break
        dist = hops_to[state]
        if dist >= max_hops:
            continue
        for label, nxt in adj.get(node, []):
            if label in exclude:
                continue
            # Labels outside `include` contribute no bit, leaving the mask unchanged.
            child = (nxt, mask | bit_of.get(label, 0))
            if child in parent:
                continue
            parent[child] = (state, label)
            hops_to[child] = dist + 1
            frontier.append(child)

    if found is None:
        return None

    # Follow parent links back to the origin, then flip to forward order.
    walk_nodes, walk_labels = [found[0]], []
    cursor = found
    while parent[cursor] is not None:
        cursor, label = parent[cursor]
        walk_labels.append(label)
        walk_nodes.append(cursor[0])
    return walk_nodes[::-1], walk_labels[::-1]

def bfs_constrained_exact_hops(adj: dict[str, list[tuple[str, str]]], u: str, v: str, include: set[str], exclude: set[str], hops: int, rng: random.Random) -> tuple[list[str], list[str]] | None:
    """Find a walk from ``u`` to ``v`` with exactly ``hops`` edges under label constraints.

    Breadth-first over ``(node, depth, mask)`` states, where ``mask`` records
    which ``include`` labels have been used. Edges labelled with anything in
    ``exclude`` are skipped. The full neighbour list is shuffled with ``rng``
    before filtering, at every expansion.

    Returns:
        ``(nodes, labels)`` for the first matching walk found, or ``None``.
    """
    bit_of = {label: 1 << IDX[label] for label in include}
    required = 0
    for bit in bit_of.values():
        required |= bit

    origin = (u, 0, 0)
    parent: dict[tuple[str, int, int], tuple[tuple[str, int, int], str] | None] = {origin: None}
    frontier = deque([origin])
    while frontier:
        state = frontier.popleft()
        node, dist, mask = state
        if dist >= hops:
            if dist == hops and node == v and (mask & required) == required:
                # Follow parent links back to the origin, then flip to forward order.
                walk_nodes, walk_labels = [node], []
                cursor = state
                while parent[cursor] is not None:
                    cursor, label = parent[cursor]
                    walk_labels.append(label)
                    walk_nodes.append(cursor[0])
                return walk_nodes[::-1], walk_labels[::-1]
            continue
        neighbours = list(adj.get(node, []))
        if neighbours:
            rng.shuffle(neighbours)
        for label, nxt in neighbours:
            if label in exclude:
                continue
            # Labels outside `include` contribute no bit, leaving the mask unchanged.
            child = (nxt, dist + 1, mask | bit_of.get(label, 0))
            if child in parent:
                continue
            parent[child] = (state, label)
            frontier.append(child)
    return None

# ---------------------------------------------------------------------
# Difficulty (multiplicative) and rarity
# ---------------------------------------------------------------------
def rarity(labels: list[str], w: dict[str, float]) -> float:
    """Mean of ``-log(w[label])`` over ``labels``; ``0.0`` when ``labels`` is empty."""
    if not labels:
        return 0.0
    surprisals = [-math.log(w[label]) for label in labels]
    return float(sum(surprisals) / len(labels))

def diff_eval(hops: int, inc: set[str], exc: set[str], labels: list[str], w: dict[str, float], cfg: CFG) -> float:
    """Difficulty score of an eval example.

    ``(w_hops * hops + w_rare * rarity(labels))`` multiplied by
    ``(1 + w_inc * |inc|) * (1 + w_exc * |exc|)``, with the weights taken
    from ``cfg``.
    """
    hop_term = cfg.diff_w_hops * hops
    rare_term = cfg.diff_w_rare * rarity(labels, w)
    inc_factor = 1.0 + cfg.diff_w_inc * len(inc)
    exc_factor = 1.0 + cfg.diff_w_exc * len(exc)
    return float((hop_term + rare_term) * (inc_factor * exc_factor))

def diff_train(labels: list[str], inc: set[str], exc: set[str], w: dict[str, float], cfg: CFG) -> float:
    """Difficulty score of a train example.

    Same formula as the eval score, with the hop count taken as
    ``len(labels)``: ``(w_hops * len(labels) + w_rare * rarity(labels))``
    times ``(1 + w_inc * |inc|) * (1 + w_exc * |exc|)``.
    """
    hop_term = cfg.diff_w_hops * len(labels)
    rare_term = cfg.diff_w_rare * rarity(labels, w)
    inc_factor = 1.0 + cfg.diff_w_inc * len(inc)
    exc_factor = 1.0 + cfg.diff_w_exc * len(exc)
    return float((hop_term + rare_term) * (inc_factor * exc_factor))

# ---------------------------------------------------------------------
# Constraint planning for eval levels
# ---------------------------------------------------------------------
def constraints_for_level(labels_in_path: list[str], level: int, rng: random.Random, p_inc: float) -> tuple[set[str], set[str]]:
    """Pick include/exclude label constraints for one difficulty level.

    Level 1 (or lower) has no constraints. Level ``n`` attempts ``n - 1``
    picks. Each pick is an include (a label that occurs in the path) with
    probability ``p_inc``, otherwise an exclude (a lowercase letter that does
    not occur in the path). When the preferred pool is used up the other pool
    is tried; when both are used up the pick is dropped, so fewer than
    ``level - 1`` constraints can come back.

    Args:
        labels_in_path: Edge labels of the base path (duplicates allowed).
        level: Difficulty level, 1-based.
        rng: Source of randomness; consumed once per pick for the coin flip
            and once more for each label actually chosen.
        p_inc: Probability of preferring an include on each pick.

    Returns:
        ``(include, exclude)`` as two disjoint sets of labels.
    """
    if level <= 1:
        return set(), set()
    want = level - 1
    in_path = set(labels_in_path)
    not_in_path = [c for c in ascii_lowercase if c not in in_path]
    inc: set[str] = set()
    exc: set[str] = set()
    for _ in range(want):
        choose_inc = rng.random() < p_inc
        can_inc = len(inc) < len(in_path)
        if choose_inc and can_inc:
            # Sort the candidates: set iteration order of strings changes with
            # hash randomization, which would make a seeded run irreproducible.
            inc.add(rng.choice(sorted(in_path - inc)))
        elif len(exc) < len(not_in_path):
            exc.add(rng.choice([c for c in not_in_path if c not in exc]))
        elif can_inc:
            inc.add(rng.choice(sorted(in_path - inc)))
    return inc, exc

# ---------------------------------------------------------------------
# Edge coverage for train
# ---------------------------------------------------------------------
def cover_edges_for_train(adj: dict[str, list[tuple[str, str]]]) -> list[tuple[str, str, set[str], set[str], list[str], list[str]]]:
    """Emit one single-hop example per direction for every distinct edge.

    Each example is ``(src, dst, {label}, set(), [src, dst], [label])``. An
    edge is identified by ``(node, label, node)`` and, once emitted, is
    skipped when met again from either endpoint.
    """
    visited: set[tuple[str, str, str]] = set()
    examples = []
    for src, neighbours in adj.items():
        for label, dst in neighbours:
            if (src, label, dst) in visited:
                continue
            examples.append((src, dst, {label}, set(), [src, dst], [label]))
            examples.append((dst, src, {label}, set(), [dst, src], [label]))
            visited.update(((src, label, dst), (dst, label, src)))
    return examples

# ---------------------------------------------------------------------
# Generators
# ---------------------------------------------------------------------
# Module-level adjacency list read by _bfs_constrained_worker; stays None
# until _init_pool has run in the current process.
_G_ADJ: dict[str, list[tuple[str, str]]] | None = None

def _init_pool(adj: dict[str, list[tuple[str, str]]]) -> None:
    """Store ``adj`` in the module-level ``_G_ADJ``.

    Passed as the ``initializer`` of the ``ProcessPoolExecutor`` in
    ``gen_train`` so tasks do not have to carry the graph as an argument.
    """
    global _G_ADJ
    _G_ADJ = adj

def _bfs_constrained_worker(u: str, v: str, inc: set[str], exc: set[str], max_hops: int) -> tuple[list[str], list[str]] | None:
    """Run ``bfs_constrained`` against the adjacency list stored by ``_init_pool``."""
    # The graph comes from the module global set by the pool initializer.
    graph = _G_ADJ
    assert graph is not None
    return bfs_constrained(graph, u, v, inc, exc, max_hops)

def gen_eval(nodes: list[str], adj: dict[str, list[tuple[str, str]]], w: dict[str, float], cfg: CFG, rng: random.Random, nrng: np.random.Generator) -> tuple[list[str], list[tuple]]:
    """Build the evaluation set: base paths per hop count, each expanded into levels.

    Repeatedly draws a hop count in ``1..cfg.max_hops_eval`` and a random node
    pair, and searches for a walk with exactly that many hops. Every base walk
    found yields ``cfg.levels_max`` examples: level 1 is the base walk itself,
    higher levels add include/exclude constraints derived from the base walk's
    labels and re-search under those constraints at the same hop count.

    Args:
        nodes: All node names; two distinct ones are sampled per attempt.
        adj: Adjacency list mapping a node to ``(label, neighbour)`` pairs.
        w: Label weights used for the rarity part of the difficulty score.
        cfg: Generation settings.
        rng: Drives hop/pair sampling, search shuffling and constraint picks.
        nrng: Not used by this function.

    Returns:
        ``(eval_lines, meta)``. Each line is ``"<query> #d=<difficulty>"``;
        each meta tuple is ``(u, v, hops, level, include, exclude, path,
        difficulty)`` with include/exclude as sorted label strings.

    Raises:
        RuntimeError: If a constrained re-search finds nothing even though the
            constraints were derived from an existing base walk.
    """
    eval_lines: list[str] = []
    meta: list[tuple] = []
    per_hop_target = {h: cfg.eval_per_hop for h in range(1, cfg.max_hops_eval + 1)}
    per_hop_counts = {h: 0 for h in per_hop_target}
    max_attempts = 50_000_000  # upper bound on base searches so the loop always ends
    attempts = 0
    logger.info(f"eval planning: hops=1..{cfg.max_hops_eval}, per_hop={cfg.eval_per_hop}, levels={cfg.levels_max}")
    with tqdm(total=sum(per_hop_target.values()), desc="Eval bases") as pbar:
        while any(per_hop_counts[h] < per_hop_target[h] for h in per_hop_target) and attempts < max_attempts:
            h = rng.randint(1, cfg.max_hops_eval)
            if per_hop_counts[h] >= per_hop_target[h]:
                continue  # this hop bucket is already full; redraw
            u, v = rng.sample(nodes, 2)
            res = bfs_exact_hops(adj, u, v, h, rng)
            attempts += 1
            if res is None:
                continue
            base_nodes, base_labels = res
            labels_in_path = base_labels[:]
            for lvl in range(1, cfg.levels_max + 1):
                inc, exc = constraints_for_level(labels_in_path, lvl, rng, cfg.p_inc)
                if lvl == 1:
                    nds, lbs = base_nodes, base_labels
                else:
                    r2 = bfs_constrained_exact_hops(adj, u, v, inc, exc, h, rng)
                    # Includes come from the base walk's labels and excludes from
                    # labels it does not use, so the base walk itself is a valid
                    # answer. Raise explicitly: an `assert` vanishes under -O.
                    if r2 is None:
                        raise RuntimeError(
                            f"no {h}-hop walk {u}->{v} satisfies include={sorted(inc)} exclude={sorted(exc)}"
                        )
                    nds, lbs = r2
                q = fmt_query(u, v, inc, exc, nds, lbs)
                d = diff_eval(h, inc, exc, lbs, w, cfg)
                eval_lines.append(f"{q} #d={d:.3f}")
                meta.append((u, v, h, lvl, "".join(sorted(inc)), "".join(sorted(exc)), path_string(nds, lbs), round(d, 3)))
            per_hop_counts[h] += 1
            pbar.update(1)
    # Report a shortfall instead of silently returning a smaller eval set.
    short = [(h, per_hop_target[h] - per_hop_counts[h]) for h in per_hop_target if per_hop_counts[h] < per_hop_target[h]]
    if short:
        detail = ", ".join(f"hops={h} missing={n}" for h, n in short)
        logger.warning(f"eval incomplete after {attempts} attempts: {detail}")
    logger.info(f"eval built: total_examples={len(eval_lines)} (bases={sum(per_hop_counts.values())}, levels_each={cfg.levels_max})")
    return eval_lines, meta

def draw_label_set(nrng: np.random.Generator, w: dict[str, float], k: int, banned: set[str] | None = None) -> set[str]:
    """Sample up to ``k`` distinct lowercase labels, weighted by ``w``.

    Args:
        nrng: NumPy generator used for the weighted draw without replacement.
        w: Weight per lowercase letter; renormalised over the allowed letters.
        k: Number of labels wanted; values <= 0 give an empty set.
        banned: Letters that must not be drawn. ``None`` means no restriction.

    Returns:
        A set of ``min(k, number of allowed letters)`` labels. It is empty
        when ``k <= 0`` or every letter is banned; in those cases ``nrng`` is
        not consumed.
    """
    if k <= 0:
        return set()
    # `None` default instead of a shared mutable `set()` default object.
    blocked = banned if banned is not None else ()
    letters = [c for c in ascii_lowercase if c not in blocked]
    if not letters:
        return set()  # nothing left to draw; avoids normalising by a zero sum
    probs = np.array([w[c] for c in letters], dtype=float)
    probs /= probs.sum()
    picks = nrng.choice(len(letters), size=min(k, len(letters)), replace=False, p=probs)
    return {letters[i] for i in picks}

def sample_k(nrng: np.random.Generator, kmax: int, p: float) -> int:
    """Draw ``geometric(p) - 1`` and clamp the result to ``0..kmax``."""
    k = int(nrng.geometric(p) - 1)
    if k < 0:
        return 0
    if k > kmax:
        return kmax
    return k

def gen_train(nodes: list[str], adj: dict[str, list[tuple[str, str]]], w: dict[str, float], cfg: CFG, rng: random.Random, nrng: np.random.Generator, forbid_pairs: set[tuple[str, str]]) -> list[str]:
    """Build the training lines: single-edge coverage plus searched examples.

    First every edge from ``cover_edges_for_train`` becomes one example per
    direction. Then random node pairs with random include/exclude label sets
    are searched in a process pool until ``cfg.n_train_extra`` further
    examples have been found. Any pair whose sorted form is in
    ``forbid_pairs`` is skipped in both phases.

    Args:
        nodes: All node names; two distinct ones are sampled per query.
        adj: Adjacency list; sent to each worker once through the pool initializer.
        w: Label weights, used both for sampling constraints and for scoring.
        cfg: Generation settings.
        rng: Drives node-pair sampling.
        nrng: Drives constraint-count and label sampling.
        forbid_pairs: Sorted ``(u, v)`` tuples that must not appear.

    Returns:
        Lines of the form ``"<query> #d=<difficulty>"``.
    """
    lines: list[str] = []
    for a, b, inc, exc, nds, lbs in tqdm(cover_edges_for_train(adj), desc="Train coverage"):
        if tuple(sorted((a, b))) in forbid_pairs:
            continue
        q = fmt_query(a, b, inc, exc, nds, lbs)
        d = diff_train(lbs, inc, exc, w, cfg)
        lines.append(f"{q} #d={d:.3f}")

    n_extra = cfg.n_train_extra
    if n_extra <= 0:
        logger.info(f"train built: total_examples={len(lines)}")
        return lines
    target = n_extra + len(lines)

    workers = max(1, (cpu_count() or 1))
    buf = workers * 4  # number of submitted-but-unread searches kept in flight
    logger.info(f"parallel train extra: workers={workers}, buffer={buf}")

    # FIFO of (future, u, v, inc, exc). Results are read in submission order,
    # so the output does not depend on which worker finishes first.
    pending: deque[tuple] = deque()

    with ProcessPoolExecutor(max_workers=workers, initializer=_init_pool, initargs=(adj,)) as ex, tqdm(total=n_extra, desc="Train extra") as pbar:
        try:
            while len(lines) < target:
                while len(pending) < buf:
                    u, v = rng.sample(nodes, 2)
                    if tuple(sorted((u, v))) in forbid_pairs:
                        continue
                    kin = sample_k(nrng, 3, 0.6)
                    kex = sample_k(nrng, 3, 0.7)
                    inc = draw_label_set(nrng, w, kin)
                    exc = draw_label_set(nrng, w, kex, banned=inc)
                    fut = ex.submit(_bfs_constrained_worker, u, v, inc, exc, cfg.max_hops_train)
                    pending.append((fut, u, v, inc, exc))

                fut, u, v, inc, exc = pending.popleft()
                res = fut.result()
                if res is None:
                    continue  # no path satisfies this query; the buffer is refilled next turn
                nds, lbs = res
                q = fmt_query(u, v, inc, exc, nds, lbs)
                d = diff_train(lbs, inc, exc, w, cfg)
                lines.append(f"{q} #d={d:.3f}")
                pbar.update(1)
        finally:
            # Drop searches that have not started yet so pool shutdown does not
            # wait for results that would be thrown away.
            for leftover, *_unused in pending:
                leftover.cancel()

    logger.info(f"train built: total_examples={len(lines)}")
    return lines

# ---------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------
def main() -> None:
    """Generate label weights, the eval set and the train set, and write them out.

    Files written under ``cfg.out_dir``: ``label_weights.json``, ``eval.txt``,
    ``train.txt`` and ``eval_meta.csv``. Node pairs used by eval are passed to
    train generation as forbidden pairs.
    """
    cfg = CFG()
    random.seed(cfg.seed)
    nrng = np.random.default_rng(cfg.seed)
    rng = random.Random(cfg.seed)

    out = Path(cfg.out_dir)
    out.mkdir(parents=True, exist_ok=True)

    nodes, adj = load_graph(Path(cfg.graph_path))
    w = sample_label_weights(nrng, cfg.weight_sigma)
    weights_file = out / "label_weights.json"
    weights_file.write_text(json.dumps(w, indent=2), encoding="utf-8")
    logger.info(f"label weights saved: {weights_file.resolve()}")

    eval_lines, meta = gen_eval(nodes, adj, w, cfg, rng, nrng)

    # Each meta record starts with (u, v); store pairs order-independently.
    holdout_pairs = set()
    for record in meta:
        holdout_pairs.add(tuple(sorted(record[:2])))

    train_lines = gen_train(nodes, adj, w, cfg, rng, nrng, forbid_pairs=holdout_pairs)

    (out / "eval.txt").write_text("\n".join(eval_lines) + "\n", encoding="utf-8")
    (out / "train.txt").write_text("\n".join(train_lines) + "\n", encoding="utf-8")

    csv_rows = []
    for u, v, h, lvl, inc, exc, p, d in meta:
        csv_rows.append(f"{u},{v},{h},{lvl},{inc},{exc},{p},{d:.3f}")
    csv_text = "u,v,hops,level,include,exclude,path,difficulty\n" + "\n".join(csv_rows) + "\n"
    (out / "eval_meta.csv").write_text(csv_text, encoding="utf-8")

    logger.info(f"wrote {out}/train.txt, {out}/eval.txt, {out}/eval_meta.csv")

# Script entry point: generate the datasets only when this file is run
# directly, not when it is imported.
if __name__ == "__main__":
    main()
