# train.py
import os
import argparse
import torch
from torch.utils.tensorboard import SummaryWriter
from torch_geometric.data import Batch, Data

from data import load_tcga_graphs, split_graphs, get_dataloaders, to_device, ensure_edge_attr
from models import GraphEncoder, ArielModel
from utils import (
    set_seed,
    symmetric_node_contrast_loss,
    make_view_RE_MF,
    info_regularizer_cos,
    _project_l1_to_budget,
)


def _get_edge_candidates_cached(g: Data, topk_non_edges: int):
    """Return the candidate edge list for ``g``, building and caching it once.

    The candidate list is the graph's existing edges followed by randomly
    sampled node pairs. The result is stored on ``g`` itself as the private
    attributes ``_cand_src``, ``_cand_dst`` and ``_E0`` and reused on later
    calls with the same object.

    Notes:
        * The cache is not keyed on ``topk_non_edges``; a later call with a
          different value returns the first result.
        * Sampled pairs only have self-pairs removed. They are not checked
          against existing edges or against each other, so duplicates are
          possible.

    Args:
        g: Graph providing ``x`` (node features) and ``edge_index``.
        topk_non_edges: Upper bound on the number of random pairs drawn
            (also capped at ``10 * num_nodes``).

    Returns:
        Tuple ``(cand_src, cand_dst, E0)`` where the first ``E0`` entries of
        the two index tensors are the original edges.
    """
    cache_attrs = ("_cand_src", "_cand_dst", "_E0")
    if all(hasattr(g, attr) for attr in cache_attrs):
        return g._cand_src, g._cand_dst, g._E0

    edge_index = g.edge_index
    target_device = edge_index.device
    num_nodes = g.x.size(0)

    # Draw at most `topk_non_edges` random pairs, never more than 10 per node.
    num_samples = min(topk_non_edges, max(0, num_nodes * 10))
    rand_src = torch.randint(0, num_nodes, (num_samples,), device=target_device)
    rand_dst = torch.randint(0, num_nodes, (num_samples,), device=target_device)

    # Drop self-pairs; everything else is kept as a candidate.
    not_self_pair = rand_src != rand_dst
    rand_src = rand_src[not_self_pair]
    rand_dst = rand_dst[not_self_pair]

    g._cand_src = torch.cat([edge_index[0], rand_src], dim=0)
    g._cand_dst = torch.cat([edge_index[1], rand_dst], dim=0)
    g._E0 = edge_index.size(1)
    return g._cand_src, g._cand_dst, g._E0


@torch.enable_grad()
def generate_adversarial_view(
    model,
    base_view: Data,
    ref_view: Data,
    tau: float,
    pgd_steps: int = 5,
    step_x: float = 0.1,
    step_a: float = 0.5,
    delta_x: float = 0.5,
    budget_a: float = 50.0,
    topk_non_edges: int = 2000,
    thr_edge: float = 0.5,
) -> Data:
    """Build an adversarial view of ``base_view`` with sign-gradient PGD.

    Two perturbation variables are optimised to *increase*
    ``symmetric_node_contrast_loss`` between the embedding of ``ref_view``
    and the embedding of the perturbed graph:

    * ``w``  - one scalar per candidate edge (existing edges start at 1,
      sampled pairs start at 0). After every step it is passed through
      ``_project_l1_to_budget(w, w0, budget_a)`` (presumably bounding the
      L1 distance to the initial ``w0`` - confirm in ``utils``).
    * ``LX`` - an additive feature perturbation, clamped element-wise to
      ``[-delta_x, delta_x]``.

    Model parameters are frozen for the duration of the attack and their
    original ``requires_grad`` flags are always restored, even if the attack
    raises. Gradients are enabled by the decorator, so the function also
    works when called from a ``torch.no_grad()`` context.

    Args:
        model: Module called as ``model(x, edge_index, edge_attr)``.
        base_view: Graph to perturb.
        ref_view: Graph whose (fixed) embedding the attack pushes away from.
        tau: Temperature forwarded to ``symmetric_node_contrast_loss``.
        pgd_steps: Number of ascent steps.
        step_x: Step size for the feature perturbation.
        step_a: Step size for the edge variable.
        delta_x: Element-wise bound on the feature perturbation.
        budget_a: Budget forwarded to ``_project_l1_to_budget``.
        topk_non_edges: Forwarded to ``_get_edge_candidates_cached``.
        thr_edge: Candidates with final ``w > thr_edge`` are kept as edges.

    Returns:
        A new ``Data`` with perturbed ``x``, the kept candidate edges in both
        directions as ``edge_index``, and an all-ones ``edge_attr`` of shape
        ``(num_edges, 1)``.
    """
    device = base_view.x.device
    base_view = ensure_edge_attr(base_view)

    cand_src, cand_dst, E0 = _get_edge_candidates_cached(base_view, topk_non_edges)
    E_cand = cand_src.numel()

    # Initial edge variable: 1 for the E0 original edges, 0 for sampled pairs.
    w0 = torch.zeros(E_cand, device=device)
    w0[:E0] = 1.0
    w = w0.clone().requires_grad_(True)

    LX = torch.zeros_like(base_view.x, device=device, requires_grad=True)

    # Candidate topology never changes during the attack; build it once.
    ei_cand = torch.stack([cand_src, cand_dst], dim=0)      # [2, E_cand]
    ei_ud = torch.cat([ei_cand, ei_cand.flip(0)], dim=1)    # [2, 2E_cand]

    x_min = base_view.x.min().item() - delta_x
    x_max = base_view.x.max().item() + delta_x

    # Freeze the model so only w / LX receive gradients. The flags are
    # restored in `finally` so a failure here cannot leave the model frozen.
    params = list(model.parameters())
    req_grad_flags = [p.requires_grad for p in params]
    try:
        for p in params:
            p.requires_grad_(False)

        with torch.no_grad():
            z_ref = model(ref_view.x, ref_view.edge_index, ref_view.edge_attr).detach()

        for _ in range(pgd_steps):
            w_cont = w.sigmoid().unsqueeze(-1)                 # (E_cand, 1)
            ea_ud = torch.cat([w_cont, w_cont], dim=0)         # [2E_cand, 1]

            X_adv_cont = (base_view.x + LX).clamp(x_min, x_max)

            z_adv = model(X_adv_cont, ei_ud, ea_ud)
            loss_attack = symmetric_node_contrast_loss(z_ref, z_adv, tau)

            g_w, g_x = torch.autograd.grad(
                loss_attack, [w, LX],
                retain_graph=False, create_graph=False, allow_unused=True
            )

            # Ascent step + projection. Detaching turns the result back into
            # a fresh leaf so autograd history does not chain across steps.
            if g_w is not None:
                w = _project_l1_to_budget(w + step_a * g_w.sign(), w0, budget_a)
                w = w.detach().requires_grad_(True)
            if g_x is not None:
                LX = (LX + step_x * g_x.sign()).clamp(-delta_x, delta_x)
                LX = LX.detach().requires_grad_(True)
    finally:
        for p, f in zip(params, req_grad_flags):
            p.requires_grad_(f)

    # NOTE(review): the loop feeds sigmoid(w) to the model, but the threshold
    # below is applied to raw w - confirm which scale `thr_edge` is meant for.
    keep = (w.detach() > thr_edge)
    src_keep = cand_src[keep]
    dst_keep = cand_dst[keep]

    m = (src_keep != dst_keep)
    src_keep, dst_keep = src_keep[m], dst_keep[m]

    # Emit every kept pair in both directions.
    ei_adv = torch.stack(
        [torch.cat([src_keep, dst_keep], dim=0),
         torch.cat([dst_keep, src_keep], dim=0)],
        dim=0
    )
    ea_adv = torch.ones((ei_adv.size(1), 1), device=device)

    X_adv = (base_view.x + LX.detach()).clamp(x_min, x_max)
    return Data(x=X_adv, edge_index=ei_adv, edge_attr=ea_adv)

def _two_views_and_adv(model, data: Data, p_re: float, p_mf: float, tau: float, pgd_cfg: dict):
    """Create the two stochastic views, the adversarial view and the clean graph.

    ``edge_attr`` of the stochastic views is reshaped to 2-D *before* ``v1``
    is handed to the PGD routine as its reference view, so the reference
    forward pass there sees the same ``(num_edges, 1)`` layout that the
    training/eval forward passes use afterwards.

    Args:
        model: Model forwarded to ``generate_adversarial_view``.
        data: Input graph.
        p_re: First probability forwarded to ``make_view_RE_MF``.
        p_mf: Second probability forwarded to ``make_view_RE_MF``.
        tau: Temperature forwarded to ``generate_adversarial_view``.
        pgd_cfg: Extra keyword arguments for ``generate_adversarial_view``.

    Returns:
        Tuple ``(v1, v2, adv, clean)``.
    """

    def _as_2d_edge_attr(view):
        # Turn a 1-D edge_attr into a single-column 2-D tensor, in place.
        edge_attr = getattr(view, "edge_attr", None)
        if edge_attr is not None and edge_attr.dim() == 1:
            view.edge_attr = edge_attr.unsqueeze(-1)
        return view

    clean = ensure_edge_attr(data)
    v1 = make_view_RE_MF(clean, p_re, p_mf)
    v2 = make_view_RE_MF(clean, p_re, p_mf)

    # Normalise only after both views exist, but before v1 is used below.
    _as_2d_edge_attr(v1)
    _as_2d_edge_attr(v2)

    adv = generate_adversarial_view(
        model=model,
        base_view=clean,
        ref_view=v1,
        tau=tau,
        **pgd_cfg
    )
    _as_2d_edge_attr(adv)

    return v1, v2, adv, clean

def train_epoch(model, loader, optimizer, device, tau, eps1, eps2, p_re, p_mf, pgd_cfg):
    """Run one optimisation pass over ``loader`` and return the mean graph loss.

    Each batch is split back into individual graphs. For every non-empty
    graph the loss ``l_con + eps1 * l_adv + eps2 * l_inf`` is computed; the
    per-graph losses of a batch are averaged and a single optimizer step is
    taken per batch.

    Args:
        model: Module called as ``model(x, edge_index, edge_attr)``.
        loader: Iterable of batches exposing ``to_data_list()``.
        optimizer: Optimizer stepping the model parameters.
        device: Target device passed to ``to_device``.
        tau: Temperature for ``symmetric_node_contrast_loss``.
        eps1: Weight of the adversarial contrast term.
        eps2: Weight of the ``info_regularizer_cos`` term.
        p_re, p_mf: Augmentation probabilities for the stochastic views.
        pgd_cfg: Keyword arguments for ``generate_adversarial_view``.

    Returns:
        Loss averaged over all processed graphs (``0.0`` if none were seen).
    """
    model.train()
    epoch_loss_sum = 0.0
    graphs_seen = 0

    for batch in loader:
        graph_losses = []

        for graph in batch.to_data_list():
            graph = to_device(graph, device)
            if graph.x.size(0) == 0:
                # Nothing to embed for a graph without nodes.
                continue

            v1, v2, adv, clean = _two_views_and_adv(model, graph, p_re, p_mf, tau, pgd_cfg)

            z1 = model(v1.x, v1.edge_index, v1.edge_attr)
            z2 = model(v2.x, v2.edge_index, v2.edge_attr)
            za = model(adv.x, adv.edge_index, adv.edge_attr)
            zh = model(clean.x, clean.edge_index, clean.edge_attr)

            contrast_term = symmetric_node_contrast_loss(z1, z2, tau)
            adversarial_term = symmetric_node_contrast_loss(z1, za, tau)
            info_term = info_regularizer_cos(z1, z2, zh)

            graph_losses.append(contrast_term + eps1 * adversarial_term + eps2 * info_term)

        if not graph_losses:
            continue

        batch_size = len(graph_losses)
        # Sequential sum starting at 0.0, then mean over the batch's graphs.
        batch_loss = sum(graph_losses, 0.0) / batch_size

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        epoch_loss_sum += batch_loss.item() * batch_size
        graphs_seen += batch_size

    return epoch_loss_sum / max(graphs_seen, 1)

def evaluate(model, loader, device, tau, eps1, eps2, p_re, p_mf, pgd_cfg):
    """Compute the mean per-graph loss over ``loader`` without updating the model.

    The views (including the PGD adversarial view) are generated outside
    ``torch.no_grad()``; only the final forward passes and loss terms run
    with gradients disabled. View generation is stochastic, so repeated
    calls may return different values.

    Args:
        model: Module called as ``model(x, edge_index, edge_attr)``.
        loader: Iterable of batches exposing ``to_data_list()``.
        device: Target device passed to ``to_device``.
        tau: Temperature for ``symmetric_node_contrast_loss``.
        eps1: Weight of the adversarial contrast term.
        eps2: Weight of the ``info_regularizer_cos`` term.
        p_re, p_mf: Augmentation probabilities for the stochastic views.
        pgd_cfg: Keyword arguments for ``generate_adversarial_view``.

    Returns:
        Loss averaged over all processed graphs (``0.0`` if none were seen).
    """

    def embed(view):
        return model(view.x, view.edge_index, view.edge_attr)

    model.eval()
    loss_sum = 0.0
    graphs_seen = 0

    for batch in loader:
        for graph in batch.to_data_list():
            graph = to_device(graph, device)
            if graph.x.size(0) == 0:
                continue

            v1, v2, adv, clean = _two_views_and_adv(model, graph, p_re, p_mf, tau, pgd_cfg)

            with torch.no_grad():
                z1 = embed(v1)
                z2 = embed(v2)
                za = embed(adv)
                zh = embed(clean)

                contrast_term = symmetric_node_contrast_loss(z1, z2, tau)
                adversarial_term = symmetric_node_contrast_loss(z1, za, tau)
                info_term = info_regularizer_cos(z1, z2, zh)

                graph_loss = contrast_term + eps1 * adversarial_term + eps2 * info_term

            loss_sum += graph_loss.item()
            graphs_seen += 1

    return loss_sum / max(graphs_seen, 1)


def main():
    """Parse CLI arguments and run pretraining.

    Per epoch: one ``train_epoch`` pass, one ``evaluate`` pass on the second
    loader returned by ``get_dataloaders``, TensorBoard logging, a
    ``best.pt`` checkpoint whenever the validation loss improves and a
    periodic checkpoint every 100 epochs. The TensorBoard writer is closed
    on every exit path, including exceptions and Ctrl-C.
    """
    parser = argparse.ArgumentParser(description="ArieL Pretraining on GRN (TCGA) with GRACE-encoder")

    # Paths
    parser.add_argument('--tcga-graphs-path', type=str,
                        default='/path/to/data/graph_data_object/tcga_graphs.pkl')
    parser.add_argument('--checkpoint-dir', type=str,
                        default='/path/to/Pretrain/ArieL/checkpoints')
    parser.add_argument('--log-dir', type=str,
                        default='/path/to/Pretrain/ArieL/runs')

    # Optimisation
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--batch-size', type=int, default=4)
    parser.add_argument('--learning-rate', type=float, default=1e-4)
    parser.add_argument('--num-epochs', type=int, default=1000)
    parser.add_argument('--tau', type=float, default=0.5)

    # Loss weights
    parser.add_argument('--eps1', type=float, default=1.0)
    parser.add_argument('--eps2', type=float, default=0.5)

    # Stochastic augmentation probabilities
    parser.add_argument('--p-re', type=float, default=0.1)
    parser.add_argument('--p-mf', type=float, default=0.1)

    # PGD adversarial view
    parser.add_argument('--pgd-steps', type=int, default=3)
    parser.add_argument('--pgd-step-x', type=float, default=0.1)
    parser.add_argument('--pgd-step-a', type=float, default=0.5)
    parser.add_argument('--delta-x', type=float, default=0.5)
    parser.add_argument('--budget-a', type=float, default=50.0)
    parser.add_argument('--topk-non-edges', type=int, default=500)
    parser.add_argument('--thr-edge', type=float, default=0.5)

    # Model sizes
    parser.add_argument('--hidden-channels', type=int, default=64)
    parser.add_argument('--out-channels', type=int, default=64)
    parser.add_argument('--num-heads', type=int, default=8)
    parser.add_argument('--proj-hidden-dim', type=int, default=64)
    parser.add_argument('--proj-out-dim', type=int, default=64)

    args = parser.parse_args()
    set_seed(args.seed)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Logs and checkpoints for one hyper-parameter setting share a run name.
    run_name = (
        f"lr{args.learning_rate}_tau{args.tau}_eps1{args.eps1}_eps2{args.eps2}"
        f"_pgd{args.pgd_steps}x{args.pgd_step_x}a{args.pgd_step_a}_B{args.budget_a}"
    )
    log_dir = os.path.join(args.log_dir, run_name)
    ckpt_dir = os.path.join(args.checkpoint_dir, run_name)
    os.makedirs(ckpt_dir, exist_ok=True)
    writer = SummaryWriter(log_dir=log_dir)

    try:
        graphs = load_tcga_graphs(args.tcga_graphs_path)
        train_graphs, test_graphs = split_graphs(graphs)
        train_loader, test_loader = get_dataloaders(train_graphs, test_graphs, args.batch_size)

        in_dim = graphs[0].x.shape[1]
        encoder = GraphEncoder(in_dim, args.hidden_channels, args.out_channels, num_heads=args.num_heads)
        model = ArielModel(encoder, args.proj_hidden_dim, args.proj_out_dim, args.out_channels).to(device)

        optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate)

        pgd_cfg = dict(
            pgd_steps=args.pgd_steps,
            step_x=args.pgd_step_x,
            step_a=args.pgd_step_a,
            delta_x=args.delta_x,
            budget_a=args.budget_a,
            topk_non_edges=args.topk_non_edges,
            thr_edge=args.thr_edge,
        )

        def save_checkpoint(path, epoch, val_loss):
            # Single place that defines the checkpoint layout.
            ckpt = {'model': model.state_dict(), 'args': vars(args), 'epoch': epoch, 'val_loss': val_loss}
            torch.save(ckpt, path)

        best_val = float('inf')
        best_path = os.path.join(ckpt_dir, 'best.pt')

        for epoch in range(1, args.num_epochs + 1):
            train_loss = train_epoch(model, train_loader, optimizer, device, args.tau,
                                     args.eps1, args.eps2, args.p_re, args.p_mf, pgd_cfg)
            val_loss = evaluate(model, test_loader, device, args.tau,
                                args.eps1, args.eps2, args.p_re, args.p_mf, pgd_cfg)

            writer.add_scalar('Loss/train', train_loss, epoch)
            writer.add_scalar('Loss/val',   val_loss,  epoch)
            print(f"Epoch {epoch:04d}: Train {train_loss:.4f} | Val {val_loss:.4f}")

            if val_loss < best_val:
                best_val = val_loss
                save_checkpoint(best_path, epoch, best_val)
                print(f"[Best] Updated at epoch {epoch} (val={best_val:.4f}) -> {best_path}")

            # Periodic snapshot regardless of validation performance.
            if epoch % 100 == 0:
                path = os.path.join(ckpt_dir, f'checkpoint_epoch_{epoch}.pt')
                save_checkpoint(path, epoch, val_loss)
                print(f"Checkpoint saved: {path}")
    finally:
        # Close the writer even when the run is interrupted or fails.
        writer.close()

# Script entry point: start pretraining only when this file is run directly.
if __name__ == '__main__':
    main()
