import os
import argparse
import time
import pickle
import math
import numpy as np
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
import networkx as nx
import matplotlib.pyplot as plt


from Denoiser_A_embedding import DenoiseNetworkA


# ---------- utils used by train_fake and sample_fake ----------

def set_seed(seed: int):
    """Seed torch's CPU RNG, NumPy's global RNG and, if CUDA is present, all CUDA RNGs."""
    for seeder in (torch.manual_seed, np.random.seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def collate_graphs(batch):
    """Zero-pad a sequence of square adjacency tensors to a common size.

    Returns ``(padded, node_mask)``. ``padded`` is ``(B, n_max, n_max)`` with the
    dtype of ``batch[0]``. ``node_mask`` is a ``(B, n_max)`` bool tensor that is True
    for the real (non-padding) nodes of each graph. Both are allocated on the default device.
    """
    sizes = [adj.size(0) for adj in batch]
    n_max = max(sizes)
    padded = torch.zeros(len(batch), n_max, n_max, dtype=batch[0].dtype)
    node_mask = torch.zeros(len(batch), n_max, dtype=torch.bool)
    for row, (adj, n) in enumerate(zip(batch, sizes)):
        padded[row, :n, :n] = adj
        node_mask[row, :n] = True
    return padded, node_mask

def zero_diag_(M: torch.Tensor) -> torch.Tensor:
    """Zero the diagonal of *M* in place and return *M*.

    Accepts a single ``(N, N)`` matrix or a batch ``(..., N, N)``. The diagonal is
    taken over the last two dimensions, so each matrix in a batch gets its own
    diagonal zeroed.
    """
    M.diagonal(dim1=-2, dim2=-1).zero_()
    return M

def sym_zero_diag_valid(M: torch.Tensor, node_mask: torch.Tensor) -> torch.Tensor:
    """Project *M* onto symmetric, zero-diagonal matrices supported on valid nodes.

    Rows and columns of invalid nodes are zeroed first. The strict upper triangle is
    then mirrored into the lower triangle; the original lower triangle and diagonal
    are discarded.

    Accepts a single matrix ``(N, N)`` with ``node_mask`` of shape ``(N,)``, or a
    batch ``(B, N, N)`` with ``node_mask`` of shape ``(B, N)``. Returns a new tensor.
    """
    if M.dim() not in (2, 3):
        raise ValueError(f"expected a 2-D or 3-D tensor, got {M.dim()}-D")
    valid = node_mask.to(M.dtype)
    pair = valid.unsqueeze(-1) * valid.unsqueeze(-2)
    # torch.triu acts on the last two dims, so one path serves both ranks.
    # diagonal=1 also zeroes the diagonal, so the mirrored sum has a zero diagonal.
    upper = torch.triu(M * pair, diagonal=1)
    return (upper + upper.transpose(-2, -1)) * pair

def add_masked_symmetric_noise(M: torch.Tensor,
                               node_mask: torch.Tensor,
                               edge_mask: torch.Tensor,
                               sigma: float,
                               clip01: bool = True) -> torch.Tensor:
    """Add symmetric Gaussian noise of std *sigma* to *M* where ``edge_mask`` is 0.

    ``edge_mask`` marks the known entries; noise goes only on ``1 - edge_mask``.
    The noise, the unknown mask and the result are each projected with
    ``sym_zero_diag_valid``. With ``clip01`` the result is clamped to [0, 1] in place.
    For ``sigma <= 0`` no noise is drawn and only the projection of *M* is returned.
    """
    if sigma <= 0.0:
        return sym_zero_diag_valid(M, node_mask)

    free = sym_zero_diag_valid((1.0 - edge_mask).to(M.dtype), node_mask)
    noise = sym_zero_diag_valid(torch.randn_like(M), node_mask)
    noisy = sym_zero_diag_valid(M + sigma * (noise * free), node_mask)
    if clip01:
        noisy.clamp_(0.0, 1.0)
    return noisy

def linear_coeffs(t: torch.Tensor):
    """Return ``(alpha, beta, alpha_dot, beta_dot)`` for the linear path.

    The path uses ``alpha = 1 - t`` and ``beta = t``. Their time derivatives are the
    constants -1 and +1, broadcast to the shape and dtype of *t*.
    """
    return (1.0 - t, t, torch.full_like(t, -1.0), torch.ones_like(t))

def upper_triu_mask_batched(node_mask: torch.Tensor) -> torch.Tensor:
    """Return a ``(B, N, N)`` bool mask of strict-upper-triangle pairs (i < j).

    A pair is included only if both endpoints are valid in ``node_mask`` (``(B, N)``).
    """
    _, n = node_mask.shape
    upper = torch.ones(n, n, dtype=torch.bool, device=node_mask.device).triu(diagonal=1)
    valid_pairs = node_mask[:, :, None] & node_mask[:, None, :]
    return upper & valid_pairs

def masked_upper_mse(pred: torch.Tensor,
                     target: torch.Tensor,
                     node_mask: torch.Tensor,
                     edge_mask: torch.Tensor) -> torch.Tensor:
    """Mean over graphs of the per-graph MSE on unknown upper-triangular pairs.

    A pair (i, j) of graph b counts when i < j, both nodes are valid in ``node_mask``
    and ``edge_mask[b, i, j] < 0.5`` (i.e. the entry is unknown). Graphs with no such
    pair contribute 0 to the average.

    The computation is vectorized: no Python loop and no per-graph host sync. The
    result stays attached to ``pred``'s graph even when every graph in the batch has
    an empty mask, so ``loss.backward()`` never fails on such a batch.
    """
    masked_ut = upper_triu_mask_batched(node_mask) & (edge_mask < 0.5)
    sq_err = (pred - target).pow(2)
    # where() (rather than multiplying by the mask) keeps non-finite values outside
    # the mask from leaking into the sum, matching the original's indexing.
    sq_err = torch.where(masked_ut, sq_err, torch.zeros_like(sq_err))
    counts = masked_ut.sum(dim=(1, 2)).to(sq_err.dtype)
    # clamp(min=1): an empty graph has sum 0, so it yields 0 instead of 0/0.
    per_graph = sq_err.sum(dim=(1, 2)) / counts.clamp(min=1.0)
    return per_graph.mean()

def _to_adj_tensor(x):
    """Convert a torch tensor or networkx graph to a float32 adjacency tensor.

    Tensors are cast with ``.float()``. Graphs go through ``nx.to_numpy_array``.
    Any other type raises ``TypeError``.
    """
    if torch.is_tensor(x):
        return x.float()
    if isinstance(x, nx.Graph):
        return torch.tensor(nx.to_numpy_array(x), dtype=torch.float32)
    raise TypeError("unsupported graph type")

def _size_of(x):
    return x.size(0) if torch.is_tensor(x) else x.number_of_nodes()

def _sanitize_R_for_A(R: torch.Tensor, A: torch.Tensor, node_mask: torch.Tensor):
    """Clean a fake-edge mask *R* so it only marks non-edges of *A*.

    *R* is clamped to [0, 1] and projected with ``sym_zero_diag_valid`` (symmetric,
    zero diagonal, valid nodes only). Entries where ``A > 0`` are then zeroed.
    """
    existing_edges = (A > 0).float()
    cleaned = sym_zero_diag_valid(R.clamp(0.0, 1.0), node_mask)
    return cleaned * (1.0 - existing_edges)

def _save_six_panel_fake(A_true: torch.Tensor,
                         R: torch.Tensor,
                         A_final: torch.Tensor,
                         outpath: str,
                         title_note: str = ""):
    """Save a 2x3 sanity-check figure for one fake-edge reconstruction.

    Top row: true graph, fake-edge mask R, observed graph ``clip(A_true + R, 0, 1)``.
    Bottom row: true graph outside R, reconstruction, and the signed error
    ``(A_final - A_true) * R`` on a symmetric blue/red scale.

    Args:
        A_true: ground-truth adjacency (2-D, same shape as R and A_final).
        R: fake-edge mask.
        A_final: reconstructed adjacency.
        outpath: image path; its parent directory is created if it has one.
        title_note: suffix appended to the reconstruction panel title.
    """
    A1 = A_true.detach().cpu().numpy().astype(np.float32)
    Rm = R.detach().cpu().numpy().astype(np.float32)
    Af = A_final.detach().cpu().numpy().astype(np.float32)
    Aobs = np.clip(A1 + Rm, 0.0, 1.0)
    A1_outsideR = A1 * (1.0 - Rm)
    diff_raw_on_R = (Af - A1) * Rm
    # Symmetric colour range for the error panel. Fall back to a tiny positive value
    # when the error is all zero, empty or non-finite, so imshow gets a valid range.
    vmax = float(np.abs(diff_raw_on_R).max()) if diff_raw_on_R.size else 0.0
    if not (np.isfinite(vmax) and vmax > 0.0):
        vmax = 1e-6

    fig, ax = plt.subplots(2, 3, figsize=(14, 8))
    try:
        grey = dict(cmap="Greys", vmin=0.0, vmax=1.0, interpolation="nearest")
        panels = (
            (ax[0, 0], A1, "true"),
            (ax[0, 1], Rm, "fake mask"),
            (ax[0, 2], Aobs, "a_obs"),
            (ax[1, 0], A1_outsideR, "true outside r"),
            (ax[1, 1], Af, f"recon {title_note}"),
        )
        for panel, img, title in panels:
            panel.imshow(img, **grey)
            panel.set_title(title)
            panel.axis("off")
        im = ax[1, 2].imshow(diff_raw_on_R, cmap="bwr", vmin=-vmax, vmax=+vmax, interpolation="nearest")
        ax[1, 2].set_title("raw delta on r")
        ax[1, 2].axis("off")
        fig.colorbar(im, ax=ax[1, 2], fraction=0.046, pad=0.04)
        fig.tight_layout()
        out_dir = os.path.dirname(outpath)
        if out_dir:  # dirname is "" for a bare filename, and os.makedirs("") raises
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(outpath, dpi=200)
    finally:
        # always release the figure, even if drawing or saving fails
        plt.close(fig)


# --------------------------
# train_fake
# --------------------------
def _fake_edge_batch_loss(model, A_pad, node_mask, R_pad, noise_std):
    """Compute the masked velocity-matching loss for one padded batch.

    For each graph i:
      * R_i is cleaned by ``_sanitize_R_for_A`` (only on non-edges of A_i).
      * A_obs_i = clamp(A_i + R_i, 0, 1). It also serves as the update region Om_i.
      * A0_i is Gaussian noise (std ``noise_std``, clipped to [0, 1]) placed on Om_i.

    A per-graph ``t = torch.rand(B)`` gives ``I_t = (1 - t) * A0 + t * A`` and the
    target ``A - A0``. The model prediction is masked by Om and scored with
    ``masked_upper_mse`` on upper-triangular pairs inside Om. Returns the scalar loss.
    """
    B, N, _ = A_pad.size()
    A0_list, Om_list = [], []
    for i in range(B):
        R_i = _sanitize_R_for_A(R_pad[i], A_pad[i], node_mask[i])
        A_obs = sym_zero_diag_valid(torch.clamp(A_pad[i] + R_i, 0.0, 1.0), node_mask[i])
        # Noise goes on the update region Om = A_obs, so the "known" edge_mask is 1 - Om.
        A0_i = add_masked_symmetric_noise(
            M=torch.zeros_like(A_obs), node_mask=node_mask[i],
            edge_mask=(1.0 - A_obs), sigma=noise_std, clip01=True,
        )
        A0_list.append(A0_i)
        Om_list.append(A_obs)
    A0 = torch.stack(A0_list, dim=0)
    Om = torch.stack(Om_list, dim=0)

    t = torch.rand(B, device=A_pad.device, dtype=A_pad.dtype)
    alpha, beta, _, _ = linear_coeffs(t)
    I_t = sym_zero_diag_valid(alpha.view(B, 1, 1) * A0 + beta.view(B, 1, 1) * A_pad, node_mask)
    target = sym_zero_diag_valid(A_pad - A0, node_mask)

    x_feat = torch.zeros(B, N, 1, device=A_pad.device, dtype=A_pad.dtype)
    pred = model(x_feat, I_t.unsqueeze(1), node_mask, t)
    pred = sym_zero_diag_valid(pred, node_mask) * Om
    # Om is A_obs, so edge_mask = 1 - Om restricts the loss to the update region.
    return masked_upper_mse(pred, target, node_mask, edge_mask=(1.0 - Om))


def train_fake(args):
    """Train DenoiseNetworkA to reconstruct graphs from observations with fake edges.

    Loads pickled train/val graph lists and matching lists of fake-edge masks (numpy
    arrays), then trains with Adam on ``_fake_edge_batch_loss``. Prints the
    sample-weighted mean train/val loss each epoch. Saves the model ``state_dict``
    every ``args.ckpt_every`` epochs and at the last epoch, under
    ``args.models_root/<name>_<timestamp>/``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    set_seed(args.seed)

    # NOTE: pickle.load can execute arbitrary code; only point these at trusted files.
    def _load_pickle(path):
        with open(path, "rb") as fh:
            return pickle.load(fh)

    train_graphs = _load_pickle(args.train_pkl)
    val_graphs = _load_pickle(args.val_pkl)
    train_R_np = _load_pickle(args.train_fake_mask_pkl)
    val_R_np = _load_pickle(args.val_fake_mask_pkl)

    class GraphFakeDS(Dataset):
        # Pairs graph i with fake-edge mask i; both are converted to float tensors.
        def __init__(self, graphs, R_list):
            self.graphs = graphs
            self.R_list = R_list

        def __len__(self):
            return len(self.graphs)

        def __getitem__(self, i):
            A = _to_adj_tensor(self.graphs[i])
            R = torch.from_numpy(self.R_list[i]).float()
            return A, R

    def collate_fn(batch):
        # Pad graphs via collate_graphs and pad masks to the same max_n.
        As, Rs = zip(*batch)
        A_pad, node_mask = collate_graphs(As)
        max_n = A_pad.size(1)
        R_pad = torch.zeros(len(Rs), max_n, max_n, dtype=torch.float32)
        for i, R in enumerate(Rs):
            n = R.size(0)
            R_pad[i, :n, :n] = R
        return A_pad, node_mask, R_pad

    train_loader = DataLoader(GraphFakeDS(train_graphs, train_R_np),
                              batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(GraphFakeDS(val_graphs, val_R_np),
                            batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn)

    # NOTE(review): max_node_num comes from the training graphs only; if
    # DenoiseNetworkA depends on it, larger validation graphs may not fit - confirm.
    Nmax = max(_size_of(g) for g in train_graphs)
    model = DenoiseNetworkA(
        max_feat_num=1, max_node_num=Nmax,
        nhid=args.hidden_dim, num_layers=args.num_layers, num_linears=args.num_linears,
        c_init=args.c_init, c_hid=args.c_hid, c_final=args.c_final, adim=args.hidden_dim
    ).to(device)
    opt = optim.Adam(model.parameters(), lr=args.lr)

    # one timestamped checkpoint directory per run
    ts = time.strftime("%Y%m%d_%H%M%S")
    ckpt_dir = os.path.join(args.models_root, f"{args.name}_{ts}")
    os.makedirs(ckpt_dir, exist_ok=True)

    for ep in range(1, args.epochs + 1):
        # train
        model.train()
        ep_sum, ep_cnt = 0.0, 0
        for A_pad, node_mask, R_pad in train_loader:
            A_pad = A_pad.to(device)
            node_mask = node_mask.to(device)
            R_pad = R_pad.to(device)
            loss = _fake_edge_batch_loss(model, A_pad, node_mask, R_pad, args.train_noise_std)

            opt.zero_grad()
            loss.backward()
            opt.step()

            B = A_pad.size(0)
            ep_sum += loss.item() * B
            ep_cnt += B
        train_loss = ep_sum / max(1, ep_cnt)

        # val
        model.eval()
        v_sum, v_cnt = 0.0, 0
        with torch.no_grad():
            for A_pad, node_mask, R_pad in val_loader:
                A_pad = A_pad.to(device)
                node_mask = node_mask.to(device)
                R_pad = R_pad.to(device)
                l = _fake_edge_batch_loss(model, A_pad, node_mask, R_pad, args.val_noise_std)
                B = A_pad.size(0)
                v_sum += l.item() * B
                v_cnt += B
        val_loss = v_sum / max(1, v_cnt)
        print(f"epoch {ep}: train={train_loss:.6f} val={val_loss:.6f}", flush=True)

        if (ep % args.ckpt_every == 0) or (ep == args.epochs):
            path = os.path.join(ckpt_dir, f"ep{ep:04d}.pt")
            torch.save(model.state_dict(), path)
            print(f"saved ckpt: {path}", flush=True)


# --------------------------
# sample_fake
# --------------------------
def sample_fake(args):
    """Reconstruct graphs from observations with fake edges by Euler integration.

    For each (graph, fake mask) pair:
      * builds A_obs = clamp(A + R, 0, 1) with R cleaned by ``_sanitize_R_for_A``;
      * draws a noisy start A0 on the update region Om = A_obs;
      * for every positive step count k in ``args.sample_steps``, integrates the
        model's drift from t = 0 to 1 in k Euler steps, updating only inside Om.

    Writes ``a0_raw/g{i}_a0_raw.npy``, ``recon_raw/g{i}_k{k}_raw.npy`` and a six-panel
    plot for the largest k, under ``args.out_root/<name>_<timestamp>/``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    ts = time.strftime("%Y%m%d_%H%M%S")
    out_dir = os.path.join(args.out_root, f"{args.name}_{ts}")
    plots_dir = os.path.join(out_dir, "plots")
    a0_dir = os.path.join(out_dir, "a0_raw")
    recon_dir = os.path.join(out_dir, "recon_raw")
    for d in (plots_dir, a0_dir, recon_dir):
        os.makedirs(d, exist_ok=True)

    # load inputs
    if args.sample_pkl and args.fake_mask_pkl:
        # NOTE: pickle.load can execute arbitrary code; only load trusted files.
        with open(args.sample_pkl, "rb") as fh:
            A_list_obj = pickle.load(fh)
        with open(args.fake_mask_pkl, "rb") as fh:
            R_list_np = pickle.load(fh)
        A_list = [_to_adj_tensor(g).to(device) for g in A_list_obj]
        R_list = [torch.from_numpy(r).float().to(device) for r in R_list_np]
    elif args.input_graph and args.fake_mask_npy:
        A = torch.from_numpy(np.load(args.input_graph).astype(np.float32)).to(device)
        R = torch.from_numpy(np.load(args.fake_mask_npy).astype(np.float32)).to(device)
        A_list, R_list = [A], [R]
    else:
        raise ValueError("provide (--sample_pkl and --fake_mask_pkl) or (--input_graph and --fake_mask_npy)")

    # model
    maxN = max(A.size(0) for A in A_list)
    model = DenoiseNetworkA(
        max_feat_num=1, max_node_num=maxN,
        nhid=args.hidden_dim, num_layers=args.num_layers, num_linears=args.num_linears,
        c_init=args.c_init, c_hid=args.c_hid, c_final=args.c_final, adim=args.hidden_dim
    ).to(device)
    model.load_state_dict(torch.load(args.ckpt, map_location=device))
    model.eval()

    # Step grid: comma-separated step counts; non-positive counts are dropped.
    steps_spec = args.sample_steps or "1,2,5,10,20,50,100"
    steps = [int(x) for x in steps_spec.split(",")]
    steps = [s for s in steps if s > 0]
    k_max = max(steps) if steps else None

    for i, (A1, R) in enumerate(zip(A_list, R_list)):
        N = A1.size(0)
        node_mask = torch.ones(N, dtype=torch.bool, device=device)
        batch_mask = node_mask.unsqueeze(0)

        R = _sanitize_R_for_A(R, A1, node_mask)
        A_obs = sym_zero_diag_valid(torch.clamp(A1 + R, 0.0, 1.0), node_mask)
        Om = A_obs  # the update region is the support of A_obs

        # a0 raw: noise only on Om (the "known" edge_mask is its complement)
        A0 = add_masked_symmetric_noise(M=torch.zeros_like(A_obs),
                                        node_mask=node_mask,
                                        edge_mask=(1.0 - Om),
                                        sigma=args.noise_std,
                                        clip01=True)
        np.save(os.path.join(a0_dir, f"g{i}_a0_raw.npy"), A0.detach().cpu().numpy())

        x_feat = torch.zeros(1, N, 1, device=device, dtype=A0.dtype)
        max_step_recon = None  # reconstruction for the largest step count, for the plot
        for s in steps:
            A = A0.clone()
            dt = 1.0 / float(s)
            for t_idx in range(s):
                t = torch.full((1,), t_idx * dt, device=device, dtype=A.dtype)
                with torch.no_grad():
                    drift = model(x_feat, A.unsqueeze(0).unsqueeze(1), batch_mask, t).squeeze(0)
                    drift = sym_zero_diag_valid(drift, node_mask)
                drift = drift * Om
                A = A + dt * drift
                A.clamp_(0.0, 1.0)
                # Inpainting step: entries outside Om keep their observed values and
                # only Om evolves. Adding A_obs on top of A would push every Om entry
                # to >= 1. For binary A_obs this reduces to Om * A.
                A = A_obs * (1.0 - Om) + Om * A
                A = sym_zero_diag_valid(A, node_mask)
            zero_diag_(A)
            np.save(os.path.join(recon_dir, f"g{i}_k{s}_raw.npy"), A.detach().cpu().numpy())
            if s == k_max:
                max_step_recon = A.clone()

        # small plot
        if max_step_recon is not None:
            _save_six_panel_fake(A_true=A1, R=R, A_final=max_step_recon,
                                 outpath=os.path.join(plots_dir, f"g{i}_panel.png"),
                                 title_note=f"k{k_max}")


# --------------------------
# cli
# --------------------------
def build_parser():
    """Build the CLI with two required subcommands: ``train_fake`` and ``sample_fake``.

    Both subcommands take the same DenoiseNetworkA architecture flags (hidden_dim,
    num_layers, num_linears, c_init, c_hid, c_final). Sampling therefore builds the
    model from the same flag set used in training.
    """
    parser = argparse.ArgumentParser(description="fake-edge training and sampling")
    commands = parser.add_subparsers(dest="cmd", required=True)

    # (flag, default) pairs for the architecture, all integers
    model_flags = (
        ("--hidden_dim", 32),
        ("--num_layers", 5),
        ("--num_linears", 2),
        ("--c_init", 2),
        ("--c_hid", 8),
        ("--c_final", 2),
    )

    def add_model_args(sub_parser):
        for flag, default in model_flags:
            sub_parser.add_argument(flag, type=int, default=default)

    # train_fake
    train = commands.add_parser("train_fake")
    for flag in ("--train_pkl", "--val_pkl", "--train_fake_mask_pkl", "--val_fake_mask_pkl"):
        train.add_argument(flag, required=True)
    train.add_argument("--batch_size", type=int, default=8)
    train.add_argument("--epochs", type=int, default=1000)
    train.add_argument("--lr", type=float, default=2e-4)
    train.add_argument("--train_noise_std", type=float, default=0.1)
    train.add_argument("--val_noise_std", type=float, default=0.1)
    train.add_argument("--ckpt_every", type=int, default=100)
    train.add_argument("--seed", type=int, default=0)
    train.add_argument("--name", type=str, default="fake_edge")
    train.add_argument("--models_root", type=str, default="./models_fake")
    add_model_args(train)

    # sample_fake: exactly one input source (pickle list or single .npy graph)
    sample = commands.add_parser("sample_fake")
    sample.add_argument("--ckpt", required=True)
    source = sample.add_mutually_exclusive_group(required=True)
    source.add_argument("--sample_pkl")
    source.add_argument("--input_graph")
    sample.add_argument("--fake_mask_pkl")
    sample.add_argument("--fake_mask_npy")
    sample.add_argument("--sample_steps", type=str, default="1,2,5,10,20,50,100")
    sample.add_argument("--noise_std", type=float, default=0.1)
    sample.add_argument("--name", type=str, default="fake_edge_sample")
    sample.add_argument("--out_root", type=str, default="./out_fake")
    sample.add_argument("--seed", type=int, default=0)
    add_model_args(sample)

    return parser

def main():
    """CLI entry point: parse arguments and run the chosen subcommand.

    ``sample_fake`` is seeded here; ``train_fake`` seeds itself from ``args.seed``.
    """
    args = build_parser().parse_args()
    if args.cmd == "sample_fake":
        set_seed(args.seed)
        sample_fake(args)
    elif args.cmd == "train_fake":
        train_fake(args)


if __name__ == "__main__":
    main()
