import os
import time
import torch
import inspect
import argparse
import pandas as pd
from typing import Literal, List
from tqdm.auto import tqdm

from needle import NeedleConfig
from _models.model import get_embedding_func_batched
from _datasets.data import DatasetConfig
from utils.transform_utils import *
from utils.string_utils import *


class ExperimentConfig:

    def __init__(
        self,
        mode: Literal["insert", "remove"],
        dataset_name: str,
        num_examples: int,
        needle_keywords: List[str],
        needle_sizes: List[float],
        needle_posns: List[float],
        model_name: str = "BAAI/bge-small-en-v1.5",
        max_length: int = 8192,  # TODO: automate getting this number based on the model name
    ):
        """
        Load the dataset, the needle configurations and the embedding function
        for one experiment, and create the output directory for the model.

        Parameters:
            mode: Either "insert" or "remove"; selects which ablation is applied
                to every text.
            dataset_name: Name handed to DatasetConfig.
            num_examples: Number of examples handed to DatasetConfig.
            needle_keywords: Keywords forwarded to NeedleConfig.
            needle_sizes: Needle sizes forwarded to NeedleConfig. In insert mode
                the largest value also determines how far each text is truncated.
            needle_posns: Needle positions forwarded to NeedleConfig.
            model_name: Name of the embedding model (default is "BAAI/bge-small-en-v1.5").
            max_length: Token budget passed to the dataset loader and used to
                derive the truncation size in insert mode (default is 8192).

        Raises:
            AssertionError: If mode is neither "insert" nor "remove".
        """
        assert mode in ["insert", "remove"], "Mode must be 'insert' or 'remove'."

        self.mode = mode
        self.model_name = model_name
        self.dataset_config = DatasetConfig(dataset_name, num_examples)
        self.dataset = self.dataset_config.get_dataset(True, max_length)
        print(f"Dataset {dataset_name} loaded.")

        if mode == "insert":
            # Leave room for the largest needle: a text of t tokens grows to
            # roughly t * (1 + size) tokens after insertion, so cap t accordingly.
            max_text_size = int(max_length // (1 + max(needle_sizes)))
            self.dataset["original"] = self.dataset["original"].apply(
                lambda text: truncate(text, max_text_size)
            )

            self.dataset = self.dataset.reset_index(drop=True)
            # NOTE: texts are truncated above, no rows are dropped, so this
            # count equals the number of loaded examples.
            print(f"Filtered dataset to {len(self.dataset)} examples.")

        self.needle_configs = NeedleConfig(
            needle_keywords, needle_sizes, needle_posns, mode=mode
        ).get_configs()
        print(f"Loaded {len(self.needle_configs)} needle configs.")

        self.embedding_func = get_embedding_func_batched(model_name)
        # Working frame that accumulates modified texts, embeddings and scores.
        self.similarity_data = pd.DataFrame(self.dataset)

        # Create directory for model data if it doesn't exist
        self.model_data_path = os.path.join(
            "data", self.model_name.replace("/", "_")
        )  # Replacing '/' with '_' to avoid subdirectories
        os.makedirs(self.model_data_path, exist_ok=True)

    def run(self):
        """
        Execute the whole experiment and persist the resulting dataframe.

        Steps, in order: build the needle-modified texts, embed the texts,
        embed the individual sentences, compute cosine similarities, then
        pickle ``self.similarity_data`` into the model-specific directory.
        """
        # The same keyword arguments are handed to both embedding stages.
        embed_kwargs = {"model_name": self.model_name, "use_gpu": True}

        self.generate_text_with_needles()
        print("Added needles to dataframe.")

        stages = (
            (self.generate_embeddings, "Generated embeddings."),
            (self.generate_sentence_embeddings, "Generated sentence embeddings."),
        )
        for stage, done_message in stages:
            stage(embedding_func=self.embedding_func, **embed_kwargs)
            print(done_message)

        self.calculate_similarities()

        # Save the similarity data as a pickle in the model-specific directory
        file_path = "{}/{}_{}.pkl".format(
            self.model_data_path, self.dataset_config.name, self.mode
        )
        self.similarity_data.to_pickle(file_path)
        print(f"Saved data to {file_path}.")

    # adds needles with all specified needle configs to data
    def generate_text_with_needles(self):
        """
        Adds to the experiemnt dataframe based on the experiment mode and needle configurations.
        """
        ablation_method = self.get_ablation_method(self.mode)

        for config in self.needle_configs:
            config_name = config["name"]
            col_name = f"text_{config_name}"
            self.similarity_data[col_name] = self.similarity_data["original"].apply(
                lambda x: ablation_method(x, **config["params"])
            )  # replace with appropriate keys for config

    def get_ablation_method(self, mode):
        """
        Select the ablation method that matches the given mode.

        Parameters:
            mode (str): "insert" to add needles, "remove" to cut text out.

        Returns:
            function: ``self.add_needle_single`` for "insert",
            ``self.add_removal_single`` for "remove".

        Raises:
            ValueError: If the mode is neither 'insert' nor 'remove'.
        """
        if mode == "insert":
            return self.add_needle_single
        if mode == "remove":
            return self.add_removal_single
        # Raise (not return) so an unknown mode fails here rather than later
        # when the caller tries to call the returned exception object.
        raise ValueError("Mode must be 'insert' or 'remove'.")

    # generates embeddings for every specified needle config after the needle is added
    def generate_embeddings(self, embedding_func, **kwargs):
        """
        Embed the original texts, a shuffled copy of them, and every
        needle-modified text column.

        The results are written to ``self.similarity_data`` as
        ``embeddings_original``, ``embeddings_shuffled`` and one
        ``embeddings_<config name>`` column per needle configuration.

        Parameters:
            embedding_func: Callable invoked as
                ``embedding_func(prompts=..., pbar=False, **kwargs)``.
            **kwargs: Extra keyword arguments for the embedding function. When
                the function's source does not mention "huggingface",
                ``model_name`` is renamed to ``model`` and ``use_gpu`` is
                dropped, so both keys must be present in that case.

        Returns:
            None
        """
        # For models that are not from huggingface
        if "huggingface" not in inspect.getsource(embedding_func):
            kwargs["model"] = kwargs["model_name"]
            del kwargs["model_name"], kwargs["use_gpu"]

        frame = self.similarity_data

        def embed(texts):
            # Normalise the result to a plain list so it can be stored as a column.
            vectors = embedding_func(prompts=texts, pbar=False, **kwargs)
            return vectors if isinstance(vectors, list) else vectors.tolist()

        originals = frame["original"].dropna()
        frame["embeddings_original"] = embed(originals.tolist())

        shuffled = originals.apply(lambda text: shuffle_text(text, spacing=False))
        frame["embeddings_shuffled"] = embed(shuffled.tolist())

        # Generate embeddings for each modified text column.
        for needle in tqdm(self.needle_configs, desc="needle configurations"):
            tag = needle["name"]
            text_column = f"text_{tag}"

            if text_column not in frame:
                print(
                    f"Warning: Column {text_column} does not exist in the DataFrame"
                )
                continue

            frame[f"embeddings_{tag}"] = embed(frame[text_column].dropna().tolist())

    def generate_sentence_embeddings(self, embedding_func, **kwargs):
        """
        Split every original text into sentences and embed each row's sentences.

        Adds the columns ``sentences``, ``sentence_positions``,
        ``sentence_proportions``, ``num_sentences`` and ``sentence_embeddings``
        to ``self.similarity_data``. The splitting and position helpers come
        from the star-imported utils modules; their exact output format is not
        visible in this file.

        Parameters:
            embedding_func: Callable invoked once per row as
                ``embedding_func(prompts=<row's sentences>, pbar=False, **kwargs)``.
            **kwargs: Extra keyword arguments for the embedding function. When
                the function's source does not mention "huggingface",
                ``model_name`` is renamed to ``model`` and ``use_gpu`` is
                dropped, so both keys must be present in that case.

        Returns:
            None
        """
        frame = self.similarity_data
        frame["sentences"] = get_sentences(frame["original"].to_list())
        frame["sentence_positions"] = frame["sentences"].apply(get_sentence_positions)
        frame["sentence_proportions"] = frame["sentences"].apply(
            get_sentence_proportions
        )
        frame["num_sentences"] = frame["sentences"].apply(len)

        # For models that are not from huggingface
        if "huggingface" not in inspect.getsource(embedding_func):
            kwargs["model"] = kwargs["model_name"]
            del kwargs["model_name"], kwargs["use_gpu"]

        # Iterate over the values themselves: indexing the column with
        # 0..n-1 is a label lookup and breaks when the frame does not carry a
        # default integer index (the index is only reset in insert mode).
        frame["sentence_embeddings"] = [
            embedding_func(prompts=sentences, pbar=False, **kwargs)
            for sentences in tqdm(
                frame["sentences"].tolist(), desc="sentence embeddings"
            )
        ]

    def calculate_similarities(self):
        """
        Calculate cosine similarity between original text embeddings and each needle config's embedding.
        """
        original_embeddings = torch.tensor(
            self.similarity_data["embeddings_original"].tolist()
        ).to("cpu")

        for config in self.needle_configs:
            embeddings_column = f"embeddings_{config['name']}"
            comparison_embeddings = torch.tensor(
                self.similarity_data[embeddings_column].tolist()
            ).to("cpu")
            self.similarity_data[f"cosine_similarity_{config['name']}"] = (
                torch.cosine_similarity(
                    original_embeddings, comparison_embeddings, dim=1
                )
            )

    def add_needle_single(self, text, corpus, length, posn):
        """
        Insert a needle cut from ``corpus`` into ``text``.

        Parameters:
            text: The text that receives the needle.
            corpus: The text the needle is cut from; must not be None.
            length: Needle size as a fraction of the token count of ``text``
                (at least one token is always used).
            posn: Insertion point as a fraction of the character length of
                ``text``, between 0 and 1 inclusive.

        Returns:
            ``text`` with the needle spliced in at the requested position.

        Raises:
            AssertionError: If ``corpus`` is None.
            ValueError: If ``posn`` lies outside [0, 1].
        """
        assert corpus is not None, "Corpora is required for needle addition."
        posn_in_range = 0 <= posn <= 1
        if not posn_in_range:
            raise ValueError("Percent location must be between 0 and 1 (inclusive).")

        # Needle size is measured in tokens, relative to the host text.
        needle_tokens = int(length * num_tokens(text))
        if needle_tokens < 1:
            needle_tokens = 1
        needle = truncate(corpus, needle_tokens)

        # The insertion point is measured in characters of the host text.
        split_at = int(posn * len(text))
        head, tail = text[:split_at], text[split_at:]
        return head + needle + tail

    def add_removal_single(self, text, corpus, length, posn):
        """
        Remove a contiguous run of tokens from ``text``.

        Parameters:
            text: The input text from which to remove a portion.
            corpus: Must be None; removal does not use a corpus.
            length: Fraction of the text's tokens to remove; must be below 1.
            posn: Where the removed window sits, as a fraction between 0 and 1
                inclusive (0 = window at the start, 1 = window at the end).

        Returns:
            The decoded text with the selected tokens removed.

        Raises:
            AssertionError: If ``corpus`` is not None or ``length`` exceeds 1.
            ValueError: If ``length`` equals 1 (the whole text would be
                removed) or ``posn`` lies outside [0, 1].
        """
        assert corpus is None, "Needle removal is not supported for corpora."
        assert length <= 1, "Length must be less than or equal to 1."
        # A full-length window starts at token 0 whatever posn is (see the
        # scaling below), so it always wipes the entire text.
        if length == 1:
            raise ValueError("Cannot remove the entire text.")
        if not 0 <= posn <= 1:
            raise ValueError("Percent location must be between 0 and 1 (inclusive).")

        encoding = tokenizer.encode(text)
        # NOTE(review): assumes num_tokens counts the same tokens that
        # tokenizer.encode returns - confirm in utils.
        n_tokens = num_tokens(text)
        # Scale posn by (1 - length) so the window [start, start + span)
        # always fits inside the text.
        start = int(n_tokens * (posn - length * posn))
        span = int(n_tokens * length)
        return tokenizer.decode(encoding[:start] + encoding[start + span :])


def main(
    mode="insert",
    dataset_name="paul_graham",
    num_examples=5,
    needle_keywords=None,
    needle_size=None,
    needle_posn=None,
    model_name="embed-english-v3.0",
    max_length=8192,
):
    """
    Build an ExperimentConfig from the given settings and run it.

    Parameters:
        mode: "insert" or "remove".
        dataset_name: Name of the dataset to load.
        num_examples: Number of examples to load.
        needle_keywords: List of needle keywords; defaults to ["lorem"].
        needle_size: List of needle sizes; defaults to [0.1].
        needle_posn: List of needle positions; defaults to [0.5].
        model_name: Name of the embedding model.
        max_length: Token budget passed to ExperimentConfig.

    Returns:
        None
    """
    # Defaults are created per call so no list is shared between invocations.
    if needle_keywords is None:
        needle_keywords = ["lorem"]
    if needle_size is None:
        needle_size = [0.1]
    if needle_posn is None:
        needle_posn = [0.5]

    exp_config = ExperimentConfig(
        mode,
        dataset_name,
        num_examples,
        needle_keywords,
        needle_size,
        needle_posn,
        model_name,
        max_length,
    )
    exp_config.run()


if __name__ == "__main__":
    # Command-line entry point: parse the experiment settings and run main().
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--mode", type=str, choices=["insert", "remove"], default="insert"
    )
    parser.add_argument("--dataset_name", type=str, default="paul_graham")
    parser.add_argument("--num_examples", type=int, default=5)
    # List-valued options take one or more space-separated values. Using
    # type=list here would split every value into single characters.
    parser.add_argument("--needle_keyword", type=str, nargs="+", default=["lorem"])
    parser.add_argument("--needle_size", type=float, nargs="+", default=[0.1])
    parser.add_argument("--needle_posn", type=float, nargs="+", default=[0.5])
    parser.add_argument("--model_name", type=str, default="embed-english-v3.0")
    parser.add_argument("--max_length", type=int, default=8192)
    args = parser.parse_args()

    main(
        mode=args.mode,
        dataset_name=args.dataset_name,
        num_examples=args.num_examples,
        needle_keywords=args.needle_keyword,
        needle_size=args.needle_size,
        needle_posn=args.needle_posn,
        model_name=args.model_name,
        max_length=args.max_length,
    )
