from prompt_graph.tasker import NodeTask, GraphTask
from prompt_graph.utils import seed_everything, load_yaml
from torchsummary import summary
from prompt_graph.utils import print_model_parameters
from prompt_graph.utils import  get_args
import os.path as osp
import pdb

CONFIG_DIR = 'configurations_improvedAIO'  # directory holding the per-experiment YAML configs
PRETRAINED_DIR = 'pre_trained_gnn'         # directory holding pre-trained GNN checkpoints


def _config_path(args):
    """Return the path of the YAML configuration file for this run.

    The file name depends on the run mode:

    * transfer runs:   ``transfer_from_<source>_to_<dataset>_<pretrain>_<gnn>.yaml``
    * ``GraphTask``:   ``<dataset>_<pretrain>_<gnn>.yaml``
    * any other task:  ``<dataset>_<pretrain>_<gnn>_<task>.yaml``

    Args:
        args: Parsed CLI namespace. Reads ``transfer``, ``source``, ``task``,
            ``dataset_name``, ``pretrain_method`` and ``gnn_type``.

    Returns:
        Path string under ``CONFIG_DIR``.
    """
    stem = f'{args.dataset_name}_{args.pretrain_method}_{args.gnn_type}'
    if args.transfer:
        stem = f'transfer_from_{args.source}_to_{stem}'
    elif args.task != 'GraphTask':
        # Only non-graph tasks carry the task name as a suffix.
        stem = f'{stem}_{args.task}'
    return osp.join(CONFIG_DIR, stem + '.yaml')


def _default_pretrain_path(args, config):
    """Return the conventional checkpoint path of the pre-trained GNN.

    * non-transfer: ``<dataset>.<pretrain>.<gnn>.<hid_dim>hidden_dim.pth``
    * transfer:     ``transfer_<source>.<pretrain>.<gnn>.<hid_dim>hidden_dim_featuredim<feature_dim>.pth``

    Args:
        args: Parsed CLI namespace. Reads ``transfer``, ``source``,
            ``dataset_name``, ``pretrain_method`` and ``gnn_type``.
        config: Loaded YAML config. Reads ``hid_dim`` and, for transfer runs,
            ``feature_dim``.

    Returns:
        Path string under ``PRETRAINED_DIR``.
    """
    if args.transfer:
        name = (f'transfer_{args.source}.{args.pretrain_method}.{args.gnn_type}.'
                f'{config.hid_dim}hidden_dim_featuredim{config.feature_dim}.pth')
    else:
        name = (f'{args.dataset_name}.{args.pretrain_method}.{args.gnn_type}.'
                f'{config.hid_dim}hidden_dim.pth')
    return osp.join(PRETRAINED_DIR, name)


def _is_unset(path):
    """Return True when no pre-trained model path was supplied.

    The string ``'None'`` is the sentinel this script checks for (presumably
    the CLI default in ``get_args`` -- TODO confirm). A real ``None`` is
    treated the same way, so it is not forwarded to ``GraphTask`` as-is.
    """
    return path is None or path == 'None'


def main():
    """Load the experiment config, resolve the checkpoint path and run GraphTask.

    Steps: parse CLI args via ``get_args``, load the YAML chosen by
    ``_config_path``, call ``seed_everything(config.seed)``, fill in a default
    pre-trained checkpoint path for non-supervised runs that did not supply
    one, pick the epoch count, then build and run a ``GraphTask``.
    """
    args = get_args()
    config = load_yaml(_config_path(args))
    seed_everything(config.seed)

    if not args.supervised and _is_unset(args.pre_train_model_path):
        args.pre_train_model_path = _default_pretrain_path(args, config)

    # Supervised / fine-tuning runs take the epoch count from the CLI;
    # all other runs take it from the YAML config.
    epochs = args.epochs if (args.supervised or args.finetune) else config.epochs

    # GraphTask is used for every task; the task name is forwarded via `task=`.
    tasker = GraphTask(pre_train_model_path=args.pre_train_model_path,
                       dataset_name=args.dataset_name, num_layer=config.num_layer,
                       gnn_type=args.gnn_type, prompt_type=args.prompt_type,
                       epochs=epochs, shot_num=config.shot_num, device=args.device,
                       config=config, finetune=args.finetune, task=args.task,
                       transfer=args.transfer)
    tasker.run()


if __name__ == '__main__':
    main()