'''Example invocation (shell):

    model=GCN; gpu_id=0; python train_feat_adv.py --lr_feat=1e-3 --model=GCN --gpu_id=3 --seed=0 --noise_feature=0 --finetune=0
'''
# TODO: best_val_acc
import argparse
import numpy as np
from gtransform_both import GraphAgent
from utils import *
import torch
import random
import time
import sys
from models import *
# Timestamp taken at script start; `st` is not read again in the visible code.
st = time.time()

def _str2bool(value):
    """Convert a command-line string to a bool.

    ``type=bool`` does not work with argparse: ``bool("False")`` and
    ``bool("0")`` are both True, so the option could never be switched off.
    The empty string still maps to False, as it did under ``type=bool``.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


# Command-line options of the script. Names, order, types and defaults are
# unchanged, except that --normalize_features now really parses booleans.
parser = argparse.ArgumentParser()
parser.add_argument('--gpu_id', type=int, default=0, help='gpu id')
parser.add_argument('--dataset', type=str, default='cora')
parser.add_argument('--epochs', type=int, default=50)
parser.add_argument('--hidden', type=int, default=32)
parser.add_argument('--weight_decay', type=float, default=5e-4)
parser.add_argument('--normalize_features', type=_str2bool, default=True,
                    help='true/false (also accepts 1/0, yes/no)')
parser.add_argument('--seed', type=int, default=0, help='Random seed.')
parser.add_argument('--lr', type=float, default=0.01)
parser.add_argument('--lr_feat', type=float, default=0.001)
parser.add_argument('--nlayers', type=int, default=5)
parser.add_argument('--model', type=str, default='SAGE',
                    help='GNN architecture: GCN, GAT, SAGE or GPR')
parser.add_argument('--loss', type=str, default='entropy')
parser.add_argument('--ptb_rate', type=float, default=-1)
parser.add_argument('--debug', type=int, default=1)
parser.add_argument('--ood', type=int, default=1)
parser.add_argument('--finetune', type=int, default=0)
parser.add_argument('--noise_feature', type=float, default=0.0)
parser.add_argument('--noise_structure', type=float, default=0.0)
parser.add_argument('--cop', type=int, default=0)
parser.add_argument('--with_bn', type=int, default=1)
parser.add_argument('--lr_adj', type=float, default=0.1)
parser.add_argument('--ratio', type=float, default=0.1)
parser.add_argument('--margin', type=float, default=-1)
parser.add_argument('--existing_space', type=int, default=1)
parser.add_argument('--loop_adj', type=int, default=1)
parser.add_argument('--loop_feat', type=int, default=4)
parser.add_argument('--test_val', type=int, default=0)
parser.add_argument('--tune', type=int, default=0)
parser.add_argument('--dropedge', type=float, default=0)
args = parser.parse_args()

# Make the GPU given by --gpu_id the current CUDA device.
torch.cuda.set_device(args.gpu_id)

# Keep the command-line values of the tunable hyper-parameters so they can be
# put back after reset_args() (defined outside this file; presumably it
# overwrites args with per-dataset settings -- confirm).
lr_feat, epochs, ratio, lr_adj = args.lr_feat, args.epochs, args.ratio, args.lr_adj
print('===========')
reset_args(args)
if args.tune:
    # With --tune set, the command-line values take precedence again.
    args.lr_feat, args.epochs, args.ratio, args.lr_adj = lr_feat, epochs, ratio, lr_adj
if args.model == 'GAT':
    args.loop_adj = 0
print(args)


# Build `data` as [train, val, test]. With --ood set, the splits are imported
# from a dataset-specific helper module whose folder under GraphOOD-EERM/ is
# appended to sys.path; otherwise get_dataset() supplies them.
if args.ood:
    # Folder (relative to GraphOOD-EERM/) holding the helper module per dataset.
    _ood_subdirs = {
        'elliptic': 'temp_elliptic',
        'fb100': 'multigraph',
        'amazon-photo': 'synthetic',
        'cora': 'synthetic',
        'ogb-arxiv': 'temp_arxiv',
        'twitch-e': 'multigraph',
    }
    if args.dataset not in _ood_subdirs:
        raise NotImplementedError
    path = 'GraphOOD-EERM/' + _ood_subdirs[args.dataset]
    sys.path.append(path)

    if args.dataset == 'elliptic':
        from main_as_utils import datasets_tr, datasets_val, datasets_te
        data = [datasets_tr, datasets_val, datasets_te]
    elif args.dataset == 'fb100':
        from main_as_utils_fb import datasets_tr, datasets_val, datasets_te
        data = [datasets_tr, datasets_val, datasets_te]
    elif args.dataset == 'amazon-photo':
        from main_as_utils_photo import dataset_tr, dataset_val, datasets_te
        data = [dataset_tr, dataset_val, datasets_te]
    else:
        # cora, ogb-arxiv and twitch-e use the same module and export names.
        from main_as_utils import dataset_tr, dataset_val, datasets_te
        data = [dataset_tr, dataset_val, datasets_te]
else:
    data = get_dataset(args.dataset, args.normalize_features)


# Seed Python's `random`, NumPy, PyTorch and PyTorch's CUDA generator with
# the same --seed value.
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_rng(args.seed)

res = []
# agent = GraphAgent(data, args)


def pretrain_model(verbose=True):
    """Build the GNN selected by ``args.model``, fit it, and optionally evaluate it.

    Reads the module-level ``data`` (indexed as [train, val, test]) and
    ``args``.

    Args:
        verbose: When true, print the model and the test metrics.

    Returns:
        The fitted model.

    Raises:
        NotImplementedError: If ``args.model`` is not GCN, GAT, SAGE or GPR,
            or if evaluation meets an unsupported ``args.dataset``.
    """
    data_all = data
    device = 'cuda'

    # The training split is either one graph or a list of graphs; the first
    # graph is used to read the feature width and the labels.
    first_train = data_all[0][0] if isinstance(data_all[0], list) else data_all[0]
    feat, labels = first_train.graph['node_feat'], first_train.label
    nfeat = feat.shape[1]
    # Tensor-level max; the builtin max() would iterate the tensor in Python.
    nclass = labels.max().item() + 1

    # reset_args(args)
    # NOTE(review): GCN/GAT hard-code with_bn=True and SAGE/GPR hard-code
    # lr=0.01 instead of using args.with_bn / args.lr -- confirm intended.
    if args.model == "GCN":
        save_mem = False
        model = GCN(nfeat=nfeat, nhid=args.hidden, dropout=args.dropout, nlayers=args.nlayers,
                    weight_decay=args.weight_decay, with_bn=True, lr=args.lr, save_mem=save_mem,
                    nclass=nclass, device=device, args=args).to(device)
    elif args.model == "GAT":
        model = GAT(nfeat=nfeat, nhid=32, heads=4, lr=args.lr, nlayers=args.nlayers,
                    nclass=nclass, with_bn=True, weight_decay=args.weight_decay,
                    dropout=0.0, device=device, args=args).to(device)
    elif args.model == "SAGE":
        # fb100 uses the SAGE2 variant; the constructor arguments are the same.
        sage_cls = SAGE2 if args.dataset == "fb100" else SAGE
        model = sage_cls(nfeat, 32, nclass, num_layers=args.nlayers, dropout=0.0, lr=0.01,
                         weight_decay=args.weight_decay, device=device, args=args,
                         with_bn=args.with_bn).to(device)
    elif args.model == "GPR":
        model = GPRGNN(nfeat, 32, nclass, dropout=0.0, lr=0.01,
                       weight_decay=args.weight_decay, device=device, args=args).to(device)
    else:
        raise NotImplementedError
    if verbose:
        print(model)

    train_iters = 500 if args.dataset == 'ogb-arxiv' else 200
    model.dropedge = args.dropedge
    # NOTE(review): fit_inductive is always called with verbose=True,
    # independent of this function's `verbose` argument.
    model.fit_inductive(data_all, train_iters=train_iters, patience=500, verbose=True)

    def evaluate(model):
        """Print the ``model.eval_func`` metric for the test graphs.

        Returns:
            ``(accs, flat_metric)``: the printed list of metrics and the
            metric over all test predictions concatenated (``None`` when
            there are no test graphs).
        """
        model.eval()
        # Loop-invariant, and needed after the loop even if it never runs.
        eval_func = model.eval_func
        test_graphs = data_all[2]
        accs = []
        y_te, out_te = [], []
        y_te_all, out_te_all = [], []
        for ii, test_data in enumerate(test_graphs):
            x, edge_index = test_data.graph['node_feat'], test_data.graph['edge_index']
            x, edge_index = x.to(device), edge_index.to(device)
            output = model.predict(x, edge_index)
            labels = test_data.label.to(device)  # .squeeze()
            if args.dataset in ['ogb-arxiv']:
                # Only the nodes selected by test_mask are scored.
                mask = test_data.test_mask
                accs.append(eval_func(labels[mask], output[mask]))
                y_te_all.append(labels[mask])
                out_te_all.append(output[mask])
            elif args.dataset in ['cora', 'amazon-photo', 'twitch-e', 'fb100']:
                # Every node of the test graph is scored.
                accs.append(eval_func(labels, output))
                y_te_all.append(labels)
                out_te_all.append(output)
            elif args.dataset in ['elliptic']:
                mask = test_data.mask
                # Kept from the original: per-graph metric, result unused.
                acc_test = eval_func(labels[mask], output[mask])
                y_te.append(labels[mask])
                out_te.append(output[mask])
                y_te_all.append(labels[mask])
                out_te_all.append(output[mask])
                # Graphs are pooled and scored together at every index that is
                # a multiple of 4 and at the last index; the pool is then reset.
                if ii % 4 == 0 or ii == len(test_graphs) - 1:
                    acc_te = eval_func(torch.cat(y_te, dim=0), torch.cat(out_te, dim=0))
                    accs += [float(f'{acc_te:.2f}')]
                    y_te, out_te = [], []
            else:
                raise NotImplementedError
        print('Test accs:', accs)
        if not y_te_all:
            # No test graphs: torch.cat() of an empty list would raise.
            return accs, None
        acc_te = eval_func(torch.cat(y_te_all, dim=0), torch.cat(out_te_all, dim=0))
        return accs, acc_te

    if verbose:
        evaluate(model)
    return model

# Run the pretraining with the default verbose=True; the returned model is discarded.
pretrain_model()
