import json
import argparse
import numpy as np
import torch
from dataset import get_dataset, get_dataset2,get_dataset3, get_dataset_condensed
from utils import seed_everything,generate_condensed_z_y
from emb_test import Embed_test
from link_pred import Link_pred


def results_test(args, data, condensed_graph):
    """Train/evaluate ``args.runs`` times and print mean ± std for each metric.

    Args:
        args: experiment namespace; reads ``runs``, ``task`` ('nc' or 'lp') and
            ``model``. For 'lp' a copy of ``args`` is made with the model name
            suffixed by ``_lp``, so the caller's namespace is not modified.
        data: original graph, passed to the evaluator as ``ori_data``.
        condensed_graph: condensed graph, passed to the evaluator as ``con_data``.

    Returns:
        dict mapping each metric key returned by ``agent.model_train()`` to a
        ``(mean, std)`` tuple, both multiplied by 100 (percentages, assuming the
        metrics are fractions). Empty dict when ``args.runs`` is 0.

    Raises:
        ValueError: if ``args.task`` is neither 'nc' nor 'lp'.
    """
    # Fail fast: previously an unknown task left ``agent`` unbound (NameError).
    if args.task not in ('nc', 'lp'):
        raise ValueError(f"Unsupported task {args.task!r}; expected 'nc' or 'lp'")

    results = []
    for _ in range(args.runs):
        if args.task == 'nc':
            agent = Embed_test(args=args, ori_data=data, con_data=condensed_graph)
        else:
            # Fresh copy per run so link prediction uses the "<model>_lp" variant
            # without touching the shared args namespace.
            lp_args = argparse.Namespace(**vars(args))
            lp_args.model = f"{args.model}_lp"
            agent = Link_pred(args=lp_args, ori_data=data, con_data=condensed_graph)
        results.append(agent.model_train())

    summary = {}
    if not results:
        # No runs requested: nothing to aggregate (previously an IndexError).
        return summary

    # Assumes every run returns the same metric keys as the first one.
    for key in results[0]:
        values = [result[key] for result in results]
        mean = np.mean(values) * 100
        std = np.std(values) * 100
        print(f"{key.upper()} = {mean:.2f}±{std:.2f}")
        summary[key] = (mean, std)
    return summary



def _str2bool(value):
    """Parse a command-line boolean such as 'true'/'false', 'yes'/'no', '1'/'0'.

    argparse's ``type=bool`` calls ``bool(str)``, which is True for every
    non-empty string (including "False" and "0"), so a flag could never be
    switched off from the command line.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser()
parser.add_argument("--gpu_id", type=int, default=0, help="gpu id")
parser.add_argument("--seed", type=int, default=2026)
parser.add_argument("--config", type=str, default='./config/config_init.json')
parser.add_argument("--runs", type=int, default=5)

# gcdm gcond sfgc sgdd gdem cgc pgc
parser.add_argument("--gc_method", type=str, default="pgc")
parser.add_argument("--dataset", type=str, default="pubmed")
parser.add_argument("--reduction_rate", type=float, default=0.5)
parser.add_argument("--hard", type=_str2bool, default=False)

# mlp gcn sgc appnp graphconv gat graphsage gcn2 ssgc  bernnet gprgnn
parser.add_argument("--model", type=str, default="sgc")
parser.add_argument("--task", type=str, default='nc')

parser.add_argument("--fine_tune", type=_str2bool, default=False)

# Parsing an empty list ignores sys.argv entirely, so the defaults above are
# always used (presumably so the script also runs inside a notebook).
args = parser.parse_args([])




def _load_original_data(dataset):
    """Load the original (uncondensed) graph for ``dataset``.

    Dispatches to ``get_dataset2``, ``get_dataset`` or ``get_dataset3`` based on
    which name list ``dataset`` belongs to.

    Raises:
        ValueError: if ``dataset`` is not a known name (previously this left
            ``ori_data`` unbound and failed later with a NameError).
    """
    loader2_names = frozenset([
        'arxiv_topic', 'arxiv_year', 'hm_class', 'hm_regre',
        'arxiv_topic_s', 'arxiv_year_s', 'hm_class_s', 'hm_regre_s',
        'arxiv_topic_s_0.15', 'arxiv_topic_s_0.3', 'arxiv_topic_s_0.45',
        'arxiv_topic_s_0.6', 'arxiv_topic_s_0.75',
        'hm_class_s_0.15', 'hm_class_s_0.3', 'hm_class_s_0.45',
        'hm_class_s_0.6', 'hm_class_s_0.75',
    ])
    loader1_names = frozenset(['cora', 'pubmed', 'citeseer'])
    loader3_names = frozenset([
        'cora_0.25_1', 'cora_0.25_2', 'cora_0.25_3', 'cora_0.25_4',
        'cora_0.5_1', 'cora_0.5_2',
        'citeseer_0.25_1', 'citeseer_0.25_2', 'citeseer_0.25_3', 'citeseer_0.25_4',
        'citeseer_0.5_1', 'citeseer_0.5_2',
        'pubmed_0.25_1', 'pubmed_0.25_2', 'pubmed_0.25_3', 'pubmed_0.25_4',
        'pubmed_0.5_1', 'pubmed_0.5_2',
    ])

    if dataset in loader2_names:
        return get_dataset2(dataset)
    if dataset in loader1_names:
        return get_dataset(dataset)
    if dataset in loader3_names:
        return get_dataset3(dataset)
    raise ValueError(f"Unknown dataset {dataset!r}: no loader is registered for it")


def gnn_test():
    """Evaluate the condensed graph with each GNN backbone in the model list.

    For every model name, this:
    - loads hyperparameters from ``args.config`` (the per-model section if
      present, otherwise the whole config);
    - loads the original and condensed graphs;
    - applies dataset-specific overrides;
    - calls ``results_test``.

    Settings come from the module-level ``args``. Each model works on its own
    copy of ``args``, so config values and overrides from one model (including
    the arxiv_year_s -> arxiv_topic_s rename) do not leak into the next.
    """
    # Full sweep: mlp sgc gcn appnp graphconv gat graphsage gcn2 ssgc bernnet gprgnn
    for model in ['sgc']:
        run_args = argparse.Namespace(**vars(args))
        run_args.model = model

        with open(run_args.config, "r", encoding="utf-8") as config_file:
            config = json.load(config_file)
        config = config.get(run_args.model, config)
        for key, value in config.items():
            setattr(run_args, key, value)

        torch.cuda.set_device(run_args.gpu_id)
        seed_everything(run_args.seed)

        if 'pgc' in run_args.gc_method:
            run_args.hard = False

        ori_data = _load_original_data(run_args.dataset)

        if run_args.dataset == 'arxiv_year_s' and 'pgc' in run_args.gc_method:
            # PGC reuses the condensed graph built for arxiv_topic_s. Its labels
            # are replaced with ones derived from the original (arxiv_year_s)
            # graph through ``condensed_data.map`` -- see
            # utils.generate_condensed_z_y for the exact mapping.
            run_args.dataset = 'arxiv_topic_s'
            condensed_data = get_dataset_condensed(run_args.gc_method, run_args.dataset, run_args.reduction_rate)
            condensed_data.y = generate_condensed_z_y(ori_data, condensed_data.map)
        else:
            condensed_data = get_dataset_condensed(run_args.gc_method, run_args.dataset, run_args.reduction_rate)

        # Fixed hyperparameters for hm_* datasets (they override the JSON
        # config) for every backbone except MLP.
        if 'hm' in run_args.dataset and run_args.model != 'mlp':
            run_args.K = 0
            run_args.hidden_dim = 64
            run_args.dropout = 0.0
            if run_args.dataset == 'hm_regre_s':
                run_args.epochs = 200
                run_args.weight_decay = 0

        results_test(run_args, ori_data.cuda(), condensed_data.cuda())


# Run the evaluation only when executed as a script (or notebook), not as an
# import side effect when another module reuses results_test/gnn_test.
if __name__ == "__main__":
    gnn_test()