from transformers import (
    AutoTokenizer,
    GPTNeoForCausalLM,
    AutoConfig,
    AutoModelForCausalLM,
)
import torch
import argparse
from sklearn.model_selection import train_test_split
import pandas as pd
from sklearn.metrics import f1_score
import pickle
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from models import TransformerLanguageModel
from sklearn.decomposition import PCA, SparsePCA
from models import TransformerModel
from typing import List, Dict, Tuple, Union
from reasoning_utils import (
    compute_pca,
    get_non_corrup_embeds,
    compute_pca_non_corrupt,
    get_pca_embeds_repeat,
)


import os


def _remap_backbone_keys(state_dict, prefix_map):
    """Return a copy of ``state_dict`` with wrapper key prefixes renamed.

    ``prefix_map`` is an ordered sequence of ``(old_prefix, new_prefix)``
    pairs. For each key, the first pair whose ``old_prefix`` starts the key
    rewrites that leading prefix only; keys matching no pair are kept as-is.
    """
    remapped = {}
    for key, value in state_dict.items():
        for old_prefix, new_prefix in prefix_map:
            if key.startswith(old_prefix):
                key = new_prefix + key[len(old_prefix):]
                break
        remapped[key] = value
    return remapped


def load_model(
    model_name, path_to_synthetic_model=None, frozen_model=False, ordering=False
):
    """Load two copies of a pretrained causal LM plus its tokenizer.

    Both copies start from the pretrained ``model_name`` weights. When
    ``path_to_synthetic_model`` is given, the state dict stored there has its
    ``_backbone.*`` keys renamed to the Hugging Face module names of the model
    family (chosen by substring match on ``model_name``: "pythia", "opt", or
    otherwise a ``transformer``/``lm_head`` layout) and is loaded into the
    first copy with ``strict=False``.

    Args:
        model_name: Hugging Face model id or local path.
        path_to_synthetic_model: Optional path to a ``torch.save``-d state dict.
        frozen_model: Unused; kept for interface compatibility.
        ordering: Unused; kept for interface compatibility.

    Returns:
        ``(trained_model, base_model, tokenizer)``. Both models are in eval
        mode and moved to CUDA; the tokenizer pads with ``"<|pad|>"`` and
        truncates from the left.
    """
    config = AutoConfig.from_pretrained(model_name)

    tokenizer = AutoTokenizer.from_pretrained(model_name, pad_token="<|pad|>")
    # Left truncation keeps the end of over-long prompts (the query part).
    tokenizer.truncation_side = "left"

    # should work for all generative LM models
    trained_model = AutoModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path=model_name,
        config=config,
    )
    base_model = AutoModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path=model_name,
        config=config,
    )

    if path_to_synthetic_model is not None:
        # map_location="cpu": the checkpoint may have been saved from GPU; the
        # models are only moved to CUDA at the end of this function, so there
        # is no need to materialize a second GPU copy of the weights here.
        synthetic_state = torch.load(path_to_synthetic_model, map_location="cpu")

        if "pythia" in model_name:
            prefix_map = (
                ("_backbone.gpt_neox", "gpt_neox"),
                ("_backbone.embed_out", "embed_out"),
                ("_backbone.embed_in", "embed_in"),
            )
        elif "opt" in model_name:
            prefix_map = (
                ("_backbone.model", "model"),
                ("_backbone.lm_head", "lm_head"),
            )
        else:
            print("loading synthetic")
            prefix_map = (
                ("_backbone.transformer", "transformer"),
                ("_backbone.lm_head", "lm_head"),
            )

        new_state_dict = _remap_backbone_keys(synthetic_state, prefix_map)

        # strict=False tolerates partial checkpoints, but report mismatches so
        # a wrong key remap does not silently leave pretrained weights in place.
        incompatible = trained_model.load_state_dict(new_state_dict, strict=False)
        if incompatible.missing_keys or incompatible.unexpected_keys:
            print(
                "WARNING: synthetic checkpoint partially loaded: "
                f"{len(incompatible.missing_keys)} missing keys, "
                f"{len(incompatible.unexpected_keys)} unexpected keys"
            )

    trained_model.eval()
    base_model.eval()

    return (
        trained_model.cuda(),
        base_model.cuda(),
        tokenizer,
    )


# Data load function
def load_data(data_path, dataset="sms", input_key="sms", seed=42):
    """Read the training pool and the test split of ``dataset`` from CSV.

    Files live under ``<data_path>/<dataset>/``. Training data comes from
    ``train_samples_s<seed>.csv``. Test data comes from
    ``test_samples_bal.csv`` when the dataset name contains "ag_news",
    "dbpedia" or "civil_comments", and from ``test_samples_orig.csv``
    otherwise.

    Returns:
        ``((X_train, y_train), (X_test, y_test))`` as plain lists, taken from
        the ``input_key`` and ``"label"`` columns.
    """

    def _read_split(file_name):
        # Read one CSV split and return its (inputs, labels) columns as lists.
        frame = pd.read_csv(os.path.join(data_path, f"{dataset}/{file_name}"))
        return frame[input_key].tolist(), frame["label"].tolist()

    uses_balanced_test = any(
        name in dataset for name in ("ag_news", "dbpedia", "civil_comments")
    )
    test_file = (
        "test_samples_bal.csv" if uses_balanced_test else "test_samples_orig.csv"
    )

    train_split = _read_split(f"train_samples_s{seed}.csv")
    test_split = _read_split(test_file)
    return train_split, test_split


def generate_in_context_example(
    X_train, y_train, template, seed, order=None, text_threshold=100
):
    """Render labelled samples into one concatenated few-shot prompt string.

    Args:
        X_train: Input texts.
        y_train: Labels aligned with ``X_train``; a label equal to ``1`` is
            rendered as "positive", any other value as "negative".
        template: Format string with ``{sentence}`` and ``{label}`` fields.
        seed: Seed for the global NumPy RNG; only used when ``order`` is None.
        order: Optional sequence of indices into the samples giving which
            demonstrations appear and in what order (a list or a NumPy array).
            When None, a random permutation of all sample indices is drawn.
        text_threshold: Each text is cut to this many characters, then
            stripped of surrounding whitespace.

    Returns:
        ``(prompt, ys)``: the concatenated demonstrations and the raw labels
        in the order they appear in the prompt.
    """
    samples = list(zip(X_train, y_train))

    # Test against None explicitly: truthiness of a NumPy index array is
    # ambiguous (raises for >1 element, and [0] would be treated as "no order").
    if order is None:
        # NOTE: this seeds the *global* NumPy RNG, as the original code did.
        np.random.seed(seed)
        order = np.random.permutation(len(samples))

    parts = []
    ys = []
    for idx in order:
        text, y = samples[int(idx)]
        label = "positive" if y == 1 else "negative"
        ys.append(y)
        parts.append(
            template.format(sentence=text[0:text_threshold].strip(), label=label)
        )

    return "".join(parts), ys


def get_template(template):
    """Return ``(train_template, test_template)`` for a named prompt style.

    Both strings are consumed with ``str.format(sentence=..., label=...)``.

    * ``"sentence_label"``: a "Sentence: <text>" line followed by a
      "Label: <label>" line; the test template ends right after "Label:".
    * ``"colon_label"``: "<text> : <label> , " on a single line; the test
      template ends right after the colon.

    Raises:
        ValueError: if ``template`` is not one of the supported style names.
    """
    templates = {
        "sentence_label": (
            "Sentence: {sentence:}\nLabel: {label:}\n",
            "Sentence: {sentence:}\nLabel:",
        ),
        "colon_label": ("{sentence:} : {label:} , ", "{sentence:} :"),
    }
    if template not in templates:
        raise ValueError(
            f"Unknown prompt format {template!r}; expected one of {sorted(templates)}"
        )
    return templates[template]


def evaluate_original_or_tuned(
    model,
    tokenizer,
    dataset,
    data_path,
    input_key,
    seed=42,
    prompt_format="sentence_label",
    random_seed=None,
    k_range=(4, 8, 16, 32, 48, 64, 96, 128),
    text_threshold=100,
):
    """Evaluate a causal LM with k-shot in-context prompting.

    For each ``k`` in ``k_range``, the first ``k`` training samples are
    arranged in a seeded random order into a few-shot prompt. Each test sample
    is appended to that prompt and the model greedily generates one token;
    the stripped text of that token is the predicted label.

    Args:
        model: Causal LM exposing a Hugging Face style ``generate``.
        tokenizer: Tokenizer matching ``model``.
        dataset: Dataset sub-directory name (see ``load_data``).
        data_path: Root data directory.
        input_key: CSV column holding the input text.
        seed: Selects ``train_samples_s<seed>.csv`` and, unless
            ``random_seed`` is given, the demonstration order.
        prompt_format: Template name accepted by ``get_template``.
        random_seed: Optional override for the demonstration-order seed.
        k_range: Numbers of in-context demonstrations to evaluate.
        text_threshold: Every text is cut to this many characters.

    Returns:
        ``{0: {k: result}}`` where ``result`` has keys ``original_text``,
        ``predicted_label``, ``gt_label``, ``predicted_scores`` and
        ``accuracy``. Each ``predicted_scores`` entry holds the softmax
        probabilities, at the generated position, of the first token of
        " positive", "positive", " negative" and "negative" (in that order).
    """
    map_label = {0: "negative", 1: "positive"}

    def first_token_id(word):
        # add_special_tokens=False: tokenizers that prepend a BOS token (e.g.
        # OPT's) would otherwise return the BOS id instead of the word's id.
        return tokenizer(word, add_special_tokens=False).input_ids[0]

    positive_token_id_no_space = first_token_id("positive")
    negative_token_id_no_space = first_token_id("negative")
    positive_token_id_space = first_token_id(" positive")
    negative_token_id_space = first_token_id(" negative")

    (X_train, y_train), (X_test, y_test) = load_data(
        data_path, dataset, input_key, seed
    )
    train_template, test_template = get_template(prompt_format)
    demo_seed = seed if random_seed is None else random_seed

    results = {0: {}}

    with torch.no_grad():
        for k in k_range:
            gt_label, predicted_label, original_text, predicted_scores = (
                [],
                [],
                [],
                [],
            )

            in_context_example, _ = generate_in_context_example(
                X_train[0:k],
                y_train[0:k],
                template=train_template,
                seed=demo_seed,
                order=None,
                text_threshold=text_threshold,
            )

            for text, label in tqdm(zip(X_test, y_test), total=len(X_test)):
                current_sample = test_template.format(
                    sentence=text[0:text_threshold].strip(), label=""
                )

                # Prompts longer than 2048 tokens are truncated on the
                # tokenizer's truncation_side (set to "left" by this file's
                # loaders, which drops the oldest demonstrations first).
                encoded = tokenizer(
                    f"{in_context_example}{current_sample}",
                    max_length=2048,
                    truncation=True,
                    return_tensors="pt",
                )
                input_ids = encoded.input_ids.to(model.device)
                attention_mask = encoded.get("attention_mask")
                if attention_mask is not None:
                    attention_mask = attention_mask.to(model.device)

                # Greedy decoding of a single token; temperature is not used
                # when do_sample=False, so it is not passed.
                outputs = model.generate(
                    input_ids,
                    attention_mask=attention_mask,
                    do_sample=False,
                    max_new_tokens=1,
                    output_scores=True,
                    return_dict_in_generate=True,
                )

                # Decode only the newly generated token(s), not the whole prompt.
                new_tokens = outputs["sequences"][0, input_ids.shape[1]:]
                pred_label = tokenizer.decode(
                    new_tokens, skip_special_tokens=True
                ).strip()

                # scores[0] holds the logits for the single generated position
                # of the single batch row.
                probs = torch.softmax(outputs["scores"][0], dim=-1)[0]
                predicted_scores.append(
                    (
                        probs[positive_token_id_space].item(),
                        probs[positive_token_id_no_space].item(),
                        probs[negative_token_id_space].item(),
                        probs[negative_token_id_no_space].item(),
                    )
                )

                # Integer labels 0/1 are mapped to words; others are kept as-is.
                gt_label.append(map_label.get(label, label))
                predicted_label.append(pred_label)
                original_text.append(text)

            accuracy = sum(
                gt == pred for gt, pred in zip(gt_label, predicted_label)
            ) / len(gt_label)
            results[0][k] = {
                "original_text": original_text,
                "predicted_label": predicted_label,
                "gt_label": gt_label,
                "predicted_scores": predicted_scores,
                "accuracy": accuracy,
            }
            print(f"Accuracy: {results[0][k]['accuracy']}")
    return results


def evaluate_adaptor(
    inference_head,
    embed_model,
    embed_tokenizer,
    num_pca_components=8,  # for pca
    seed=42,
    input_key="sms",
    data_path="./data",
    dataset="sms",
    prompt_format="sentence_label",
    random_seed=None,
    corrupted_embeds=False,
    k_range=(4, 8, 16, 32, 48, 64, 96, 128),
    text_threshold=100,
):
    """Evaluate the embedding + PCA linear-solver baseline for each k.

    For every ``k`` in ``k_range``, ``embed_model`` embeds the first ``k``
    training samples and all test samples through the ``reasoning_utils``
    helpers. The PCA helper's per-test predictions are then mapped to labels:
    ``0`` becomes "negative" and any other value becomes "positive".

    Args:
        inference_head: Unused; kept for interface compatibility.
        embed_model: Model used to embed the texts.
        embed_tokenizer: Tokenizer matching ``embed_model``.
        num_pca_components: Number of PCA components passed to the helper.
        seed: Selects ``train_samples_s<seed>.csv``; also forwarded to
            ``get_pca_embeds_repeat`` when ``corrupted_embeds`` is True.
        input_key: CSV column holding the input text (see ``load_data``).
        data_path: Root data directory (see ``load_data``).
        dataset: Dataset sub-directory name (see ``load_data``).
        prompt_format: Unused; kept for interface compatibility.
        random_seed: Unused; kept for interface compatibility.
        corrupted_embeds: If True use ``get_pca_embeds_repeat`` +
            ``compute_pca``; otherwise ``get_non_corrup_embeds`` +
            ``compute_pca_non_corrupt``.
        k_range: Numbers of training samples to evaluate.
        text_threshold: Forwarded as ``thresh`` to the embedding helpers.

    Returns:
        ``{0: {k: result}}`` where ``result`` has keys ``original_text``,
        ``predicted_label``, ``gt_label``, ``predicted_scores`` (placeholder
        zeros) and ``accuracy``.
    """
    map_label = {0: "negative", 1: "positive"}

    (X_train, y_train), (X_test, y_test) = load_data(
        data_path, dataset, input_key, seed
    )

    results = {0: {}}

    with torch.no_grad():
        for k in k_range:
            X_train_k, y_train_k = X_train[0:k], y_train[0:k]

            if corrupted_embeds:
                X_tr_embed, X_tst_embed, y_tr_shuffle = get_pca_embeds_repeat(
                    embed_tokenizer,
                    embed_model,
                    X_test,
                    X_train_k,
                    y_train_k,
                    k,
                    thresh=text_threshold,
                    mini_batch=32,
                    seed=seed,
                )
                _, _, pca_pred = compute_pca(
                    num_pca_components,
                    X_tr_embed,
                    X_tst_embed,
                    y_tr_shuffle,
                    y_test,
                )
            else:
                X_tr_embed, X_tst_embed, _ = get_non_corrup_embeds(
                    embed_tokenizer,
                    embed_model,
                    X_test,
                    X_train_k,
                    y_train_k,
                    k,
                    thresh=text_threshold,
                )
                _, _, pca_pred = compute_pca_non_corrupt(
                    num_pca_components,
                    X_tr_embed,
                    X_tst_embed,
                    y_train_k,
                    k,
                    y_test,
                )

            original_text, predicted_label, gt_label = [], [], []
            for test_idx, (text, label) in tqdm(
                enumerate(zip(X_test, y_test)), total=len(X_test)
            ):
                # pca_pred is indexed per test sample; 0 means "negative".
                predicted_label.append(
                    "negative" if pca_pred[test_idx] == 0 else "positive"
                )
                original_text.append(text)
                # Integer labels 0/1 are mapped to words; others are kept as-is.
                gt_label.append(map_label.get(label, label))

            accuracy = sum(
                gt == pred for gt, pred in zip(gt_label, predicted_label)
            ) / len(gt_label)
            results[0][k] = {
                "original_text": original_text,
                "predicted_label": predicted_label,
                "gt_label": gt_label,
                # This baseline produces no per-sample score; keep placeholders
                # so the result layout matches evaluate_original_or_tuned.
                "predicted_scores": [0] * len(predicted_label),
                "accuracy": accuracy,
            }
            print(results[0][k]["accuracy"])
    return results


def load_model_tokenizer(model_name):
    """Load a GPT-Neo causal LM and its tokenizer for inference.

    The tokenizer pads with ``"<|pad|>"`` and truncates from the left; the
    model is moved to CUDA and switched to eval mode.

    Returns:
        ``(model, tokenizer)``.
    """
    tok = AutoTokenizer.from_pretrained(model_name, pad_token="<|pad|>")
    tok.truncation_side = "left"

    lm = GPTNeoForCausalLM.from_pretrained(model_name).cuda()
    lm.eval()

    return lm, tok


def eval_lrsolver(
    model_name="EleutherAI/gpt-neo-1.3B",
    path_to_synthetic_model=None,
    dataset="sms",
    key="sms",
    data_path="./data",
    prompt_format="sentence_label",
    k_range=(4, 8, 16, 32, 48, 64, 96, 128),
    prefix="",
    save_dir="./outputs",
    text_threshold=100,
    seeds=(42, 69, 128, 512, 1024),
    num_pca_components=8,
    input_key=None,
    corrupted_embeds=False,
):
    """Run the PCA linear-solver baseline over several seeds and pickle results.

    The embedding model is loaded with ``load_model``. The synthetic-weights
    copy is used when ``path_to_synthetic_model`` is given, otherwise the
    pretrained copy; this mirrors ``eval_base``.
    NOTE(review): confirm that this is the intended embedding model.

    Args:
        model_name: Hugging Face model id used for the embedding model.
        path_to_synthetic_model: Optional checkpoint path; its parent
            directory name and file stem name the output file.
        dataset: Dataset sub-directory name (see ``load_data``).
        key: CSV column holding the input text, used unless ``input_key``
            is given.
        data_path: Root data directory (see ``load_data``).
        prompt_format: Forwarded to ``evaluate_adaptor``.
        k_range: Forwarded to ``evaluate_adaptor``.
        prefix: Appended to the output file name.
        save_dir: Output directory (created if missing).
        text_threshold: Forwarded to ``evaluate_adaptor``.
        seeds: Data/ordering seeds to evaluate.
        num_pca_components: Forwarded to ``evaluate_adaptor``.
        input_key: Optional explicit column name; overrides ``key``.
        corrupted_embeds: Forwarded to ``evaluate_adaptor``.

    Writes ``{seed: evaluate_adaptor(...)}`` to
    ``<save_dir>/LRSolver_<dataset>_runid_<run>_ckpt_<ckpt>_rs_9_m<model>_<prefix>.pkl``.
    """
    column = key if input_key is None else input_key

    use_synthetic = path_to_synthetic_model is not None
    if use_synthetic:
        print("loading synthetic")
    trained_model, base_model, tokenizer = load_model(
        path_to_synthetic_model=path_to_synthetic_model, model_name=model_name
    )
    embed_model = trained_model if use_synthetic else base_model
    # Drop the unused copy so its GPU memory can be reclaimed.
    del trained_model, base_model

    rs = 9  # fixed tag kept in the output file name for compatibility

    final_results_lrsolver = {}
    for seed in seeds:
        print("evaling trained model")
        final_results_lrsolver[seed] = evaluate_adaptor(
            inference_head=None,
            embed_model=embed_model,
            embed_tokenizer=tokenizer,
            num_pca_components=num_pca_components,  # for pca
            seed=seed,
            input_key=column,
            data_path=data_path,
            dataset=dataset,
            prompt_format=prompt_format,
            random_seed=seed,
            corrupted_embeds=corrupted_embeds,
            k_range=k_range,
            text_threshold=text_threshold,
        )

    if use_synthetic:
        run_id = os.path.basename(os.path.dirname(path_to_synthetic_model))
        checkpoint = os.path.basename(path_to_synthetic_model).split(".")[0]
    else:
        run_id = checkpoint = "none"

    model_name_split = model_name.split("/")[-1]
    os.makedirs(save_dir, exist_ok=True)
    file_name = f"{save_dir}/LRSolver_{dataset}_runid_{run_id}_ckpt_{checkpoint}_rs_{rs}_m{model_name_split}_{prefix}.pkl"
    with open(file_name, "wb") as f:
        pickle.dump(final_results_lrsolver, f)


def eval_base(
    model_name="EleutherAI/gpt-neo-1.3B",
    path_to_synthetic_model=None,
    dataset="sms",
    key="sms",
    data_path="./data",
    prompt_format="sentence_label",
    k_range=(4, 8, 16, 32, 48, 64, 96, 128),
    prefix="",
    save_dir="./outputs",
    text_threshold=100,
    seeds=(42, 69, 128, 512, 1024),
):
    """Evaluate in-context prompting of a (possibly synthetic-tuned) LM.

    The model comes from ``load_model``. The synthetic-weights copy is used
    when ``path_to_synthetic_model`` is given, otherwise the pretrained copy.
    Results of ``evaluate_original_or_tuned`` for every seed are pickled as
    ``{seed: results}`` to
    ``<save_dir>/<dataset>/<model>/base/base_tt<text_threshold>.pkl``.

    Args:
        model_name: Hugging Face model id.
        path_to_synthetic_model: Optional checkpoint path (see ``load_model``).
        dataset: Dataset sub-directory name.
        key: CSV column holding the input text.
        data_path: Root data directory.
        prompt_format: Template name accepted by ``get_template``.
        k_range: Numbers of in-context demonstrations to evaluate.
        prefix: Unused; kept for interface compatibility.
        save_dir: Root output directory.
        text_threshold: Character cut-off for every text.
        seeds: Data/ordering seeds to evaluate.
    """
    use_synthetic = path_to_synthetic_model is not None
    if use_synthetic:
        print("loading synthetic")
    trained_model, base_model, tokenizer = load_model(
        path_to_synthetic_model=path_to_synthetic_model, model_name=model_name
    )
    model = trained_model if use_synthetic else base_model
    # Drop the unused copy so its GPU memory can be reclaimed during evaluation.
    del trained_model, base_model

    final_results_base = {}
    for seed in seeds:
        print("evaling trained model")
        final_results_base[seed] = evaluate_original_or_tuned(
            model,
            tokenizer,
            dataset,
            data_path=data_path,
            input_key=key,
            seed=seed,
            prompt_format=prompt_format,
            random_seed=None,
            k_range=k_range,
            text_threshold=text_threshold,
        )

    model_name_split = model_name.split("/")[-1]
    save_path_base = os.path.join(save_dir, dataset, model_name_split, "base")
    os.makedirs(save_path_base, exist_ok=True)

    file_name = os.path.join(save_path_base, f"base_tt{text_threshold}.pkl")
    with open(file_name, "wb") as f:
        pickle.dump(final_results_base, f)


if __name__ == "__main__":
    import ast

    parser = argparse.ArgumentParser(
        description=(
            "Evaluate in-context learning of a (optionally synthetic-tuned) "
            "causal LM and/or a PCA linear-solver baseline."
        )
    )
    parser.add_argument(
        "--base_model_name",
        type=str,
        default="EleutherAI/gpt-neo-125m",
    )
    # NOTE(review): --path_to_adaptor_model, --n_dims, --n_positions and
    # --corrupted_embeds are parsed but not forwarded to any evaluation below.
    parser.add_argument(
        "--path_to_adaptor_model",
        type=str,
        default=None,
    )
    parser.add_argument(
        "--path_to_synthetic_model",
        type=str,
        default=None,
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="sms",
    )
    parser.add_argument(
        "--key",
        type=str,
        default="text",
    )
    parser.add_argument(
        "--data_path",
        type=str,
        default="./data",
    )
    parser.add_argument(
        "--prompt_format",
        type=str,
        default="sentence_label",
        choices=["sentence_label", "colon_label"],
    )
    parser.add_argument(
        "--k_range",
        type=str,
        default="[4, 8, 16, 32, 48, 64, 96, 128]",
    )
    # Seeds select train_samples_s<seed>.csv; the default matches the seed
    # defaults of the evaluation functions.
    parser.add_argument(
        "--seeds",
        type=str,
        default="[42, 69, 128, 512, 1024]",
    )
    parser.add_argument(
        "--prefix",
        type=str,
        default="",
    )
    parser.add_argument(
        "--n_dims",
        type=int,
        default=16,
    )
    parser.add_argument(
        "--n_positions",
        type=int,
        default=64,
    )
    parser.add_argument(
        "--corrupted_embeds",
        # type=bool would treat any non-empty string (even "False") as True.
        type=lambda s: s.strip().lower() in ("1", "true", "yes", "y"),
        default=True,
    )
    parser.add_argument(
        "--save_dir",
        type=str,
        default="./outputs",
    )
    parser.add_argument(
        "--text_threshold",
        type=int,
        default=100,
    )
    parser.add_argument("--num_pca_components", type=int, default=8)
    parser.add_argument("--eval_base", action="store_true")
    parser.add_argument("--eval_lrsolver", action="store_true")

    args = parser.parse_args()

    # literal_eval only accepts Python literals, unlike eval() on CLI input.
    k_range = ast.literal_eval(args.k_range)
    seeds = ast.literal_eval(args.seeds)

    os.makedirs(args.save_dir, exist_ok=True)

    if args.eval_base:
        eval_base(
            model_name=args.base_model_name,
            path_to_synthetic_model=args.path_to_synthetic_model,
            dataset=args.dataset,
            key=args.key,
            data_path=args.data_path,
            prompt_format=args.prompt_format,
            k_range=k_range,
            prefix=args.prefix,
            save_dir=args.save_dir,
            text_threshold=args.text_threshold,
            seeds=seeds,
        )
    if args.eval_lrsolver:
        eval_lrsolver(
            model_name=args.base_model_name,
            path_to_synthetic_model=args.path_to_synthetic_model,
            dataset=args.dataset,
            key=args.key,
            data_path=args.data_path,
            prompt_format=args.prompt_format,
            k_range=k_range,
            prefix=args.prefix,
            save_dir=args.save_dir,
            text_threshold=args.text_threshold,
            seeds=seeds,
            num_pca_components=args.num_pca_components,
            input_key=args.key,
        )
