import analyze_utils

# TODO always load args from disk, delete this dict.
ARGS_DICT = {
    # Each value is a single command-line argument string. It is written as
    # adjacent string literals (one flag per line) that Python joins into one
    # string; every piece except the last ends with the separating space.
    "dpr_nq__msl32_beta": (
        "--dataset_name nq "
        "--per_device_train_batch_size 128 "
        "--per_device_eval_batch_size 128 "
        "--max_seq_length 32 "
        "--model_name_or_path t5-base "
        "--embedder_model_name gtr_base "
        "--num_repeat_tokens 16 "
        "--embedder_no_grad True "
        "--exp_group_name mar17-baselines "
        "--learning_rate 0.0003 "
        "--freeze_strategy none "
        "--embedder_fake_with_zeros False "
        "--use_frozen_embeddings_as_input False "
        "--num_train_epochs 24 "
        "--max_eval_samples 500 "
        "--eval_steps 25000 "
        "--warmup_steps 100000 "
        "--bf16=1 "
        "--use_wandb=0"
    ),
    # Identical to the entry above except for `--max_seq_length 128`.
    "gtr_nq__msl128_beta": (
        "--dataset_name nq "
        "--per_device_train_batch_size 128 "
        "--per_device_eval_batch_size 128 "
        "--max_seq_length 128 "
        "--model_name_or_path t5-base "
        "--embedder_model_name gtr_base "
        "--num_repeat_tokens 16 "
        "--embedder_no_grad True "
        "--exp_group_name mar17-baselines "
        "--learning_rate 0.0003 "
        "--freeze_strategy none "
        "--embedder_fake_with_zeros False "
        "--use_frozen_embeddings_as_input False "
        "--num_train_epochs 24 "
        "--max_eval_samples 500 "
        "--eval_steps 25000 "
        "--warmup_steps 100000 "
        "--bf16=1 "
        "--use_wandb=0"
    ),
    # The two entries below are disabled: while they stay commented out,
    # these aliases have no argument string in this dict.
    # "gtr_nq__msl32_beta__correct": "--experiment corrector_encoder --per_device_train_batch_size 256 --per_device_eval_batch_size 256 --max_seq_length 32 --model_name_or_path t5-base --embedder_model_name gtr_base --num_repeat_tokens 16 --embedder_no_grad True --exp_group_name may19-corr-encoder --learning_rate 0.002 --freeze_strategy none --embedder_fake_with_zeros False --use_frozen_embeddings_as_input False --encoder_dropout_disabled False --decoder_dropout_disabled False --use_less_data -1 --num_train_epochs 60 --max_eval_samples 500 --eval_steps 25000 --warmup_steps 200000 --bf16=1 --use_lora=0 --use_wandb=1",
    # "openai_msmarco__msl128__100epoch": "--per_device_train_batch_size 128 --per_device_eval_batch_size 128 --max_seq_length 128 --model_name_or_path t5-base --embedder_model_name gtr_base --num_repeat_tokens 16 --embedder_no_grad True --learning_rate 0.0002 --freeze_strategy none --embedder_fake_with_zeros False --encoder_dropout_disabled False --decoder_dropout_disabled False --use_less_data 1000000 --num_train_epochs 100 --max_eval_samples 500 --eval_steps 50000 --warmup_steps 20000 --bf16=1 --use_lora=0 --use_wandb=0 --embedder_model_api text-embedding-ada-002 --use_frozen_embeddings_as_input True --exp_group_name jun3-openai-4gpu-ddp-3",
}

# Maps each model alias to the folder that holds its saved checkpoint.
CHECKPOINT_FOLDERS_DICT = {
    # Keys are model aliases; values are the checkpoint folder used when that
    # alias is loaded. Most values are empty strings, i.e. no folder is
    # recorded for the alias in this file.
    #
    # --- NQ models ---
    "dpr_nq__msl32_beta": "",
    "gtr_nq__msl128_beta": "",
    "gtr_nq__msl32_beta__correct": "",
    "gtr_nq__msl32_beta__correct__nofeedback": "",
    "clinicalbert_nq__msl32": "",
    # --- MS MARCO models ---
    "gtr_msmarco__msl128__100epoch": "",
    "openai_msmarco__msl32__100epoch": "",
    "openai_msmarco__msl32__100epoch__correct": "",
    "openai_msmarco__msl128__100epoch": "",
    # --- other OpenAI-embedding models ---
    "openai_gsm8k_100epoch": "",
    "openai_ag_news_100epoch": "",
    # NOTE(review): "imbd" looks like a typo for "imdb"; the key is left
    # untouched because callers look aliases up by this exact string.
    "openai_imbd_100epoch": "",
    # The only alias with a concrete folder; it is a machine-specific
    # absolute path.
    "openai_sst2_100epoch": "/home/users/astar/ares/li_jing/pri2/dc/Jinghan/vec2text/saves/openai-inversion-sst2",
    "openai_msmarco__msl128__100epoch__correct": "",
    "openai_msmarco__msl128__200epoch__correct": "",
    # --- logits / instruction models ---
    "logits__gpt2": "",
    # The next two keys differ only in the number of underscores (three vs
    # two) between the name parts; both are distinct, valid aliases.
    "t5-base___llama-7b___one-million-paired-instructions": "",
    "t5-base__llama-7b__one-million-paired-instructions": "",
    "t5_base__llama-7b__one-million-instructions__correct__70epoch": "",
    "t5-base___llama-7b___one-million-instructions__correct": "",
}


def load_experiment_and_trainer_from_alias(
    alias: str, max_seq_length: int = None, use_less_data: int = None
):
    """Load the experiment and trainer registered under ``alias``.

    If ``alias`` is a key of ``CHECKPOINT_FOLDERS_DICT``, its checkpoint
    folder is taken from there and its argument string from ``ARGS_DICT``
    (``None`` when the alias has no entry in ``ARGS_DICT``). Otherwise a
    notice is printed and ``alias`` itself is used as the checkpoint folder,
    with no argument string.

    Args:
        alias: A registered alias, or a checkpoint folder to use directly.
        max_seq_length: Forwarded unchanged to
            ``analyze_utils.load_experiment_and_trainer``.
        use_less_data: Forwarded unchanged to
            ``analyze_utils.load_experiment_and_trainer``.

    Returns:
        The ``(experiment, trainer)`` pair produced by
        ``analyze_utils.load_experiment_and_trainer`` (called with
        ``do_eval=False``).
    """
    if alias in CHECKPOINT_FOLDERS_DICT:
        folder = CHECKPOINT_FOLDERS_DICT[alias]
        cli_args = ARGS_DICT.get(alias)
    else:
        # Unregistered name: treat it as a path to a checkpoint folder.
        print(f"{alias} not found in aliases.py, using as checkpoint folder")
        folder = alias
        cli_args = None

    print(f"loading alias {alias} from {folder}...")
    loaded_experiment, loaded_trainer = analyze_utils.load_experiment_and_trainer(
        folder,
        cli_args,
        do_eval=False,
        max_seq_length=max_seq_length,
        use_less_data=use_less_data,
    )
    return loaded_experiment, loaded_trainer


def load_model_from_alias(alias: str, max_seq_length: int = None):
    """Return the ``model`` attribute of the trainer loaded for ``alias``.

    Args:
        alias: Passed to ``load_experiment_and_trainer_from_alias``.
        max_seq_length: Passed through to
            ``load_experiment_and_trainer_from_alias``.

    Returns:
        ``trainer.model`` of the loaded trainer; the experiment that is
        loaded alongside it is discarded.
    """
    _unused_experiment, loaded_trainer = load_experiment_and_trainer_from_alias(
        alias,
        max_seq_length=max_seq_length,
    )
    model = loaded_trainer.model
    return model
