import os, re, time, argparse, pickle
import numpy as np
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
import networkx as nx
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.metrics import (
    mean_absolute_error, mean_squared_error,
    recall_score, f1_score, roc_auc_score, average_precision_score
)

from Denoiser_A_embedding import DenoiseNetworkA


# ---------- utils ----------
def set_seed(seed: int):
    """Seed NumPy's global RNG and torch (CPU, plus every CUDA device when CUDA is available)."""
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def collate_graphs(batch):
    """Zero-pad square adjacency tensors to a common size and stack them.

    Returns ``(A, nm)``: ``A`` has shape (len(batch), nmax, nmax) and the dtype
    of ``batch[0]``; ``nm`` is a (len(batch), nmax) bool mask that is True for
    the real (unpadded) nodes of each graph. Both are created on torch's
    default device.
    """
    sizes = [g.size(0) for g in batch]
    width = max(sizes)
    padded = torch.zeros(len(batch), width, width, dtype=batch[0].dtype)
    node_mask = torch.zeros(len(batch), width, dtype=torch.bool)
    for row, (graph, n) in enumerate(zip(batch, sizes)):
        padded[row, :n, :n] = graph
        node_mask[row, :n] = True
    return padded, node_mask

def linear_coeffs(t: torch.Tensor):
    """Coefficients of the linear interpolant x_t = a(t) * x0 + b(t) * x1.

    Returns ``(a, b, a_dot, b_dot)`` where a = 1 - t, b = t (the same tensor
    object as ``t``), and their constant time derivatives -1 and +1, each
    shaped like ``t``.
    """
    alpha = 1.0 - t
    beta = t
    alpha_dot = torch.full_like(t, -1.0)
    beta_dot = torch.ones_like(t)
    return alpha, beta, alpha_dot, beta_dot

def zero_diag_(M: torch.Tensor):
    """Zero the diagonal of ``M`` in place and return ``M``.

    Accepts a single (N, N) matrix or a batch (..., N, N). The diagonal is
    taken over the last two dimensions, so every matrix in a batch has its own
    diagonal cleared. A bare ``M.diagonal()`` would use dims 0 and 1 and zero
    whole rows ``M[i, i, :]`` of a 3-D batch instead.
    """
    M.diagonal(dim1=-2, dim2=-1).zero_()
    return M

def sym_zero_diag_valid(M: torch.Tensor, node_mask: torch.Tensor):
    """Symmetrise ``M`` from its strict upper triangle and zero invalid entries.

    Accepts a single matrix ``M`` (N, N) with ``node_mask`` (N,), or a batch
    ``M`` (B, N, N) with ``node_mask`` (B, N).

    Processing steps:
    - Entries touching a node whose mask is False/0 are zeroed.
    - The strict upper triangle is mirrored onto the lower one; the original
      lower triangle is discarded, not averaged.
    - The diagonal is set to 0.

    Returns a new tensor shaped like ``M``.
    """
    single = M.dim() == 2
    if single:
        # Treat a lone matrix as a batch of one; the batch dim is removed at the end.
        M, node_mask = M.unsqueeze(0), node_mask.unsqueeze(0)
    valid = node_mask.to(M.dtype)
    pair = valid.unsqueeze(2) * valid.unsqueeze(1)   # 1 where both endpoints are valid
    upper = torch.triu(M * pair, diagonal=1)         # keep the strict upper triangle only
    out = upper + upper.transpose(1, 2)
    diag = torch.eye(M.size(-1), dtype=torch.bool, device=M.device)
    out = out.masked_fill(diag, 0.0) * pair
    return out.squeeze(0) if single else out

def add_masked_symmetric_noise(M, node_mask, edge_mask, sigma: float, clip01: bool = True):
    """Add symmetric Gaussian noise to the unobserved entries of ``M``.

    Args:
        M: adjacency-like tensor, (N, N) or (B, N, N).
        node_mask: validity mask of the nodes, (N,) or (B, N).
        edge_mask: same shape as ``M``. The noise at each entry is scaled by
            ``1 - edge_mask``, so entries with edge_mask == 1 stay noise-free.
            Either a float or a bool tensor is accepted.
        sigma: noise standard deviation. If ``sigma <= 0``, no noise is added
            and nothing is drawn from torch's RNG.
        clip01: clamp the result into [0, 1]. This applies whatever the value
            of ``sigma``.

    Returns:
        A new tensor passed through ``sym_zero_diag_valid``. It is symmetric,
        has a zero diagonal, and is zero on rows/columns of invalid nodes.
    """
    if sigma <= 0.0:
        out = sym_zero_diag_valid(M, node_mask)
    else:
        # 1 where the entry is unobserved; cast first so bool masks work too.
        unk = sym_zero_diag_valid(1.0 - edge_mask.to(M.dtype), node_mask)
        # Symmetrise the noise itself so M[i, j] and M[j, i] get the same draw.
        eps = sym_zero_diag_valid(torch.randn_like(M), node_mask)
        out = sym_zero_diag_valid(M + sigma * (eps * unk), node_mask)
    if clip01:
        # Previously skipped on the sigma <= 0 path, silently ignoring clip01.
        out = out.clamp(0.0, 1.0)
    return out

def permute_square(A: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
    """Reorder both the rows and the columns of ``A`` by index tensor ``p``.

    The result's entry (i, j) is ``A[p[i], p[j]]``.
    """
    rows_permuted = A[p]
    return rows_permuted[:, p]

def invert_perm(p: torch.Tensor) -> torch.Tensor:
    """Return the inverse ``q`` of permutation ``p``, i.e. ``q[p[j]] == j``.

    The result has the same dtype and device as ``p``.
    """
    # Sorting a permutation lists, for each value r, the position where r sits,
    # which is exactly the inverse permutation.
    return torch.argsort(p).to(dtype=p.dtype)

def upper_triu_mask_batched(node_mask: torch.Tensor) -> torch.Tensor:
    """Build a (B, N, N) bool mask of pairs (i, j) with i < j where both nodes are valid.

    ``node_mask`` is a (B, N) bool tensor of valid nodes.
    """
    n = node_mask.size(1)
    strict_upper = torch.ones(n, n, dtype=torch.bool, device=node_mask.device).triu(1)
    both_valid = node_mask[:, :, None] & node_mask[:, None, :]
    return both_valid & strict_upper

def masked_upper_mse(pred, target, node_mask, edge_mask):
    """Return the batch mean of per-sample MSE over unobserved upper-triangular pairs.

    For each sample, the MSE is taken over pairs (i < j) where both nodes are
    valid in ``node_mask`` and ``edge_mask < 0.5``. A sample with no such pair
    contributes 0 (it still counts in the batch mean).

    Args:
        pred, target: (B, N, N) tensors.
        node_mask: (B, N) bool mask of valid nodes.
        edge_mask: (B, N, N) tensor; entries below 0.5 are treated as unobserved.

    Returns:
        A scalar tensor.
    """
    sel = upper_triu_mask_batched(node_mask) & (edge_mask < 0.5)
    # masked_fill (rather than multiplying by the mask) keeps non-finite values
    # outside the selection out of both the loss and its gradient.
    diff = (pred - target).masked_fill(~sel, 0.0)
    sq_sum = diff.pow(2).sum(dim=(1, 2))
    count = sel.sum(dim=(1, 2)).to(sq_sum.dtype)
    # Vectorised per-sample mean. count == 0 gives 0 / 1 = 0, which reproduces
    # the "empty selection contributes 0" rule without a Python loop over samples.
    return (sq_sum / count.clamp_min(1.0)).mean()

def _pc1_01(Z: np.ndarray) -> np.ndarray:
    Zc = Z - Z.mean(0, keepdims=True)
    v = np.random.randn(Zc.shape[1]); v /= (np.linalg.norm(v) + 1e-12)
    for _ in range(20):
        v = Zc.T @ (Zc @ v); v /= (np.linalg.norm(v) + 1e-12)
    s = Zc @ v
    s -= s.min(); rng = s.max() - s.min()
    return (s / rng) if rng > 1e-12 else np.linspace(0, 1, Z.shape[0])

def _natural_sort_key(name: str):
    """Return a sort key that orders embedded integers numerically ("g2" before "g10"), breaking ties on the raw name."""
    parts = re.split(r"(\d+)", name)
    # With a capturing group, re.split puts the digit runs at the odd positions,
    # so every position compares str-to-str or int-to-int.
    return [int(part) if pos % 2 else part for pos, part in enumerate(parts)], name


def _load_priors_from_npy_dir(npy_dir: str, masks):
    """Load one prior adjacency per graph from the ``.npy`` files in ``npy_dir``.

    File selection for graph ``idx``:
    - The first file, in natural order, whose name contains a standalone
      ``g<idx>`` token (case-insensitive, not followed by another digit) is
      used, e.g. ``prior_g3.npy``.
    - If no name matches and the directory holds exactly ``len(masks)`` files,
      the ``idx``-th file in natural order is used instead.

    Each prior must be square with the same shape as ``masks[idx]``. It is
    symmetrised, its diagonal is zeroed, and it is clipped to [0, 1].

    Returns:
        ``(priors, z1d)``: lists of float32 tensors holding the (N, N) priors
        and, per prior, the 1-D node scores from ``_pc1_01``.

    Raises:
        FileNotFoundError: the directory is missing, holds no ``.npy`` files,
            or no file can be assigned to a graph.
        ValueError: a prior is not square or its shape differs from its mask.
    """
    if not os.path.isdir(npy_dir): raise FileNotFoundError(npy_dir)
    # Natural order matters for the positional fallback: plain sorted() puts
    # "10.npy" before "2.npy" and would silently pair graph 2 with the wrong prior.
    files = sorted((f for f in os.listdir(npy_dir) if f.endswith(".npy")), key=_natural_sort_key)
    if not files: raise FileNotFoundError(f"no npy in {npy_dir}")
    L = len(masks); priors, z1d = [], []
    for idx in range(L):
        # "g<idx>" must not be preceded or followed by another digit (so g1 != g10).
        pat = re.compile(rf'(?i)(^|[^0-9])g{idx}([^0-9]|$)')
        picks = [f for f in files if pat.search(f)]
        pick = picks[0] if picks else (files[idx] if len(files) == L else None)
        if pick is None:
            raise FileNotFoundError(f"no prior for graph {idx} in {npy_dir}")
        M_mask = masks[idx].detach().cpu().numpy() if torch.is_tensor(masks[idx]) else np.array(masks[idx])
        A = np.array(np.load(os.path.join(npy_dir, pick)), dtype=np.float32)
        if A.ndim != 2 or A.shape[0] != A.shape[1]: raise ValueError("prior is not square")
        if A.shape != M_mask.shape: raise ValueError("shape mismatch in prior vs mask")
        # Symmetrise, drop self-loops and clip into [0, 1].
        A = (A + A.T) * 0.5; np.fill_diagonal(A, 0.0); A = np.clip(A, 0.0, 1.0)
        priors.append(torch.from_numpy(A.astype(np.float32)))
        z1d.append(torch.from_numpy(_pc1_01(A).astype(np.float32)))
    return priors, z1d

def _save_grid(A_list, titles, out_png):
    """Save the matrices in ``A_list`` side by side as one greyscale PNG.

    Each element must be a torch tensor; it is detached and copied to CPU. It
    is drawn on a fixed [0, 1] colour scale under the matching entry of
    ``titles``. The parent directory of ``out_png`` is created when needed.
    """
    k = len(A_list)
    out_dir = os.path.dirname(out_png)
    if out_dir:  # os.makedirs("") raises, so skip it when out_png is a bare filename
        os.makedirs(out_dir, exist_ok=True)
    fig = plt.figure(figsize=(4 * k, 4))
    try:
        for i, (A, t) in enumerate(zip(A_list, titles), 1):
            ax = fig.add_subplot(1, k, i)
            ax.imshow(A.detach().cpu().numpy(), cmap="Greys", vmin=0.0, vmax=1.0)
            ax.set_title(t)
            ax.axis("off")
        fig.tight_layout()
        fig.savefig(out_png, dpi=200)
    finally:
        plt.close(fig)  # release the figure even if drawing or saving fails

def _eval_masked_to_csv(out_csv, A_true_list, A_pred_list, M_list):
    """Score reconstructions on the masked (unobserved) pairs and write the results to a CSV.

    For each (truth, prediction, mask) triple of torch tensors, the pairs
    i < j with ``mask < 0.5`` are evaluated. This is the same "unobserved" rule
    as ``masked_upper_mse``.

    Metrics per sample:
    - MAE and MSE against the raw truth values.
    - AP and ROC-AUC against the labels ``truth > 0.5``.
    - Recall and F1 against the same labels, with predictions thresholded at 0.5.

    Samples without masked pairs are skipped. An ``avg`` row with the column
    means (pandas skips NaN) is appended. If no sample has masked pairs,
    nothing is written and a message is printed.
    """
    rows = []
    for i, (A1, Ap, M) in enumerate(zip(A_true_list, A_pred_list, M_list)):
        gt = A1.detach().cpu().numpy()
        pr = Ap.detach().cpu().numpy()
        mk = M.detach().cpu().numpy()
        iu = np.triu_indices(gt.shape[0], 1)
        sel = mk[iu] < 0.5
        if not sel.any():
            continue
        truth = gt[iu][sel]
        s = pr[iu][sel]
        y = (truth > 0.5).astype(int)
        yhat = (s > 0.5).astype(int)
        # roc_auc_score raises ValueError when only one class is present among
        # the masked pairs; both ranking metrics fall back to NaN on ValueError.
        try:
            auc = roc_auc_score(y, s)
        except ValueError:
            auc = float("nan")
        try:
            ap = average_precision_score(y, s)
        except ValueError:
            ap = float("nan")
        rows.append({
            "sample": i,
            "MAE": mean_absolute_error(truth, s),
            "MSE": mean_squared_error(truth, s),
            "AP": ap,
            "ROC_AUC": auc,
            "Rec@0.5": recall_score(y, yhat, zero_division=0),
            "F1@0.5": f1_score(y, yhat, zero_division=0),
        })
    if not rows:
        print("no masked edges to eval"); return
    df = pd.DataFrame(rows)
    avg = df.mean(numeric_only=True); avg["sample"] = "avg"
    df = pd.concat([df, avg.to_frame().T], ignore_index=True)
    out_dir = os.path.dirname(out_csv)
    if out_dir:  # os.makedirs("") raises for a bare filename
        os.makedirs(out_dir, exist_ok=True)
    df.to_csv(out_csv, index=False)
    print(f"metrics saved -> {out_csv}")


# ---------- train ----------
def _align_by_z_(A, nm, M, Y, Z):
    """Reorder each sample's nodes by ascending score ``Z``, in place.

    Sample ``i`` is permuted by ``argsort(Z[i])``. The same permutation is
    applied to ``Z[i]``, to ``nm[i]``, and to both the rows and the columns of
    ``A[i]``, ``M[i]`` and ``Y[i]``, so all five tensors stay consistent.
    """
    for i in range(A.size(0)):
        p = torch.argsort(Z[i])
        Z[i] = Z[i].index_select(0, p)
        nm[i] = nm[i].index_select(0, p)
        A[i] = permute_square(A[i], p)
        M[i] = permute_square(M[i], p)
        Y[i] = permute_square(Y[i], p)


def _interpolant_loss(model, A, nm, M, Y, Z, sigma):
    """Return the masked velocity-matching loss for one z-aligned batch.

    Steps:
    - Build the start point A0: the observed entries of ``A`` (where ``M`` is 1)
      and the prior ``Y`` elsewhere, plus masked noise with std ``sigma``.
    - Draw one time t from torch.rand per sample (uniform on [0, 1)).
    - Form the interpolant I = (1 - t) * A0 + t * A and feed it to the model.
    - Regress the model output onto the velocity A - A0 with ``masked_upper_mse``.
    """
    B = A.size(0)
    A0 = M * A + (1.0 - M) * Y
    A0 = add_masked_symmetric_noise(A0, nm, M, sigma=sigma, clip01=True)
    t = torch.rand(B, device=A.device)
    a, b, _, _ = linear_coeffs(t)
    I = sym_zero_diag_valid(a.view(B, 1, 1) * A0 + b.view(B, 1, 1) * A, nm)
    target = sym_zero_diag_valid(A - A0, nm)
    x_feat = torch.zeros_like(Z).unsqueeze(-1)  # all-zero node features, one channel
    pred = model(x_feat, I.unsqueeze(1), nm, t)
    pred = sym_zero_diag_valid(pred, nm) * (1.0 - M)
    return masked_upper_mse(pred, target, nm, M)


def train(args):
    """Train ``DenoiseNetworkA`` to predict the velocity A - A0 on masked entries.

    Inputs (paths come from ``args``):
    - ``train_pkl`` / ``val_pkl``: pickled lists of graphs, each a networkx
      graph or a square torch tensor.
    - ``<dir of train_pkl>/masks_drop<drop_prob>/{train,val}_masks.pkl``:
      pickled lists of NumPy masks, one per graph (1 = observed entry).
    - ``prior_train_dir`` / ``prior_val_dir``: per-graph prior ``.npy`` files.

    Outputs:
    - state_dict checkpoints ``<output_dir>/ckpts/<run>/epNNNN.pt``, written
      every ``ckpt_every`` epochs and after the last epoch.
    - A train/val loss-curve PNG in ``output_dir``.
    """
    print("start train")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    set_seed(args.seed)

    def _load_pickle(path):
        # NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.
        with open(path, "rb") as fh:  # closes the handle (a bare open() leaked it)
            return pickle.load(fh)

    train_graphs = _load_pickle(args.train_pkl)
    val_graphs = _load_pickle(args.val_pkl)

    mask_dir = os.path.join(os.path.dirname(args.train_pkl), f"masks_drop{args.drop_prob}")
    train_masks = [torch.from_numpy(m).float() for m in _load_pickle(os.path.join(mask_dir, "train_masks.pkl"))]
    val_masks = [torch.from_numpy(m).float() for m in _load_pickle(os.path.join(mask_dir, "val_masks.pkl"))]

    pri_tr, z_tr = _load_priors_from_npy_dir(args.prior_train_dir, train_masks)
    pri_va, z_va = _load_priors_from_npy_dir(args.prior_val_dir, val_masks)

    class DS(Dataset):
        """Yields (adjacency, mask, prior, z-score) for graph ``i`` as float tensors."""
        def __init__(self, graphs, masks, priors, z1d):
            self.G, self.M, self.S, self.Z = graphs, masks, priors, z1d
        def __len__(self): return len(self.G)
        def __getitem__(self, i):
            g = self.G[i]
            A = torch.tensor(nx.to_numpy_array(g), dtype=torch.float32) if isinstance(g, nx.Graph) else g.float()
            return A, self.M[i], self.S[i].float(), self.Z[i].float()

    def col(batch):
        """Pad DS items to the batch's largest graph; all padding (incl. Z) is zero."""
        As, Ms, Ys, Zs = zip(*batch)
        A, nm = collate_graphs(As)
        b, n, _ = A.size()
        M = torch.zeros(b, n, n); Y = torch.zeros(b, n, n); Z = torch.zeros(b, n)
        for i, (m, y, z) in enumerate(zip(Ms, Ys, Zs)):
            k = m.size(0); M[i, :k, :k] = m; Y[i, :k, :k] = y; Z[i, :k] = z
        return A, nm, M, Y, Z

    dl_tr = DataLoader(DS(train_graphs, train_masks, pri_tr, z_tr), batch_size=args.batch, shuffle=True,  collate_fn=col)
    dl_va = DataLoader(DS(val_graphs,   val_masks,   pri_va, z_va), batch_size=args.batch, shuffle=False, collate_fn=col)

    def _sz(g):
        if isinstance(g, torch.Tensor): return g.size(0)
        if isinstance(g, nx.Graph): return g.number_of_nodes()
        raise TypeError(f"unsupported graph type: {type(g).__name__}")

    # Size the model for every graph seen in this run; a validation graph larger
    # than all training graphs would otherwise exceed max_node_num.
    nmax = max(_sz(g) for g in [*train_graphs, *val_graphs])

    model = DenoiseNetworkA(
        max_feat_num=1, max_node_num=nmax,
        nhid=args.hidden, num_layers=args.layers, num_linears=args.linears,
        c_init=args.c_init, c_hid=args.c_hid, c_final=args.c_final, adim=args.hidden
    ).to(device)
    opt = optim.Adam(model.parameters(), lr=args.lr)

    out_root = os.path.join(args.output_dir, "ckpts")
    os.makedirs(out_root, exist_ok=True)
    run_dir = os.path.join(out_root, f"{args.name}_{args.drop_prob}drop_{time.strftime('%Y%m%d_%H%M%S')}")
    os.makedirs(run_dir, exist_ok=True)

    tr_hist, va_hist = [], []
    for ep in range(1, args.epochs + 1):
        model.train()
        s_sum, n_sum = 0.0, 0
        for batch in dl_tr:
            A, nm, M, Y, Z = (x.to(device) for x in batch)
            # Padded nodes carry Z = 0 (see col) and sort with the lowest-scored
            # real nodes; nm is permuted alongside, so masking stays correct.
            _align_by_z_(A, nm, M, Y, Z)
            loss = _interpolant_loss(model, A, nm, M, Y, Z, args.train_noise)
            opt.zero_grad(); loss.backward(); opt.step()
            s_sum += float(loss.item()) * A.size(0); n_sum += A.size(0)
        tr = s_sum / max(1, n_sum); tr_hist.append(tr)

        # Validation draws its own noise and t, so this is a stochastic estimate.
        model.eval()
        s_sum, n_sum = 0.0, 0
        with torch.no_grad():
            for batch in dl_va:
                A, nm, M, Y, Z = (x.to(device) for x in batch)
                _align_by_z_(A, nm, M, Y, Z)
                l = _interpolant_loss(model, A, nm, M, Y, Z, args.val_noise)
                s_sum += float(l.item()) * A.size(0); n_sum += A.size(0)
        va = s_sum / max(1, n_sum); va_hist.append(va)
        print(f"epoch {ep}: train {tr:.6f}  val {va:.6f}")
        if ep % args.ckpt_every == 0 or ep == args.epochs:
            path = os.path.join(run_dir, f"ep{ep:04d}.pt")
            torch.save(model.state_dict(), path)
            print(f"ckpt -> {path}")

    # plot losses
    os.makedirs(args.output_dir, exist_ok=True)
    plt.figure(figsize=(8,5))
    plt.plot(tr_hist, label="train")
    plt.plot(va_hist, label="val")
    plt.legend(); plt.xlabel("epoch"); plt.ylabel("loss"); plt.tight_layout()
    lp = os.path.join(args.output_dir, f"loss_{args.name}_{time.strftime('%Y%m%d_%H%M%S')}.png")
    plt.savefig(lp, dpi=200); plt.close()
    print(f"loss plot -> {lp}")


# ---------- sample ----------
def sample(args):
    """Reconstruct masked edges by Euler-integrating the trained model's output.

    Inputs:
    - Graphs and masks come from ``--sample_pkl`` + ``--mask_pkl`` (pickled
      lists) or from one ``--input_graph`` + ``--mask_npy`` pair.
    - Priors come from ``--prior_test_dir``.

    Per graph:
    - Nodes are ordered by the prior's z-score.
    - The start A0 takes observed entries (mask 1) from the true graph and the
      rest from the prior, plus masked noise with std ``--noise_std``.
    - For each step count k, the model is evaluated at t = 0, 1/k, ...,
      (k-1)/k and its output added with step 1/k. After every step the result
      is clamped to [0, 1] and the observed entries are re-imposed.

    Outputs go under ``<output_dir>/<name>_<drop_prob>drop_<timestamp>/``:
    - ``recon_raw/g{i}_A0raw.npy`` and ``recon_raw/g{i}_k{k}_raw.npy``.
    - ``plots/g{i}.png``, drawn for the largest k.
    - ``eval/k{k}.csv`` and ``eval/a0.csv``.

    Raises:
        ValueError: neither input pair is given, or the step counts are not
            positive integers.
    """
    print("start sample")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    set_seed(args.seed)

    # Parse step counts first so a bad --n_steps_grid fails before any heavy work.
    if args.n_steps_grid:
        raw_steps = [int(s) for s in args.n_steps_grid.split(",") if s.strip()]
    else:
        raw_steps = [int(args.n_steps)]
    # Deduplicate, keeping order. A repeated k would append two results per
    # graph to one list and misalign them with A1_list during evaluation.
    steps_grid = list(dict.fromkeys(raw_steps))
    if not steps_grid or min(steps_grid) < 1:
        raise ValueError(f"step counts must be positive integers, got {args.n_steps_grid or args.n_steps!r}")

    # NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.
    if args.sample_pkl and args.mask_pkl:
        with open(args.sample_pkl, "rb") as fh:
            Gs = pickle.load(fh)
        with open(args.mask_pkl, "rb") as fh:
            masks_np = pickle.load(fh)
        A1_list = [
            torch.tensor(nx.to_numpy_array(g), dtype=torch.float32, device=device)
            if isinstance(g, nx.Graph) else g.to(dtype=torch.float32, device=device)
            for g in Gs
        ]
        M_list = [torch.from_numpy(m).to(device).float() for m in masks_np]
    elif args.input_graph and args.mask_npy:
        A = np.load(args.input_graph).astype(np.float32)
        M = np.load(args.mask_npy).astype(np.float32)
        A1_list = [torch.from_numpy(A).to(device)]
        M_list  = [torch.from_numpy(M).to(device).float()]
    else:
        raise ValueError("provide (--sample_pkl and --mask_pkl) or (--input_graph and --mask_npy)")

    priors, zlist = _load_priors_from_npy_dir(args.prior_test_dir, [m.detach().cpu() for m in M_list])

    model = DenoiseNetworkA(
        max_feat_num=1, max_node_num=args.max_nodes,
        nhid=args.hidden, num_layers=args.layers, num_linears=args.linears,
        c_init=args.c_init, c_hid=args.c_hid, c_final=args.c_final, adim=args.hidden
    ).to(device)
    model.load_state_dict(torch.load(args.ckpt, map_location=device))
    model.eval()

    out_run = os.path.join(args.output_dir, f"{args.name}_{args.drop_prob}drop_{time.strftime('%Y%m%d_%H%M%S')}")
    dir_plot = os.path.join(out_run, "plots")
    dir_npy  = os.path.join(out_run, "recon_raw")
    os.makedirs(dir_plot, exist_ok=True); os.makedirs(dir_npy, exist_ok=True)

    kmax = max(steps_grid)
    results_by_k = {k: [] for k in steps_grid}
    A0_raw_all = []
    with torch.no_grad():  # pure inference: nothing here needs autograd
        for i, (A1, M, S, z) in enumerate(zip(A1_list, M_list, priors, zlist)):
            z = z.to(device).float(); S = S.to(device).float()
            # Work in z-sorted node order; pinv maps results back to the input order.
            p = torch.argsort(z); pinv = invert_perm(p)
            A1p, Mp, Sp = permute_square(A1, p), permute_square(M, p), permute_square(S, p)
            zp = z.index_select(0, p)
            nm = torch.ones(A1p.size(0), dtype=torch.bool, device=device)

            A0 = Mp * A1p + (1.0 - Mp) * Sp
            A0 = add_masked_symmetric_noise(A0, nm, Mp, sigma=args.noise_std, clip01=True)
            A0u = permute_square(A0, pinv)
            A0_raw_all.append(A0u.detach().cpu().clone())
            np.save(os.path.join(dir_npy, f"g{i}_A0raw.npy"), A0u.detach().cpu().numpy())

            x_feat = torch.zeros(1, zp.shape[0], 1, device=device, dtype=zp.dtype)
            for k in steps_grid:
                A = A0.clone(); dt = 1.0 / float(k)
                for s in range(k):
                    t = torch.full((1,), s * dt, device=device)
                    b = model(x_feat, A.unsqueeze(0).unsqueeze(1), nm.unsqueeze(0), t).squeeze(0)
                    b = sym_zero_diag_valid(b, nm) * (1.0 - Mp)
                    A = (A + dt * b).clamp_(0.0, 1.0)
                    A = Mp * A1p + (1.0 - Mp) * A  # re-impose the observed entries
                    A = sym_zero_diag_valid(A, nm)
                Au = permute_square(A, pinv); zero_diag_(Au)
                results_by_k[k].append(Au.detach().cpu().clone())
                np.save(os.path.join(dir_npy, f"g{i}_k{k}_raw.npy"), Au.detach().cpu().numpy())

            # quick grid plot for the largest k
            _save_grid(
                [A1.detach().cpu(), M.detach().cpu(), A0u.detach().cpu(), results_by_k[kmax][-1]],
                ["true", "mask", "a0", f"recon_k{kmax}"],
                os.path.join(dir_plot, f"g{i}.png")
            )

    # masked eval per k and for a0
    eval_dir = os.path.join(out_run, "eval"); os.makedirs(eval_dir, exist_ok=True)
    for k, lst in results_by_k.items():
        _eval_masked_to_csv(os.path.join(eval_dir, f"k{k}.csv"), A1_list, lst, M_list)
    _eval_masked_to_csv(os.path.join(eval_dir, "a0.csv"), A1_list, A0_raw_all, M_list)
    print(f"done -> {out_run}")


# ---------- cli ----------
def _add_model_args(parser):
    """Register the DenoiseNetworkA size hyper-parameters on ``parser``.

    ``train`` and ``sample`` both build the network from these flags. The
    sampled model presumably has to be built like the trained one for its
    checkpoint to load, so the flags and defaults are declared once here.
    """
    parser.add_argument("--hidden", type=int, default=32)
    parser.add_argument("--layers", type=int, default=5)
    parser.add_argument("--linears", type=int, default=2)
    parser.add_argument("--c_init", type=int, default=2)
    parser.add_argument("--c_hid", type=int, default=8)
    parser.add_argument("--c_final", type=int, default=2)


def build_parser():
    """Build the CLI parser with the ``train`` and ``sample`` sub-commands (one is required)."""
    p = argparse.ArgumentParser(description="link predictor")
    sub = p.add_subparsers(dest="cmd", required=True)

    pt = sub.add_parser("train")
    pt.add_argument("--train_pkl", required=True)
    pt.add_argument("--val_pkl", required=True)
    pt.add_argument("--drop_prob", type=float, required=True)
    pt.add_argument("--prior_train_dir", required=True)
    pt.add_argument("--prior_val_dir", required=True)
    _add_model_args(pt)
    pt.add_argument("--lr", type=float, default=2e-4)
    pt.add_argument("--batch", type=int, default=64)
    pt.add_argument("--epochs", type=int, default=1000)
    pt.add_argument("--train_noise", type=float, default=0.1)
    pt.add_argument("--val_noise", type=float, default=0.1)
    pt.add_argument("--ckpt_every", type=int, default=100)
    pt.add_argument("--name", default="link_prediction")
    pt.add_argument("--output_dir", default="./runs")
    pt.add_argument("--seed", type=int, default=0)

    ps = sub.add_parser("sample")
    ps.add_argument("--ckpt", required=True)
    ps.add_argument("--drop_prob", type=float, required=True)
    ps.add_argument("--prior_test_dir", required=True)
    ps.add_argument("--sample_pkl", default="")
    ps.add_argument("--mask_pkl", default="")
    ps.add_argument("--input_graph", default="")
    ps.add_argument("--mask_npy", default="")
    ps.add_argument("--max_nodes", type=int, default=512)
    _add_model_args(ps)
    ps.add_argument("--noise_std", type=float, default=0.1)
    ps.add_argument("--n_steps", type=int, default=100)
    ps.add_argument("--n_steps_grid", default="",
                    help='comma-separated Euler step counts, e.g. "1,10,100"; overrides --n_steps')
    ps.add_argument("--name", default="sample")
    ps.add_argument("--output_dir", default="./runs")
    ps.add_argument("--seed", type=int, default=0)
    return p

def main():
    """Parse the command line and dispatch to the selected sub-command."""
    args = build_parser().parse_args()
    handlers = {"train": train, "sample": sample}
    handler = handlers.get(args.cmd)
    if handler is not None:
        handler(args)


if __name__ == "__main__":
    main()
