import torch
from prompt_graph.utils import get_morpher_args, load_yaml, get_morpher_zeroshot_args
from prompt_graph.utils import seed_everything, seed_torch
from prompt_graph.tasker import ZeroShot
from transformers import RobertaModel, RobertaTokenizer
import pdb
import os.path as osp

if __name__ == '__main__':
    # Script entry point: read CLI args and the matching YAML configuration,
    # load the RoBERTa tokenizer/encoder, then build and run the ZeroShot tasker.
    args = get_morpher_zeroshot_args()

    # The configuration file is named "<dataset>_<pretrain method>_<gnn type>.yaml".
    config_file = f'{args.dataset_name}_{args.pretrain_method}_{args.gnn_type}.yaml'
    config = load_yaml(osp.join('configurations_morpher', config_file))

    torch.use_deterministic_algorithms(True)
    seed_torch(config.seed)

    # args.device is used as a CUDA device index. torch.cuda.set_device() only
    # accepts CUDA devices, so it must not be called (with 'cpu') when CUDA is
    # unavailable -- the previous conditional expression did exactly that and
    # raised on CPU-only machines.
    cuda_available = torch.cuda.is_available()
    if cuda_available:
        torch.cuda.set_device('cuda:' + str(args.device))

    # Load pre-trained model tokenizer (vocabulary)
    tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
    # Load pre-trained model (weights); it stays on the CPU when no GPU exists.
    model = RobertaModel.from_pretrained('roberta-base')
    if cuda_available:
        model = model.to(args.device)
    # Put the model in "evaluation" mode, meaning feed-forward operation.
    model.eval()

    # The literal string 'None' means "no checkpoint path given": fall back to
    # the conventional file name under pre_trained_gnn/.
    if args.pre_train_model_path == 'None':
        checkpoint_file = (
            f'{args.dataset_name}.{args.pretrain_method}.{args.gnn_type}.'
            f'{config.hid_dim}hidden_dim.pth'
        )
        args.pre_train_model_path = osp.join('pre_trained_gnn', checkpoint_file)

    # NOTE(review): ZeroShot receives the raw args.device value; it is presumably
    # responsible for resolving it to a torch device itself -- confirm.
    tasker = ZeroShot(
        # Pre-trained GNN and dataset selection.
        pre_train_model_path=args.pre_train_model_path,
        pretrain_method=args.pretrain_method,
        gnn_type=args.gnn_type,
        hid_dim=config.hid_dim,
        num_layer=config.num_layer,
        dataset_name=args.dataset_name,
        prompt_type=args.prompt_type,
        # General training setup and language model.
        epochs=config.epochs,
        shot_num=config.shot_num,
        batch_size=config.batch_size,
        prompt_graph_token_num=config.prompt_graph_token_num,
        tokenizer=tokenizer,
        llm=model,
        device=args.device,
        # Optimizer settings for the projector.
        projector_lr=config.projector_lr,
        projector_weight_decay=config.projector_weight_decay,
        projector_tune_lr=config.projector_tune_lr,
        projector_tune_weight_decay=config.projector_tune_weight_decay,
        # Optimizer settings for the graph prompt and the text prompt.
        pg_lr=config.pg_lr,
        pg_weight_decay=config.pg_weight_decay,
        text_prompt_lr=config.text_prompt_lr,
        text_prompt_weight_decay=config.text_prompt_weight_decay,
        projector_dropout_ratio=config.projector_dropout_ratio,
        temperature=config.temperature,
        text_prompt_start_vocab=config.text_prompt_start_vocab,
        # Projector training schedule.
        projector_epochs=config.projector_epochs,
        projector_train_eval_diff_threshold=config.projector_train_eval_diff_threshold,
        projector_train_modular=config.projector_train_modular,
        projector_tune_epochs=config.projector_tune_epochs,
        prompt_tune_epochs=config.prompt_tune_epochs,
        # Thresholds, warm-up and seeding.
        train_val_acc_diff_tol=config.train_val_acc_diff_tol,
        val_acc_threshold=config.val_acc_threshold,
        warmup_epochs=config.warmup_epochs,
        random_seed=config.seed,
        train_good_threshold=config.train_good_threshold,
    )

    tasker.run()