import torch
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import argparse
from torch_geometric.loader import DataLoader
from models import GINModel
from data_loader import NPYGraphDataset

def train(model, loader, optimizer, device, task):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Args:
        model: Module called as ``model(x, edge_index, batch)``. Its output is
            passed straight to ``F.nll_loss``, so it is presumably
            log-probabilities -- confirm against the model definition.
        loader: Iterable of batches exposing ``to(device)``, ``x``,
            ``edge_index``, ``y`` and, optionally, ``batch``. Batches must
            also expose ``train_mask`` when ``task`` is not ``'graph'``.
        optimizer: Optimizer stepped once per batch.
        device: Device every batch is moved to before the forward pass.
        task: ``'graph'`` computes the loss over every output row; any other
            value restricts the loss to the rows selected by ``train_mask``.

    Returns:
        The mean of the per-batch losses as a float, or ``0.0`` when the
        loader yields no batches.
    """
    model.train()
    total_loss = 0.0
    # Count batches while iterating instead of calling len(loader): this also
    # works for iterables without __len__ and lets us guard the division.
    num_batches = 0
    for data in loader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data.x, data.edge_index, getattr(data, 'batch', None))
        if task == 'graph':
            loss = F.nll_loss(out, data.y.view(-1))
        else:
            loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1
    if num_batches == 0:
        # Nothing was trained on; avoid ZeroDivisionError.
        return 0.0
    return total_loss / num_batches

def evaluate(model, loader, device, split='val', task='graph'):
    """Return the classification accuracy of ``model`` over ``loader``.

    Args:
        model: Module called as ``model(x, edge_index, batch)``; the predicted
            class of each output row is its ``argmax`` along dimension 1.
        loader: Iterable of batches exposing ``to(device)``, ``x``,
            ``edge_index``, ``y`` and, optionally, ``batch``.
        device: Device every batch is moved to before the forward pass.
        split: Name prefix of the mask attribute (``f'{split}_mask'``) read
            from each batch when ``task`` is not ``'graph'``; unused for
            graph-level evaluation.
        task: ``'graph'`` scores every output row against ``y``; any other
            value scores only the rows selected by the split mask.

    Returns:
        Fraction of correct predictions as a float, or ``0.0`` when no target
        was evaluated (empty loader or empty mask).
    """
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            out = model(data.x, data.edge_index, getattr(data, 'batch', None))
            pred = out.argmax(dim=1)
            if task == 'graph':
                target = data.y.view(-1)
            else:
                mask = getattr(data, f'{split}_mask')
                pred = pred[mask]
                target = data.y[mask]
            correct += (pred == target).sum().item()
            # Count the targets actually compared: mask.sum() is only a node
            # count for boolean masks and would be wrong for index masks.
            total += target.numel()
    if total == 0:
        # Nothing to score; avoid ZeroDivisionError.
        return 0.0
    return correct / total

def _run_round(round_name, dataset_name, device):
    """Train, validate and test one model on the splits named ``round_name``.

    Loads ``train_/val_/test_{round_name}.npy`` from ``./data/molecular``,
    trains a ``GINModel`` for a fixed number of epochs, checkpoints the
    weights with the best validation accuracy, and prints per-epoch metrics
    followed by a one-line summary.

    Args:
        round_name: Suffix identifying the dataset split files of this round.
        dataset_name: Dataset family; only ``'molecular'`` is supported.
        device: Device the model and batches are placed on.

    Raises:
        ValueError: If ``dataset_name`` is not a supported dataset.
    """
    print(f"{round_name} round:")

    if dataset_name == 'molecular':
        task = 'graph'
        train_dataset = NPYGraphDataset(f'./data/molecular/train_{round_name}.npy', task=task)
        val_dataset = NPYGraphDataset(f'./data/molecular/val_{round_name}.npy', task=task)
        test_dataset = NPYGraphDataset(f'./data/molecular/test_{round_name}.npy', task=task)
    else:
        raise ValueError(f'Unsupported dataset: {dataset_name!r}')

    batch_size = 32
    # Only the training loader needs shuffling; accuracy does not depend on
    # the order in which evaluation batches are visited.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Number of classes is inferred from the largest training label.
    num_classes = int(train_dataset.data.y.max()) + 1
    model = GINModel(train_dataset.num_features, num_classes, task=task).to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

    checkpoint_path = f'./best_model_{dataset_name}.pt'
    # Start below any reachable accuracy so the first epoch always writes a
    # checkpoint. Otherwise a round whose validation accuracy never exceeds 0
    # would reload the previous round's weights (the path is shared across
    # rounds) or fail because the file does not exist.
    best_val_acc = float('-inf')
    num_epochs = 50
    window_val_accs = []
    train_accs = []
    test_accs = []

    for epoch in range(1, num_epochs + 1):
        loss = train(model, train_loader, optimizer, device, task)
        train_acc = evaluate(model, train_loader, device, 'train', task)
        val_acc = evaluate(model, val_loader, device, 'val', task)
        test_acc = evaluate(model, test_loader, device, 'test', task)

        train_accs.append(train_acc)
        test_accs.append(test_acc)
        window_val_accs.append(val_acc)

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), checkpoint_path)

        print(f'Epoch: {epoch}, Loss: {loss:.4f}, Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}, Test Acc: {test_acc:.4f}')

        # Report validation stability over each block of five epochs.
        if epoch % 5 == 0:
            std_val = np.std(window_val_accs)
            print(f"[Epoch {epoch - 4}~{epoch}] Val Acc Std: {std_val:.4f}")
            window_val_accs = []

    # Mean/std are taken across epochs of this round, in percent.
    final_train_mean = np.mean(train_accs) * 100
    final_train_std = np.std(train_accs) * 100
    final_in_test_std = np.std(test_accs) * 100

    # The reported test accuracy comes from the best-validation checkpoint;
    # the value after "±" is the std of the per-epoch test accuracies.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    final_test_acc = evaluate(model, test_loader, device, 'test', task) * 100
    print(f'Train Acc  {final_train_mean:.2f} ± {final_train_std:.2f}  Test Acc  {final_test_acc:.2f} ± {final_in_test_std:.2f}')


def main():
    """Parse command-line options and run every confounding round in turn."""
    # Parse arguments and pick the device once; neither depends on the round.
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='molecular', choices=['molecular'])
    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    round_names = [f'mol_confound_{k}_others_causal' for k in range(1, 6)]
    for round_name in round_names:
        _run_round(round_name, args.dataset, device)

# Run the experiment only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == '__main__':
    main()
