# train.py
import os
import random
import argparse
import torch
from torch.utils.tensorboard import SummaryWriter
from torch_geometric.data import Batch

from data import load_tcga_graphs, split_graphs, get_dataloaders, apply_virtual_knockdown, to_device, get_graph_embedding
from models import GraphEncoder, GraphCLModel
from utils import set_seed, nt_xent_loss

def _sample_knockdown_pair(num_nodes, rng):
    """Pick two node indices in [0, num_nodes - 1] to knock down.

    The indices are distinct whenever ``num_nodes > 1``; a single-node graph
    necessarily gets the same index twice. The RNG calls happen in the same
    order as the original inline code, so seeded runs are reproducible.

    Args:
        num_nodes: Number of nodes in the graph (must be >= 1).
        rng: Object exposing ``randint(a, b)`` (the ``random`` module or a
            ``random.Random`` instance).

    Returns:
        Tuple ``(idx1, idx2)`` of node indices.
    """
    idx1 = rng.randint(0, num_nodes - 1)
    idx2 = rng.randint(0, num_nodes - 1)
    if num_nodes > 1:
        # Rejection-sample until the second knockdown differs from the first.
        while idx2 == idx1:
            idx2 = rng.randint(0, num_nodes - 1)
    return idx1, idx2


def _make_augmented_views(batch, rng):
    """Build two virtual-knockdown views of every non-empty graph in ``batch``.

    Args:
        batch: A torch_geometric ``Batch`` (anything with ``to_data_list()``).
        rng: Source of randomness; see ``_sample_knockdown_pair``.

    Returns:
        ``(views1, views2)``: parallel lists holding the first and second view
        of each graph. Graphs with zero nodes are skipped, so both lists may
        be empty.
    """
    views1, views2 = [], []
    for graph in batch.to_data_list():
        num_nodes = graph.x.size(0)
        if num_nodes <= 0:
            continue
        idx1, idx2 = _sample_knockdown_pair(num_nodes, rng)
        views1.append(apply_virtual_knockdown(graph, idx1))
        views2.append(apply_virtual_knockdown(graph, idx2))
    return views1, views2


def _contrastive_batch_loss(model, views1, views2, device, temperature):
    """Encode both view lists, pool them to graph embeddings, return the NT-Xent loss.

    ``views1[i]`` and ``views2[i]`` are the two views of the same graph, i.e.
    the positive pair passed to ``nt_xent_loss``.
    """
    batch_view1 = to_device(Batch.from_data_list(views1), device)
    batch_view2 = to_device(Batch.from_data_list(views2), device)

    z1_node = model(batch_view1.x, batch_view1.edge_index, batch_view1.edge_attr)
    graph_emb1 = get_graph_embedding(z1_node, batch_view1.batch)
    z2_node = model(batch_view2.x, batch_view2.edge_index, batch_view2.edge_attr)
    graph_emb2 = get_graph_embedding(z2_node, batch_view2.batch)

    # NOTE(review): if a batch yields a single graph, NT-Xent has no negatives;
    # how that is handled depends on nt_xent_loss — verify.
    return nt_xent_loss(graph_emb1, graph_emb2, temperature)


def train_epoch(model, loader, optimizer, device, temperature, rng=None):
    """Run one contrastive training epoch and return the mean per-batch loss.

    For each batch, every non-empty graph gets two virtual-knockdown views.
    The model is optimized with NT-Xent loss between the pooled graph
    embeddings of the two views.

    Args:
        model: Module called as ``model(x, edge_index, edge_attr)`` that
            returns node embeddings.
        loader: Iterable of torch_geometric ``Batch`` objects.
        optimizer: Optimizer over ``model``'s parameters.
        device: Target device for the augmented batches.
        temperature: NT-Xent temperature.
        rng: Optional object with ``randint(a, b)`` (e.g.
            ``random.Random(seed)``). Defaults to the global ``random``
            module, matching the previous behaviour.

    Returns:
        Average loss over the batches that produced at least one view pair,
        or 0.0 if none did.
    """
    rng = random if rng is None else rng
    model.train()
    total_loss = 0.0
    num_batches = 0

    for batch in loader:
        views1, views2 = _make_augmented_views(batch, rng)
        if not views1:
            # Every graph in this batch was empty; nothing to contrast.
            continue

        loss = _contrastive_batch_loss(model, views1, views2, device, temperature)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    return total_loss / num_batches if num_batches > 0 else 0.0


def evaluate(model, loader, device, temperature, rng=None):
    """Compute the mean per-batch NT-Xent loss on ``loader`` without gradients.

    This uses the same random knockdown augmentation as ``train_epoch``, so
    the returned value varies between calls unless a seeded ``rng`` is passed.

    Args:
        model: Module called as ``model(x, edge_index, edge_attr)``.
        loader: Iterable of torch_geometric ``Batch`` objects.
        device: Target device for the augmented batches.
        temperature: NT-Xent temperature.
        rng: Optional object with ``randint(a, b)``. Defaults to the global
            ``random`` module. Pass e.g. ``random.Random(0)`` for a
            reproducible evaluation that does not consume the global stream
            used by training.

    Returns:
        Average loss over the batches that produced at least one view pair,
        or 0.0 if none did.
    """
    rng = random if rng is None else rng
    model.eval()
    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for batch in loader:
            views1, views2 = _make_augmented_views(batch, rng)
            if not views1:
                continue

            loss = _contrastive_batch_loss(model, views1, views2, device, temperature)
            total_loss += loss.item()
            num_batches += 1

    return total_loss / num_batches if num_batches > 0 else 0.0

def main():
    """Train a GraphCL model on TCGA graphs, logging to TensorBoard and checkpointing.

    Hyperparameters are encoded into a run name that is used both as a
    TensorBoard sub-directory of ``--log-dir`` and as a sub-directory of
    ``--checkpoint-dir``. Model ``state_dict`` checkpoints are written every
    ``--checkpoint-interval`` epochs, and always after the final epoch.
    """
    parser = argparse.ArgumentParser(description="GraphCL for TCGA Gene Regulatory Network Representation Learning")
    # Input/Output paths
    parser.add_argument('--tcga-graphs-path', type=str,
                        default='path/to/tcga_graphs.pkl',
                        help='Path to TCGA graph data pickle file')
    parser.add_argument('--checkpoint-dir', type=str,
                        default='path/to/checkpoints',
                        help='Checkpoint save directory')
    parser.add_argument('--log-dir', type=str, default='./runs',
                        help='TensorBoard log directory')
    parser.add_argument('--checkpoint-interval', type=int, default=100,
                        help='Save a checkpoint every N epochs (the final epoch is always saved)')
    # Hyperparameters
    parser.add_argument('--seed', type=int, default=42, help='Random seed')
    parser.add_argument('--batch-size', type=int, default=8, help='Batch size')
    parser.add_argument('--learning-rate', type=float, default=1e-4, help='Learning rate')
    parser.add_argument('--num-epochs', type=int, default=3000, help='Number of epochs')
    parser.add_argument('--temperature', type=float, default=0.25, help='Temperature parameter for NT-Xent Loss')
    # Model parameters
    parser.add_argument('--hidden-dim', type=int, default=64, help='GraphEncoder hidden dimension')
    parser.add_argument('--out-dim', type=int, default=64, help='GraphEncoder output dimension')
    parser.add_argument('--proj-hidden-dim', type=int, default=64, help='Projection head intermediate dimension')
    parser.add_argument('--proj-out-dim', type=int, default=64, help='Projection head output dimension')
    parser.add_argument('--num-heads', type=int, default=8, help='Number of TransformerConv heads')

    args = parser.parse_args()
    if args.checkpoint_interval <= 0:
        parser.error('--checkpoint-interval must be a positive integer')

    set_seed(args.seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Reflect hyperparameters in directory name
    run_name = f"lr{args.learning_rate}_temp{args.temperature}_projout{args.proj_out_dim}"
    log_dir = os.path.join(args.log_dir, run_name)
    checkpoint_dir = os.path.join(args.checkpoint_dir, run_name)
    os.makedirs(checkpoint_dir, exist_ok=True)

    graphs = load_tcga_graphs(args.tcga_graphs_path)
    if len(graphs) == 0:
        # Fail with a clear message instead of an IndexError on graphs[0] below.
        parser.error(f"no graphs loaded from {args.tcga_graphs_path!r}")
    train_graphs, test_graphs = split_graphs(graphs)
    train_loader, test_loader = get_dataloaders(train_graphs, test_graphs, args.batch_size)

    in_channels = graphs[0].x.shape[1]
    encoder = GraphEncoder(in_channels, args.hidden_dim, args.out_dim, num_heads=args.num_heads)
    model = GraphCLModel(encoder, proj_hidden_dim=args.proj_hidden_dim, proj_out_dim=args.proj_out_dim).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate)

    # Create the writer only once setup has succeeded, and always close it so
    # buffered events are flushed even if training raises or is interrupted.
    writer = SummaryWriter(log_dir=log_dir)
    try:
        for epoch in range(1, args.num_epochs + 1):
            train_loss = train_epoch(model, train_loader, optimizer, device, args.temperature)
            test_loss = evaluate(model, test_loader, device, args.temperature)

            writer.add_scalar('Loss/train', train_loss, epoch)
            writer.add_scalar('Loss/test', test_loss, epoch)

            print(f"Epoch {epoch:04d}: Train Loss = {train_loss:.4f}, Test Loss = {test_loss:.4f}")

            # Also save after the last epoch so the final model is never lost
            # when num_epochs is not a multiple of the interval.
            if epoch % args.checkpoint_interval == 0 or epoch == args.num_epochs:
                checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch}.pt')
                torch.save(model.state_dict(), checkpoint_path)
                print(f"Checkpoint saved: {checkpoint_path}")
    finally:
        writer.close()

if __name__ == '__main__':
    # Start training only when this file is executed as a script, not on import.
    main()
