import os
import json
import time
import argparse
import logging
from typing import List, Tuple, Dict, Optional

import numpy as np
import networkx as nx

import torch
from torch.utils.data import Dataset, DataLoader
import dgl

from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, recall_score, f1_score, average_precision_score

from config import DATASET_REGISTRY, NUM_NODE_TYPES
from ast_processor import ASTProcessor
from data_loader import load_graph_structure
from motif_miner import MotifMiner, MotifInstance
from graph_builder import GraphEngine
from model import RMGNN, rmgnn_loss

# Module-level logger named "train". basicConfig runs at import time and sets the
# root logging level to INFO with a "timestamp - level - message" line format.
logger = logging.getLogger("train")
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")


# -----------------------------------------------------------------------------
# Metrics: AP / Acc / Rec / F1
# -----------------------------------------------------------------------------
def compute_metrics(y_true: np.ndarray, y_prob_pos: np.ndarray, y_pred: np.ndarray) -> Dict[str, float]:
    """
    Compute the four reported classification metrics.

    Args:
      - y_true     : ground-truth labels.
      - y_prob_pos : score of the positive class (used for AP only).
      - y_pred     : hard predictions (used for Acc / Rec / F1).

    Returns:
      - dict with keys "AP", "Acc", "Rec", "F1", all plain Python floats.
        Recall and F1 are computed with zero_division=0.
    """
    ap = average_precision_score(y_true, y_prob_pos)
    acc = accuracy_score(y_true, y_pred)
    rec = recall_score(y_true, y_pred, zero_division=0)
    f1 = f1_score(y_true, y_pred, zero_division=0)

    scores: Dict[str, float] = {}
    scores["AP"] = float(ap)
    scores["Acc"] = float(acc)
    scores["Rec"] = float(rec)
    scores["F1"] = float(f1)
    return scores


# -----------------------------------------------------------------------------
# WL subtree kernel on motif induced subgraphs (strict teacher)
# -----------------------------------------------------------------------------
def _stable_hash(s: str) -> str:
    """Return the hex MD5 digest of the UTF-8 encoding of `s` (stable across runs/processes)."""
    from hashlib import md5

    digest = md5(s.encode("utf-8"))
    return digest.hexdigest()


def _edge_type(g: nx.Graph, u: int, v: int) -> int:
    """Return the integer "type" attribute of edge (u, v); defaults to 1 when the attribute is absent."""
    return int(g.edges[u, v].get("type", 1))


def wl_feature_map(g: nx.Graph, h: int = 2, use_edge_type: bool = True) -> Dict[str, int]:
    """
    Weisfeiler-Lehman subtree feature map of a motif induced subgraph.

    Round 0:
      - every node is labelled "T<k>", where k is the argmax over the first
        NUM_NODE_TYPES entries of node['attr']; k falls back to 0 when the
        attribute is missing, is not a list, is too short, or its prefix sums
        to a non-positive value.

    Rounds 1..h:
      - each node's new label is a hash of its current label followed by the
        sorted multiset of neighbor labels (each optionally prefixed with the
        edge relation type), which makes the result order-independent.

    Returns:
      - sparse feature map (label string -> number of nodes that carried it,
        accumulated over all rounds including round 0)
    """
    labels: Dict[int, str] = {}
    counts: Dict[str, int] = {}

    # Round 0: node-type labels.
    # NOTE(review): only plain Python lists are decoded here; tuple / ndarray
    # attributes fall through to type 0 - confirm the loader emits lists.
    for node, attrs in g.nodes(data=True):
        vec = attrs.get("attr", [])
        node_type = 0
        if isinstance(vec, list) and len(vec) >= NUM_NODE_TYPES:
            if sum(vec[:NUM_NODE_TYPES]) > 0:
                prefix = np.asarray(vec[:NUM_NODE_TYPES], dtype=np.float32)
                node_type = int(np.argmax(prefix))
        label = f"T{node_type}"
        labels[node] = label
        counts[label] = counts.get(label, 0) + 1

    # Rounds 1..h: relabel from the previous round's labels only.
    for round_idx in range(1, h + 1):
        refined: Dict[int, str] = {}
        for node in g.nodes():
            if use_edge_type:
                parts = sorted(
                    f"{_edge_type(g, node, nb)}:{labels[nb]}" for nb in g.neighbors(node)
                )
            else:
                parts = sorted(labels[nb] for nb in g.neighbors(node))

            # The separator after the own label is kept even for isolated nodes.
            signature = f"{labels[node]}|{'|'.join(parts)}"
            new_label = f"WL{round_idx}_{_stable_hash(signature)}"
            refined[node] = new_label
            counts[new_label] = counts.get(new_label, 0) + 1
        labels = refined

    return counts


def wl_kernel_matrix(graphs: List[nx.Graph], h: int = 2, use_edge_type: bool = True) -> np.ndarray:
    """
    Linear WL kernel between all motif subgraphs of one sample.

    Each graph is mapped to its sparse WL feature map, the maps are embedded
    into a shared dense count matrix X (one row per graph, one column per
    distinct label), and the kernel is K = X X^T, i.e. K_ij = <phi(G_i), phi(G_j)>.

    Returns:
      - float32 array of shape [len(graphs), len(graphs)]
    """
    feature_maps = []
    for graph in graphs:
        feature_maps.append(wl_feature_map(graph, h=h, use_edge_type=use_edge_type))

    # Assign column indices in first-seen order.
    column_of: Dict[str, int] = {}
    for fmap in feature_maps:
        for label in fmap:
            column_of.setdefault(label, len(column_of))

    counts = np.zeros((len(graphs), len(column_of)), dtype=np.float32)
    for row, fmap in enumerate(feature_maps):
        for label, count in fmap.items():
            counts[row, column_of[label]] = float(count)

    return np.matmul(counts, counts.T)


def teacher_distribution_from_motifs(
    motif_subgraphs: List[nx.Graph],
    tau: float,
    device: torch.device,
    wl_h: int = 2,
    use_edge_type: bool = True
) -> Optional[torch.Tensor]:
    """
    Strict teacher distribution computed from WL kernel on motif induced subgraphs.

    T_ij = softmax(K_ij / tau) row-wise.

    Args:
      - motif_subgraphs : induced motif subgraphs of a single sample.
      - tau             : softmax temperature; must be strictly positive.
      - device          : device on which the returned tensor is allocated.
      - wl_h            : number of WL refinement rounds.
      - use_edge_type   : whether edge relation types enter the WL signatures.

    Returns:
      - Tensor [M, M] if M >= 2 (each row sums to 1)
      - None if M <= 1

    Raises:
      - ValueError if tau is not a positive number. A zero temperature would
        otherwise yield inf/NaN rows that silently poison the distillation loss.
    """
    tau = float(tau)
    # `not tau > 0` (rather than `tau <= 0`) also rejects NaN.
    if not tau > 0.0:
        raise ValueError(f"tau must be a positive temperature, got {tau!r}")

    if len(motif_subgraphs) <= 1:
        # A single motif (or none) has no pairwise relation to distil.
        return None

    K = wl_kernel_matrix(motif_subgraphs, h=wl_h, use_edge_type=use_edge_type)
    Kt = torch.tensor(K, dtype=torch.float32, device=device)
    return torch.softmax(Kt / tau, dim=1)


# -----------------------------------------------------------------------------
# Dataset wrapper: carries nx_graph for strict teacher
# -----------------------------------------------------------------------------
class SampleDataset(Dataset):
    """
    Each item contains:
      - dgl_graph : DGLGraph with motif super-nodes
      - label     : int
      - nx_graph  : original NetworkX HRG (strict teacher source)
      - insts     : list[MotifInstance] extracted from nx_graph

    The four input sequences are index-aligned: position i in each of them
    describes the same sample.
    """
    def __init__(
        self,
        dgl_graphs: List[dgl.DGLGraph],
        labels: np.ndarray,
        nx_graphs: List[nx.Graph],
        insts: List[List[MotifInstance]],
    ):
        """
        Store the aligned per-sample sequences.

        Raises:
          - ValueError if the four inputs do not all have the same length.
        """
        sizes = (len(dgl_graphs), len(labels), len(nx_graphs), len(insts))
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if len(set(sizes)) != 1:
            raise ValueError(
                "dgl_graphs, labels, nx_graphs and insts must have equal lengths, "
                f"got {sizes}"
            )
        self.dgl_graphs = dgl_graphs
        # np.array copies (like the previous .astype) and also accepts plain sequences.
        self.labels = np.array(labels, dtype=np.int64)
        self.nx_graphs = nx_graphs
        self.insts = insts

    def __len__(self) -> int:
        """Number of samples."""
        return len(self.dgl_graphs)

    def __getitem__(self, idx: int):
        """Return the tuple (dgl_graph, label as Python int, nx_graph, insts) for sample `idx`."""
        return self.dgl_graphs[idx], int(self.labels[idx]), self.nx_graphs[idx], self.insts[idx]


def collate_fn(batch):
    """
    Collate a list of SampleDataset items into one mini-batch.

    Returns:
      - the DGL graphs merged with dgl.batch
      - labels as a torch.long tensor
      - list of original NetworkX graphs (kept un-batched)
      - list of per-sample motif instance lists (kept un-batched)
    """
    # Transpose [(g, y, nx_g, insts), ...] into four parallel tuples.
    dgl_items, label_items, nx_items, inst_items = zip(*batch)

    batched_graph = dgl.batch(list(dgl_items))
    targets = torch.tensor(label_items, dtype=torch.long)
    return batched_graph, targets, list(nx_items), list(inst_items)


# -----------------------------------------------------------------------------
# Build per-sample DGL graphs (motif super-nodes included)
# -----------------------------------------------------------------------------
def build_per_sample_dgl(
    nx_graphs: List[nx.Graph],
    insts_list: List[List[MotifInstance]],
) -> Tuple[List[dgl.DGLGraph], np.ndarray]:
    """
    Build one DGL graph per sample by running GraphEngine on a single-graph batch.

    Args:
      - nx_graphs  : original graphs, one per sample.
      - insts_list : motif instances per sample, index-aligned with nx_graphs
                     (pairs are formed with zip, so extra items on either side
                     are ignored).

    Returns:
      - list of graphs returned by GraphEngine.build_batched_graph()
      - int64 array holding the first label each engine call returned
    """
    built: List[dgl.DGLGraph] = []
    targets: List[int] = []

    for sample_graph, sample_insts in zip(nx_graphs, insts_list):
        # One engine per sample so every sample becomes its own DGL graph.
        batched, ys = GraphEngine([sample_graph], [sample_insts]).build_batched_graph()
        built.append(batched)
        targets.append(int(ys[0]))

    return built, np.asarray(targets, dtype=np.int64)


# -----------------------------------------------------------------------------
# One epoch (train or eval): strict teacher from original nx_graph
# -----------------------------------------------------------------------------
def _build_teacher_pack(
    nx_batch,
    insts_batch,
    tau: float,
    device: torch.device,
    wl_h: int,
    use_edge_type: bool,
) -> List[Optional[torch.Tensor]]:
    """Compute one strict teacher distribution (or None) per sample of a mini-batch."""
    pack: List[Optional[torch.Tensor]] = []
    for nx_g, insts in zip(nx_batch, insts_batch):
        motif_subgraphs: List[nx.Graph] = []
        for inst in insts:
            nodes = list(inst.nodes)
            # NOTE(review): single-node instances are dropped here; if the model's
            # student pack still contains them, teacher and student matrices may
            # differ in size - confirm against rmgnn_loss / RMGNN.
            if len(nodes) <= 1:
                continue
            # Strict: induced subgraph in the original HRG
            motif_subgraphs.append(nx_g.subgraph(nodes).copy())

        pack.append(
            teacher_distribution_from_motifs(
                motif_subgraphs, tau=tau, device=device, wl_h=wl_h, use_edge_type=use_edge_type
            )
        )
    return pack


def run_epoch(
    model: RMGNN,
    loader: DataLoader,
    optimizer: Optional[torch.optim.Optimizer],
    device: torch.device,
    alpha: float,
    tau: float,
    thr: float = 0.5,
    wl_h: int = 2,
    use_edge_type: bool = True,
) -> Tuple[float, Dict[str, float]]:
    """
    Run one pass over `loader`, training when an optimizer is given and
    evaluating otherwise.

    Args:
      - model         : network returning (logits, student_pack) for a batched graph.
      - loader        : yields (batched_graph, labels, nx_graphs, motif_instances).
      - optimizer     : optimizer for a training pass; None selects evaluation mode.
      - device        : device for graphs, labels and teacher tensors.
      - alpha         : forwarded to rmgnn_loss.
      - tau           : temperature of the teacher softmax.
      - thr           : threshold on the positive-class probability for hard predictions.
      - wl_h          : number of WL rounds for the teacher kernel.
      - use_edge_type : whether edge types enter the teacher's WL signatures.

    Returns:
      - mean batch loss (0.0 if the loader produced no batches)
      - metrics dict from compute_metrics over all samples seen
    """
    is_train = optimizer is not None
    model.train(is_train)

    all_y: List[int] = []
    all_prob: List[float] = []
    all_pred: List[int] = []
    losses: List[float] = []

    for bg, y, nx_batch, insts_batch in loader:
        bg = bg.to(device)
        y = y.to(device)

        # Strict teacher pack: computed per sample using the original HRG and induced motif subgraphs
        teacher_pack = _build_teacher_pack(
            nx_batch, insts_batch, tau=tau, device=device, wl_h=wl_h, use_edge_type=use_edge_type
        )

        # Autograd is only needed when we are going to call backward(); disabling it
        # during evaluation avoids building (and holding) an unused graph.
        with torch.set_grad_enabled(is_train):
            logits, student_pack = model(bg)
            loss = rmgnn_loss(logits, y, student_pack, teacher_pack, alpha=alpha)

        if is_train:
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
            optimizer.step()

        losses.append(float(loss.item()))

        # Column 1 of the softmax is treated as the positive-class probability.
        prob_pos = torch.softmax(logits.detach(), dim=1)[:, 1].cpu().numpy().astype(float)
        pred = (prob_pos >= float(thr)).astype(int)
        gt = y.detach().cpu().numpy().astype(int)

        all_y.extend(gt.tolist())
        all_prob.extend(prob_pos.tolist())
        all_pred.extend(pred.tolist())

    avg_loss = float(np.mean(losses)) if losses else 0.0
    metrics = compute_metrics(np.asarray(all_y), np.asarray(all_prob), np.asarray(all_pred))
    return avg_loss, metrics


# -----------------------------------------------------------------------------
# Main
# -----------------------------------------------------------------------------
def main():
    """
    Command-line entry point: preprocess, mine motifs, build graphs, and run
    stratified K-fold training/evaluation.

    Steps:
      1. Convert raw ASTs into HRG intermediates unless "A.txt" already exists.
      2. Load the HRGs as NetworkX graphs.
      3. Mine motif instances per graph.
      4. Build one DGL graph per sample (labels come from this step).
      5. Train a fresh model per fold; keep the metrics of the epoch with the
         highest test F1 and write them to "<save_dir>/<data>_fold<k>.json".
      6. Write mean/std over folds to "<save_dir>/<data>_summary.json".

    Raises:
      - RuntimeError if no graphs could be loaded, or if a fold finished
        without recording any metrics.
      - SystemExit (via argparse) on invalid --epochs / --kfold values.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--data", type=str, default="Corpus", choices=list(DATASET_REGISTRY.keys()))
    p.add_argument("--base_dir", type=str, default="./dataset")
    p.add_argument("--epochs", type=int, default=30)
    p.add_argument("--batch_size", type=int, default=32)
    p.add_argument("--lr", type=float, default=2e-4)
    p.add_argument("--wd", type=float, default=1e-4)
    p.add_argument("--kfold", type=int, default=10)
    p.add_argument("--gpu", type=int, default=0)
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--alpha", type=float, default=0.1)
    p.add_argument("--tau", type=float, default=0.5)
    p.add_argument("--thr", type=float, default=0.5)
    p.add_argument("--max_depth", type=int, default=6)
    p.add_argument("--motif_vocab", type=int, default=1000)
    p.add_argument("--wl_h", type=int, default=2)
    p.add_argument("--save_dir", type=str, default="./runs/rmgnn")
    args = p.parse_args()

    # Fail fast on values that would otherwise only blow up after the expensive
    # preprocessing / mining steps below.
    if args.epochs < 1:
        p.error("--epochs must be >= 1")
    if args.kfold < 2:
        p.error("--kfold must be >= 2")

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # A negative --gpu selects the CPU instead of producing an invalid "cuda:-1".
    use_cuda = args.gpu >= 0 and torch.cuda.is_available()
    device = torch.device(f"cuda:{args.gpu}" if use_cuda else "cpu")
    os.makedirs(args.save_dir, exist_ok=True)

    dataset_dir = os.path.join(args.base_dir, args.data)
    raw_dir = os.path.join(dataset_dir, "raw")

    # 1) AST -> HRG intermediates
    if not os.path.exists(os.path.join(dataset_dir, "A.txt")):
        logger.info("Preprocessing ASTs into HRG intermediates: %s", dataset_dir)
        ASTProcessor(raw_dir=raw_dir, out_dir=dataset_dir).process_all()
    else:
        logger.info("Found existing HRG intermediates: %s", dataset_dir)

    # 2) Load original HRGs
    nx_graphs = load_graph_structure(dataset_dir)
    if len(nx_graphs) == 0:
        raise RuntimeError("No graphs loaded from intermediate files.")

    # 3) Motif mining with canonicalization + hashing
    miner = MotifMiner(vocab_size=args.motif_vocab, max_depth=args.max_depth)
    insts_list: List[List[MotifInstance]] = []
    for g in nx_graphs:
        _, insts = miner.extract_with_instances(g)
        insts_list.append(insts)

    logger.info("Motif stats: %s", json.dumps(miner.debug_stats(nx_graphs), ensure_ascii=False))

    # 4) Build per-sample DGL graphs (this is the single source of the labels)
    dgl_graphs, labels = build_per_sample_dgl(nx_graphs, insts_list)

    # Model config
    d_in = int(dgl_graphs[0].ndata["feat"].shape[1])
    cfg = {
        "d_in": d_in,
        "d_h": 128,
        "n_layers": 3,
        "dropout": 0.2,
        "tau": args.tau,
        "n_heads": 4,
    }

    # 5) Stratified K-Fold training
    skf = StratifiedKFold(n_splits=args.kfold, shuffle=True, random_state=args.seed)

    fold_results: List[Dict[str, float]] = []
    for fold, (tr_idx, te_idx) in enumerate(skf.split(np.zeros_like(labels), labels), start=1):
        logger.info("Fold %d/%d", fold, args.kfold)

        tr_dgl = [dgl_graphs[i] for i in tr_idx]
        tr_y = labels[tr_idx]
        tr_nx = [nx_graphs[i] for i in tr_idx]
        tr_inst = [insts_list[i] for i in tr_idx]

        te_dgl = [dgl_graphs[i] for i in te_idx]
        te_y = labels[te_idx]
        te_nx = [nx_graphs[i] for i in te_idx]
        te_inst = [insts_list[i] for i in te_idx]

        train_ds = SampleDataset(tr_dgl, tr_y, tr_nx, tr_inst)
        test_ds = SampleDataset(te_dgl, te_y, te_nx, te_inst)

        train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn)
        test_loader = DataLoader(test_ds, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn)

        # Fresh model and optimizer per fold so folds do not share weights.
        model = RMGNN(cfg).to(device)
        opt = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wd)

        best_f1 = -1.0
        best_metrics: Optional[Dict[str, float]] = None

        for ep in range(1, args.epochs + 1):
            t0 = time.time()

            tr_loss, tr_m = run_epoch(
                model, train_loader, opt, device,
                alpha=args.alpha, tau=args.tau, thr=args.thr,
                wl_h=args.wl_h, use_edge_type=True
            )
            te_loss, te_m = run_epoch(
                model, test_loader, None, device,
                alpha=args.alpha, tau=args.tau, thr=args.thr,
                wl_h=args.wl_h, use_edge_type=True
            )

            dt = time.time() - t0
            logger.info(
                "Fold %d Ep %02d | train_loss=%.4f test_loss=%.4f | "
                "Test AP=%.4f Acc=%.4f Rec=%.4f F1=%.4f | time=%.1fs",
                fold, ep, tr_loss, te_loss,
                te_m["AP"], te_m["Acc"], te_m["Rec"], te_m["F1"], dt
            )

            # NOTE(review): the best epoch is picked using the test fold itself, so the
            # reported per-fold numbers are optimistic; unbiased model selection would
            # need a separate validation split.
            if te_m["F1"] > best_f1:
                best_f1 = te_m["F1"]
                best_metrics = dict(te_m)

        if best_metrics is None:
            raise RuntimeError(f"Fold {fold}: no epoch produced usable metrics.")
        fold_results.append(best_metrics)

        with open(os.path.join(args.save_dir, f"{args.data}_fold{fold}.json"), "w", encoding="utf-8") as f:
            json.dump(best_metrics, f, ensure_ascii=False, indent=2)

    # 6) Aggregate
    keys = ["AP", "Acc", "Rec", "F1"]
    avg = {k: float(np.mean([r[k] for r in fold_results])) for k in keys}
    std = {k: float(np.std([r[k] for r in fold_results])) for k in keys}

    summary = {
        "dataset": args.data,
        "engine": DATASET_REGISTRY.get(args.data, {}).get("engine", "Unknown"),
        "kfold": args.kfold,
        "avg": avg,
        "std": std,
        "folds": fold_results,
        "config": cfg,
        "train_args": vars(args),
    }

    out_path = os.path.join(args.save_dir, f"{args.data}_summary.json")
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(summary, f, ensure_ascii=False, indent=2)

    logger.info("Final Summary: %s", json.dumps({"avg": avg, "std": std}, ensure_ascii=False))
    logger.info("Saved summary to: %s", out_path)


# Run the training pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
