from transformers import (
    AutoTokenizer,
    GPTNeoForCausalLM,
    AutoConfig,
    AutoModelForCausalLM,
    LongT5EncoderModel,
    LongT5Model,
)
import torch
import argparse
from sklearn.model_selection import train_test_split
import pandas as pd
from sklearn.metrics import f1_score
import pickle
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from models import TransformerLanguageModel
from sklearn.decomposition import PCA, SparsePCA
from models import TransformerModel
from typing import List, Dict, Tuple, Union
from reasoning_utils import (
    compute_pca,
    get_non_corrup_embeds,
    compute_pca_non_corrupt,
    get_pca_embeds_repeat,
    get_pca_embeds,
    get_pca_embeds_long_context,
    get_pca_embeds_naive_repeat,
    get_non_corrup_embeds_audio,
    get_non_corrup_embeds_image,
)
from datasets import load_dataset
from transformers import ViTModel, WhisperModel, AutoFeatureExtractor


import os


def _remap_backbone_keys(state_dict, model_name):
    """
    Rename checkpoint keys saved under a ``_backbone.`` wrapper to plain module names.

    Keys that start with ``_backbone.<module>`` (``<module>`` chosen per
    architecture below) lose the leading ``_backbone.``; all other keys are
    kept unchanged. Insertion order is preserved, so if two keys collide the
    later one wins.
    """
    if "pythia" in model_name:
        modules = ("gpt_neox", "embed_out", "embed_in")
    elif "opt" in model_name:
        modules = ("model", "lm_head")
    else:
        modules = ("transformer", "lm_head")

    remapped = {}
    for key, value in state_dict.items():
        for module in modules:
            prefix = f"_backbone.{module}"
            if key.startswith(prefix):
                # Only rewrite the leading prefix, never later occurrences.
                key = module + key[len(prefix):]
                break
        remapped[key] = value
    return remapped


def load_model(
    model_name, path_to_synthetic_model=None, frozen_model=False, ordering=False
):
    """
    Load a tokenizer plus two copies of ``model_name``: a "trained" one and a base one.

    If ``path_to_synthetic_model`` is given, its state dict is loaded
    non-strictly into the trained copy after stripping the ``_backbone.``
    wrapper prefix (see ``_remap_backbone_keys``). The base copy always keeps
    the pretrained weights. Names containing ``"t5"`` are loaded with
    ``LongT5Model``; every other name with ``AutoModelForCausalLM``.

    Args:
        model_name: Hugging Face model id or local path.
        path_to_synthetic_model: optional path to a state dict for ``torch.load``.
        frozen_model: unused; kept for call-site compatibility.
        ordering: unused; kept for call-site compatibility.

    Returns:
        ``(trained_model, base_model, tokenizer)``; both models are in eval
        mode and moved to CUDA.
    """
    config = AutoConfig.from_pretrained(model_name)

    tokenizer = AutoTokenizer.from_pretrained(model_name, pad_token="<|pad|>")
    tokenizer.truncation_side = "left"

    model_cls = LongT5Model if "t5" in model_name else AutoModelForCausalLM
    trained_model = model_cls.from_pretrained(
        pretrained_model_name_or_path=model_name, config=config
    )
    base_model = model_cls.from_pretrained(
        pretrained_model_name_or_path=model_name, config=config
    )

    if path_to_synthetic_model is not None:
        # Load on CPU: the model is still on CPU here and load_state_dict copies
        # the values into its parameters, so the saved device does not matter.
        synthetic_state = torch.load(path_to_synthetic_model, map_location="cpu")
        # strict=False: keys that still don't match a parameter are skipped silently.
        trained_model.load_state_dict(
            _remap_backbone_keys(synthetic_state, model_name), strict=False
        )

    trained_model.eval()
    base_model.eval()

    return trained_model.cuda(), base_model.cuda(), tokenizer


def load_data(data_path, dataset="sms", input_key="sms", seed=42):
    """
    Read the train/test CSV splits of a text classification dataset.

    Training rows come from ``<data_path>/<dataset>/train_samples_s<seed>.csv``.
    Test rows come from ``test_samples_bal.csv`` when the dataset name contains
    ``ag_news``, ``dbpedia`` or ``civil_comments``, and from
    ``test_samples_orig.csv`` otherwise.

    Returns:
        ``((X_train, y_train), (X_test, y_test))`` as plain lists taken from
        the ``input_key`` and ``"label"`` columns.
    """
    uses_balanced_test = any(
        name in dataset for name in ("ag_news", "dbpedia", "civil_comments")
    )
    test_file = (
        "test_samples_bal.csv" if uses_balanced_test else "test_samples_orig.csv"
    )

    def read_split(file_name):
        frame = pd.read_csv(os.path.join(data_path, f"{dataset}/{file_name}"))
        return frame[input_key].tolist(), frame["label"].tolist()

    train_split = read_split(f"train_samples_s{seed}.csv")
    test_split = read_split(test_file)
    return train_split, test_split


def _binarize_split(split, negative_label, positive_label):
    """Keep only rows labelled ``negative_label``/``positive_label`` and relabel them 0/1."""
    kept = (negative_label, positive_label)
    return split.filter(lambda example: example["label"] in kept).map(
        lambda example: {"label": 0 if example["label"] == negative_label else 1}
    )


def load_data_mm(
    data_path, dataset="sms", input_key="sms", seed=42, n_per_class=128
):
    """
    Load a binary image/audio classification task via ``datasets.load_dataset``.

    Two original classes are kept from each dataset and relabelled to 0/1:
      * ``"cifar10"``: label 2 -> 0, label 0 -> 1; input column forced to ``"img"``.
      * ``"mnist"``: label 1 -> 0, label 5 -> 1; input column forced to ``"image"``.
      * ``"speech_commands"`` (config ``"v0.01"``): label 2 -> 0, label 3 -> 1;
        the input column is ``input_key`` as given.

    Args:
        data_path: unused; kept so the signature mirrors ``load_data``.
        dataset: one of the dataset names above.
        input_key: column holding the raw inputs (overridden for cifar10/mnist).
        seed: shuffle seed used when sampling each training class.
        n_per_class: number of training samples drawn per class (default 128).

    Returns:
        ``((X_train_1, y_train_1, X_train_0, y_train_0), (X_test, y_test))``:
        ``n_per_class`` shuffled training samples of class 1 and of class 0,
        plus the complete filtered test split.

    Raises:
        ValueError: if ``dataset`` is not one of the supported names.
    """
    if dataset == "cifar10":
        hf_dataset = load_dataset("cifar10")
        input_key = "img"
        negative_label, positive_label = 2, 0
    elif dataset == "mnist":
        hf_dataset = load_dataset("mnist")
        input_key = "image"
        negative_label, positive_label = 1, 5
    elif dataset == "speech_commands":
        hf_dataset = load_dataset("speech_commands", "v0.01")
        negative_label, positive_label = 2, 3
    else:
        raise ValueError(
            f"unsupported dataset {dataset!r}; expected one of "
            "'cifar10', 'mnist', 'speech_commands'"
        )

    dataset_train = _binarize_split(
        hf_dataset["train"], negative_label, positive_label
    )
    dataset_test = _binarize_split(hf_dataset["test"], negative_label, positive_label)

    def _sample_class(label):
        # Seeded random subset of the training rows carrying ``label``.
        return (
            dataset_train.filter(lambda example: example["label"] == label)
            .shuffle(seed=seed)
            .select(range(n_per_class))
        )

    dataset_train_1 = _sample_class(1)
    dataset_train_0 = _sample_class(0)

    return (
        dataset_train_1[input_key],
        dataset_train_1["label"],
        dataset_train_0[input_key],
        dataset_train_0["label"],
    ), (dataset_test[input_key], dataset_test["label"])


def generate_in_context_example(
    X_train, y_train, template, seed, order=None, text_threshold=100
):
    """
    Render labelled samples into a single in-context prompt string.

    Args:
        X_train: input texts.
        y_train: labels aligned with ``X_train``. A label equal to 1 is rendered
            as ``"positive"``; any other label is rendered as ``"negative"``.
        template: format string with ``{sentence}`` and ``{label}`` fields.
        seed: seed for the random sample order used when ``order`` is not given.
        order: optional sequence (list or NumPy array) of sample indices.
            ``None`` or an empty sequence selects a seeded random permutation.
        text_threshold: each input is cut to this many characters and then
            stripped before formatting.

    Returns:
        ``(in_context_example, ys)``: the rendered prompt and the raw labels
        in prompt order.
    """
    samples = list(zip(X_train, y_train))
    # ``len`` instead of truthiness so NumPy index arrays are accepted.
    if order is None or len(order) == 0:
        # A local RandomState gives the same permutation as seeding NumPy's global
        # RNG, without mutating global state for the rest of the program.
        order = np.random.RandomState(seed).permutation(len(samples))

    parts = []
    ys = []
    for idx in order:
        text, y = samples[int(idx)]
        label = "positive" if y == 1 else "negative"
        ys.append(y)
        parts.append(
            template.format(sentence=text[0:text_threshold].strip(), label=label)
        )

    return "".join(parts), ys


def get_template(template):
    """
    Return the ``(train_template, test_template)`` format strings for a prompt style.

    Both strings contain ``{sentence:}`` and (train only) ``{label:}``
    placeholders for ``str.format``.

    Args:
        template: ``"sentence_label"`` or ``"colon_label"``.

    Raises:
        ValueError: if ``template`` is not a known prompt style.
    """
    templates = {
        "sentence_label": (
            "Sentence: {sentence:}\nLabel: {label:}\n",
            "Sentence: {sentence:}\nLabel:",
        ),
        "colon_label": (
            "{sentence:} : {label:} , ",
            "{sentence:} :",
        ),
    }
    try:
        return templates[template]
    except KeyError:
        raise ValueError(
            f"unknown prompt template {template!r}; expected one of {sorted(templates)}"
        ) from None


def _accuracy(gt_labels, pred_labels):
    """Return the fraction of positions where ``gt_labels`` and ``pred_labels`` agree."""
    return sum(gt == pred for gt, pred in zip(gt_labels, pred_labels)) / len(
        gt_labels
    )


def evaluate_adaptor(
    inference_head,
    embed_model,
    embed_tokenizer,
    num_pca_components=8,  # for pca
    seed=42,
    input_key="sms",
    data_path="./data",
    dataset="sms",
    prompt_format="sentence_label",
    random_seed=None,
    k_range=[4, 8, 16, 32, 48, 64, 96, 128],
    text_threshold=100,
    corrupted_type=None,
    domain=None,
):
    """
    Evaluate the adaptor (inference head) and a PCA baseline on a multimodal dataset.

    For each ``k`` in ``k_range`` the first ``k/2`` label-0 and the first
    ``k/2`` label-1 training samples are embedded together with the test set.
    The embeddings are projected with ``compute_pca``, which also returns the
    baseline's test predictions. Then every test sample is appended to the
    projected in-context set and scored by ``inference_head``.

    Args:
        inference_head: torch module called as ``inference_head(xs, ys)``; the
            sigmoid of its output at the last position is the positive score.
        embed_model: model used to embed in-context and test examples.
        embed_tokenizer: tokenizer / feature extractor forwarded to the embed helper.
        num_pca_components: number of components passed to ``compute_pca``.
        seed: forwarded to ``load_data_mm`` and to the embed helper.
        input_key: raw-input column forwarded to ``load_data_mm``.
        data_path: forwarded to ``load_data_mm``.
        dataset: dataset name forwarded to ``load_data_mm``.
        prompt_format: accepted for interface compatibility; unused.
        random_seed: accepted for interface compatibility; unused.
        k_range: in-context set sizes to evaluate.
        text_threshold: forwarded as ``thresh`` to the embed helper.
        corrupted_type: accepted for interface compatibility; unused.
        domain: ``"image"`` or ``"audio"``; selects the embed helper.

    Returns:
        ``(results, results_pca)``: for the adaptor and the PCA baseline
        respectively, ``{0: {k: {"predicted_label", "gt_label",
        "predicted_scores", "accuracy"}}}``. The baseline's scores are
        placeholder zeros.

    Raises:
        ValueError: if ``domain`` is not ``"image"`` or ``"audio"``.
    """
    # Both embed helpers share the call signature used below.
    embed_fns = {
        "image": get_non_corrup_embeds_image,
        "audio": get_non_corrup_embeds_audio,
    }
    if domain not in embed_fns:
        raise ValueError(f"domain must be one of {sorted(embed_fns)}, got {domain!r}")
    embed_fn = embed_fns[domain]

    sigmoid = torch.nn.Sigmoid()
    map_label = {0: "negative", 1: "positive"}

    (X_train_1, y_train_1, X_train_0, y_train_0), (X_test, y_test) = load_data_mm(
        data_path, dataset, input_key, seed
    )
    n_test = len(y_test)
    # Ground truth as strings; labels outside map_label are kept unchanged.
    gt_label = [map_label.get(label, label) for label in y_test]
    y_test_torch = torch.Tensor(y_test).float()

    results = {0: {}}
    results_pca = {0: {}}

    with torch.no_grad():
        for k in k_range:
            half = int(k / 2)
            x_k = X_train_0[0:half] + X_train_1[0:half]
            y_k = y_train_0[0:half] + y_train_1[0:half]
            print(domain)
            X_tr_embed, X_tst_embed, y_tr_shuffle = embed_fn(
                embed_tokenizer,
                embed_model,
                X_test,
                x_k,
                y_k,
                k,
                thresh=text_threshold,
                mini_batch=32,
                seed=seed,
            )
            X_tr_pca, X_tst_pca, pca_pred = compute_pca(
                num_pca_components,
                X_tr_embed.cpu(),
                X_tst_embed.cpu(),
                y_tr_shuffle,
                y_test,
            )
            X_tr_pca_torch = torch.from_numpy(X_tr_pca).float()
            X_tst_pca_torch = torch.from_numpy(X_tst_pca).float()
            y_tr_shuffle_torch = torch.Tensor(y_tr_shuffle).float()

            predicted_label, predicted_scores = [], []
            for test_idx in tqdm(range(n_test)):
                # In-context rows followed by the single query row: (1, n, d).
                xs = torch.cat(
                    [X_tr_pca_torch, X_tst_pca_torch[test_idx, :].unsqueeze(0)],
                    dim=0,
                ).unsqueeze(0)
                # NOTE(review): the query's own label fills the last ys slot;
                # presumably the head is causal so the last prediction cannot
                # see it — TODO confirm against TransformerModel.
                ys = torch.cat(
                    [y_tr_shuffle_torch, y_test_torch[test_idx : test_idx + 1]],
                    dim=0,
                ).unsqueeze(0)
                # Duplicate features along the last dim: (1, n, d) -> (1, n, 2d).
                # TODO: make this more general.
                xs = torch.cat([xs, xs], dim=-1)
                outs = inference_head(xs.cuda(), ys.cuda())
                score = sigmoid(outs)[0][-1].item()
                predicted_scores.append(score)
                predicted_label.append("positive" if score >= 0.5 else "negative")

            predicted_label_pca = [
                "negative" if pca_pred[i] == 0 else "positive" for i in range(n_test)
            ]

            results[0][k] = {
                "predicted_label": predicted_label,
                "gt_label": list(gt_label),
                "predicted_scores": predicted_scores,
                "accuracy": _accuracy(gt_label, predicted_label),
            }
            results_pca[0][k] = {
                "predicted_label": predicted_label_pca,
                "gt_label": list(gt_label),
                "predicted_scores": [0] * n_test,
                "accuracy": _accuracy(gt_label, predicted_label_pca),
            }

    return results, results_pca


def load_model_tokenizer(model_name):
    """
    Load a GPT-Neo causal LM (on CUDA, in eval mode) and its tokenizer.

    The tokenizer is created with ``pad_token="<|pad|>"`` and its
    ``truncation_side`` is set to ``"left"``.

    Returns:
        ``(model, tokenizer)``.
    """
    tok = AutoTokenizer.from_pretrained(model_name, pad_token="<|pad|>")
    tok.truncation_side = "left"

    lm = GPTNeoForCausalLM.from_pretrained(model_name).cuda()
    lm.eval()
    return lm, tok


def eval_adaptor(
    n_dims,
    n_positions,
    num_pca_components,
    path_to_pretrained_adaptor,
    dataset,
    input_key,
    data_path="./data",
    embed_model_name="EleutherAI/gpt-neo-125M",
    path_to_pretrained_model=None,
    prompt_format="colon_label",
    k_range=[4, 8, 16, 32, 48, 64, 96, 128],
    prefix="",
    save_dir="./outputs",
    text_threshold=100,
    seeds=[42, 69, 128, 512, 1024],
    corrupted_type=None,
    domain=None,
):
    """
    Load the adaptor and an embedding model, evaluate over ``seeds``, and pickle the results.

    Results go to
    ``<save_dir>/<dataset>/<model>/{adaptor,lrsolver}/<run_id>/<checkpoint>/``.
    ``<model>`` is the last path component of ``embed_model_name``.
    ``<run_id>`` and ``<checkpoint>`` are taken from the parent directory and
    the file stem of ``path_to_pretrained_adaptor``.

    Embedder selection:
      * ``path_to_pretrained_model`` given: ``load_model(embed_model_name, ...)``
        with that checkpoint; its trained copy is used.
      * otherwise ``domain="image"`` uses ``ViTModel`` (input column ``"img"``),
        and ``domain="audio"`` uses ``WhisperModel`` plus ``AutoFeatureExtractor``
        (input column ``"audio"``).

    Raises:
        ValueError: if no ``path_to_pretrained_model`` is given and ``domain`` is
            not ``"image"`` or ``"audio"``.
    """
    print("loading adaptor model")
    adaptor_model = TransformerModel(
        n_dims=n_dims, n_positions=n_positions, n_embd=256, n_head=8, n_layer=12, n_y=1
    )
    adaptor_weights = torch.load(path_to_pretrained_adaptor, map_location="cpu")
    # strict=False: missing/unexpected keys are silently ignored.
    adaptor_model.load_state_dict(adaptor_weights, strict=False)
    # Inference only: disable any train-mode behaviour (e.g. dropout).
    adaptor_model = adaptor_model.cuda().eval()

    print("loading embed model")
    if path_to_pretrained_model is not None:
        # load_model(model_name, path_to_synthetic_model=...) returns (trained, base, tokenizer).
        embed_model, _, embed_tokenizer = load_model(
            embed_model_name, path_to_synthetic_model=path_to_pretrained_model
        )
    elif domain == "image":
        input_key = "img"
        embed_model = ViTModel.from_pretrained(embed_model_name).cuda()
        embed_tokenizer = None
        embed_model.eval()
    elif domain == "audio":
        input_key = "audio"
        embed_model = WhisperModel.from_pretrained(embed_model_name).cuda()
        embed_tokenizer = AutoFeatureExtractor.from_pretrained(embed_model_name)
        embed_model.eval()
    else:
        raise ValueError(
            "domain must be 'image' or 'audio' when path_to_pretrained_model is not "
            f"given, got {domain!r}"
        )

    final_results_adaptor = {}
    final_results_lrsolver = {}
    for seed in seeds:
        print("evaling trained model")
        results_adaptor, results_lrsolver = evaluate_adaptor(
            inference_head=adaptor_model,
            embed_model=embed_model,
            embed_tokenizer=embed_tokenizer,
            num_pca_components=num_pca_components,
            seed=seed,
            input_key=input_key,
            data_path=data_path,
            dataset=dataset,
            prompt_format=prompt_format,
            random_seed=seed,
            k_range=k_range,
            text_threshold=text_threshold,
            corrupted_type=corrupted_type,
            domain=domain,
        )
        final_results_lrsolver[seed] = results_lrsolver
        final_results_adaptor[seed] = results_adaptor

    run_id = path_to_pretrained_adaptor.split("/")[-2]
    checkpoint = path_to_pretrained_adaptor.split("/")[-1].split(".")[0]
    model_name_split = embed_model_name.split("/")[-1]

    for subdir, file_prefix, payload in (
        ("adaptor", "Adaptor", final_results_adaptor),
        ("lrsolver", "LRSolver", final_results_lrsolver),
    ):
        save_path = os.path.join(
            save_dir, dataset, model_name_split, subdir, run_id, checkpoint
        )
        os.makedirs(save_path, exist_ok=True)
        file_name = os.path.join(save_path, f"{file_prefix}_{corrupted_type}.pkl")
        with open(file_name, "wb") as f:
            pickle.dump(payload, f)


def eval_base(
    model_name="EleutherAI/gpt-neo-1.3B",
    path_to_synthetic_model=None,
    dataset="sms",
    key="sms",
    data_path="./data",
    prompt_format="sentence_label",
    k_range=[4, 8, 16, 32, 48, 64, 96, 128],
    prefix="",
    save_dir="./outputs",
    text_threshold=100,
    seeds=[42, 69, 128, 512, 1024],
):
    """
    Evaluate the pretrained base model over ``seeds`` and pickle the results.

    Only the base copy returned by ``load_model`` (second element) is
    evaluated. Results (``{seed: results}``) are written to
    ``<save_dir>/<model_name>_eval_<dataset>_rs_9.pkl``. Every ``/`` in the
    file name is replaced by ``-``; ``save_dir`` itself is left intact.

    ``prefix`` is accepted for interface compatibility and unused.
    """
    rs = 9  # fixed tag that appears in the output file name
    _, model, tokenizer = load_model(
        path_to_synthetic_model=path_to_synthetic_model, model_name=model_name
    )

    final_results_base = {}
    for seed in seeds:
        print("evaling trained model")
        # NOTE(review): evaluate_original_or_tuned is neither defined nor imported
        # in this file; presumably it lives in another module — TODO import it.
        final_results_base[seed] = evaluate_original_or_tuned(
            model,
            tokenizer,
            dataset,
            data_path=data_path,
            input_key=key,
            seed=seed,
            prompt_format=prompt_format,
            random_seed=None,
            k_range=k_range,
            text_threshold=text_threshold,
        )

    # Sanitize only the file name (model ids contain "/"), not the directory.
    file_name = os.path.join(
        save_dir, f"{model_name}_eval_{dataset}_rs_{rs}.pkl".replace("/", "-")
    )
    with open(file_name, "wb") as f:
        pickle.dump(final_results_base, f)


def eval_lrtuned(
    model_name="EleutherAI/gpt-neo-1.3B",
    path_to_synthetic_model=None,
    dataset="sms",
    key="sms",
    data_path="./data",
    prompt_format="colon_label",
    k_range=[4, 8, 16, 32, 48, 64, 96, 128],
    prefix="",
    save_dir="./outputs",
    text_threshold=10,
    seeds=(42, 69, 128, 512, 1024),
):
    """
    Evaluate the model tuned from ``path_to_synthetic_model`` over ``seeds`` and pickle the results.

    The trained copy from ``load_model`` (first element, with the synthetic
    checkpoint loaded) is evaluated. Results (``{seed: results}``) go to
    ``<save_dir>/LRTuned_<dataset>_runid_<run_id>_ckpt_<checkpoint>_rs_9_<prefix>.pkl``.
    ``<run_id>`` and ``<checkpoint>`` are the parent directory and the file
    stem of ``path_to_synthetic_model``.

    Raises:
        ValueError: if ``path_to_synthetic_model`` is None (the output name
            depends on it), checked before any evaluation work.
    """
    if path_to_synthetic_model is None:
        raise ValueError("eval_lrtuned requires path_to_synthetic_model")

    rs = 9  # fixed tag that appears in the output file name
    run_id = path_to_synthetic_model.split("/")[-2]
    checkpoint = path_to_synthetic_model.split("/")[-1].split(".")[0]

    # First element is the checkpoint-initialised model; the second is untouched.
    model, _, tokenizer = load_model(
        path_to_synthetic_model=path_to_synthetic_model, model_name=model_name
    )

    final_results_base = {}
    for seed in seeds:
        print("evaling trained model")
        # NOTE(review): evaluate_original_or_tuned is neither defined nor imported
        # in this file; presumably it lives in another module — TODO import it.
        final_results_base[seed] = evaluate_original_or_tuned(
            model,
            tokenizer,
            dataset,
            data_path=data_path,
            input_key=key,
            seed=seed,
            prompt_format=prompt_format,
            random_seed=None,
            k_range=k_range,
            text_threshold=text_threshold,
        )

    file_name = os.path.join(
        save_dir,
        f"LRTuned_{dataset}_runid_{run_id}_ckpt_{checkpoint}_rs_{rs}_{prefix}.pkl",
    )
    with open(file_name, "wb") as f:
        pickle.dump(final_results_base, f)


if __name__ == "__main__":
    import ast

    parser = argparse.ArgumentParser(
        description="Evaluate adaptor, base, or LR-tuned models on classification datasets."
    )
    parser.add_argument("--base_model_name", type=str, default="EleutherAI/gpt-neo-125m")
    parser.add_argument("--path_to_adaptor_model", type=str, default=None)
    parser.add_argument("--path_to_synthetic_model", type=str, default=None)
    parser.add_argument("--dataset", type=str, default="sms")
    parser.add_argument("--key", type=str, default="text")
    parser.add_argument("--data_path", type=str, default="./data")
    parser.add_argument("--prompt_format", type=str, default=None)
    parser.add_argument(
        "--k_range",
        type=str,
        default="[4, 8, 16, 32, 48, 64, 96, 128]",
        help="Python list literal of in-context example counts.",
    )
    parser.add_argument(
        "--seeds",
        type=str,
        default="[42, 69, 128, 512, 1024]",
        help="Python list literal of evaluation seeds.",
    )
    parser.add_argument("--prefix", type=str, default="")
    parser.add_argument("--n_dims", type=int, default=16)
    parser.add_argument("--n_positions", type=int, default=64)
    parser.add_argument(
        "--corrupted_type",
        type=str,
        default=None,
        choices=[
            "corrupted_base",
            "corrupted_loo",
            "non_corrupted",
            "corrupted_long",
            "corrupted_long_2x",
        ],
    )
    parser.add_argument(
        "--domain",
        type=str,
        default=None,
        choices=["image", "audio"],
    )
    parser.add_argument("--save_dir", type=str, default="./outputs")
    parser.add_argument("--text_threshold", type=int, default=100)
    parser.add_argument("--num_pca_components", type=int, default=8)
    parser.add_argument("--eval_adaptor", action="store_true")
    parser.add_argument("--eval_base", action="store_true")
    parser.add_argument("--eval_lrtuned", action="store_true")

    args = parser.parse_args()

    # literal_eval accepts only Python literals; eval() would run arbitrary code.
    k_range = ast.literal_eval(args.k_range)
    seeds = ast.literal_eval(args.seeds)

    os.makedirs(args.save_dir, exist_ok=True)

    if args.eval_adaptor:
        eval_adaptor(
            embed_model_name=args.base_model_name,
            path_to_pretrained_adaptor=args.path_to_adaptor_model,
            dataset=args.dataset,
            input_key=args.key,
            data_path=args.data_path,
            k_range=k_range,
            prefix=args.prefix,
            n_dims=args.n_dims,
            n_positions=args.n_positions,
            num_pca_components=args.num_pca_components,
            prompt_format=args.prompt_format,
            save_dir=args.save_dir,
            text_threshold=args.text_threshold,
            seeds=seeds,
            corrupted_type=args.corrupted_type,
            domain=args.domain,
        )
    if args.eval_base:
        eval_base(
            model_name=args.base_model_name,
            path_to_synthetic_model=args.path_to_synthetic_model,
            dataset=args.dataset,
            key=args.key,
            data_path=args.data_path,
            prompt_format=args.prompt_format,
            k_range=k_range,
            prefix=args.prefix,
            save_dir=args.save_dir,
            text_threshold=args.text_threshold,
            seeds=seeds,
        )
    if args.eval_lrtuned:
        eval_lrtuned(
            model_name=args.base_model_name,
            path_to_synthetic_model=args.path_to_synthetic_model,
            dataset=args.dataset,
            key=args.key,
            data_path=args.data_path,
            prompt_format=args.prompt_format,
            k_range=k_range,
            prefix=args.prefix,
            save_dir=args.save_dir,
            text_threshold=args.text_threshold,
        )
