import argparse
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.utils import to_undirected, remove_self_loops, add_self_loops

from logger import *
from data import load_dataset
from utils import eval_acc, eval_rocauc, load_splits, augment_graph, modify_globalMask
from evaluate import *
from parse import parse_model, parser_add_main_args
import os.path as osp
import os
# Configure PyTorch's CUDA caching allocator to use expandable segments.
# This is read when the allocator initializes, so it must be set before the
# first CUDA allocation in this process.
os.environ.update(PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True")


def fix_seed(seed=110):
    """Seed every RNG the pipeline uses and force deterministic cuDNN.

    Args:
        seed: Value handed to the Python, NumPy and PyTorch seeding
            functions (CPU, current CUDA device and all CUDA devices).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

def run_semi(args, model):
    """Evaluate ``model`` once per run and summarize the per-run metrics.

    For each of ``args.runs`` runs, the RNGs are re-seeded with
    ``args.seed + run`` and the split for that run is selected. Weights are
    then loaded through ``load_model`` and the model is scored with
    ``evaluate``. The mean and standard deviation over runs are printed.

    Relies on module-level globals built by the script: ``dataset``,
    ``split_idx_lst``, ``device``, ``eval_func``, ``criterion`` and
    ``logger``.

    Args:
        args: Parsed command-line namespace (reads ``runs``, ``seed`` and
            ``split``; also forwarded to the project helpers).
        model: Model instance handed to ``load_model``.

    Returns:
        A tensor with one row per run. Each row holds every entry of the
        ``evaluate`` result except the last one; the log line below reads
        indices 0-2 as train/valid/test accuracy and 3-5 as the matching
        losses. This is the value the script passes to ``save_result``.
    """
    all_results = []
    for run in range(args.runs):
        fix_seed(args.seed + run)

        if args.split == "random":
            # Every run has its own random split, so the global mask and the
            # concatenated node/global feature matrix are rebuilt for it.
            split_idx = split_idx_lst[run]
            modify_globalMask(args, dataset, split_idx, run)
            dataset.graph["all_feat"] = torch.cat(
                [dataset.graph["node_feat"], dataset.graph["global_feat"]], dim=0)
        else:
            split_idx = split_idx_lst[0]

        model = load_model(args, model, run).to(device)

        result = evaluate(model, dataset, split_idx, eval_func, criterion, args)

        logger.write(f'Train Loss: {result[3]:.4f}, Train Acc: {100 * result[0]:.2f}%, '
                    f'Valid Loss: {result[4]:.4f}, Valid Acc: {100 * result[1]:.2f}%, '
                    f'Test Loss: {result[5]:.4f}, Test Acc: {100 * result[2]:.2f}%')
        # Drop the last element of the evaluate() result; only the scalar
        # metrics before it are aggregated (presumably the last entry is
        # the raw model output — confirm against evaluate()).
        all_results.append(result[:-1])

    all_results = torch.tensor(all_results)
    print(all_results.mean(dim=0), all_results.std(dim=0))
    # Return the per-run metrics so callers (e.g. save_result) get real data
    # instead of None.
    return all_results

### Parse args ###
parser = argparse.ArgumentParser(description='Training Pipeline for Node Classification')
parser_add_main_args(parser)
args = parser.parse_args()

# args.use_res = True
# fix_seed(args.seed)

### set device
# Check CUDA availability before calling torch.cuda.set_device, which raises
# on machines without CUDA. This makes the CPU fallback reachable.
if args.cpu or not torch.cuda.is_available():
    device = torch.device("cpu")
else:
    # set_device selects args.device as the current GPU, so the bare
    # 'cuda' device below resolves to it.
    torch.cuda.set_device(args.device)
    device = torch.device('cuda')

### Load and preprocess data ###
# Datasets are read from ~/datasets/.
data_dir = osp.join(osp.expanduser('~'), 'datasets/')
dataset = load_dataset(data_dir, args.dataset)

# Keep labels 2-D so that label.shape[1] is defined when computing c below.
if dataset.label.dim() == 1:
    dataset.label = dataset.label.unsqueeze(-1)
dataset.label = dataset.label.to(device)

# "random" splits are loaded from disk (indexed by run in run_semi); any
# other split mode uses the single split attached to the dataset.
split_idx_lst = (load_splits(data_dir, name=args.dataset)
                 if args.split == "random"
                 else [dataset.split])

### Basic information of datasets ###
n = dataset.graph['num_nodes']
# Edge count as loaded, i.e. before symmetrization and self-loop handling.
e = dataset.graph['edge_index'].shape[1]
c = max(dataset.label.max().item() + 1, dataset.label.shape[1])
d = dataset.graph['node_feat'].shape[1]


# Symmetrize the graph, drop existing self-loops, then add one self-loop
# for each of the n nodes.
dataset.graph['edge_index'] = to_undirected(dataset.graph['edge_index'])
dataset.graph['edge_index'], _ = remove_self_loops(dataset.graph['edge_index'])
dataset.graph['edge_index'], _ = add_self_loops(dataset.graph['edge_index'], num_nodes=n)

dataset.graph['edge_index'] = dataset.graph['edge_index'].to(device)
dataset.graph['node_feat'] = dataset.graph['node_feat'].to(device)
num_edges = dataset.graph['edge_index'].size(1)

augment_graph(args, dataset)
# For the fixed split the concatenated node + global features are built once
# here; run_semi rebuilds them per run for random splits.
if args.split == "fixed":
    dataset.graph["all_feat"] = torch.cat(
        (dataset.graph["node_feat"], dataset.graph["global_feat"]), dim=0)
    
### Loss function (Single-class, Multi-class) ###
# Datasets trained with BCEWithLogitsLoss. The trailing comma matters:
# ('questions') is just a string, and `in` on a string does substring
# matching, so names such as 'ion' or 'quest' would have matched.
if args.dataset in ('questions',):
    criterion = nn.BCEWithLogitsLoss()
else:
    # NLLLoss expects log-probabilities; presumably the model ends in
    # log_softmax for these datasets — confirm in parse_model.
    criterion = nn.NLLLoss()

### Performance metric (Acc, AUC) ###
if args.metric == 'rocauc':
    eval_func = eval_rocauc
else:
    eval_func = eval_acc

# Set up logging, record the dataset summary and arguments, then build the
# model and report its parameter counts.
logger = Logger(args.runs, args)
logger.write(
    f"dataset {args.dataset} | num nodes {n} | num edge {e} | "
    f"num node feats {d} | num classes {c}"
)
logger.write(args)
model = parse_model(
    args, n, c, d, dataset.label_global.shape[0], device
)
logger.write('MODEL:', model)

# All parameters, and separately the ones that receive gradients.
total_params = sum(map(torch.Tensor.numel, model.parameters()))
trainable_params = sum(
    param.numel() for param in model.parameters() if param.requires_grad
)
logger.write("total parameters: ", total_params, " total trainable parameters: ", trainable_params)

### Training loop ###
# if args.semi:
results = run_semi(args, model)

### Save results ###
if args.save_result:
    save_result(args, results)

# os._exit skips normal interpreter shutdown (presumably to avoid hangs while
# tearing down CUDA / worker threads — confirm). It also skips flushing
# Python's stdio buffers, so flush them first or the summary printed by
# run_semi can be lost when output is redirected to a file.
# NOTE(review): file handles held by Logger are not flushed here; confirm
# that Logger flushes on write.
# Exit status 0 reports success to shells and job schedulers; the previous
# status of 1 signalled failure after every run.
import sys

sys.stdout.flush()
sys.stderr.flush()
os._exit(0)