from prompt_graph.tasker import NodeTask
from prompt_graph.utils import seed_everything
# from torchsummary import summary
from prompt_graph.utils import  get_args
from prompt_graph.data import load4node, split_induced_graphs
import pickle
import os
import ipdb
import torch


def load_induced_graph(dataset_name, data, device):
    """Return the induced graphs for ``dataset_name``, building the cache if missing.

    The graphs are cached as a pickle at
    ``./dataspace/induced_graph/<dataset_name>/induced_graph_min100_max300.pkl``
    (relative to the current working directory). When that file is absent,
    ``split_induced_graphs`` is called with ``smallest_size=100`` and
    ``largest_size=300`` and the file is read afterwards, so that helper is
    presumed to write the cache file at exactly that path.

    Args:
        dataset_name: Name used as the cache sub-directory.
        data: Dataset object forwarded to ``split_induced_graphs``; unused
            when the cache file already exists.
        device: Target passed to ``split_induced_graphs`` and to each loaded
            graph's ``.to(...)``.

    Returns:
        A list with the result of ``graph.to(device)`` for every cached graph.

    Raises:
        FileNotFoundError: if the cache file is still missing after
            ``split_induced_graphs`` has run.
    """
    folder_path = './dataspace/induced_graph/' + dataset_name
    # exist_ok=True replaces the exists()-then-makedirs() pair, which could
    # raise if another process created the directory in between.
    os.makedirs(folder_path, exist_ok=True)

    file_path = folder_path + '/induced_graph_min100_max300.pkl'
    cached = os.path.exists(file_path)
    if cached:
        print('loading induced graph...')
    else:
        print('Begin split_induced_graphs.')
        split_induced_graphs(data, folder_path, device, smallest_size=100, largest_size=300)

    # NOTE(security): pickle.load can execute arbitrary code; only point this
    # at cache files produced locally by this project.
    with open(file_path, 'rb') as f:
        graphs_list = pickle.load(f)
    if cached:
        print('Done!!!')

    return [graph.to(device) for graph in graphs_list]


# Run configuration obtained from get_args() at import time (presumably the
# parsed command-line options — confirm in prompt_graph.utils). main() reads
# this module-level object and also overwrites some of its attributes
# (pre_train_model_path, prompt_type).
args = get_args()

def main():
    """Build and run a ``NodeTask`` according to the module-level ``args``.

    Steps:
      1. Derive ``args.pre_train_model_path`` from the pre-train data, the
         pre-train type and the GNN type.
      2. For ``FineTuneNodeTask``, force ``args.prompt_type`` to ``'None'``.
      3. Load the node dataset with ``load4node`` and move it to CUDA when
         available, otherwise to CPU.
      4. Load induced graphs only for the prompt types that are given a
         graph list ('All-in-one', 'GPF', 'GPF-plus'); otherwise pass ``None``.
      5. Construct the ``NodeTask`` and call its ``run()``.

    ``teacher_idx``, ``student_prompt`` and ``weighted_pate`` are always passed
    as ``None``/``False``/``False``.

    Raises:
        ValueError: if ``args.task`` is neither ``'NodeTask'`` nor
            ``'FineTuneNodeTask'`` (previously this case silently did nothing
            beyond printing the summary line).
    """
    supported_tasks = ('NodeTask', 'FineTuneNodeTask')
    if args.task not in supported_tasks:
        raise ValueError(
            'Unsupported task {!r}; expected one of {}'.format(args.task, supported_tasks)
        )

    # NOTE(review): the checkpoint name hard-codes "128hidden_dim" while
    # args.hid_dim is passed to NodeTask separately — confirm they agree.
    args.pre_train_model_path = './dataspace/pre_trained_model/{}/{}.{}.128hidden_dim.pth'.format(args.pre_train_data, args.pre_train_type, args.gnn_type)

    if args.task == 'FineTuneNodeTask':
        # Apply the override before loading data and printing the summary, so
        # that induced graphs are not loaded/generated for a prompt type that
        # is about to be discarded and the log shows the value actually used.
        args.prompt_type = 'None'

    data, input_dim, output_dim = load4node(args.dataset_name)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    data = data.to(device)
    if args.prompt_type in ['All-in-one', 'GPF', 'GPF-plus']:
        graphs_list = load_induced_graph(args.dataset_name, data, device)
    else:
        graphs_list = None

    print("Dataset: {}, Pre-train Data: {}, GNN: {}, Pretrain: {}, Prompt: {}, ShotNum: {}, Seed: {}".format(args.dataset_name, args.pre_train_data, args.gnn_type, args.pre_train_type, args.prompt_type, args.shot_num, args.seed))

    # NOTE(review): `data` (and any induced graphs) live on the auto-detected
    # `device` above, whereas NodeTask receives args.device — confirm both
    # resolve to the same device.
    tasker = NodeTask(
        pre_train_model_path=args.pre_train_model_path,
        dataset_name=args.dataset_name,
        num_layer=args.num_layer,
        gnn_type=args.gnn_type,
        hid_dim=args.hid_dim,
        prompt_type=args.prompt_type,
        epochs=args.epochs,
        shot_num=args.shot_num,
        device=args.device,
        lr=args.lr,
        wd=args.decay,
        batch_size=args.batch_size,
        seed=args.seed,
        data=data,
        input_dim=input_dim,
        output_dim=output_dim,
        graphs_list=graphs_list,
        pre_train_data=args.pre_train_data,
        dpsgd=args.dpsgd,
        eps=args.eps,
        delta=args.delta,
        sample_rate=args.sample_rate,
        pate=args.pate,
        teacher_idx=None,
        student_prompt=False,
        weighted_pate=False,
    )

    # Both supported tasks are executed the same way.
    tasker.run()
    

# Script entry point: run the task only when this file is executed directly.
if __name__ == "__main__":
    main()



    




