import argparse
import sys
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.utils import to_undirected
from torch_scatter import scatter

from logger import Logger, SimpleLogger
from dataset import load_nc_dataset
from data_utils import normalize, gen_normalized_adjs, evaluate, evaluate_whole_graph, eval_acc, eval_rocauc, eval_f1, to_sparse_tensor, load_fixed_splits
from parse import parse_method_base, parse_method_ours, parser_add_main_args

# NOTE: for consistent data splits, see data_utils.rand_train_test_idx
def fix_seed(seed):
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
fix_seed(0)

### Parse args ###

def parse_args(parser, args=None, namespace=None):
    args, argv = parser.parse_known_args(args, namespace)
    # Remove the unreconized arguments by commenting the following lines...
    # if argv:
    #     msg = _('unrecognized arguments: %s')
    #     self.error(msg % ' '.join(argv))
    return args

parser = argparse.ArgumentParser(description='General Training Pipeline')
parser_add_main_args(parser)
# args = parser.parse_args()
args = parse_args(parser)
print(args)

device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu")

def get_dataset(dataset, year=None):
    ### Load and preprocess data ###
    if dataset == 'ogb-arxiv':
        dataset = load_nc_dataset(args.data_dir, 'ogb-arxiv', year=year)
    else:
        raise ValueError('Invalid dataname')

    if len(dataset.label.shape) == 1:
        dataset.label = dataset.label.unsqueeze(1)

    dataset.n = dataset.graph['num_nodes']
    dataset.c = max(dataset.label.max().item() + 1, dataset.label.shape[1])
    dataset.d = dataset.graph['node_feat'].shape[1]

    dataset.graph['edge_index'], dataset.graph['node_feat'] = \
        dataset.graph['edge_index'], dataset.graph['node_feat']

    return dataset

if args.dataset == 'ogb-arxiv':
    tr_year, val_year, te_years = [[1950, 2011]], [[2011, 2014]], [[2014, 2016], [2016, 2018], [2018, 2020]]
    # te_years = [te_years[-1]]
    dataset_tr = get_dataset(dataset='ogb-arxiv', year=tr_year[0])
    dataset_val = get_dataset(dataset='ogb-arxiv', year=val_year[0])
    datasets_te = [get_dataset(dataset='ogb-arxiv', year=te_years[i]) for i in range(len(te_years))]
    # datasets_te = [get_dataset(dataset='ogb-arxiv', year=te_years[0])]
else:
    raise ValueError('Invalid dataname')

print(f"Train num nodes {dataset_tr.n} | target nodes {dataset_tr.test_mask.sum()} | num classes {dataset_tr.c} | num node feats {dataset_tr.d}")
print(f"Val num nodes {dataset_val.n} | target nodes {dataset_val.test_mask.sum()} | num classes {dataset_val.c} | num node feats {dataset_val.d}")
# for i in range(len(te_years)):
#     dataset_te = datasets_te[i]
#     print(f"Test {i} num nodes {dataset_te.n} | target nodes {dataset_te.test_mask.sum()} | num classes {dataset_te.c} | num node feats {dataset_te.d}")

if __name__ == "__main__":
    ### Load method ###
    if args.method == 'erm':
        model = parse_method_base(args, dataset_tr, device)
    else:
        model = parse_method_ours(args, dataset_tr, device)

    # using rocauc as the eval function
    criterion = nn.NLLLoss()
    eval_func = eval_acc

    logger = Logger(args.runs, args)

    model.train()
    print('MODEL:', model)
    print('DATASET:', args.dataset)

    ### Training loop ###
    for run in range(args.runs):
        model.reset_parameters()
        if args.method == 'erm':
            optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        elif args.method == 'eerm':
            optimizer_gnn = torch.optim.AdamW(model.gnn.parameters(), lr=args.lr, weight_decay=args.weight_decay)
            optimizer_aug = torch.optim.AdamW(model.gl.parameters(), lr=args.lr_a)
        best_val = float('-inf')
        for epoch in range(args.epochs):
            model.train()
            if args.method == 'erm':
                optimizer.zero_grad()
                loss = model(dataset_tr, criterion)
                loss.backward()
                optimizer.step()
            elif args.method == 'eerm':
                model.gl.reset_parameters()
                for m in range(args.T):
                    Var, Mean, Log_p = model(dataset_tr, criterion)
                    outer_loss = Var + args.beta * Mean
                    reward = Var.detach()
                    inner_loss = - reward * Log_p
                    if m == 0:
                        optimizer_gnn.zero_grad()
                        outer_loss.backward()
                        optimizer_gnn.step()
                    optimizer_aug.zero_grad()
                    inner_loss.backward()
                    optimizer_aug.step()

            accs, test_outs = evaluate_whole_graph(args, model, dataset_tr, dataset_val, datasets_te, eval_func)
            logger.add_result(run, accs)

            if epoch % args.display_step == 0:
                if args.method == 'erm':
                    print(f'Epoch: {epoch:02d}, '
                      f'Loss: {loss:.4f}, '
                      f'Train: {100 * accs[0]:.2f}%, '
                      f'Valid: {100 * accs[1]:.2f}%, ')
                    test_info = ''
                    for test_acc in accs[2:]:
                        test_info += f'Test: {100 * test_acc:.2f}% '
                    print(test_info)
                elif args.method == 'eerm':
                    print(f'Epoch: {epoch:02d}, '
                          f'Mean Loss: {Mean:.4f}, '
                          f'Var Loss: {Var:.4f}, '
                          f'Train: {100 * accs[0]:.2f}%, '
                          f'Valid: {100 * accs[1]:.2f}%, ')
                    test_info = ''
                    for test_acc in accs[2:]:
                        test_info += f'Test: {100 * test_acc:.2f}% '
                    print(test_info)

        logger.print_statistics(run)


