# train_adgcl.py
import os
import argparse
import torch
from torch.utils.tensorboard import SummaryWriter
from torch_geometric.data import Batch

from data import load_tcga_graphs, split_graphs, get_dataloaders, ensure_edge_attr, to_device
from models import GraphEncoder, ADGCLModel, EdgeDropAugmenter
from utils import set_seed, symmetric_node_contrast_loss_stable, apply_edge_keep_mask


def build_batch_views_and_loss(model, augmenter, data_list, device, tau, aug_temp,
                               require_aug_grad: bool):
    """Build two views of a list of graphs and return ``(loss, reg)``.

    View 1 is the input graphs (after ``ensure_edge_attr`` and a move to
    ``device``) collated into one batch. View 2 reuses each graph's ``x``
    and ``edge_index`` but replaces ``edge_attr`` with
    ``apply_edge_keep_mask(edge_attr, keep)``, where ``keep`` is produced
    by the augmenter for that graph.

    Args:
        model: Called as ``model(x, edge_index, edge_attr)`` on each
            collated view; its outputs are fed to the contrastive loss.
        augmenter: Called once per graph as
            ``augmenter(x, edge_index, edge_attr, temperature=aug_temp)``
            and expected to return a ``(p_drop, keep)`` pair.
        data_list: Iterable of graph objects exposing ``x``,
            ``edge_index`` and (after ``ensure_edge_attr``) ``edge_attr``.
        device: Target device handed to ``to_device``.
        tau: Forwarded to ``symmetric_node_contrast_loss_stable``.
        aug_temp: Forwarded to the augmenter as ``temperature``.
        require_aug_grad: When ``False`` the augmenter runs with autograd
            disabled. When ``True`` the caller's current grad mode is
            left untouched.

    Returns:
        Tuple ``(loss, reg)``: the contrastive loss between the two
        views, and the mean over graphs of each graph's mean ``p_drop``.
    """
    graphs = [to_device(ensure_edge_attr(item), device) for item in data_list]

    # The unmodified view is collated before the augmenter runs.
    view1 = Batch.from_data_list(graphs)

    drop_probs = []
    keep_masks = []
    # Grad stays on only if the caller already has it on AND asked for it;
    # otherwise the augmenter forward runs without building a graph.
    with torch.set_grad_enabled(torch.is_grad_enabled() and require_aug_grad):
        for graph in graphs:
            drop_prob, keep_mask = augmenter(
                graph.x, graph.edge_index, graph.edge_attr, temperature=aug_temp
            )
            drop_probs.append(drop_prob)
            keep_masks.append(keep_mask)

    # Second view: same topology, edge attributes passed through the keep mask.
    masked_graphs = []
    for graph, keep_mask in zip(graphs, keep_masks):
        masked_attr = apply_edge_keep_mask(graph.edge_attr, keep_mask)
        masked_graphs.append(
            type(graph)(x=graph.x, edge_index=graph.edge_index, edge_attr=masked_attr)
        )
    view2 = Batch.from_data_list(masked_graphs)

    emb1 = model(view1.x, view1.edge_index, view1.edge_attr)
    emb2 = model(view2.x, view2.edge_index, view2.edge_attr)

    contrast = symmetric_node_contrast_loss_stable(emb1, emb2, tau=tau)

    # Per-graph mean first, then mean across graphs, so every graph has
    # equal weight in the regulariser.
    per_graph_drop = [drop_prob.mean() for drop_prob in drop_probs]
    drop_reg = torch.stack(per_graph_drop).mean()

    return contrast, drop_reg


def _set_requires_grad(params, flag):
    """Set ``requires_grad`` to ``flag`` on every parameter in ``params``."""
    for param in params:
        param.requires_grad_(flag)


def train_epoch_adgcl(model, augmenter, loader, opt_enc, opt_aug, device, tau, aug_temp, lambda_reg,
                      aug_steps=1, enc_steps=1):
    """Run one epoch of alternating augmenter / encoder updates.

    For every batch the augmenter is updated ``aug_steps`` times with the
    encoder frozen, then the encoder is updated ``enc_steps`` times with
    the augmenter frozen.

    The augmenter objective is ``loss - lambda_reg * reg``; its negation
    is backpropagated, so a minimising optimiser step increases it. The
    encoder minimises ``loss`` directly.

    The ``requires_grad`` flag of every encoder and augmenter parameter is
    snapshotted on entry and restored on exit, including when an exception
    is raised. The function therefore does not leave either network
    frozen as a side effect.

    Args:
        model: Encoder module; its parameters are stepped by ``opt_enc``.
        augmenter: Augmenter module; its parameters are stepped by
            ``opt_aug``.
        loader: Iterable of batches exposing ``to_data_list()``.
        opt_enc: Optimiser for the encoder.
        opt_aug: Optimiser for the augmenter.
        device: Device forwarded to ``build_batch_views_and_loss``.
        tau: Forwarded to ``build_batch_views_and_loss``.
        aug_temp: Forwarded to ``build_batch_views_and_loss``.
        lambda_reg: Weight of the drop regulariser in the augmenter
            objective.
        aug_steps: Augmenter updates per batch.
        enc_steps: Encoder updates per batch.

    Returns:
        Tuple ``(avg_loss, avg_reg)`` averaged over all encoder steps, or
        ``(0.0, 0.0)`` when no encoder step ran.
    """
    model.train()
    augmenter.train()
    total_loss, total_reg, n_batches = 0.0, 0.0, 0

    enc_params = list(model.parameters())
    aug_params = list(augmenter.parameters())
    # Remember the caller's flags so they can be put back afterwards.
    saved_flags = [(param, param.requires_grad) for param in enc_params + aug_params]

    try:
        for batch in loader:
            data_list = batch.to_data_list()

            # ---- Augmenter update: encoder frozen ----
            _set_requires_grad(enc_params, False)
            _set_requires_grad(aug_params, True)

            for _ in range(aug_steps):
                opt_aug.zero_grad(set_to_none=True)
                loss, reg = build_batch_views_and_loss(
                    model, augmenter, data_list, device, tau, aug_temp,
                    require_aug_grad=True
                )
                aug_obj = loss - lambda_reg * reg
                # Negated so the optimiser's descent step raises aug_obj.
                (-aug_obj).backward()
                opt_aug.step()

            # ---- Encoder update: augmenter frozen ----
            _set_requires_grad(enc_params, True)
            _set_requires_grad(aug_params, False)

            for _ in range(enc_steps):
                opt_enc.zero_grad(set_to_none=True)
                loss, reg = build_batch_views_and_loss(
                    model, augmenter, data_list, device, tau, aug_temp,
                    require_aug_grad=False
                )
                loss.backward()
                opt_enc.step()

                # Statistics are collected per encoder step.
                total_loss += float(loss.detach().cpu())
                total_reg += float(reg.detach().cpu())
                n_batches += 1
    finally:
        for param, flag in saved_flags:
            param.requires_grad_(flag)

    avg_loss = total_loss / max(n_batches, 1)
    avg_reg = total_reg / max(n_batches, 1)
    return avg_loss, avg_reg


@torch.no_grad()
def evaluate_adgcl(model, augmenter, loader, device, tau, aug_temp, lambda_reg):
    """Compute the mean contrastive loss and drop regulariser over ``loader``.

    Both networks are switched to eval mode and the whole function runs
    with autograd disabled. No parameters are updated.

    Args:
        model: Encoder module.
        augmenter: Augmenter module.
        loader: Iterable of batches exposing ``to_data_list()``.
        device: Device forwarded to ``build_batch_views_and_loss``.
        tau: Forwarded to ``build_batch_views_and_loss``.
        aug_temp: Forwarded to ``build_batch_views_and_loss``.
        lambda_reg: Accepted for signature parity with training; not used
            here.

    Returns:
        Tuple ``(avg_loss, avg_reg)`` averaged over batches, or
        ``(0.0, 0.0)`` when ``loader`` yields nothing.
    """
    model.eval()
    augmenter.eval()

    loss_sum = 0.0
    reg_sum = 0.0
    seen = 0

    for seen, batch in enumerate(loader, start=1):
        batch_loss, batch_reg = build_batch_views_and_loss(
            model, augmenter, batch.to_data_list(), device, tau, aug_temp,
            require_aug_grad=False,
        )
        loss_sum += float(batch_loss.detach().cpu())
        reg_sum += float(batch_reg.detach().cpu())

    # Guard the division for an empty loader.
    denom = seen if seen > 0 else 1
    return loss_sum / denom, reg_sum / denom


def _save_checkpoint(path, model, augmenter, epoch):
    """Write the model and augmenter ``state_dict``s plus ``epoch`` to ``path``."""
    torch.save({
        'model': model.state_dict(),
        'augmenter': augmenter.state_dict(),
        'epoch': epoch
    }, path)


def main():
    """Parse command-line options and run AD-GCL pretraining.

    Builds the encoder, projection model and edge-drop augmenter, then
    alternates training and evaluation for ``--num-epochs`` epochs. Each
    epoch it logs scalars to TensorBoard and prints a summary line.

    Checkpoints go to ``<checkpoint-dir>/<run_name>``:

    * ``checkpoint_epoch_XXXX.pt`` every ``--save-every`` epochs
      (a value <= 0 disables periodic checkpoints);
    * ``best_epoch_XXXX.pt`` whenever the validation loss improves.

    The TensorBoard writer is closed even when training raises.
    """
    parser = argparse.ArgumentParser(description="AD-GCL pretraining on TCGA GRN")
    parser.add_argument('--tcga-graphs-path', type=str,
                        default='/path/to/data/graph_data_object/tcga_graphs.pkl')
    parser.add_argument('--checkpoint-dir', type=str,
                        default='/path/to/Pretrain/ADGCL/checkpoints')
    parser.add_argument('--log-dir', type=str,
                        default='/path/to/Pretrain/ADGCL/runs')

    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--batch-size', type=int, default=8)
    parser.add_argument('--learning-rate', type=float, default=1e-4)
    parser.add_argument('--num-epochs', type=int, default=2000)
    parser.add_argument('--tau', type=float, default=0.25)
    parser.add_argument('--aug-temp', type=float, default=0.5)
    parser.add_argument('--lambda-reg', type=float, default=0.1)
    parser.add_argument('--aug-steps', type=int, default=1)
    parser.add_argument('--enc-steps', type=int, default=1)
    parser.add_argument('--save-every', type=int, default=100,
                        help='save a periodic checkpoint every N epochs '
                             '(<= 0 disables periodic checkpoints)')

    parser.add_argument('--hidden-channels', type=int, default=64)
    parser.add_argument('--out-channels', type=int, default=64)
    parser.add_argument('--num-heads', type=int, default=8)
    parser.add_argument('--proj-hidden-dim', type=int, default=64)
    parser.add_argument('--proj-out-dim', type=int, default=64)

    parser.add_argument('--aug-hidden', type=int, default=64)
    parser.add_argument('--aug-out', type=int, default=64)
    parser.add_argument('--aug-heads', type=int, default=4)
    parser.add_argument('--aug-mlp-hidden', type=int, default=64)

    args = parser.parse_args()
    set_seed(args.seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # The run name encodes the main hyper-parameters so runs do not
    # overwrite each other's logs or checkpoints.
    run_name = f"lr{args.learning_rate}_tau{args.tau}_lam{args.lambda_reg}_t{args.aug_temp}_proj{args.proj_out_dim}"
    log_dir = os.path.join(args.log_dir, run_name)
    ckpt_dir = os.path.join(args.checkpoint_dir, run_name)
    os.makedirs(ckpt_dir, exist_ok=True)
    writer = SummaryWriter(log_dir=log_dir)

    # A modulus of zero would raise ZeroDivisionError, so non-positive
    # values switch periodic checkpoints off.
    periodic_saves = args.save_every > 0

    try:
        graphs = load_tcga_graphs(args.tcga_graphs_path)
        train_graphs, test_graphs = split_graphs(graphs)
        train_loader, test_loader = get_dataloaders(train_graphs, test_graphs, args.batch_size)

        # The input feature width is read from the first graph.
        in_channels = graphs[0].x.shape[1]
        encoder = GraphEncoder(in_channels, args.hidden_channels, args.out_channels, num_heads=args.num_heads)
        model = ADGCLModel(
            encoder,
            proj_hidden_dim=args.proj_hidden_dim,
            proj_out_dim=args.proj_out_dim,
            encoder_out_dim=args.out_channels
        ).to(device)

        augmenter = EdgeDropAugmenter(
            in_channels,
            hidden_channels=args.aug_hidden,
            out_channels=args.aug_out,
            num_heads=args.aug_heads,
            mlp_hidden=args.aug_mlp_hidden
        ).to(device)

        # Separate optimisers because the two networks are updated in
        # alternation with different objectives.
        opt_enc = torch.optim.AdamW(model.parameters(), lr=args.learning_rate)
        opt_aug = torch.optim.AdamW(augmenter.parameters(), lr=args.learning_rate)

        best_val = float('inf')
        for epoch in range(1, args.num_epochs + 1):
            train_loss, train_reg = train_epoch_adgcl(
                model, augmenter, train_loader, opt_enc, opt_aug, device,
                tau=args.tau, aug_temp=args.aug_temp, lambda_reg=args.lambda_reg,
                aug_steps=args.aug_steps, enc_steps=args.enc_steps
            )
            val_loss, val_reg = evaluate_adgcl(
                model, augmenter, test_loader, device,
                tau=args.tau, aug_temp=args.aug_temp, lambda_reg=args.lambda_reg
            )

            writer.add_scalar('Loss/train', train_loss, epoch)
            writer.add_scalar('Loss/val', val_loss, epoch)
            writer.add_scalar('RegDrop/train_mean', train_reg, epoch)
            writer.add_scalar('RegDrop/val_mean', val_reg, epoch)
            print(f"Epoch {epoch:04d} | train {train_loss:.4f} (reg {train_reg:.3f})  "
                  f"| val {val_loss:.4f} (reg {val_reg:.3f})")

            if periodic_saves and epoch % args.save_every == 0:
                path = os.path.join(ckpt_dir, f'checkpoint_epoch_{epoch:04d}.pt')
                _save_checkpoint(path, model, augmenter, epoch)
                print(f"Saved periodic checkpoint: {path}")

            if val_loss < best_val:
                best_val = val_loss
                path = os.path.join(ckpt_dir, f'best_epoch_{epoch:04d}.pt')
                _save_checkpoint(path, model, augmenter, epoch)
                print(f"Saved checkpoint: {path}")
    finally:
        # Flush and release the event file even if training failed.
        writer.close()

# Run the training CLI only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
