# train.py
import os
import argparse
import torch
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from torch_geometric.data import Batch

from data import load_tcga_graphs, split_graphs, get_dataloaders, to_device
from models import GraphEncoder, CustomDecoder, GraphModel
from utils import set_seed

def reconstruction_loss(node_features_recon, edge_features_recon, data):
    """Return the summed MSE reconstruction error over nodes and edges.

    Args:
        node_features_recon: predicted node features; compared against ``data.x``.
        edge_features_recon: predicted edge features; compared against
            ``data.edge_attr`` with its last dimension squeezed away.
        data: graph (or batch of graphs) exposing ``x`` and ``edge_attr``.

    Returns:
        Scalar tensor ``mse(nodes) + mse(edges)``.
    """
    # Drop the trailing size-1 dim of edge_attr (presumably [E, 1] -> [E];
    # TODO confirm) so the target lines up with the edge predictions.
    edge_target = data.edge_attr.squeeze(-1)
    node_term = F.mse_loss(node_features_recon, data.x)
    edge_term = F.mse_loss(edge_features_recon, edge_target)
    return node_term + edge_term

def train_epoch(model, optimizer, train_loader, device):
    """Run one optimisation pass over ``train_loader``.

    Args:
        model: module whose forward returns ``(z, node_recon, edge_recon)``.
        optimizer: optimizer over ``model``'s parameters.
        train_loader: iterable of graph batches; each must support ``.to(device)``.
        device: device the model lives on; every batch is moved there before
            the forward pass.

    Returns:
        Mean of the per-batch losses (each batch weighted equally, regardless
        of how many graphs it contains).

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    for batch in train_loader:
        # Ensure inputs sit on the same device as the model parameters,
        # independent of how/where the loader produced the batch.
        batch = batch.to(device)
        optimizer.zero_grad()
        _, node_features_recon, edge_features_recon = model(batch)
        loss = reconstruction_loss(node_features_recon, edge_features_recon, batch)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1
    if num_batches == 0:
        raise ValueError("train_loader yielded no batches; cannot compute a mean loss")
    return total_loss / num_batches

def test_epoch(model, test_loader, device):
    """Evaluate ``model`` on ``test_loader`` without tracking gradients.

    Args:
        model: module whose forward returns ``(z, node_recon, edge_recon)``.
        test_loader: iterable of graph batches; each must support ``.to(device)``.
        device: device the model lives on; every batch is moved there before
            the forward pass.

    Returns:
        Mean of the per-batch losses (each batch weighted equally).

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for batch in test_loader:
            # Match the model's device regardless of where the loader built the batch.
            batch = batch.to(device)
            _, node_features_recon, edge_features_recon = model(batch)
            loss = reconstruction_loss(node_features_recon, edge_features_recon, batch)
            total_loss += loss.item()
            num_batches += 1
    if num_batches == 0:
        raise ValueError("test_loader yielded no batches; cannot compute a mean loss")
    return total_loss / num_batches

def _parse_args(argv=None):
    """Build the command-line parser and parse ``argv`` (``sys.argv[1:]`` when None)."""
    parser = argparse.ArgumentParser(description="Graph AutoEncoder for TCGA Gene Regulatory Network Representation Learning")
    # File paths and directory settings
    parser.add_argument('--tcga-graphs-path', type=str,
                        default='path/to/tcga_graphs.pkl',
                        help='Path to TCGA graph data pickle file')
    parser.add_argument('--checkpoint-dir', type=str,
                        default='path/to/checkpoints',
                        help='Checkpoint save directory')
    parser.add_argument('--log-dir', type=str, default='./runs',
                        help='TensorBoard log directory')
    parser.add_argument('--checkpoint-interval', type=int, default=100,
                        help='Save a checkpoint every N epochs (the final epoch is always saved)')
    # Hyperparameters
    parser.add_argument('--seed', type=int, default=42, help='Random seed')
    parser.add_argument('--batch-size', type=int, default=4, help='Batch size')
    parser.add_argument('--num-epochs', type=int, default=3000, help='Number of epochs')
    parser.add_argument('--learning-rate', type=float, default=1e-4, help='Learning rate')
    # Model parameters
    parser.add_argument('--in-channels', type=int, default=1, help='Node feature dimension')
    parser.add_argument('--hidden-channels', type=int, default=64, help='Hidden layer dimension')
    parser.add_argument('--out-channels', type=int, default=64, help='Embedding dimension')

    args = parser.parse_args(argv)
    if args.checkpoint_interval <= 0:
        parser.error('--checkpoint-interval must be a positive integer')
    return args


def main():
    """Train the graph autoencoder and log/checkpoint its progress.

    Steps:
    1. Parse the CLI arguments.
    2. Seed the RNGs and pick CUDA when available.
    3. Load and split the TCGA graphs and move them to the device.
    4. Build the data loaders and the encoder/decoder model.
    5. Train for ``--num-epochs`` epochs, logging train/test loss to
       TensorBoard.
    6. Save ``model.state_dict()`` every ``--checkpoint-interval`` epochs and
       after the final epoch.
    """
    args = _parse_args()

    # Set random seed
    set_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Per-run subdirectories keyed by the swept hyperparameters
    run_name = f"lr{args.learning_rate}_outdim{args.out_channels}"
    log_dir = os.path.join(args.log_dir, run_name)
    checkpoint_dir = os.path.join(args.checkpoint_dir, run_name)
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Load and split data
    graphs = load_tcga_graphs(args.tcga_graphs_path)
    train_graphs, test_graphs = split_graphs(graphs)
    # Move graphs to the device BEFORE building the loaders so the loaders are
    # built from the moved graphs, whether to_device copies or mutates in place.
    # NOTE(review): if get_dataloaders uses worker processes (num_workers > 0),
    # CUDA tensors in the dataset may not be supported — confirm.
    train_graphs = [to_device(graph, device) for graph in train_graphs]
    test_graphs = [to_device(graph, device) for graph in test_graphs]
    train_loader, test_loader = get_dataloaders(train_graphs, test_graphs, args.batch_size)

    # Initialize model
    encoder = GraphEncoder(args.in_channels, args.hidden_channels, args.out_channels)
    decoder = CustomDecoder(args.out_channels, num_features=args.in_channels)
    model = GraphModel(encoder, decoder).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate)

    writer = SummaryWriter(log_dir=log_dir)
    try:
        for epoch in range(1, args.num_epochs + 1):
            train_loss = train_epoch(model, optimizer, train_loader, device)
            test_loss = test_epoch(model, test_loader, device)
            writer.add_scalar('Loss/Train', train_loss, epoch)
            writer.add_scalar('Loss/Test', test_loss, epoch)
            print(f"Epoch {epoch:04d}: Train Loss = {train_loss:.4f}, Test Loss = {test_loss:.4f}")

            # Also save after the last epoch so the final weights are never
            # lost when num_epochs is not a multiple of the interval.
            if epoch % args.checkpoint_interval == 0 or epoch == args.num_epochs:
                checkpoint_path = os.path.join(checkpoint_dir, f'GAE_epoch{epoch}.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print(f"Checkpoint saved at epoch {epoch}: {checkpoint_path}")
    finally:
        # Flush and close the event files even if training raises.
        writer.close()

# Run training only when executed as a script, not when this module is imported.
if __name__ == '__main__':
    main()
