import argparse
import numpy as np
from GOAT import GraphAgent
from utils import *
import torch
import random
import time
import sys
import warnings
from utils  import visualize_prompt_distribution

# Silence every warning category for the whole run.
warnings.filterwarnings("ignore")
# Wall-clock timestamp taken at start-up.
st = time.time()


def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so
    every non-empty token (including ``"False"`` and ``"0"``) becomes
    ``True``. This converter understands the usual spellings instead.

    Args:
        value: The raw token from the command line, or an actual ``bool``.

    Returns:
        ``True`` or ``False``.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean
            spelling; argparse reports this as a usage error.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    # The empty string is kept falsy, matching what ``bool('')`` used to give.
    if token in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(
        f'expected a boolean value (true/false, yes/no, 1/0), got {value!r}')


parser = argparse.ArgumentParser()

# ---- pre-training / backbone options ----
parser.add_argument('--gpu_id', type=int, default=0, help='gpu id')
parser.add_argument('--dataset', type=str, default='ogb-arxiv')
parser.add_argument('--epochs', type=int, default=50)
parser.add_argument('--patience', type=int, default=3)
parser.add_argument('--hidden', type=int, default=32)
parser.add_argument('--lr', type=float, default=0.01, help='pre-train lr')
parser.add_argument('--weight_decay', type=float, default=5e-4, help='pre-train naive GNN\'s weight decay')
parser.add_argument('--model', type=str, default='GCN')
parser.add_argument('--normalize_features', type=_str2bool, default=True)
parser.add_argument('--seed', type=int, default=0, help='Random seed.')
parser.add_argument('--nlayers', type=int, default=5, help='number of pre-train GNN layers')

# ---- test-time tuning options ----
parser.add_argument('--lr_feat', type=float, default=1e-5, help='test tuning lr')
# NOTE(review): the default is a list, but with type=str a value supplied on
# the command line arrives as one plain string -- confirm the consumer of
# args.loss handles both forms (the same applies to --strategy below).
parser.add_argument('--loss', type=str, default=["La2a", "Ls", "Lc", "Lr"],
                    help='test tuning loss function, use ["La2a", "Ls", "Lc", "Lr"] as full GOAT-A2A loss function. ')
parser.add_argument('--ood', type=int, default=1, help='set to 1 to allowing OOD dataset according to EERM.')
parser.add_argument('--debug', type=int, default=1)
parser.add_argument('--with_bn', type=int, default=1, help='whether to bn, False in GPR')
parser.add_argument('--margin', type=float, default=-1)
parser.add_argument('--existing_space', type=int, default=1, help='enable removing edges from the graph')
parser.add_argument('--test_val', type=int, default=0, help='set to 1 to evaluate performance on validation data')
parser.add_argument('--visualization', type=_str2bool, default=False,
                    help='visualize embeddings by T-SNE on test graph')
parser.add_argument('--distribution_visable', type=_str2bool, default=True,
                    help='visualize embeddings by T-SNE on test graph')
parser.add_argument('--show_SVD', type=_str2bool, default=False, help='visualize SVD in LRA on test graph')
parser.add_argument('--tune', type=int, default=0, help='set to 0 to use the best model')
parser.add_argument('--finetune', type=int, default=0, help='test-time whether to tune the model')
parser.add_argument('--tent', type=int, default=0, help='use the Tent for finetuning (need to set finetune=1)')
parser.add_argument('--strategy', type=str, default=['dropedge', 'dropedge'], help='how to get the environment sample')
parser.add_argument('--ablation', type=int, default=0,
                    help="{0: 'ALL', 1: 'without Lr', 2: 'Only Ls'', 3: 'Only one view'}")
parser.add_argument('--prompt_comp', type=str, default="",
                    help="prompt comparison: multiG, UPF, All in One, you need to set mlp_prompt True")

# ---- prompt generator options ----
parser.add_argument('--prompt', type=_str2bool, default=True, help="whether to use prompt")
parser.add_argument('--mlp_prompt', type=_str2bool, default=True, help="whether to use prompt generator")
parser.add_argument('--LR', type=_str2bool, default=True, help="whether to use Low-Rank prompt generator")
parser.add_argument('--prompt_layers', type=int, default=1, help="num layers used to generate the prompt")
parser.add_argument('--attn_ratio', type=int, default=4, help="the reduction ratio by the input embeddings")
parser.add_argument('--virtual_nodes_ratio', type=int, default=5, help="the this_ratio*cls = rank of prompt attn")
parser.add_argument('--alpha', type=float, default=0.5, help="controls the ratio between the Lc to Ls")
parser.add_argument('--lamb', type=float, default=0.1, help="controls the ratio between the Lr to others")

def _run_on_graph(agent, graph, index, args):
    """Process one test graph and return a uniform result tuple.

    ``agent.finetune`` yields three values while ``agent.learn_graph`` yields
    four (the extra one being the prompt). Normalising both to a 4-tuple keeps
    the callers from reading an unbound ``prompt`` when ``--finetune`` is set.

    Args:
        agent: The ``GraphAgent`` driving test-time adaptation.
        graph: One test graph taken from ``data[-1]``.
        index: Position of ``graph`` in the test list (passed to
            ``learn_graph`` only).
        args: Parsed command-line namespace; ``finetune`` and
            ``visualization`` are read.

    Returns:
        ``(acc, output, labels, prompt)``; ``prompt`` is ``None`` when the
        fine-tuning path was taken.
    """
    if args.finetune:
        acc, output, labels = agent.finetune(graph)
        return acc, output, labels, None
    acc, output, labels, prompt = agent.learn_graph(graph, index, visualization=args.visualization)
    return acc, output, labels, prompt


# Parsing and device selection happen at import time, as before.
args = parser.parse_args()
torch.cuda.set_device(args.gpu_id)

if __name__ == '__main__':
    import os

    os.environ['JOBLIB_THREAD_POOL_DISABLE'] = '1'

    print('===========')
    # reset_args(args)
    # NOTE: with reset_args disabled, --tune changes nothing in this script
    # itself; the flag is still forwarded to GraphAgent through ``args``.

    print(args)

    from utils import get_gpu_memory_map

    # GPU memory snapshot taken before any data or model is loaded.
    mem_st = get_gpu_memory_map()

    if args.ood:
        # The OOD splits are exposed by helper modules inside the
        # GraphOOD-EERM checkout; the matching sub-directory is put on
        # sys.path and the split objects are imported from there.
        path = 'GraphOOD-EERM/'
        if args.dataset == 'elliptic':
            path = path + 'temp_elliptic'
            sys.path.append(path)
            from main_as_utils import datasets_tr, datasets_val, datasets_te

            data = [datasets_tr, datasets_val, datasets_te]
        elif args.dataset == 'fb100':
            path = path + 'multigraph'
            sys.path.append(path)
            from main_as_utils_fb import datasets_tr, datasets_val, datasets_te

            data = [datasets_tr, datasets_val, datasets_te]
        elif args.dataset == 'amazon-photo':
            path = path + 'synthetic'
            sys.path.append(path)
            from main_as_utils_photo import dataset_tr, dataset_val, datasets_te

            data = [dataset_tr, dataset_val, datasets_te]
        else:
            if args.dataset == 'cora':
                path = path + 'synthetic'
            elif args.dataset == 'ogb-arxiv':
                path = path + 'temp_arxiv'
            elif args.dataset == 'twitch-e':
                path = path + 'multigraph'
            else:
                raise NotImplementedError
            sys.path.append(path)
            from main_as_utils import dataset_tr, dataset_val, datasets_te

            data = [dataset_tr, dataset_val, datasets_te]
    else:
        data = get_dataset(args.dataset, args.normalize_features)

    # TODO: these per-dataset overrides should be moved into the model
    # params.csv. Note that they replace whatever was given on the CLI.
    if args.dataset == "ogb-arxiv":
        args.epochs = 5
    if args.dataset == "twitch-e":
        args.dropout = 0.9

    # random seed setting
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    res = []
    agent = GraphAgent(data, args)
    # One entry per processed graph (None when fine-tuning produced no prompt).
    prompt_dict = []

    if args.test_val:
        print('using validation as test...')
        # Evaluate on the validation split by putting it in the test slot.
        data[-1] = data[-2]
        if not isinstance(data[-1], list):
            data[-1] = [data[-1]]
        y_te, out_te = [], []
        for ix, test_data in enumerate(data[-1]):
            acc, output, labels, prompt = _run_on_graph(agent, test_data, ix, args)
            prompt_dict.append(prompt)
            res.append(acc)
            y_te.append(labels)
            out_te.append(output)

            # debug == 2 stops after the first graph.
            if args.debug == 2:
                break
        # acc_te = agent.model.eval_func(torch.cat(y_te, dim=0), torch.cat(out_te, dim=0))
        # print(f'Results on test sets: {acc_te}')
        # print(f'Flatten Test: {acc_te:.2f}')
        print(f'Results on test sets: {np.mean(res)}')

    else:
        if args.dataset != 'elliptic':
            y_te, out_te = [], []
            for ix, test_data in enumerate(data[-1]):
                acc, output, labels, prompt = _run_on_graph(agent, test_data, ix, args)
                prompt_dict.append(prompt)
                res.append(acc)
                y_te.append(labels)
                out_te.append(output)

                # debug == 2 stops after the first graph.
                if args.debug == 2:
                    break
            # Metric over all graphs concatenated; computed but not printed.
            acc_te = agent.model.eval_func(torch.cat(y_te, dim=0), torch.cat(out_te, dim=0))

        else:
            # elliptic: the reported numbers are window metrics rather than
            # the per-graph accuracies returned by the agent.
            y_te_all, out_te_all = [], []
            y_te, out_te = [], []
            last_index = len(data[-1]) - 1
            for ii, test_data in enumerate(data[-1]):
                acc, output, labels, prompt = _run_on_graph(agent, test_data, ii, args)
                prompt_dict.append(prompt)
                y_te.append(labels)
                out_te.append(output)
                y_te_all.append(labels)
                out_te_all.append(output)

                # Close the current window when the index is a multiple of 4
                # or on the final graph, then start a fresh window.
                if ii % 4 == 0 or ii == last_index:
                    acc_te = agent.model.eval_func(torch.cat(y_te, dim=0), torch.cat(out_te, dim=0))
                    # Stored rounded to two decimals.
                    res += [float(f'{acc_te:.2f}')]
                    y_te, out_te = [], []
                    if args.debug == 2:
                        break

            # Metric over all processed graphs; computed but not printed.
            acc_te = agent.model.eval_func(torch.cat(y_te_all, dim=0), torch.cat(out_te_all, dim=0))

        print('Results on test sets:', res)
        print('Mean results on test sets:', np.mean(res))

