# train.py
import os
import random
import argparse
import torch
from torch.utils.tensorboard import SummaryWriter
from torch_geometric.data import Batch

from data import load_tcga_graphs, split_graphs, get_dataloaders, to_device, ensure_edge_attr
from models import GraphEncoder, AutoGCLModel, ViewGenerator
from utils import set_seed, symmetric_node_contrast_loss, view_similarity_loss

def make_autogcl_views(gen1, gen2, graph):
    """Produce two learned augmentations of ``graph``, one per view generator.

    ``graph`` is first passed through ``ensure_edge_attr`` and the result is fed
    to ``gen1`` and then ``gen2``. Each generator is expected to return a
    4-tuple ``(view, A, keep, map)``.

    Returns:
        ``(v1, v2, A1, A2, keep1, keep2, map1, map2)``: both view graphs
        first, then each generator's auxiliary outputs in generator order.
        A 1-D ``edge_attr`` on either view is reshaped in place to a
        trailing singleton dimension via ``unsqueeze(-1)``.
    """
    prepared = ensure_edge_attr(graph)
    outputs = [generator(prepared) for generator in (gen1, gen2)]

    # Encoders downstream consume edge features with an explicit feature axis.
    for view, *_ in outputs:
        attr = getattr(view, "edge_attr", None)
        if attr is not None and attr.dim() == 1:
            view.edge_attr = attr.unsqueeze(-1)

    (v1, A1, keep1, map1), (v2, A2, keep2, map2) = outputs
    return v1, v2, A1, A2, keep1, keep2, map1, map2

def _build_batch_views_and_alignment(gen1, gen2, data_list, device):
    """Augment every graph twice, collate the views, and align shared nodes.

    Graphs with zero nodes are skipped entirely. For each remaining graph the
    two generators produce views whose node indices are mapped back to the
    original graph through ``map1``/``map2``. Original nodes kept by both views
    (and with a non-negative mapped index in each) become positive pairs. Their
    indices are shifted by the running node count of the corresponding
    batched view, so they index rows of the collated batches directly.

    Returns:
        ``None`` if no graph in ``data_list`` has any nodes. Otherwise a tuple
        ``(b1, b2, A1_cat, A2_cat, pos1, pos2)``:

        * ``b1``/``b2``: batched views.
        * ``A1_cat``/``A2_cat``: generator outputs concatenated on dim 0.
        * ``pos1``/``pos2``: long index tensors of aligned node pairs, both
          ``None`` when no pair exists in the whole batch.
    """
    first_views, second_views = [], []
    first_aug, second_aug = [], []
    first_anchor_chunks, second_anchor_chunks = [], []
    shift1 = 0
    shift2 = 0

    for raw in data_list:
        raw = to_device(raw, device)
        if raw.x.size(0) == 0:
            # Empty graphs contribute neither views nor node offsets.
            continue

        (view_a, view_b, aug_a, aug_b,
         kept_a, kept_b, idx_a, idx_b) = make_autogcl_views(gen1, gen2, raw)
        view_a = to_device(view_a, device)
        view_b = to_device(view_b, device)

        first_views.append(view_a)
        second_views.append(view_b)
        first_aug.append(aug_a.to(device))
        second_aug.append(aug_b.to(device))

        shared = kept_a & kept_b
        if shared.sum() > 0:
            survivors = torch.where(shared)[0]
            local_a = idx_a[survivors]
            local_b = idx_b[survivors]
            # A negative mapped index means the node is absent from that view.
            ok = (local_a >= 0) & (local_b >= 0)
            if ok.any():
                first_anchor_chunks.append(local_a[ok] + shift1)
                second_anchor_chunks.append(local_b[ok] + shift2)

        # Batch.from_data_list stacks nodes in list order, so offsets track
        # the cumulative node count of each view stream.
        shift1 += view_a.x.size(0)
        shift2 += view_b.x.size(0)

    if not first_views:
        return None

    b1 = Batch.from_data_list(first_views).to(device)
    b2 = Batch.from_data_list(second_views).to(device)
    A1_cat = torch.cat(first_aug, dim=0) if first_aug else None
    A2_cat = torch.cat(second_aug, dim=0) if second_aug else None

    if first_anchor_chunks:
        pos1 = torch.cat(first_anchor_chunks, dim=0).long()
        pos2 = torch.cat(second_anchor_chunks, dim=0).long()
    else:
        pos1 = pos2 = None

    return b1, b2, A1_cat, A2_cat, pos1, pos2

def _nt_xent_from_indices(z1, z2, pos1, pos2, tau):
    """Contrastive loss computed only over aligned node pairs.

    Row ``pos1[k]`` of ``z1`` is paired with row ``pos2[k]`` of ``z2``. The
    selected rows are passed to ``symmetric_node_contrast_loss`` together with
    the temperature ``tau``.

    Returns:
        The loss from ``symmetric_node_contrast_loss``, or ``None`` when either
        index tensor is ``None`` or ``pos1`` is empty. Callers use ``None`` to
        skip the batch.
    """
    has_pairs = pos1 is not None and pos2 is not None and pos1.numel() > 0
    if not has_pairs:
        return None

    anchors = z1.index_select(0, pos1)
    positives = z2.index_select(0, pos2)
    return symmetric_node_contrast_loss(anchors, positives, tau)

def train_epoch(model, gen1, gen2, loader, optimizer, device, tau, lambda_sim):
    """Run one optimization pass over ``loader``; return the mean batch loss.

    All three modules are switched to training mode. For each batch:

    * Both generators augment every graph.
    * The model embeds both batched views.
    * The loss is ``contrastive + lambda_sim * view_similarity``; it is
      back-propagated and ``optimizer`` takes one step.

    Batches that yield no views, or no aligned node pairs, are skipped and
    excluded from the mean. If every batch is skipped the result is ``0.0``.
    """
    for module in (model, gen1, gen2):
        module.train()

    losses = []
    for batch in loader:
        pack = _build_batch_views_and_alignment(gen1, gen2, batch.to_data_list(), device)
        if pack is None:
            continue
        b1, b2, A1_cat, A2_cat, pos1, pos2 = pack

        optimizer.zero_grad(set_to_none=True)

        z1 = model(b1.x, b1.edge_index, b1.edge_attr)
        z2 = model(b2.x, b2.edge_index, b2.edge_attr)

        cl_loss = _nt_xent_from_indices(z1, z2, pos1, pos2, tau)
        if cl_loss is None:
            # No shared nodes between views: nothing to contrast this step.
            continue

        loss = cl_loss + lambda_sim * view_similarity_loss(A1_cat, A2_cat, reduce="mean")
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

    return sum(losses) / max(len(losses), 1)

@torch.no_grad()
def evaluate(model, gen1, gen2, loader, device, tau, lambda_sim):
    """Return the mean held-out loss over ``loader`` without any weight update.

    Uses the same objective as ``train_epoch``
    (``contrastive + lambda_sim * view_similarity``). It runs with gradients
    disabled and all modules in eval mode. Skipped batches (no views or no
    aligned pairs) are excluded from the mean; ``0.0`` is returned when
    nothing was scored.
    """
    model.eval()
    gen1.eval()
    gen2.eval()

    running = 0.0
    counted = 0
    for batch in loader:
        pack = _build_batch_views_and_alignment(gen1, gen2, batch.to_data_list(), device)
        if pack is None:
            continue
        b1, b2, A1_cat, A2_cat, pos1, pos2 = pack

        embeddings = [model(view.x, view.edge_index, view.edge_attr) for view in (b1, b2)]

        cl_loss = _nt_xent_from_indices(embeddings[0], embeddings[1], pos1, pos2, tau)
        if cl_loss is None:
            continue

        sim_loss = view_similarity_loss(A1_cat, A2_cat, reduce="mean")
        running += (cl_loss + lambda_sim * sim_loss).item()
        counted += 1

    return running / max(counted, 1)

def _save_checkpoint(checkpoint_dir, epoch, model, gen1, gen2, optimizer, args):
    """Serialize the training state for ``epoch`` and return the written path.

    The saved state includes the model, both view generators, the optimizer
    state (so a run can be resumed, not only evaluated), the CLI arguments as
    a plain dict, and the epoch number.
    """
    ckpt = {
        'model': model.state_dict(),
        'gen1': gen1.state_dict(),
        'gen2': gen2.state_dict(),
        'optimizer': optimizer.state_dict(),
        'args': vars(args),
        'epoch': epoch,
    }
    path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch}.pt')
    torch.save(ckpt, path)
    return path


def main():
    """CLI entry point: AutoGCL pretraining on TCGA graphs with a GRACE-style encoder.

    Steps:

    * Parse arguments and build a ``GraphEncoder`` wrapped in ``AutoGCLModel``.
    * Create two learnable ``ViewGenerator`` modules and optimize all three
      jointly with AdamW.
    * Log per-epoch train/test loss to TensorBoard.
    * Save checkpoints every ``--checkpoint-every`` epochs and always after
      the final epoch.

    Raises:
        ValueError: if no graphs are loaded from ``--tcga-graphs-path``.
    """
    parser = argparse.ArgumentParser(description="AutoGCL Pretraining on GRN (TCGA) with GRACE-encoder")
    parser.add_argument('--tcga-graphs-path', type=str,
                        default='/path/to/data/graph_data_object/tcga_graphs.pkl')
    parser.add_argument('--checkpoint-dir', type=str,
                        default='/path/to/Pretrain/AutoGCL/checkpoints')
    parser.add_argument('--log-dir', type=str, default='/path/to/Pretrain/AutoGCL/runs')
    parser.add_argument('--checkpoint-every', type=int, default=100,
                        help='Save a checkpoint every N epochs; the final epoch is always saved.')

    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--batch-size', type=int, default=8)
    parser.add_argument('--learning-rate', type=float, default=1e-4)
    parser.add_argument('--num-epochs', type=int, default=2000)
    parser.add_argument('--tau', type=float, default=0.25)
    parser.add_argument('--lambda-sim', type=float, default=0.1)

    parser.add_argument('--hidden-channels', type=int, default=64)
    parser.add_argument('--out-channels', type=int, default=64)
    parser.add_argument('--num-heads', type=int, default=8)

    parser.add_argument('--proj-hidden-dim', type=int, default=64)
    parser.add_argument('--proj-out-dim', type=int, default=64)

    parser.add_argument('--gen-hidden', type=int, default=64)
    parser.add_argument('--gen-heads', type=int, default=1)
    parser.add_argument('--gumbel-tau', type=float, default=1.0)

    args = parser.parse_args()
    if args.checkpoint_every <= 0:
        parser.error('--checkpoint-every must be a positive integer')

    set_seed(args.seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    run_name = f"lr{args.learning_rate}_tau{args.tau}_proj{args.proj_out_dim}_lsim{args.lambda_sim}"
    log_dir = os.path.join(args.log_dir, run_name)
    checkpoint_dir = os.path.join(args.checkpoint_dir, run_name)
    os.makedirs(checkpoint_dir, exist_ok=True)

    graphs = load_tcga_graphs(args.tcga_graphs_path)
    if len(graphs) == 0:
        # Fail with a clear message instead of an IndexError on graphs[0] below.
        raise ValueError(f"No graphs loaded from {args.tcga_graphs_path!r}")
    train_graphs, test_graphs = split_graphs(graphs)
    train_loader, test_loader = get_dataloaders(train_graphs, test_graphs, args.batch_size)

    in_channels = graphs[0].x.shape[1]
    encoder = GraphEncoder(in_channels, args.hidden_channels, args.out_channels, num_heads=args.num_heads)
    model = AutoGCLModel(encoder, proj_hidden_dim=args.proj_hidden_dim, proj_out_dim=args.proj_out_dim,
                         encoder_out_dim=args.out_channels).to(device)

    gen1 = ViewGenerator(in_channels, hidden=args.gen_hidden, num_heads=args.gen_heads, edge_dim=1, tau=args.gumbel_tau).to(device)
    gen2 = ViewGenerator(in_channels, hidden=args.gen_hidden, num_heads=args.gen_heads, edge_dim=1, tau=args.gumbel_tau).to(device)

    # Generators and model share one optimizer: they are trained jointly.
    optimizer = torch.optim.AdamW(
        list(model.parameters()) + list(gen1.parameters()) + list(gen2.parameters()),
        lr=args.learning_rate
    )

    # Open the writer only after setup succeeded, and always close it so
    # buffered events are flushed even if training is interrupted.
    writer = SummaryWriter(log_dir=log_dir)
    try:
        for epoch in range(1, args.num_epochs + 1):
            train_loss = train_epoch(model, gen1, gen2, train_loader, optimizer, device, args.tau, args.lambda_sim)
            test_loss = evaluate(model, gen1, gen2, test_loader, device, args.tau, args.lambda_sim)

            writer.add_scalar('Loss/train', train_loss, epoch)
            writer.add_scalar('Loss/test', test_loss, epoch)
            print(f"Epoch {epoch:04d}: Train Loss = {train_loss:.4f}, Test Loss = {test_loss:.4f}")

            # Also save after the last epoch so runs whose length is not a
            # multiple of the interval keep their final weights.
            if epoch % args.checkpoint_every == 0 or epoch == args.num_epochs:
                path = _save_checkpoint(checkpoint_dir, epoch, model, gen1, gen2, optimizer, args)
                print(f"Checkpoint saved: {path}")
    finally:
        writer.close()

if __name__ == '__main__':
    # Start pretraining only when run as a script, not when this module is imported.
    main()
