import argparse
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.utils import to_undirected, remove_self_loops, add_self_loops

from logger import *
from data import load_dataset
from utils import eval_acc, eval_rocauc, load_splits, augment_graph, modify_globalMask
from evaluate import *
from parse import parse_model, parser_add_main_args
import os.path as osp
import os
# Select the "expandable_segments" option of PyTorch's CUDA caching allocator
# through its environment variable. This is set at import time, i.e. before any
# of the CUDA work done further down in this script.
# NOTE(review): presumably chosen to reduce memory fragmentation -- confirm.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"


def fix_seed(seed=110):
    """Seed every random number generator used here and pin cuDNN settings.

    Args:
        seed: Value handed to the Python, NumPy and PyTorch (CPU and CUDA)
            seeding functions.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    # Request deterministic cuDNN kernels and switch off its auto-tuner.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def run_semi(args):
    """Run the train/evaluate cycle ``args.runs`` times and report statistics.

    Operates on module-level state prepared by the script: ``model``,
    ``dataset``, ``device``, ``logger``, ``criterion``, ``eval_func`` and
    ``split_idx_lst``. Each run reseeds the RNGs with ``args.seed + run``,
    resets the model parameters, creates a fresh Adam optimizer and then does
    one forward/backward pass over ``dataset.graph['all_feat']`` per epoch,
    evaluating after every optimizer step.

    Args:
        args: Parsed command-line namespace. Reads ``runs``, ``seed``,
            ``split``, ``wd``, ``lr``, ``epochs``, ``save_model`` and
            ``display_step``.

    Returns:
        The value returned by ``logger.print_statistics()`` when called
        without a run index, after all runs have finished.
    """
    for run in range(args.runs):
        # Offsetting by the run index gives each run its own reproducible seed.
        fix_seed(args.seed + run)

        if args.split == "random":
            # One split per run. The stacked feature matrix is rebuilt after
            # modify_globalMask (presumably because that call changes the
            # global features -- confirm in utils).
            split_idx = split_idx_lst[run]
            modify_globalMask(args, dataset, split_idx, run)
            stacked = (dataset.graph["node_feat"], dataset.graph["global_feat"])
            dataset.graph["all_feat"] = torch.cat(stacked, dim=0)
        else:
            split_idx = split_idx_lst[0]

        train_idx = split_idx['train'].to(device)

        model.reset_parameters()
        optimizer = torch.optim.Adam(
            [{"params": model.parameters()}],
            lr=args.lr,
            weight_decay=args.wd,
        )

        # Model selection: remember the test metric at the best validation epoch.
        best_val_ori = best_test_ori = float('-inf')

        for epoch in range(1, args.epochs + 1):
            # ---- optimisation step ----
            model.train()
            optimizer.zero_grad()

            edge_mask = dataset.graph['edge_masks']
            log_probs = F.log_softmax(model(dataset.graph['all_feat'], edge_mask), dim=-1)

            loss1, loss2 = model.cls_loss(
                log_probs, dataset.label, dataset.label_global, train_idx, criterion
            )
            total_loss = loss1 + loss2
            total_loss.backward()
            optimizer.step()

            # ---- evaluation ----
            # As used in the log line below: result[0:3] are the train/valid/test
            # metric values and result[3:6] the matching losses. The last entry
            # is not passed on to the logger.
            result = evaluate(model, dataset, split_idx, eval_func, criterion, args, edge_mask=edge_mask)
            logger.add_result(run, result[:-1])

            if result[1] > best_val_ori:
                best_val_ori, best_test_ori = result[1], result[2]
                if args.save_model:
                    save_model(args, model, optimizer, run)

            if epoch % args.display_step != 0:
                continue

            logger.write(f'Epoch: {epoch:02d}, '
                f'Train Loss: {result[3]:.4f}, Train Acc: {100 * result[0]:.2f}%, '
                f'Valid Loss: {result[4]:.4f}, Valid Acc: {100 * result[1]:.2f}%, '
                f'Test Loss: {result[5]:.4f}, Test Acc: {100 * result[2]:.2f}%, '
                f'Best Valid: {100 * best_val_ori:.2f}%, '
                f'Best Test: {100 * best_test_ori:.2f}%')

        # Per-run summary.
        logger.print_statistics(run)

    # Summary over all runs; this is what the caller receives.
    return logger.print_statistics()

### Parse args ###
# The command-line options themselves are registered by parser_add_main_args.
parser = argparse.ArgumentParser(description='Training Pipeline for Node Classification')
parser_add_main_args(parser)
args = parser.parse_args()

### Set device ###
# Use the CPU when it is requested explicitly or when CUDA is not usable.
# The availability check must come before torch.cuda.set_device(): calling
# set_device() first fails on a machine without CUDA, so the CPU fallback
# could never be reached.
if args.cpu or not torch.cuda.is_available():
    device = torch.device("cpu")
else:
    torch.cuda.set_device(args.device)
    # A bare 'cuda' device refers to the current device selected just above.
    device = torch.device('cuda')

### Load and preprocess data ###
# Datasets are read from "~/datasets/".
data_dir = osp.join(osp.expanduser('~'), 'datasets/')
dataset = load_dataset(data_dir, args.dataset)

# Turn a 1-D label vector into a single column so labels are always 2-D,
# then move them to the selected device.
if dataset.label.dim() == 1:
    dataset.label = dataset.label.unsqueeze(-1)
dataset.label = dataset.label.to(device)

# "random" loads a list of splits from disk (run_semi indexes it by run);
# every other setting uses the single split shipped with the dataset.
split_idx_lst = load_splits(data_dir, name=args.dataset) if args.split == "random" else [dataset.split]

### Basic information of datasets ###
# NOTE: `e` is taken before the edge list is symmetrised and given self-loops
# below, so it is the edge count of the graph as loaded.
n = dataset.graph['num_nodes']
e = dataset.graph['edge_index'].size(1)
# Number of classes: the larger of (max label + 1) and the label column count.
c = max(dataset.label.max().item() + 1, dataset.label.size(1))
d = dataset.graph['node_feat'].size(1)


# Canonicalise the edge list: make it undirected, then replace whatever
# self-loops exist with exactly one self-loop per node.
dataset.graph['edge_index'] = to_undirected(dataset.graph['edge_index'])
dataset.graph['edge_index'], _ = remove_self_loops(dataset.graph['edge_index'])
dataset.graph['edge_index'], _ = add_self_loops(dataset.graph['edge_index'], num_nodes=n)

# Move the graph tensors to the selected device.
dataset.graph['edge_index'] = dataset.graph['edge_index'].to(device)
dataset.graph['node_feat'] = dataset.graph['node_feat'].to(device)
# Edge count after the canonicalisation above.
num_edges = dataset.graph['edge_index'].size(1)

augment_graph(args, dataset)
if args.split == "fixed":
    # With a fixed split the stacked feature matrix is built once here;
    # for "random" splits run_semi rebuilds it at the start of every run.
    stacked_feats = [dataset.graph["node_feat"], dataset.graph["global_feat"]]
    dataset.graph["all_feat"] = torch.cat(stacked_feats, dim=0)
    
### Loss function ###
# Datasets trained with BCEWithLogitsLoss; everything else uses NLLLoss.
# This must be a real tuple (note the trailing comma): the earlier
# `args.dataset in ('questions')` tested against a plain string, i.e. a
# substring check that also matched names such as 'quest' or ''.
bce_datasets = ('questions',)
if args.dataset in bce_datasets:
    criterion = nn.BCEWithLogitsLoss()
else:
    criterion = nn.NLLLoss()

### Performance metric (Acc, AUC) ###
# ROC-AUC only when requested explicitly; accuracy is the default.
eval_func = eval_rocauc if args.metric == 'rocauc' else eval_acc

### Logger and model ###
logger = Logger(args.runs, args)
logger.write(f"dataset {args.dataset} | num nodes {n} | num edge {e} | num node feats {d} | num classes {c}")
logger.write(args)
# dataset.label_global is not set in this file (presumably by augment_graph --
# confirm in utils); its length is passed to the model factory.
model = parse_model(args, n, c, d, dataset.label_global.shape[0], device)
logger.write('MODEL:', model)

# Parameter counts, for the log only.
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
logger.write("total parameters: ", total_params, " total trainable parameters: ", trainable_params)


### Training loop ###
# run_semi is the only training routine this script invokes.
results = run_semi(args)

### Save results ###
if args.save_result:
    save_result(args, results)

# os._exit() terminates immediately: no cleanup handlers run and stdio buffers
# are not flushed, so flush the standard streams by hand first.
# Exit with status 0 -- the previous status 1 reported every successful run as
# a failure to the calling shell or job scheduler.
import sys  # local import: only needed for the explicit flush
sys.stdout.flush()
sys.stderr.flush()
os._exit(0)