from transformers import (
    AutoTokenizer,
    GPTNeoForCausalLM,
    AutoConfig,
    AutoModelForCausalLM,
    LongT5EncoderModel,
    LongT5Model,
)
import torch
import argparse
from sklearn.model_selection import train_test_split
import pandas as pd
from sklearn.metrics import f1_score
import pickle
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from models import TransformerLanguageModel


from sklearn.decomposition import PCA, SparsePCA
from models import TransformerModel
from typing import List, Dict, Tuple, Union
from reasoning_utils import (
    compute_pca,
    get_non_corrup_embeds,
    compute_pca_non_corrupt,
    get_pca_embeds_repeat,
    get_pca_embeds,
    get_pca_embeds_long_context,
    get_pca_embeds_naive_repeat,
)


import os


def load_model(
    model_name, path_to_synthetic_model=None, frozen_model=False, ordering=False
):
    config = AutoConfig.from_pretrained(model_name)

    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        pad_token="<|pad|>",
    )
    tokenizer.truncation_side = "left"

    if "t5" not in model_name:
        # should work for all generative LM models
        trained_model = AutoModelForCausalLM.from_pretrained(
            pretrained_model_name_or_path=model_name,
            config=config,
        )
        base_model = AutoModelForCausalLM.from_pretrained(
            pretrained_model_name_or_path=model_name,
            config=config,
        )
    else:
        trained_model = LongT5Model.from_pretrained(
            pretrained_model_name_or_path=model_name,
            config=config,
        )
        base_model = LongT5Model.from_pretrained(
            pretrained_model_name_or_path=model_name,
            config=config,
        )
    # load weights
    if path_to_synthetic_model is not None:
        synthetic_model = torch.load(path_to_synthetic_model)

        new_state_dict = {}
        if "pythia" in model_name:
            for k, v in synthetic_model.items():
                if k.startswith("_backbone.gpt_neox"):
                    k = k.replace("_backbone.gpt_neox", "gpt_neox")
                    new_state_dict[k] = v
                elif k.startswith("_backbone.embed_out"):
                    k = k.replace("_backbone.embed_out", "embed_out")
                    new_state_dict[k] = v
                elif k.startswith("_backbone.embed_in"):
                    k = k.replace("_backbone.embed_in", "embed_in")
                    new_state_dict[k] = v
                else:
                    new_state_dict[k] = v
        elif "opt" in model_name:
            for k, v in synthetic_model.items():
                if k.startswith("_backbone.model"):
                    k = k.replace("_backbone.model", "model")
                    new_state_dict[k] = v
                elif k.startswith("_backbone.lm_head"):
                    k = k.replace("_backbone.lm_head", "lm_head")
                    new_state_dict[k] = v
                else:
                    new_state_dict[k] = v
        else:
            for k, v in synthetic_model.items():
                if k.startswith("_backbone.transformer"):
                    k = k.replace("_backbone.transformer", "transformer")
                    new_state_dict[k] = v
                elif k.startswith("_backbone.lm_head"):
                    k = k.replace("_backbone.lm_head", "lm_head")
                    new_state_dict[k] = v
                else:
                    new_state_dict[k] = v

        # load state dict
        trained_model.load_state_dict(new_state_dict, strict=False)

    trained_model.eval()
    base_model.eval()

    return (
        trained_model.cuda(),
        base_model.cuda(),
        tokenizer,
    )


# Data load function
def load_data(data_path, dataset="sms", input_key="sms", seed=42):
    """
    Reads the train and test CSV splits of ``dataset`` under ``data_path``.

    The train split is ``{dataset}/train_samples_s{seed}.csv``. The test split
    is ``{dataset}/test_samples_bal.csv`` for ag_news / dbpedia /
    civil_comments style datasets and ``{dataset}/test_samples_orig.csv``
    otherwise. ``input_key`` names the text column; labels come from the
    ``label`` column.

    Returns ``((X_train, y_train), (X_test, y_test))`` as Python lists.
    """

    def read_split(file_name):
        frame = pd.read_csv(os.path.join(data_path, f"{dataset}/{file_name}"))
        return frame[input_key].tolist(), frame["label"].tolist()

    balanced_test = any(
        name in dataset for name in ("ag_news", "dbpedia", "civil_comments")
    )
    test_file = "test_samples_bal.csv" if balanced_test else "test_samples_orig.csv"

    train_split = read_split(f"train_samples_s{seed}.csv")
    test_split = read_split(test_file)
    return train_split, test_split


def generate_in_context_example(
    X_train, y_train, template, seed, order=None, text_threshold=100
):
    """
    Builds an in-context prompt from labelled training examples.

    X_train, y_train: parallel sequences of texts and labels.
    template: format string with ``sentence`` and ``label`` fields.
    seed: seeds numpy's global RNG when ``order`` is not supplied.
    order: sequence of indices into the examples giving their order in the
        prompt; when None or empty, a random permutation of all examples
        is used.
    text_threshold: each text is cut to this many characters, then stripped.

    Returns ``(prompt, ys)``: the concatenated formatted examples and their
    labels in prompt order. Label ``1`` is written as "positive" and any
    other label as "negative".
    """
    samples = list(zip(X_train, y_train))
    # Explicit None/length check instead of `not order`: truth-testing a
    # multi-element numpy array raises ValueError.
    if order is None or len(order) == 0:
        np.random.seed(seed)  # reseeds numpy's *global* RNG
        order = np.random.permutation(len(samples))

    pieces, ys = [], []
    for i in order:
        sentence, y = samples[int(i)]
        ys.append(y)
        pieces.append(
            template.format(
                sentence=sentence[0:text_threshold].strip(),
                label="positive" if y == 1 else "negative",
            )
        )
    return "".join(pieces), ys


def get_template(template):
    """
    Returns ``(train_template, test_template)`` format strings for a prompt style.

    Train templates take ``sentence`` and ``label`` fields. Test templates
    stop where the label would go, so the model completes it.

    Supported styles: "sentence_label" and "colon_label".
    Raises ValueError for any other value (including None).
    """
    if template == "sentence_label":
        return "Sentence: {sentence:}\nLabel: {label:}\n", "Sentence: {sentence:}\nLabel:"
    if template == "colon_label":
        return "{sentence:} : {label:} , ", "{sentence:} :"
    raise ValueError(
        f"Unknown prompt format {template!r}; "
        "expected 'sentence_label' or 'colon_label'"
    )


def evaluate_original_or_tuned(
    model,
    tokenizer,
    dataset,
    data_path,
    input_key,
    seed=42,
    prompt_format="sentence_label",
    random_seed=None,
    k_range=[4, 8, 16, 32, 48, 64, 96, 128],
    text_threshold=100,
):
    """
    Evaluates a causal LM on ``dataset`` with k-shot in-context prompting.

    For each ``k`` in ``k_range`` the first ``k`` training examples are
    formatted with ``prompt_format`` into a prompt, in an order drawn from a
    permutation seeded with ``random_seed`` (or ``seed`` when it is None).
    Each test sentence is appended, the model greedily generates one token,
    and the text after the last ":" of the decoded sequence is the predicted
    label.

    model: HF model exposing ``generate``; inputs are moved to CUDA.
    tokenizer: tokenizer matching ``model``.

    Returns ``{0: {k: {...}}}``. Each inner dict holds "original_text",
    "predicted_label", "gt_label", "predicted_text", "predicted_scores"
    (softmax probabilities of the first token of " positive", "positive",
    " negative", "negative", in that order) and "accuracy".
    """
    map_label = {0: "negative", 1: "positive"}

    def first_token_id(word):
        # add_special_tokens=False so tokenizers that prepend a BOS token
        # (e.g. OPT's) yield the label word's id rather than the BOS id.
        return tokenizer(word, add_special_tokens=False).input_ids[0]

    positive_token_id_no_space = first_token_id("positive")
    negative_token_id_no_space = first_token_id("negative")
    positive_token_id_space = first_token_id(" positive")
    negative_token_id_space = first_token_id(" negative")

    # Resolve the template first so an unsupported format fails before any
    # data loading or generation.
    train_template, test_template = get_template(prompt_format)
    (X_train, y_train), (X_test, y_test) = load_data(
        data_path, dataset, input_key, seed
    )

    results = {0: {}}

    with torch.no_grad():
        for k in k_range:
            gt_label, predicted_label, original_text = [], [], []
            predicted_text, predicted_scores = [], []
            in_context_example, _ = generate_in_context_example(
                X_train[0:k],
                y_train[0:k],
                template=train_template,
                seed=(seed if random_seed is None else random_seed),
                order=None,
                text_threshold=text_threshold,
            )

            for text, label in tqdm(zip(X_test, y_test), total=len(X_test)):
                current_sample = test_template.format(
                    sentence=text[0:text_threshold].strip(), label=""
                )
                prompt = f"{in_context_example}{current_sample}"
                # With a left-truncating tokenizer (as load_model configures),
                # over-long prompts lose their earliest context, not the query.
                encoded = tokenizer(
                    prompt, max_length=2048, truncation=True, return_tensors="pt"
                )
                attention_mask = encoded.get("attention_mask")

                # Greedy decoding of one token; temperature only matters when
                # sampling, so it is not passed.
                sample_outputs = model.generate(
                    encoded["input_ids"].cuda(),
                    attention_mask=(
                        attention_mask.cuda() if attention_mask is not None else None
                    ),
                    do_sample=False,
                    max_new_tokens=1,
                    output_scores=True,
                    return_dict_in_generate=True,
                )

                pred_text = tokenizer.decode(
                    sample_outputs["sequences"][0], skip_special_tokens=True
                )
                # scores[0] holds the logits of the single generated step for
                # the batch of one prompt.
                probs = torch.softmax(sample_outputs["scores"][0], dim=-1)[0]
                predicted_scores.append(
                    (
                        probs[positive_token_id_space].item(),
                        probs[positive_token_id_no_space].item(),
                        probs[negative_token_id_space].item(),
                        probs[negative_token_id_no_space].item(),
                    )
                )
                predicted_label.append(pred_text.split(":")[-1].strip())
                gt_label.append(map_label.get(label, label))
                original_text.append(text)
                predicted_text.append(pred_text)

            accuracy = sum(
                gt == pred for gt, pred in zip(gt_label, predicted_label)
            ) / len(gt_label)
            results[0][k] = {
                "original_text": original_text,
                "predicted_label": predicted_label,
                "gt_label": gt_label,
                "predicted_text": predicted_text,
                "predicted_scores": predicted_scores,
                "accuracy": accuracy,
            }
            print(f"LM: {k}, {seed}, {accuracy}")
    return results


def _adaptor_features(
    corrupted_type,
    embed_tokenizer,
    embed_model,
    X_train,
    y_train,
    X_test,
    y_test,
    k,
    num_pca_components,
    text_threshold,
    seed,
):
    """Embed one k-shot context plus the test set and project them with PCA.

    Returns ``(X_tr, X_tst, y_tr, pca_pred)``: float tensors holding the
    context features, test features and context labels, plus the LR-solver
    predictions produced by the PCA helper.
    """
    corrupted_embedders = {
        "corrupted_loo": get_pca_embeds_repeat,
        "corrupted_base": get_pca_embeds,
        "corrupted_long": get_pca_embeds_long_context,
    }
    if corrupted_type in corrupted_embedders:
        print(corrupted_type)
        # corrupted_long is handed the full training lists; the other
        # corruptions receive only the first k examples.
        if corrupted_type == "corrupted_long":
            X_ctx, y_ctx = X_train, y_train
        else:
            X_ctx, y_ctx = X_train[0:k], y_train[0:k]
        X_tr_embed, X_tst_embed, y_tr = corrupted_embedders[corrupted_type](
            embed_tokenizer,
            embed_model,
            X_test,
            X_ctx,
            y_ctx,
            k,
            thresh=text_threshold,
            mini_batch=32,
            seed=seed,
        )
        X_tr_pca, X_tst_pca, pca_pred = compute_pca(
            num_pca_components,
            X_tr_embed,
            X_tst_embed,
            y_tr,
            y_test,
        )
    else:
        print("non corrupted")
        X_tr_embed, X_tst_embed, _ = get_non_corrup_embeds(
            embed_tokenizer,
            embed_model,
            X_test,
            X_train[0:k],
            y_train[0:k],
            k,
            thresh=text_threshold,
        )
        X_tr_pca, X_tst_pca, pca_pred = compute_pca_non_corrupt(
            num_pca_components,
            X_tr_embed,
            X_tst_embed,
            y_train[0:k],
            k,
            y_test,
        )
        y_tr = y_train[0:k]
    return (
        torch.from_numpy(X_tr_pca).float(),
        torch.from_numpy(X_tst_pca).float(),
        torch.Tensor(y_tr).float(),
        pca_pred,
    )


def evaluate_adaptor(
    inference_head,
    embed_model,
    embed_tokenizer,
    num_pca_components=8,  # for pca
    seed=42,
    input_key="sms",
    data_path="./data",
    dataset="sms",
    prompt_format="sentence_label",
    random_seed=None,
    k_range=[4, 8, 16, 32, 48, 64, 96, 128],
    text_threshold=100,
    corrupted_type=None,
):
    """
    Evaluates an adaptor model and the PCA/LR-solver baseline on a dataset.

    For each ``k`` in ``k_range`` the context examples and test set are
    embedded with ``embed_model`` (the embedding helper depends on
    ``corrupted_type``) and PCA-projected. The adaptor then scores each
    test point given the k context points.

    inference_head: torch module called as ``inference_head(xs, ys)``; the
        sigmoid of its output at the last position is treated as the
        positive-class probability.
    embed_model / embed_tokenizer: model and tokenizer used to embed texts.
    corrupted_type: None or "non_corrupted" (clean embeddings),
        "corrupted_loo", "corrupted_base" or "corrupted_long". Any other
        value raises ValueError.
    prompt_format, random_seed: unused.

    Returns ``(results, results_pca)``, each ``{0: {k: {...}}}`` with
    "original_text", "predicted_label", "gt_label", "predicted_scores" and
    "accuracy".
    """
    supported = (
        None,
        "non_corrupted",
        "corrupted_loo",
        "corrupted_base",
        "corrupted_long",
    )
    # Fail fast instead of silently running the non-corrupted path for an
    # unrecognised type (whose results would then be saved under that name).
    if corrupted_type not in supported:
        raise ValueError(
            f"unsupported corrupted_type {corrupted_type!r}; expected one of {supported}"
        )
    map_label = {0: "negative", 1: "positive"}

    (X_train, y_train), (X_test, y_test) = load_data(
        data_path, dataset, input_key, seed
    )
    y_test_torch = torch.Tensor(y_test).float().cuda()

    results = {0: {}}
    results_pca = {0: {}}

    with torch.no_grad():
        for k in k_range:
            X_tr, X_tst, y_tr, pca_pred = _adaptor_features(
                corrupted_type,
                embed_tokenizer,
                embed_model,
                X_train,
                y_train,
                X_test,
                y_test,
                k,
                num_pca_components,
                text_threshold,
                seed,
            )
            # The context rows are the same for every query; move them once.
            X_tr, X_tst, y_tr = X_tr.cuda(), X_tst.cuda(), y_tr.cuda()

            gt_label, original_text = [], []
            predicted_label, predicted_scores = [], []
            predicted_label_pca = []

            for test_idx, (text, label) in tqdm(
                enumerate(zip(X_test, y_test)), total=len(X_test)
            ):
                # Context points followed by the query point, batch of one.
                # NOTE(review): ys includes the query's true label; presumably
                # the head is causal and does not use it for the final
                # prediction — confirm.
                xs = torch.cat(
                    [X_tr, X_tst[test_idx, :].unsqueeze(0)], dim=0
                ).unsqueeze(0)
                ys = torch.cat(
                    [y_tr, y_test_torch[test_idx : test_idx + 1]], dim=0
                ).unsqueeze(0)
                pred = torch.sigmoid(inference_head(xs, ys))[0][-1].item()

                predicted_scores.append(pred)
                predicted_label.append("positive" if pred >= 0.5 else "negative")
                predicted_label_pca.append(
                    "negative" if pca_pred[test_idx] == 0 else "positive"
                )
                original_text.append(text)
                gt_label.append(map_label.get(label, label))

            n = len(gt_label)
            results[0][k] = {
                "original_text": original_text,
                "predicted_label": predicted_label,
                "gt_label": gt_label,
                "predicted_scores": predicted_scores,
                "accuracy": sum(
                    gt == pred for gt, pred in zip(gt_label, predicted_label)
                )
                / n,
            }
            results_pca[0][k] = {
                "original_text": list(original_text),
                "predicted_label": predicted_label_pca,
                "gt_label": list(gt_label),
                "predicted_scores": [0] * n,
                "accuracy": sum(
                    gt == pred for gt, pred in zip(gt_label, predicted_label_pca)
                )
                / n,
            }
            print(f"ADAPTOR: {k}, {seed}, {results[0][k]['accuracy']}")
            print(f"LR Solver: {k}, {seed}, {results_pca[0][k]['accuracy']}")

    return results, results_pca


def load_model_tokenizer(model_name):
    """
    Loads a GPT-Neo causal LM and its tokenizer for inference.

    The tokenizer is given ``<|pad|>`` as pad token and truncates from the
    left. The model is moved to CUDA and switched to eval mode.

    Returns ``(model, tokenizer)``.
    """
    tok = AutoTokenizer.from_pretrained(model_name, pad_token="<|pad|>")
    tok.truncation_side = "left"

    lm = GPTNeoForCausalLM.from_pretrained(model_name).cuda()
    lm.eval()
    return lm, tok


def eval_adaptor(
    n_dims,
    n_positions,
    num_pca_components,
    path_to_pretrained_adaptor,
    dataset,
    input_key,
    data_path="./data",
    embed_model_name="EleutherAI/gpt-neo-125M",
    path_to_pretrained_model=None,
    prompt_format="colon_label",
    k_range=[4, 8, 16, 32, 48, 64, 96, 128],
    prefix="",
    save_dir="./outputs",
    text_threshold=100,
    seeds=[42, 69, 128, 512, 1024],
    corrupted_type=None,
):
    """
    Evaluates a pretrained adaptor and the PCA/LR-solver baseline for every
    seed and pickles both result sets.

    Outputs are written to
    ``{save_dir}/{dataset}/{model}/adaptor/{run_id}/{checkpoint}/Adaptor_{corrupted_type}.pkl``
    and the matching ``lrsolver/.../LRSolver_{corrupted_type}.pkl``. Here
    ``model`` is the last path component of ``embed_model_name``. ``run_id``
    and ``checkpoint`` are the parent directory and the file name (up to its
    first ".") of ``path_to_pretrained_adaptor``.

    path_to_pretrained_model: optional synthetic checkpoint applied on top of
        ``embed_model_name`` via ``load_model``.
    prefix: unused.
    """
    if path_to_pretrained_adaptor is None:
        raise ValueError("path_to_pretrained_adaptor is required")
    # Derive output ids up front so a malformed path fails before the slow
    # evaluation rather than after it.
    run_id = path_to_pretrained_adaptor.split("/")[-2]
    checkpoint = path_to_pretrained_adaptor.split("/")[-1].split(".")[0]

    ### load in adaptor model
    print("loading adaptor model")
    adaptor_model = TransformerModel(
        n_dims=n_dims, n_positions=n_positions, n_embd=256, n_head=8, n_layer=12, n_y=1
    )
    aweights = torch.load(path_to_pretrained_adaptor)
    # strict=False: checkpoint keys that do not match the module are ignored.
    adaptor_model.load_state_dict(aweights, strict=False)
    # nn.Modules are built in training mode; switch to eval mode so any
    # dropout layers are inactive during evaluation.
    adaptor_model = adaptor_model.cuda().eval()

    ### Load in base embed model
    print("loading embed model")
    if path_to_pretrained_model is not None:
        # load_model returns (trained, base, tokenizer); the trained copy
        # carries the checkpoint weights.
        # NOTE(review): for "t5" names load_model builds LongT5Model, while
        # the branch below uses LongT5EncoderModel — confirm which one the
        # embedding helpers expect.
        embed_model, _, embed_tokenizer = load_model(
            model_name=embed_model_name,
            path_to_synthetic_model=path_to_pretrained_model,
        )
    else:
        embed_tokenizer = AutoTokenizer.from_pretrained(
            embed_model_name,
            pad_token="<|pad|>",
        )
        embed_tokenizer.truncation_side = "left"

        if "t5" in embed_model_name:
            embed_model = LongT5EncoderModel.from_pretrained(
                embed_model_name,
            ).cuda()
        else:
            embed_model = AutoModelForCausalLM.from_pretrained(
                embed_model_name,
            ).cuda()
        embed_model.eval()

    ### Call eval function
    final_results_adaptor = {}
    final_results_lrsolver = {}
    for seed in seeds:
        print("evaling trained model")
        results_adaptor, results_lrsolver = evaluate_adaptor(
            inference_head=adaptor_model,
            embed_model=embed_model,
            embed_tokenizer=embed_tokenizer,
            num_pca_components=num_pca_components,  # for pca
            seed=seed,
            input_key=input_key,
            data_path=data_path,
            dataset=dataset,
            prompt_format=prompt_format,
            random_seed=seed,
            k_range=k_range,
            text_threshold=text_threshold,
            corrupted_type=corrupted_type,
        )
        final_results_lrsolver[seed] = results_lrsolver
        final_results_adaptor[seed] = results_adaptor

    model_name_split = embed_model_name.split("/")[-1]
    save_path_adaptor = (
        f"{save_dir}/{dataset}/{model_name_split}/adaptor/{run_id}/{checkpoint}/"
    )
    save_path_lrsolver = (
        f"{save_dir}/{dataset}/{model_name_split}/lrsolver/{run_id}/{checkpoint}/"
    )
    os.makedirs(save_path_adaptor, exist_ok=True)
    os.makedirs(save_path_lrsolver, exist_ok=True)

    with open(f"{save_path_adaptor}/Adaptor_{corrupted_type}.pkl", "wb") as f:
        pickle.dump(final_results_adaptor, f)
    with open(f"{save_path_lrsolver}/LRSolver_{corrupted_type}.pkl", "wb") as f:
        pickle.dump(final_results_lrsolver, f)


def eval_base(
    model_name="EleutherAI/gpt-neo-1.3B",
    path_to_synthetic_model=None,
    dataset="sms",
    key="sms",
    data_path="./data",
    prompt_format="sentence_label",
    k_range=[4, 8, 16, 32, 48, 64, 96, 128],
    prefix="",
    save_dir="./outputs",
    text_threshold=100,
    seeds=[42, 69, 128, 512, 1024],
):
    """
    Evaluates the pretrained ``model_name`` with in-context prompting for
    every seed and pickles the results.

    The pickle is written to
    ``{save_dir}/{model_name with "/" replaced by "-"}_eval_{dataset}_rs_9.pkl``.
    ``prefix`` is unused.
    """
    rs = 9  # fixed run tag embedded in the output file name
    # load_model returns (trained, base, tokenizer). Only the trained copy
    # receives synthetic weights, so path_to_synthetic_model does not affect
    # the model evaluated here.
    _, model, tokenizer = load_model(
        path_to_synthetic_model=path_to_synthetic_model, model_name=model_name
    )
    final_results_base = {}
    for seed in seeds:
        print("evaling trained model")
        final_results_base[seed] = evaluate_original_or_tuned(
            model,
            tokenizer,
            dataset,
            data_path=data_path,
            input_key=key,
            seed=seed,
            prompt_format=prompt_format,
            random_seed=None,
            k_range=k_range,
            text_threshold=text_threshold,
        )

    # Flatten only the "/" inside the model id (e.g. "EleutherAI/gpt-neo-1.3B")
    # so the file lands directly inside save_dir.
    os.makedirs(save_dir, exist_ok=True)
    file_name = os.path.join(
        save_dir, f"{model_name}_eval_{dataset}_rs_{rs}.pkl".replace("/", "-")
    )
    with open(file_name, "wb") as f:
        pickle.dump(final_results_base, f)


def eval_lrtuned(
    model_name="EleutherAI/gpt-neo-1.3B",
    path_to_synthetic_model=None,
    dataset="sms",
    key="sms",
    data_path="./data",
    prompt_format="colon_label",
    k_range=[4, 8, 16, 32, 48, 64, 96, 128],
    prefix="",
    save_dir="./outputs",
    text_threshold=10,
    seeds=(42, 69, 128, 512, 1024),
):
    """
    Evaluates ``model_name`` with the synthetic (LR-tuned) checkpoint applied,
    using in-context prompting for every seed, and pickles the results.

    The pickle is written to
    ``{save_dir}/LRTuned_{dataset}_runid_{run_id}_ckpt_{checkpoint}_rs_9_{prefix}.pkl``.
    ``run_id`` and ``checkpoint`` are the parent directory and the file name
    (up to its first ".") of ``path_to_synthetic_model``.

    Raises ValueError if ``path_to_synthetic_model`` is None.
    """
    if path_to_synthetic_model is None:
        raise ValueError("eval_lrtuned requires path_to_synthetic_model")
    rs = 9  # fixed run tag embedded in the output file name
    # Derive output ids before the slow evaluation so bad paths fail early.
    run_id = path_to_synthetic_model.split("/")[-2]
    checkpoint = path_to_synthetic_model.split("/")[-1].split(".")[0]

    # load_model returns (trained, base, tokenizer); the trained copy is the
    # one that received the synthetic checkpoint weights.
    model, _, tokenizer = load_model(
        path_to_synthetic_model=path_to_synthetic_model, model_name=model_name
    )
    final_results_base = {}
    for seed in seeds:
        print("evaling trained model")
        final_results_base[seed] = evaluate_original_or_tuned(
            model,
            tokenizer,
            dataset,
            data_path=data_path,
            input_key=key,
            seed=seed,
            prompt_format=prompt_format,
            random_seed=None,
            k_range=k_range,
            text_threshold=text_threshold,
        )

    os.makedirs(save_dir, exist_ok=True)
    file_name = f"{save_dir}/LRTuned_{dataset}_runid_{run_id}_ckpt_{checkpoint}_rs_{rs}_{prefix}.pkl"
    with open(file_name, "wb") as f:
        pickle.dump(final_results_base, f)


if __name__ == "__main__":
    import ast

    parser = argparse.ArgumentParser(
        description="Evaluate adaptor, base, or LR-tuned models on text classification datasets."
    )
    parser.add_argument(
        "--base_model_name",
        type=str,
        default="EleutherAI/gpt-neo-125m",
    )
    parser.add_argument(
        "--path_to_adaptor_model",
        type=str,
        default=None,
    )
    parser.add_argument(
        "--path_to_synthetic_model",
        type=str,
        default=None,
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="sms",
    )
    parser.add_argument(
        "--key",
        type=str,
        default="text",
    )
    parser.add_argument(
        "--data_path",
        type=str,
        default="./data",
    )
    parser.add_argument(
        "--prompt_format",
        type=str,
        default=None,
        choices=["sentence_label", "colon_label"],
    )
    # List-valued options are given as Python literals, e.g. "[4, 8, 16]".
    parser.add_argument(
        "--k_range",
        type=str,
        default="[4, 8, 16, 32, 48, 64, 96, 128]",
    )
    parser.add_argument(
        "--seeds",
        type=str,
        default="[42, 69, 128, 512, 1024]",
    )
    parser.add_argument(
        "--prefix",
        type=str,
        default="",
    )
    parser.add_argument(
        "--n_dims",
        type=int,
        default=16,
    )
    parser.add_argument(
        "--n_positions",
        type=int,
        default=64,
    )
    parser.add_argument(
        "--corrupted_type",
        type=str,
        default=None,
        choices=[
            "corrupted_base",
            "corrupted_loo",
            "non_corrupted",
            "corrupted_long",
            "corrupted_long_2x",
        ],
    )
    parser.add_argument(
        "--save_dir",
        type=str,
        default="./outputs",
    )
    parser.add_argument(
        "--text_threshold",
        type=int,
        default=100,
    )
    parser.add_argument("--num_pca_components", type=int, default=8)
    parser.add_argument("--eval_adaptor", action="store_true")
    parser.add_argument("--eval_base", action="store_true")
    parser.add_argument("--eval_lrtuned", action="store_true")

    args = parser.parse_args()

    # ast.literal_eval only accepts Python literals, unlike eval(), which
    # would execute arbitrary code passed on the command line.
    k_range = ast.literal_eval(args.k_range)
    seeds = ast.literal_eval(args.seeds)
    # Forward --prompt_format only when given, so each entry point falls back
    # to its own default template instead of receiving None.
    prompt_kwargs = (
        {} if args.prompt_format is None else {"prompt_format": args.prompt_format}
    )

    os.makedirs(args.save_dir, exist_ok=True)

    if args.eval_adaptor:
        eval_adaptor(
            embed_model_name=args.base_model_name,
            path_to_pretrained_adaptor=args.path_to_adaptor_model,
            dataset=args.dataset,
            input_key=args.key,
            data_path=args.data_path,
            k_range=k_range,
            prefix=args.prefix,
            n_dims=args.n_dims,
            n_positions=args.n_positions,
            num_pca_components=args.num_pca_components,
            save_dir=args.save_dir,
            text_threshold=args.text_threshold,
            seeds=seeds,
            corrupted_type=args.corrupted_type,
            **prompt_kwargs,
        )
    if args.eval_base:
        eval_base(
            model_name=args.base_model_name,
            path_to_synthetic_model=args.path_to_synthetic_model,
            dataset=args.dataset,
            key=args.key,
            data_path=args.data_path,
            k_range=k_range,
            prefix=args.prefix,
            save_dir=args.save_dir,
            text_threshold=args.text_threshold,
            seeds=seeds,
            **prompt_kwargs,
        )
    if args.eval_lrtuned:
        eval_lrtuned(
            model_name=args.base_model_name,
            path_to_synthetic_model=args.path_to_synthetic_model,
            dataset=args.dataset,
            key=args.key,
            data_path=args.data_path,
            k_range=k_range,
            prefix=args.prefix,
            save_dir=args.save_dir,
            text_threshold=args.text_threshold,
            **prompt_kwargs,
        )
