# train.py
import os
import time
import argparse
import torch
from torch.utils.tensorboard import SummaryWriter
from torch_geometric.utils import to_dense_batch, to_dense_adj

from data import load_tcga_graphs, split_graphs, get_dataloaders, to_device
from models import Encoder, Online, Target
from utils import set_seed


def train_epoch(online_model, target_model, opt_on, opt_tgt, loader, device, global_hop):
    """Run one training epoch, updating the online and target models in turn.

    For every batch of graphs:
      1. Online step: the online encoder's node projections are diffused over
         the dense per-graph adjacency (``global_hop`` multiplications), added
         to the undiffused projections, and passed through the online
         predictor. The online loss compares these predictions per graph, on
         real (unpadded) nodes only, against a gradient-free target projection.
         It is averaged over the graphs in the batch and stepped with ``opt_on``.
      2. Target step: the target model is run again with gradients enabled.
         Its own loss (``target_model.get_loss``) is computed over the whole
         batch projection and stepped with ``opt_tgt``.

    Args:
        online_model: Model exposing ``online_encoder`` (returning a
            ``(raw, projection)`` pair), ``predictor`` and ``get_loss``.
        target_model: Model whose forward returns a single projection tensor
            and which exposes ``get_loss``.
        opt_on: Optimizer stepped with the online loss.
        opt_tgt: Optimizer stepped with the target loss.
        loader: Iterable of batched graphs with ``x``, ``edge_index``,
            ``edge_attr`` and ``batch`` attributes.
        device: Device each batch is moved to via ``to_device``.
        global_hop: Number of adjacency multiplications in the diffusion step.

    Returns:
        ``(online_loss, target_loss)`` for the epoch. Each batch loss is
        weighted by the number of graphs in that batch. Both values are
        ``nan`` if ``loader`` yields no batches, rather than raising
        ``ZeroDivisionError``.
    """
    online_model.train()
    target_model.train()
    total_on, total_tgt, count = 0.0, 0.0, 0

    for batch in loader:
        batch = to_device(batch, device)

        # --- Online update -------------------------------------------------
        _, h_proj = online_model.online_encoder(batch.x, batch.edge_index, batch.edge_attr)
        with torch.no_grad():
            # Regression targets for the online branch: the online loss must
            # not push gradients into the target model.
            h_tgt_proj_online = target_model(batch.x, batch.edge_index, batch.edge_attr)

        # Pad per-graph node tensors to (B, Nmax, D); `mask` marks real nodes.
        dense_adjs = to_dense_adj(batch.edge_index, batch=batch.batch)
        Hb, mask = to_dense_batch(h_proj, batch.batch)
        Tb_online, _ = to_dense_batch(h_tgt_proj_online, batch.batch)

        # TCM: multi-hop diffusion. H2 = A^global_hop @ Hb using the dense
        # adjacency as returned by to_dense_adj (no normalization applied here),
        # then combined with the undiffused projections.
        H2 = Hb.clone()
        for _ in range(global_hop):
            H2 = torch.einsum('bij,bjk->bik', dense_adjs, H2)
        Hc = Hb + H2
        B, Nmax, D = Hc.shape
        Hp = online_model.predictor(Hc.view(-1, D)).view(B, Nmax, D)

        # Per-graph loss on real nodes only (padding excluded), averaged over graphs.
        loss_on = sum(
            online_model.get_loss(Hp[i][mask[i]], Tb_online[i][mask[i]].detach())
            for i in range(B)
        ) / B
        opt_on.zero_grad()
        loss_on.backward()
        opt_on.step()

        # --- Target update (RSM loss) ---------------------------------------
        # Recompute the target projection, this time with gradients enabled.
        h_tgt_proj = target_model(batch.x, batch.edge_index, batch.edge_attr)
        loss_tgt = target_model.get_loss(h_tgt_proj)
        opt_tgt.zero_grad()
        loss_tgt.backward()
        opt_tgt.step()

        # Weight each batch by its graph count so the epoch value is a per-graph average.
        total_on += loss_on.item() * B
        total_tgt += loss_tgt.item() * B
        count += B

    if count == 0:
        # Empty loader: nothing was trained, so report NaN instead of dividing by zero.
        nan = float('nan')
        return nan, nan
    return total_on / count, total_tgt / count

@torch.no_grad()
def test_epoch(online_model, target_model, loader, device, global_hop):
    """Evaluate the online and target losses on ``loader`` without updating weights.

    Both models are put in eval mode, and the whole function runs under
    ``torch.no_grad()``. The losses are computed exactly as in
    ``train_epoch``, so the train and test curves are comparable:

      * online loss: per graph, on real (unpadded) nodes only, after the same
        multi-hop diffusion and predictor, then averaged over graphs;
      * target loss: ``target_model.get_loss`` over the whole batch
        projection, weighted by the number of graphs in the batch.

    Args:
        online_model: Model exposing ``online_encoder``, ``predictor`` and ``get_loss``.
        target_model: Model whose forward returns a single projection tensor
            and which exposes ``get_loss``.
        loader: Iterable of batched graphs with ``x``, ``edge_index``,
            ``edge_attr`` and ``batch`` attributes.
        device: Device each batch is moved to via ``to_device``.
        global_hop: Number of adjacency multiplications in the diffusion step.

    Returns:
        ``(online_loss, target_loss)`` averaged per graph. Both are ``nan`` if
        ``loader`` yields no batches, rather than raising ``ZeroDivisionError``.
    """
    online_model.eval()
    target_model.eval()
    total_on, total_tgt, count = 0.0, 0.0, 0

    for batch in loader:
        batch = to_device(batch, device)
        # Online projection and (single-output) target projection.
        _, h_proj = online_model.online_encoder(batch.x, batch.edge_index, batch.edge_attr)
        h_tgt_proj = target_model(batch.x, batch.edge_index, batch.edge_attr)

        # Pad per-graph node tensors to (B, Nmax, D); `mask` marks real nodes.
        dense_adjs = to_dense_adj(batch.edge_index, batch=batch.batch)
        Hb, mask = to_dense_batch(h_proj, batch.batch)
        Tb, _ = to_dense_batch(h_tgt_proj, batch.batch)

        # TCM: multi-hop diffusion, identical to train_epoch.
        H2 = Hb.clone()
        for _ in range(global_hop):
            H2 = torch.einsum('bij,bjk->bik', dense_adjs, H2)
        Hc = Hb + H2
        B, Nmax, D = Hc.shape
        Hp = online_model.predictor(Hc.view(-1, D)).view(B, Nmax, D)

        # Online loss: summed per graph here, divided by the graph count at the end.
        for i in range(B):
            total_on += online_model.get_loss(Hp[i][mask[i]], Tb[i][mask[i]]).item()
        # Target loss: over the whole batch, as train_epoch optimizes it, and
        # weighted by graph count so tr_tgt and te_tgt measure the same quantity.
        total_tgt += target_model.get_loss(h_tgt_proj).item() * B
        count += B

    if count == 0:
        # Empty loader: report NaN instead of dividing by zero.
        nan = float('nan')
        return nan, nan
    return total_on / count, total_tgt / count


def main():
    """Parse CLI arguments and run SGRL pretraining on TCGA graphs.

    Builds an online and a target ``Encoder`` from identical initial weights,
    wraps them in ``Online``/``Target``, and trains them with separate AdamW
    optimizers. Train and test losses are logged to TensorBoard.

    Checkpoints are written to ``--checkpoint-dir`` every 100 epochs, and
    ``best_*.pt`` is written whenever the test online loss improves. The EMA
    target update runs *before* evaluation, so the logged test losses, the
    best-model selection and the saved state dicts all refer to the same
    weights.
    """
    parser = argparse.ArgumentParser(description="SGRL for TCGA Gene Regulatory Network")
    parser.add_argument('--tcga-graphs-path', type=str, default='/path/to/tcga_graphs.pkl')
    parser.add_argument('--checkpoint-dir',   type=str, default='/path/to/checkpoints')
    parser.add_argument('--log-dir',          type=str, default='./runs')
    parser.add_argument('--seed', type=int, default=42, help='Random seed')
    parser.add_argument('--batch-size', type=int, default=4, help='Batch size')
    parser.add_argument('--num-epochs', type=int, default=3000, help='Number of epochs')
    parser.add_argument('--lr-online', type=float, default=1e-4, help='Online learning rate')
    parser.add_argument('--lr-target', type=float, default=1e-4, help='Target learning rate')
    parser.add_argument('--global-hop', type=int, default=3, help='Number of global hops')
    parser.add_argument('--hidden-channels', type=int, default=64, help='Hidden layer dimension')
    parser.add_argument('--proj-dim', type=int, default=64, help='Projection dimension')
    parser.add_argument('--num-heads', type=int, default=8, help='Number of TransformerConv heads')
    parser.add_argument('--momentum', type=float, default=0.99, help='EMA update coefficient')
    args = parser.parse_args()

    set_seed(args.seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    os.makedirs(args.checkpoint_dir, exist_ok=True)
    writer = SummaryWriter(log_dir=args.log_dir)

    # Close the writer even if loading or training raises, so event files are flushed.
    try:
        # Prepare data
        graphs = load_tcga_graphs(args.tcga_graphs_path)
        train_g, test_g = split_graphs(graphs)
        train_loader, test_loader = get_dataloaders(train_g, test_g, args.batch_size)

        # Initialize models
        in_ch = graphs[0].x.shape[1]
        online_enc = Encoder(in_ch, args.hidden_channels, args.proj_dim, args.num_heads).to(device)
        target_enc = Encoder(in_ch, args.hidden_channels, args.proj_dim, args.num_heads).to(device)
        target_enc.load_state_dict(online_enc.state_dict())  # Synchronize initialization
        online_model = Online(online_enc, target_enc, args.proj_dim, args.momentum).to(device)
        target_model = Target(target_enc).to(device)

        opt_on = torch.optim.AdamW(online_model.parameters(), lr=args.lr_online)
        opt_tgt = torch.optim.AdamW(target_model.parameters(), lr=args.lr_target)

        print("=== Start Pretraining ===")
        best_val = float('inf')
        for epoch in range(1, args.num_epochs + 1):
            tr_on, tr_tgt = train_epoch(online_model, target_model, opt_on, opt_tgt,
                                        train_loader, device, args.global_hop)

            # Update EMA before evaluating. The checkpoints below are then the exact
            # weights that produced te_on/te_tgt. test_epoch runs under no_grad in
            # eval mode, so this ordering does not change the training trajectory.
            online_model.update_target_encoder()

            te_on, te_tgt = test_epoch(online_model, target_model,
                                       test_loader, device, args.global_hop)
            writer.add_scalars('OnlineLoss', {'train': tr_on, 'test': te_on}, epoch)
            writer.add_scalars('TargetLoss', {'train': tr_tgt, 'test': te_tgt}, epoch)
            print(f"Epoch {epoch} | tr_on {tr_on:.4f} te_on {te_on:.4f} | tr_tgt {tr_tgt:.4f} te_tgt {te_tgt:.4f}")

            # Save checkpoint every 100 epochs
            if epoch % 100 == 0:
                torch.save(online_model.state_dict(), os.path.join(args.checkpoint_dir, f'checkpoint_online_epoch_{epoch}.pt'))
                torch.save(target_model.state_dict(),  os.path.join(args.checkpoint_dir, f'checkpoint_target_epoch_{epoch}.pt'))

            # Save best model (a NaN te_on never compares as smaller, so it is skipped)
            if te_on < best_val:
                best_val = te_on
                torch.save(online_model.state_dict(), os.path.join(args.checkpoint_dir, 'best_online.pt'))
                torch.save(target_model.state_dict(),  os.path.join(args.checkpoint_dir, 'best_target.pt'))
    finally:
        writer.close()
    print("Pretraining finished.")

if __name__ == "__main__":
    # Start pretraining only when run as a script, not when imported.
    main()