from prompt_graph.tasker import NodeTask, GraphTask
from prompt_graph.utils import seed_everything
# from torchsummary import summary
from prompt_graph.utils import  get_args
from prompt_graph.data import load4node, split_induced_graphs
import pickle
import os
import ipdb
import torch


def load_induced_graph(dataset_name, data, device):
    """Return the cached induced-subgraph list for ``dataset_name`` on ``device``.

    The cache file is
    ``./dataspace/induced_graph/<dataset_name>/induced_graph_min100_max300.pkl``.
    Its directory is created if missing. When the file does not exist yet,
    ``split_induced_graphs`` is called with ``smallest_size=100`` and
    ``largest_size=300`` and the file is read afterwards. This assumes
    ``split_induced_graphs`` writes that exact file (TODO confirm).

    Args:
        dataset_name: Sub-directory name under ``./dataspace/induced_graph/``.
        data: Graph data handed to ``split_induced_graphs`` on a cache miss.
        device: Target passed to ``split_induced_graphs`` and to each graph's
            ``.to()``.

    Returns:
        The unpickled list of graphs, each moved with ``.to(device)``.
    """
    cache_root = './dataspace/induced_graph/'
    cache_dir = cache_root + dataset_name
    if not os.path.exists(cache_dir):
        os.makedirs(cache_dir)

    cache_file = cache_dir + '/induced_graph_min100_max300.pkl'

    if not os.path.exists(cache_file):
        # Cache miss: build the subgraphs first, then read back what was written.
        print('Begin split_induced_graphs.')
        split_induced_graphs(data, cache_dir, device, smallest_size=100, largest_size=300)
        with open(cache_file, 'rb') as handle:
            cached_graphs = pickle.load(handle)
    else:
        with open(cache_file, 'rb') as handle:
            print('loading induced graph...')
            # NOTE(review): pickle.load can execute arbitrary code; only load
            # cache files this project generated itself.
            cached_graphs = pickle.load(handle)
            print('Done!!!')

    return [g.to(device) for g in cached_graphs]


# Parse the command-line arguments once, at import time. main() reads this
# module-level namespace and also mutates it (it sets pre_train_model_path).
args = get_args()

# Tasks whose branches in main() need the downstream node dataset loaded.
_NODE_TASKS = ('NodeTask', 'FineTuneNodeTask', 'NodeTaskWeightedPATE', 'NodeTaskStudentPrompt')
# Prompt types that operate on induced subgraphs rather than the full graph.
_INDUCED_GRAPH_PROMPTS = ('All-in-one', 'GPF', 'GPF-plus')


def _build_node_task(data, input_dim, output_dim, graphs_list, teacher_idx, student_prompt):
    """Construct a ``NodeTask`` configured from the module-level ``args``.

    The teacher runs and the student run share every constructor argument
    except ``teacher_idx`` and ``student_prompt``. A single call site keeps the
    two configurations from drifting apart.
    """
    return NodeTask(pre_train_model_path=args.pre_train_model_path,
                    dataset_name=args.dataset_name, num_layer=args.num_layer,
                    gnn_type=args.gnn_type, hid_dim=args.hid_dim, prompt_type=args.prompt_type,
                    epochs=args.epochs, shot_num=args.shot_num, device=args.device,
                    lr=args.lr, wd=args.decay, batch_size=args.batch_size, seed=args.seed,
                    data=data, input_dim=input_dim, output_dim=output_dim,
                    graphs_list=graphs_list,
                    use_different_dataset=args.use_different_dataset,
                    pre_train_data=args.pre_train_data, disable_dp=args.disable_dp,
                    eps=args.eps, delta=args.delta, sample_rate=args.sample_rate,
                    pate=args.pate, teacher_idx=teacher_idx,
                    student_prompt=student_prompt, weighted_pate=args.weighted_pate)


def _artifact_path(root):
    """Return the per-experiment ``.pt`` path under ``root``.

    The layout is
    ``<root>/<shot_num>shot/<dataset>_<pre_train_data>/seed_<seed>/<pre_train_type>_<prompt_type>_<gnn_type>.pt``.
    """
    return '{}/{}shot/{}_{}/seed_{}/{}_{}_{}.pt'.format(
        root, args.shot_num, args.dataset_name, args.pre_train_data, args.seed,
        args.pre_train_type, args.prompt_type, args.gnn_type)


def _save_tensor(tensor, path):
    """Write ``tensor`` to ``path`` with ``torch.save``, creating parent dirs first."""
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(os.path.dirname(path), exist_ok=True)
    torch.save(tensor, path)


def main():
    """Run the node-level workflow selected by ``args.task``.

    Steps:
    - Set ``args.pre_train_model_path`` from the pre-training data, type and
      GNN type.
    - For any task in ``_NODE_TASKS``, load the dataset with ``load4node`` and
      move it to CUDA when available (CPU otherwise).
    - For prompt types in ``_INDUCED_GRAPH_PROMPTS``, also load the induced
      subgraphs.

    Task-specific work:
    - ``'NodeTaskWeightedPATE'``: run ``pate_ensemble(header=None)`` once per
      teacher (``args.number_of_teachers``). The returned train indices are
      concatenated and saved under ``./dataspace/TrainIdx``. The centrality
      scores are stacked and saved under ``./dataspace/CentralityScore``.
    - ``'NodeTaskStudentPrompt'``: run ``pate_ensemble`` once on the student
      prompt, passing a result header list (presumably used for logging
      results; TODO confirm in ``NodeTask``).
    """
    args.pre_train_model_path = './dataspace/pre_trained_model/{}/{}.{}.128hidden_dim.pth'.format(args.pre_train_data, args.pre_train_type, args.gnn_type)

    # Fix: data used to be loaded only for 'NodeTask'/'FineTuneNodeTask'. The
    # two branches below then raised UnboundLocalError on `data`.
    data = input_dim = output_dim = graphs_list = None
    if args.task in _NODE_TASKS:
        data, input_dim, output_dim = load4node(args.dataset_name)
        # NOTE(review): data is moved to this auto-selected device, while
        # NodeTask receives args.device. Confirm the two always agree.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        data = data.to(device)
        if args.prompt_type in _INDUCED_GRAPH_PROMPTS:
            graphs_list = load_induced_graph(args.dataset_name, data, device)

    print("Dataset: {}, Pre-train Data: {}, GNN: {}, Pretrain: {}, Prompt: {}, ShotNum: {}, Seed: {}".format(args.dataset_name, args.pre_train_data, args.gnn_type, args.pre_train_type, args.prompt_type, args.shot_num, args.seed))

    if args.task == 'NodeTaskWeightedPATE':
        train_idx_all = []
        centrality_scores = []
        for teacher_idx in range(args.number_of_teachers):
            tasker = _build_node_task(data, input_dim, output_dim, graphs_list,
                                      teacher_idx=teacher_idx, student_prompt=False)
            train_idx, centrality_score = tasker.pate_ensemble(header=None)
            train_idx_all.append(train_idx)
            centrality_scores.append(centrality_score)
        # Build both tensors before writing either file, so a shape error in
        # one does not leave a half-written set of artifacts.
        train_idx_all = torch.cat(train_idx_all, dim=0)
        # One entry per teacher along dim 0 (stacked, not averaged).
        stacked_centrality = torch.stack(centrality_scores, dim=0)
        _save_tensor(train_idx_all, _artifact_path('./dataspace/TrainIdx'))
        _save_tensor(stacked_centrality, _artifact_path('./dataspace/CentralityScore'))
    elif args.task == 'NodeTaskStudentPrompt':
        tasker = _build_node_task(data, input_dim, output_dim, graphs_list,
                                  teacher_idx=None, student_prompt=args.student_prompt)
        header = ['pre_train_type', 'pre_train_data', 'prompt_type', 'downstream_data', 'gnn', 'shot_num', 'seed', 'train_acc', 'train_loss', 'test_acc', 'test_loss', 'created_time']
        tasker.pate_ensemble(header)
    
# Run the experiment only when executed as a script, not when imported.
if __name__ == "__main__":
    main()



    




