import os
import argparse
from torch.utils.tensorboard import SummaryWriter
import torch

from afgrl_data import load_tcga_graphs, split_graphs, get_dataloaders, to_device
from afgrl_model import AFGRLModel
from utils_afgrl import set_seed

def train_epoch(model, loader, optimizer, device, epoch):
    """Train ``model`` for one epoch and return the mean per-graph loss.

    Every batch from ``loader`` is unpacked with ``batch.to_data_list()`` and
    each graph is optimised on its own: forward pass, ``optimizer.zero_grad()``,
    backward, ``optimizer.step()``, then ``model.update_teacher()``.  The
    loader's batch size therefore only controls how graphs are grouped when
    loaded, not how many graphs contribute to one gradient step.

    Args:
        model: Model called as ``model(x, edge_index, edge_attr, epoch)``; its
            result must unpack into exactly four items, the second being the
            scalar loss tensor.
        loader: Iterable of batch objects exposing ``to_data_list()``
            (presumably PyG ``Batch`` objects -- assumption, confirm).
        optimizer: Optimizer over ``model``'s parameters.
        device: Device handed to ``to_device`` for each graph.
        epoch: Current epoch number, forwarded to the model.

    Returns:
        float: Average of ``loss.item()`` over all graphs seen, or ``0.0`` if
        the loader produced no graphs.
    """
    model.train()

    # Flatten the batches into a lazy stream of individual graphs.
    per_graph = (g for batch in loader for g in batch.to_data_list())

    running_loss = 0.0
    n_seen = 0
    for sample in per_graph:
        sample = to_device(sample, device)
        node_feats = sample.x
        edges = sample.edge_index
        edge_feats = getattr(sample, 'edge_attr', None)  # edge features are optional
        _, loss, _, _ = model(node_feats, edges, edge_feats, epoch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # The teacher update follows every single optimizer step.
        model.update_teacher()

        running_loss += loss.item()
        n_seen += 1

    if n_seen == 0:
        return 0.0
    return running_loss / n_seen

def evaluate(model, loader, device, epoch):
    """Compute the mean per-graph loss of ``model`` on ``loader`` without training.

    The model is switched to eval mode and every graph (obtained by unpacking
    each batch with ``batch.to_data_list()``) is scored under
    ``torch.no_grad()``; no optimizer step or teacher update is performed.

    Args:
        model: Model called as ``model(x, edge_index, edge_attr, epoch)``; its
            result must unpack into exactly four items, the second being the
            scalar loss tensor.
        loader: Iterable of batch objects exposing ``to_data_list()``.
        device: Device handed to ``to_device`` for each graph.
        epoch: Current epoch number, forwarded to the model.

    Returns:
        float: Average of ``loss.item()`` over all graphs, or ``0.0`` if the
        loader produced no graphs.
    """
    model.eval()
    loss_sum = 0.0
    n_graphs = 0
    with torch.no_grad():  # evaluation only: no autograd graph is needed
        for batch in loader:
            graph_list = batch.to_data_list()
            for item in graph_list:
                item = to_device(item, device)
                result = model(
                    item.x,
                    item.edge_index,
                    getattr(item, 'edge_attr', None),
                    epoch,
                )
                _, loss, _, _ = result
                loss_sum += loss.item()
                n_graphs += 1
    return loss_sum / n_graphs if n_graphs else 0.0

def _build_arg_parser():
    """Build the command-line parser for AFGRL pretraining on TCGA graphs."""
    parser = argparse.ArgumentParser(description="AFGRL pretraining on TCGA GRNs")
    # Input / output locations (the default paths look like placeholders to override).
    parser.add_argument('--tcga-graphs-path', type=str,
                        default='/path/to/data/graph_data_object/tcga_graphs.pkl')
    parser.add_argument('--log-dir', type=str, default='/path/to/Pretrain/AFGRL/runs/AFGRL_tcga_lr-5')
    parser.add_argument('--checkpoint-dir', type=str, default='/path/to/Pretrain/AFGRL/checkpoints/AFGRL_tcga_lr-5')
    parser.add_argument('--save-every', type=int, default=100,
                        help='Save a checkpoint every N epochs (<= 0 disables periodic saves); '
                             'the final epoch is always saved')
    # Training setup.
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--batch-size', type=int, default=4)
    parser.add_argument('--lr', type=float, default=1e-5)
    parser.add_argument('--num-epochs', type=int, default=3000)
    # Remaining hyper-parameters; the full ``args`` namespace is handed to
    # AFGRLModel, which presumably reads these -- confirm in afgrl_model.
    parser.add_argument('--topk', type=int, default=4)
    parser.add_argument('--num-centroids', type=int, default=50)
    parser.add_argument('--num-kmeans', type=int, default=5)
    parser.add_argument('--clus-num-iters', type=int, default=20)
    parser.add_argument('--mad', type=float, default=0.9)
    parser.add_argument('--pred-hid', type=int, default=64)
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--device', type=int, default=0, help='CUDA device id')

    parser.add_argument('--hidden-channels', type=int, default=64, help='Transformer hidden dimension')
    parser.add_argument('--out-channels', type=int, default=64, help='Encoder output dimension')
    parser.add_argument('--num-heads', type=int, default=8, help='Number of attention heads in TransformerConv')
    return parser


def main():
    """Pretrain an AFGRL model on TCGA gene-regulatory-network graphs.

    Parses CLI arguments and seeds via ``set_seed``. Selects ``cuda:<--device>``
    when CUDA is available, else CPU. Loads the graphs from
    ``--tcga-graphs-path`` and splits them 80/20 into train/test sets (seeded
    with ``--seed``). Trains ``AFGRLModel`` with AdamW for ``--num-epochs``
    epochs.

    Per-epoch train and test losses are printed and logged to TensorBoard under
    ``Loss/train`` and ``Loss/test``. The model ``state_dict`` is saved to
    ``--checkpoint-dir`` every ``--save-every`` epochs and always after the
    last epoch.

    Raises:
        ValueError: if no graphs are loaded from ``--tcga-graphs-path``.
    """
    args = _build_arg_parser().parse_args()

    set_seed(args.seed)
    device = torch.device(f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu')
    print(f"[Setup] Device: {device}")

    os.makedirs(args.log_dir, exist_ok=True)
    os.makedirs(args.checkpoint_dir, exist_ok=True)
    writer = SummaryWriter(log_dir=args.log_dir)

    # Always close the writer (flushing pending events), even if loading or
    # training raises.
    try:
        # load data
        graphs = load_tcga_graphs(args.tcga_graphs_path)
        if len(graphs) == 0:
            # Fail with a clear message instead of an IndexError at graphs[0].
            raise ValueError(f"No graphs loaded from {args.tcga_graphs_path!r}")
        train_graphs, test_graphs = split_graphs(graphs, train_ratio=0.8, seed=args.seed)
        train_loader, test_loader = get_dataloaders(train_graphs, test_graphs, batch_size=args.batch_size)

        # Input width is taken from the first graph's node-feature matrix.
        in_dim = graphs[0].x.shape[1]
        model = AFGRLModel(
            in_channels=in_dim,
            hidden_channels=args.hidden_channels,
            out_channels=args.out_channels,
            num_heads=args.num_heads,
            args=args
        ).to(device)

        optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)

        # training loop
        for epoch in range(1, args.num_epochs + 1):
            train_loss = train_epoch(model, train_loader, optimizer, device, epoch)
            test_loss = evaluate(model, test_loader, device, epoch)
            writer.add_scalar('Loss/train', train_loss, epoch)
            writer.add_scalar('Loss/test', test_loss, epoch)
            print(f"[Epoch {epoch:04d}] Train Loss: {train_loss:.4f} | Eval Loss: {test_loss:.4f}")

            # Save periodically, and always at the last epoch so the final
            # weights are kept when num_epochs is not a multiple of save_every.
            periodic = args.save_every > 0 and epoch % args.save_every == 0
            if periodic or epoch == args.num_epochs:
                ckpt_path = os.path.join(args.checkpoint_dir, f'afgrl_epoch_{epoch}.pt')
                torch.save(model.state_dict(), ckpt_path)
                print(f"[Checkpoint saved] {ckpt_path}")
    finally:
        writer.close()


if __name__ == '__main__':
    main()
