import argparse
import sys
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.utils import to_undirected
from torch_scatter import scatter
from dataset import load_nc_dataset
from parse import parser_add_main_args


def fix_seed(seed):
    """Seed the NumPy and PyTorch random number generators.

    Args:
        seed: Value handed to ``np.random.seed``, ``torch.manual_seed`` and
            ``torch.cuda.manual_seed``.

    Also switches ``torch.backends.cudnn.deterministic`` on.
    """
    # Same seed for every generator, applied in a fixed order.
    for seed_fn in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)
    torch.backends.cudnn.deterministic = True


# Seed the NumPy / PyTorch RNGs with a fixed value at import time, before any
# of the module-level work below runs.
fix_seed(0)


def parse_args(parser, args=None, namespace=None):
    """Parse the command line with *parser*, tolerating unknown arguments.

    Args:
        parser: The ``argparse.ArgumentParser`` to parse with.
        args: Optional argument list, forwarded to ``parse_known_args``.
        namespace: Optional namespace to populate, forwarded likewise.

    Returns:
        The namespace of recognised arguments; unrecognised ones are dropped.
    """
    # parse_known_args returns (namespace, leftover_argv); only the namespace
    # is of interest here.
    return parser.parse_known_args(args, namespace)[0]


# Build the command-line parser, let the project helper register its options,
# and parse whatever was passed on the command line.
parser = argparse.ArgumentParser(description='General Training Pipeline')
parser_add_main_args(parser)
args = parse_args(parser)

# Use the CUDA device whose index is given by args.device when CUDA is
# available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:" + str(args.device))
else:
    device = torch.device("cpu")


def get_dataset(dataset, year=None):
    """Load a node-classification dataset and attach size metadata to it.

    Args:
        dataset: Name of the dataset to load. Only ``'ogb-arxiv'`` is
            accepted.
        year: Forwarded unchanged to ``load_nc_dataset`` as its ``year``
            keyword argument.

    Returns:
        The object returned by ``load_nc_dataset``. Its ``label`` is given a
        trailing dimension if it was 1-D, and the attributes ``n``, ``c`` and
        ``d`` are set on it.

    Raises:
        ValueError: If ``dataset`` is not a supported name.

    Side effects:
        Overwrites the module-level ``args.data_dir`` with
        ``"GraphOOD/data"``.
    """
    if dataset != 'ogb-arxiv':
        raise ValueError('Invalid dataname')

    # NOTE: this replaces any data_dir already present on the global args.
    args.data_dir = "GraphOOD/data"
    data = load_nc_dataset(args.data_dir, 'ogb-arxiv', year=year)

    # Make sure labels always carry a second dimension.
    if len(data.label.shape) == 1:
        data.label = data.label.unsqueeze(1)

    data.n = data.graph['num_nodes']
    # Larger of (max label value + 1) and the label's second dimension;
    # presumably the class count for index-style and multi-column labels
    # respectively -- TODO confirm against load_nc_dataset.
    data.c = max(data.label.max().item() + 1, data.label.shape[1])
    data.d = data.graph['node_feat'].shape[1]

    return data


# Only the ogb-arxiv benchmark is wired up in this script.
if args.dataset != 'ogb-arxiv':
    raise ValueError('Invalid dataname')

# Year ranges for the train / validation / test splits. Each entry is passed
# as the `year` argument of get_dataset; how the bounds are interpreted is
# decided by load_nc_dataset.
tr_year = [[1950, 2018]]
val_year = [[2018, 2019]]
te_years = [[2019, 2020]]

datasets_tr = [get_dataset(dataset='ogb-arxiv', year=tr_year[0])]
datasets_val = [get_dataset(dataset='ogb-arxiv', year=val_year[0])]
# One test dataset per entry in te_years.
datasets_te = [get_dataset(dataset='ogb-arxiv', year=span) for span in te_years]
