# train_gca.py

import os
import random
import argparse
import torch
from torch.utils.tensorboard import SummaryWriter
from torch_geometric.data import Batch, Data

from data import (
    load_tcga_graphs,
    split_graphs,
    get_dataloaders,
    to_device,
)
from models import GraphEncoder, GCAModel
from utils import set_seed, symmetric_node_contrast_loss
from augmentations import augment_graph 

def train_epoch(model, loader, optimizer, device, tau, args):
    """Run one training epoch and return the mean per-batch loss.

    Each batch from ``loader`` is split back into its individual graphs.
    For every graph two augmented views are produced with ``augment_graph``,
    both views are passed through ``model`` and compared with
    ``symmetric_node_contrast_loss``. The losses of the graphs in a batch are
    averaged and a single optimizer step is taken per batch.

    Args:
        model: Module called as ``model(x, edge_index, edge_attr)``.
        loader: Iterable of batches exposing ``to_data_list()``.
        optimizer: Optimizer stepped once per non-empty batch.
        device: Device the graphs and augmented views are moved to.
        tau: Value forwarded to ``symmetric_node_contrast_loss``.
        args: Namespace providing ``drop_scheme``, ``drop_edge_rate_1``,
            ``drop_edge_rate_2``, ``drop_feature_rate_1``,
            ``drop_feature_rate_2`` and ``threshold``.

    Returns:
        float: Mean of the per-batch losses over the batches that produced an
        optimizer step, or ``0.0`` if no batch did.
    """
    model.train()
    total_loss = 0.0
    num_steps = 0

    for batch in loader:
        # Both accumulators are reset for every batch so that the average
        # below is over the graphs of THIS batch only. A counter shared
        # across batches would shrink the loss (and its gradients) more and
        # more as the epoch progresses.
        batch_loss = 0.0
        num_graphs = 0

        for graph in batch.to_data_list():
            graph = to_device(graph, device)
            view1, view2 = augment_graph(
                graph,
                drop_scheme=args.drop_scheme,
                drop_edge_rate_1=args.drop_edge_rate_1,
                drop_edge_rate_2=args.drop_edge_rate_2,
                drop_feature_rate_1=args.drop_feature_rate_1,
                drop_feature_rate_2=args.drop_feature_rate_2,
                threshold=args.threshold,
            )

            # Each view is wrapped in a single-graph Batch before the
            # forward pass.
            b1 = Batch.from_data_list([view1]).to(device)
            b2 = Batch.from_data_list([view2]).to(device)

            z1 = model(b1.x, b1.edge_index, b1.edge_attr)
            z2 = model(b2.x, b2.edge_index, b2.edge_attr)

            batch_loss = batch_loss + symmetric_node_contrast_loss(z1, z2, tau)
            num_graphs += 1

        if num_graphs == 0:
            # Nothing to back-propagate: batch_loss is still the float 0.0.
            continue

        batch_loss = batch_loss / num_graphs
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()
        total_loss += batch_loss.item()
        num_steps += 1

    # Guard against an empty loader, mirroring the guard used in evaluate().
    return total_loss / max(num_steps, 1)

def evaluate(model, loader, device, tau, args):
    """Compute the mean per-graph contrastive loss without gradient tracking.

    Every graph of every batch in ``loader`` is augmented into two views with
    ``augment_graph``; both views are encoded by ``model`` and scored with
    ``symmetric_node_contrast_loss``.

    Args:
        model: Module called as ``model(x, edge_index, edge_attr)``; it is
            switched to eval mode and left in that mode on return.
        loader: Iterable of batches exposing ``to_data_list()``.
        device: Device the graphs and augmented views are moved to.
        tau: Value forwarded to ``symmetric_node_contrast_loss``.
        args: Namespace providing ``drop_scheme``, ``drop_edge_rate_1``,
            ``drop_edge_rate_2``, ``drop_feature_rate_1``,
            ``drop_feature_rate_2`` and ``threshold``.

    Returns:
        float: Sum of the per-graph losses divided by the number of graphs
        evaluated, or ``0.0`` when no graph was seen.
    """
    model.eval()
    total_loss = 0.0
    count = 0

    with torch.no_grad():
        for batch in loader:
            for graph in batch.to_data_list():
                graph = to_device(graph, device)
                # Keyword arguments (same names as in train_epoch) so the
                # call does not depend on augment_graph's parameter order.
                view1, view2 = augment_graph(
                    graph,
                    drop_scheme=args.drop_scheme,
                    drop_edge_rate_1=args.drop_edge_rate_1,
                    drop_edge_rate_2=args.drop_edge_rate_2,
                    drop_feature_rate_1=args.drop_feature_rate_1,
                    drop_feature_rate_2=args.drop_feature_rate_2,
                    threshold=args.threshold,
                )
                b1 = Batch.from_data_list([view1]).to(device)
                b2 = Batch.from_data_list([view2]).to(device)

                z1 = model(b1.x, b1.edge_index, b1.edge_attr)
                z2 = model(b2.x, b2.edge_index, b2.edge_attr)

                total_loss += symmetric_node_contrast_loss(z1, z2, tau).item()
                count += 1

    # max(..., 1) avoids a ZeroDivisionError when the loader yields no graphs.
    return total_loss / max(count, 1)

def main():
    """Parse CLI options, build the model and run the training loop.

    The function seeds the run via ``set_seed``, loads the graphs from
    ``--tcga-graphs-path``, splits them into train/test loaders, builds a
    ``GraphEncoder`` wrapped in a ``GCAModel`` and optimises it with AdamW.
    Train and test losses are logged to TensorBoard and printed every epoch.
    A checkpoint of ``model.state_dict()`` is written every 100 epochs and
    after the final epoch.
    """
    parser = argparse.ArgumentParser(
        description="GCA-based GRACE on TCGA GRN"
    )

    # Paths.
    parser.add_argument('--tcga-graphs-path', type=str,
                        default='/path/to/data/graph_data_object/tcga_graphs.pkl')
    parser.add_argument('--checkpoint-dir', type=str,
                        default='/path/to/Pretrain/GCA/checkpoints')
    parser.add_argument('--log-dir', type=str,
                        default='/path/to/Pretrain/GCA/runs/')

    # Optimisation settings.
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--batch-size', type=int, default=4)
    parser.add_argument('--learning-rate', type=float, default=1e-4)
    parser.add_argument('--num-epochs', type=int, default=3000)
    parser.add_argument('--tau', type=float, default=0.4)

    # Augmentation settings forwarded to augment_graph.
    parser.add_argument('--drop-scheme', type=str,
                        choices=['uniform', 'degree', 'pr', 'evc'],
                        default='degree')
    parser.add_argument('--drop-edge-rate-1', type=float, default=0.3)
    parser.add_argument('--drop-edge-rate-2', type=float, default=0.4)
    parser.add_argument('--drop-feature-rate-1', type=float, default=0.1)
    parser.add_argument('--drop-feature-rate-2', type=float, default=0.0)
    parser.add_argument('--threshold', type=float,
                        default=0.7)

    # Model sizes.
    parser.add_argument('--hidden-channels', type=int, default=64)
    parser.add_argument('--out-channels', type=int, default=64)
    parser.add_argument('--num-heads', type=int, default=8)
    parser.add_argument('--proj-hidden-dim', type=int, default=64)
    parser.add_argument('--proj-out-dim', type=int, default=64)

    args = parser.parse_args()

    set_seed(args.seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # The TensorBoard run directory is named after the augmentation settings.
    run_name = f"{args.drop_scheme}_er{args.drop_edge_rate_1}-{args.drop_edge_rate_2}_" + \
               f"fr{args.drop_feature_rate_1}-{args.drop_feature_rate_2}"
    writer = SummaryWriter(log_dir=os.path.join(args.log_dir, run_name))

    # try/finally so the writer is closed even if loading or training raises.
    try:
        os.makedirs(args.checkpoint_dir, exist_ok=True)

        graphs = load_tcga_graphs(args.tcga_graphs_path)
        train_g, test_g = split_graphs(graphs)
        train_loader, test_loader = get_dataloaders(
            train_g, test_g, args.batch_size
        )

        # Input feature width is read from the first graph's node features.
        in_ch = graphs[0].x.size(1)
        encoder = GraphEncoder(in_ch,
                               args.hidden_channels,
                               args.out_channels,
                               args.num_heads).to(device)
        model = GCAModel(
            encoder,
            proj_hidden_dim=args.proj_hidden_dim,
            proj_out_dim=args.proj_out_dim,
            encoder_out_dim=args.out_channels
        ).to(device)
        optimizer = torch.optim.AdamW(
            model.parameters(), lr=args.learning_rate
        )

        for epoch in range(1, args.num_epochs + 1):
            tr_loss = train_epoch(
                model, train_loader, optimizer, device,
                args.tau, args
            )
            te_loss = evaluate(
                model, test_loader, device,
                args.tau, args
            )

            writer.add_scalar('Loss/train', tr_loss, epoch)
            writer.add_scalar('Loss/test', te_loss, epoch)
            print(f"[Epoch {epoch:04d}] train={tr_loss:.4f}, test={te_loss:.4f}")

            # Also save after the last epoch so that runs whose length is
            # not a multiple of 100 still leave their final weights on disk.
            if epoch % 100 == 0 or epoch == args.num_epochs:
                ck = os.path.join(
                    args.checkpoint_dir,
                    f"gca_epoch_{epoch}.pt"
                )
                torch.save(model.state_dict(), ck)
                print("Saved checkpoint:", ck)
    finally:
        writer.close()

# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
