import argparse
import numpy as np
from GNNs import GNN
from utils import get_gpu_memory_map
from utils import *
import torch
import random
import time
import sys
import matplotlib.pyplot as plt
import wandb
import socket


class NCDataset(object):
    """Container for a single-graph node-classification dataset.

    Holds one graph (``self.graph``, initialised to an empty dict) and its
    labels (``self.label``, initialised to ``None``); index 0 is the only
    valid index.
    """

    def __init__(self, name):
        self.name = name
        # Both start empty here; presumably filled in by the caller.
        self.graph = {}
        self.label = None

    def __getitem__(self, idx):
        """Return the ``(graph, label)`` pair stored at index 0.

        Raises:
            IndexError: for any index other than 0. IndexError (not an
                ``assert``) is what lets ``iter()``/``list()`` stop after the
                single graph, and the check stays active under ``python -O``.
        """
        if idx != 0:
            raise IndexError('This dataset has only one graph')
        return self.graph, self.label

    def __len__(self):
        """Always 1: the dataset holds exactly one graph."""
        return 1

    def __repr__(self):
        return '{}({})'.format(self.__class__.__name__, len(self))


parser = argparse.ArgumentParser()
parser.add_argument('--gpu_id', type=int, default=0, help='gpu id')
parser.add_argument('--dataset', type=str, default='ogb-arxiv')
parser.add_argument('--train_epochs', type=int, default=2000)
parser.add_argument('--hidden', type=int, default=32)
parser.add_argument('--weight_decay', type=float, default=5e-4)
parser.add_argument('--normalize_features', type=bool, default=True)
parser.add_argument('--seed', type=int, default=0, help='Random seed.')
parser.add_argument('--lr', type=float, default=0.01)
parser.add_argument('--nlayers', type=int, default=5)
parser.add_argument('--model', type=str, default='GCN')
parser.add_argument('--loss', type=str, default='LC')
parser.add_argument('--debug', type=int, default=1)
parser.add_argument('--ood', type=int, default=1)
parser.add_argument('--with_bn', type=int, default=1)
parser.add_argument('--tune', type=int, default=1)
parser.add_argument('--finetune', type=int, default=0, help='whether to finetune the model')

parser.add_argument('--method_bn', type=str, default='BNSA+BNPA')
parser.add_argument('--lr_te', type=float, default=0.001)
parser.add_argument('--tta_epochs', type=int, default=100)
parser.add_argument('--tta_epochs_bns', type=int, default=1)
parser.add_argument('--bin_num', type=int, default=10)
parser.add_argument('--learn_mask_epochs', type=int, default=300)
parser.add_argument('--lr_mask', type=float, default=0.1)
parser.add_argument('--loss_lambda', type=float, default=0.1)
parser.add_argument('--momentum', type=float, default=0.1)

# Parse the command line and hand the namespace to reset_args. reset_args comes
# from `from utils import *`; it presumably adjusts some values, but its body is
# not visible here. Then echo the final configuration between banner lines.
args = parser.parse_args()
reset_args(args)
print('---------Used Para----------', args, '----------------------------', sep='\n')

# Load the train / validation / test splits for the chosen dataset.
# Each dataset's loader module lives in a sub-directory of GraphOOD/. That
# directory is appended to sys.path and the module is imported by name; the
# split objects are module attributes (their structure is defined by the
# GraphOOD loaders, not here).
# NOTE(review): 'GraphOOD/' is resolved relative to the current working
# directory, so the script must be launched from the directory containing it.
_OOD_LOADERS = {
    # dataset: (GraphOOD sub-directory, loader module, split attribute names)
    'elliptic': ('temp_elliptic', 'main_as_utils',
                 ('datasets_tr', 'datasets_val', 'datasets_te')),
    'fb100': ('multigraph', 'main_as_utils_fb',
              ('datasets_tr', 'datasets_val', 'datasets_te')),
    'amazon-photo': ('synthetic', 'main_as_utils_photo',
                     ('datasets_tr', 'datasets_val', 'datasets_te')),
    'ogb-products': ('sales_products', 'main_as_utils_products',
                     ('datasets_train', 'datasets_valid', 'datasets_test')),
    'cora': ('synthetic', 'main_as_utils',
             ('datasets_tr', 'datasets_val', 'datasets_te')),
    'ogb-arxiv': ('temp_arxiv', 'main_as_utils',
                  ('datasets_tr', 'datasets_val', 'datasets_te')),
    'twitch-e': ('multigraph', 'main_as_utils',
                 ('datasets_tr', 'datasets_val', 'datasets_te')),
}

if args.ood:
    import importlib

    if args.dataset not in _OOD_LOADERS:
        raise NotImplementedError(
            'unsupported OOD dataset: {!r}'.format(args.dataset))
    _subdir, _module_name, _split_names = _OOD_LOADERS[args.dataset]
    path = 'GraphOOD/' + _subdir
    sys.path.append(path)
    _loader = importlib.import_module(_module_name)
    # Order matters: [train, validation, test].
    data = [getattr(_loader, _name) for _name in _split_names]
else:
    # Only the OOD pipeline has a data loader. Fail here with a clear message
    # instead of a NameError later, when `data` is used but was never bound.
    raise NotImplementedError(
        'only --ood 1 is supported: no data loader for the non-OOD setting')


# --- reproducibility -------------------------------------------------------
# Seed the Python, NumPy and torch (CPU) RNGs with --seed, in that order.
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(args.seed)

# --- device ----------------------------------------------------------------
# Hard-coded switch for the CUDA setup below (1 = enabled).
cuda_select = 1
if cuda_select:
    # set_device must come first: torch.cuda.manual_seed seeds the *current*
    # CUDA device.
    torch.cuda.set_device(args.gpu_id)
    torch.cuda.manual_seed(args.seed)
    # get_gpu_memory_map is imported from utils; mem_st is not read again in
    # this script (presumably a start-of-run memory snapshot).
    mem_st = get_gpu_memory_map()

# --- training --------------------------------------------------------------
# Build the GNN wrapper from the loaded splits and parsed arguments, then run
# its jem_training() routine.
agent = GNN(data, args)
agent.jem_training()
