import os
import argparse
from torch.utils.tensorboard import SummaryWriter
import torch

from data import load_tcga_graphs, split_graphs, get_loaders
from model import SimGRACEModel
from loss import nt_xent_loss
from utils import set_seed, to_device

def _run_epoch(model, loader, device, tau, optimizer=None):
    """Run one pass over ``loader`` and return the mean loss per batch.

    When ``optimizer`` is given, the model is switched to training mode and
    updated after every batch. Otherwise the model is switched to eval mode
    and the pass runs with gradient tracking disabled.

    Args:
        model: Called as ``model(x, edge_index, edge_attr, batch)``; must
            return a pair that is handed to ``nt_xent_loss``.
        loader: Sized iterable of batches. Must contain at least one batch,
            since the result is divided by ``len(loader)``.
        device: Target passed to ``to_device`` for every batch.
        tau: Temperature forwarded to ``nt_xent_loss``.
        optimizer: Optional optimizer; its presence selects training mode.

    Returns:
        Sum of the per-batch loss values divided by ``len(loader)``.
    """
    training = optimizer is not None
    if training:
        model.train()
    else:
        model.eval()

    total = 0.0
    with torch.set_grad_enabled(training):
        for batch in loader:
            batch = to_device(batch, device)
            # edge_attr is optional on the batch; pass None when it is absent.
            g1, g2 = model(batch.x, batch.edge_index, getattr(batch, 'edge_attr', None), batch.batch)
            loss = nt_xent_loss(g1, g2, tau=tau)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            total += loss.item()
    return total / len(loader)


def main():
    """Pretrain a SimGRACE model on TCGA graphs from the command line.

    Parses CLI options, calls ``set_seed`` with ``--seed``, loads the graphs
    and splits them 80/20, then runs one training pass and one validation pass
    per epoch. Mean losses are written to TensorBoard and printed. The model's
    ``state_dict`` is saved every ``--checkpoint-every`` epochs and once more
    as ``final.pt`` after the last epoch.

    Raises:
        ValueError: If the training or validation loader contains no batches.
    """
    parser = argparse.ArgumentParser(description="SimGRACE pretraining on TCGA GRNs")
    parser.add_argument('--tcga-graphs-path', type=str,
                        default='/path/to/data/graph_data_object/tcga_graphs.pkl')
    parser.add_argument('--log-dir', type=str, default='/path/to/Pretrain/simGRACE/runs/simgrace_tcga')
    parser.add_argument('--checkpoint-dir', type=str, default='/path/to/Pretrain/simGRACE/checkpoints/simgrace_tcga')
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--batch-size', type=int, default=16)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--num-epochs', type=int, default=3000)
    parser.add_argument('--eta', type=float, default=1.0, help='perturbation scale')
    parser.add_argument('--tau', type=float, default=0.25, help='temperature')
    parser.add_argument('--hidden-dim', type=int, default=64)
    parser.add_argument('--out-dim', type=int, default=64)
    parser.add_argument('--num-heads', type=int, default=8)
    parser.add_argument('--device', type=str, default='cuda:0')
    parser.add_argument('--checkpoint-every', type=int, default=100,
                        help='save a checkpoint every N epochs (<= 0 disables periodic saves)')

    args = parser.parse_args()
    set_seed(args.seed)
    # The requested device is only honoured when CUDA is available;
    # otherwise everything runs on the CPU.
    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')

    os.makedirs(args.log_dir, exist_ok=True)
    os.makedirs(args.checkpoint_dir, exist_ok=True)
    writer = SummaryWriter(log_dir=args.log_dir)

    # try/finally: close the writer even when training fails or is interrupted.
    try:
        # Data
        graphs = load_tcga_graphs(args.tcga_graphs_path)
        train_graphs, test_graphs = split_graphs(graphs, train_ratio=0.8, seed=args.seed)
        train_loader, test_loader = get_loaders(train_graphs, test_graphs, batch_size=args.batch_size)

        # Fail fast with a clear message instead of a ZeroDivisionError when
        # averaging (or an IndexError on train_graphs[0]) later on.
        if len(train_loader) == 0:
            raise ValueError(
                f"Training loader is empty; check --tcga-graphs-path ({args.tcga_graphs_path!r}) "
                f"and --batch-size ({args.batch_size})."
            )
        if len(test_loader) == 0:
            raise ValueError(
                "Validation loader is empty; the dataset is too small for the 80/20 split."
            )

        # Model
        in_dim = train_graphs[0].x.shape[1]
        model = SimGRACEModel(in_dim=in_dim,
                              hidden_dim=args.hidden_dim,
                              out_dim=args.out_dim,
                              num_heads=args.num_heads,
                              eta=args.eta).to(device)
        optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)

        for epoch in range(1, args.num_epochs + 1):
            avg_train = _run_epoch(model, train_loader, device, args.tau, optimizer)
            avg_val = _run_epoch(model, test_loader, device, args.tau)

            writer.add_scalar('loss/train', avg_train, epoch)
            writer.add_scalar('loss/val', avg_val, epoch)
            print(f"[Epoch {epoch:03d}] Train Loss: {avg_train:.4f} | Val Loss: {avg_val:.4f}")

            if args.checkpoint_every > 0 and epoch % args.checkpoint_every == 0:
                torch.save(model.state_dict(), os.path.join(args.checkpoint_dir, f'simgrace_epoch_{epoch}.pt'))

        torch.save(model.state_dict(), os.path.join(args.checkpoint_dir, 'final.pt'))
    finally:
        writer.close()

# Start pretraining only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == "__main__":
    main()
