# train.py
import os
import random
import argparse
import torch
from torch.utils.tensorboard import SummaryWriter
from torch_geometric.data import Batch

from data import load_tcga_graphs, split_graphs, get_dataloaders, apply_virtual_knockdown, to_device
from models import GraphEncoder, GRACEModel
from utils import set_seed, symmetric_node_contrast_loss

def train_epoch(model, loader, optimizer, device, tau):
    """Run one training epoch and return the mean per-batch loss.

    For every graph in every batch, two views are built by calling
    ``apply_virtual_knockdown`` with two randomly drawn node indices
    (distinct whenever the graph has more than one node). Both views are
    passed through ``model`` and compared with
    ``symmetric_node_contrast_loss``. The per-graph losses of one batch are
    averaged and a single optimizer step is taken per batch.

    Args:
        model: Module called as ``model(x, edge_index, edge_attr)``.
        loader: Iterable of batches exposing ``to_data_list()``.
        optimizer: Optimizer whose ``zero_grad``/``step`` are called once
            per batch that contains at least one non-empty graph.
        device: Device forwarded to ``to_device`` for each graph.
        tau: Temperature forwarded to ``symmetric_node_contrast_loss``.

    Returns:
        Mean of the batch losses over the batches that produced an
        optimizer step, or ``0.0`` when no batch did.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0

    for batch in loader:
        batch_loss = 0.0
        # Counted per batch: a counter shared across batches would divide
        # later batches by the cumulative graph count and shrink their
        # gradients as the epoch progresses.
        graphs_in_batch = 0

        for graph in batch.to_data_list():
            graph = to_device(graph, device)
            num_nodes = graph.x.size(0)
            if num_nodes == 0:
                continue

            # Generate two augmented views (randomly knock down different nodes)
            idx1 = random.randint(0, num_nodes - 1)
            idx2 = random.randint(0, num_nodes - 1)
            if num_nodes > 1:
                while idx2 == idx1:
                    idx2 = random.randint(0, num_nodes - 1)

            view1 = apply_virtual_knockdown(graph, idx1)
            view2 = apply_virtual_knockdown(graph, idx2)
            batch_view1 = Batch.from_data_list([view1])
            batch_view2 = Batch.from_data_list([view2])

            z1 = model(batch_view1.x, batch_view1.edge_index, batch_view1.edge_attr)
            z2 = model(batch_view2.x, batch_view2.edge_index, batch_view2.edge_attr)
            batch_loss = batch_loss + symmetric_node_contrast_loss(z1, z2, tau)
            graphs_in_batch += 1

        # A batch with no usable graph leaves batch_loss as a plain float,
        # which has no .backward(); skip it entirely.
        if graphs_in_batch == 0:
            continue

        batch_loss = batch_loss / graphs_in_batch
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()
        total_loss += batch_loss.item()
        num_batches += 1

    return total_loss / num_batches if num_batches > 0 else 0.0

def evaluate(model, loader, device, tau):
    """Compute the mean per-graph contrastive loss without gradient tracking.

    Each non-empty graph yielded by ``loader`` is turned into two views via
    ``apply_virtual_knockdown`` using two randomly drawn node indices
    (distinct whenever the graph has more than one node). The views are
    encoded by ``model`` and scored with ``symmetric_node_contrast_loss``.

    Args:
        model: Module called as ``model(x, edge_index, edge_attr)``.
        loader: Iterable of batches exposing ``to_data_list()``.
        device: Device forwarded to ``to_device`` for each graph.
        tau: Temperature forwarded to ``symmetric_node_contrast_loss``.

    Returns:
        Sum of per-graph losses divided by the number of evaluated graphs,
        or ``0`` when no graph was evaluated.
    """
    model.eval()
    loss_sum = 0.0
    evaluated = 0

    with torch.no_grad():
        for batch in loader:
            for sample in batch.to_data_list():
                sample = to_device(sample, device)
                n = sample.x.size(0)
                if n == 0:
                    # Nothing to knock down in an empty graph.
                    continue

                last = n - 1
                first_idx = random.randint(0, last)
                second_idx = random.randint(0, last)
                # Redraw until the two indices differ; impossible for n == 1.
                while n > 1 and second_idx == first_idx:
                    second_idx = random.randint(0, last)

                embeddings = []
                for idx in (first_idx, second_idx):
                    view = Batch.from_data_list([apply_virtual_knockdown(sample, idx)])
                    embeddings.append(model(view.x, view.edge_index, view.edge_attr))

                loss_sum += symmetric_node_contrast_loss(embeddings[0], embeddings[1], tau).item()
                evaluated += 1

    if evaluated > 0:
        return loss_sum / evaluated
    return 0

def main():
    """Parse CLI arguments, build the model, and run the training loop.

    Creates a run-specific checkpoint directory and TensorBoard writer
    (named after learning rate, tau and projection output dimension), loads
    and splits the graphs, then trains for ``--num-epochs`` epochs, logging
    train/test loss every epoch and saving the model ``state_dict`` every
    100 epochs.

    Raises:
        ValueError: If the loaded graph collection is empty, since the
            model input dimension is read from the first graph.
    """
    parser = argparse.ArgumentParser(description="GRACE for TCGA Gene Regulatory Network Representation Learning")

    # File paths and directory settings
    parser.add_argument('--tcga-graphs-path', type=str,
                        default='path/to/tcga_graphs.pkl',
                        help='Path to TCGA graph data pickle file')
    parser.add_argument('--checkpoint-dir', type=str,
                        default='path/to/checkpoints',
                        help='Checkpoint save directory')
    parser.add_argument('--log-dir', type=str, default='./runs',
                        help='TensorBoard log directory')

    # Hyperparameters
    parser.add_argument('--seed', type=int, default=42, help='Random seed')
    parser.add_argument('--batch-size', type=int, default=4, help='Batch size')
    parser.add_argument('--learning-rate', type=float, default=1e-4, help='Learning rate')
    parser.add_argument('--num-epochs', type=int, default=3000, help='Number of epochs')
    parser.add_argument('--tau', type=float, default=0.25, help='Temperature parameter tau')

    # Model parameters
    parser.add_argument('--hidden-channels', type=int, default=64, help='GraphEncoder hidden dimension')
    parser.add_argument('--out-channels', type=int, default=64, help='GraphEncoder output dimension')
    parser.add_argument('--num-heads', type=int, default=8, help='Number of TransformerConv heads')
    parser.add_argument('--proj-hidden-dim', type=int, default=64, help='Projection head intermediate dimension')
    parser.add_argument('--proj-out-dim', type=int, default=64, help='Projection head output dimension')

    args = parser.parse_args()

    set_seed(args.seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Include hyperparameters in directory name
    run_name = f"lr{args.learning_rate}_tau{args.tau}_projout{args.proj_out_dim}"
    log_dir = os.path.join(args.log_dir, run_name)
    checkpoint_dir = os.path.join(args.checkpoint_dir, run_name)
    os.makedirs(checkpoint_dir, exist_ok=True)
    writer = SummaryWriter(log_dir=log_dir)

    # Everything after the writer is created runs under try/finally so the
    # writer is closed even if loading or training fails or is interrupted.
    try:
        # Load graph data and create DataLoaders
        graphs = load_tcga_graphs(args.tcga_graphs_path)
        if len(graphs) == 0:
            raise ValueError(f"No graphs loaded from {args.tcga_graphs_path}")
        train_graphs, test_graphs = split_graphs(graphs)
        train_loader, test_loader = get_dataloaders(train_graphs, test_graphs, args.batch_size)

        # Initialize model: input dimension is obtained from the first graph
        in_channels = graphs[0].x.shape[1]
        encoder = GraphEncoder(in_channels, args.hidden_channels, args.out_channels, num_heads=args.num_heads)
        # encoder_out_dim is treated as out_channels here
        model = GRACEModel(
            encoder,
            proj_hidden_dim=args.proj_hidden_dim,
            proj_out_dim=args.proj_out_dim,
            encoder_out_dim=args.out_channels,
        ).to(device)
        optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate)

        # Training loop
        for epoch in range(1, args.num_epochs + 1):
            train_loss = train_epoch(model, train_loader, optimizer, device, args.tau)
            test_loss = evaluate(model, test_loader, device, args.tau)

            writer.add_scalar('Loss/train', train_loss, epoch)
            writer.add_scalar('Loss/test', test_loss, epoch)
            print(f"Epoch {epoch:04d}: Train Loss = {train_loss:.4f}, Test Loss = {test_loss:.4f}")

            if epoch % 100 == 0:
                checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch}.pt')
                torch.save(model.state_dict(), checkpoint_path)
                print(f"Checkpoint saved: {checkpoint_path}")
    finally:
        writer.close()

# Run training only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
