import os
import json
import torch
import inspect
import argparse
from typing import Literal, List
from tqdm.auto import tqdm
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression

from needle import NeedleConfig
from _datasets.data import DatasetConfig
from _models.model import get_embedding_func_batched
from utils.transform_utils import *
from utils.string_utils import *
from utils.metrics import *


class SensitivityExperimentConfig:

    def __init__(
        self,
        mode: Literal["insert", "remove"],
        dataset_name: str,
        num_examples: int,
        needle_keywords: List[str],
        needle_sizes: List[float],
        needle_posns: List[float],
        model_name: str = "BAAI/bge-small-en-v1.5",
        max_length: int = 8192,  # TODO: automate getting this number based on the model name
    ):
        """
        Set up a sensitivity experiment: dataset, needle configs, embedder and output directory.

        Parameters:
            mode: Either "insert" or "remove"; selects the ablation applied to each text.
            dataset_name: Name forwarded to DatasetConfig.
            num_examples: Number of examples forwarded to DatasetConfig.
            needle_keywords: Keywords forwarded to NeedleConfig.
            needle_sizes: Needle sizes forwarded to NeedleConfig. In "insert" mode the
                largest size also determines the token budget used to filter texts.
            needle_posns: Needle positions forwarded to NeedleConfig.
            model_name: Name of the embedding model (default is "BAAI/bge-small-en-v1.5").
            max_length: Forwarded to get_dataset and used as the numerator of the
                insert-mode token budget (default is 8192).

        Returns:
            None
        """
        assert mode in ("insert", "remove"), "Mode must be 'insert' or 'remove'."

        self.mode = mode
        self.model_name = model_name

        self.dataset_config = DatasetConfig(dataset_name, num_examples)
        self.dataset = self.dataset_config.get_dataset(True, max_length)
        print(f"Dataset {dataset_name} loaded.")

        if mode == "insert":
            # Keep only texts strictly below the budget max_length // (1 + largest size);
            # presumably so text plus the largest needle still fits in max_length.
            token_budget = int(max_length // (1 + max(needle_sizes)))
            originals = self.dataset["original"]
            within_budget = originals.apply(num_tokens) < token_budget
            # From here on self.dataset is just the filtered "original" column.
            self.dataset = originals[within_budget]
            self.dataset = self.dataset.reset_index(drop=True)
            print(f"Filtered dataset to {len(self.dataset)} examples.")

        needle_config = NeedleConfig(
            needle_keywords, needle_sizes, needle_posns, mode=mode
        )
        self.needle_configs = needle_config.get_configs()
        print(f"Loaded {len(self.needle_configs)} needle configs.")

        self.embedding_func = get_embedding_func_batched(model_name)
        self.similarity_data = pd.DataFrame(self.dataset)
        self.results = {}

        # One output directory per model; '/' in the model name is replaced with '_'
        # so the name does not create nested subdirectories.
        model_dir_name = self.model_name.replace("/", "_")
        self.model_data_path = os.path.join("data", model_dir_name)
        os.makedirs(self.model_data_path, exist_ok=True)

    def run(self):
        self.generate_text_with_needles()
        print("Added needles to dataframe.")

        self.generate_embeddings(
            embedding_func=self.embedding_func,
            **{"model_name": self.model_name, "use_gpu": True},
        )
        print("Generated embeddings.")

        self.calculate_similarities()
        print("Calculated similarities.")

        self.fit_ensembling()
        print("Fitted ensembling.")

        self.get_results()
        print("Got results.")

        # Save the similarity data to a CSV file in the model-specific directory
        data_file_path = (
            f"{self.model_data_path}/{self.dataset_config.name}_{self.mode}.pkl"
        )
        self.similarity_data.to_pickle(data_file_path)
        print(f"Saved data to {data_file_path}.")

        # Save the results to a JSON file in the model-specific directory
        results_file_path = (
            f"{self.model_data_path}/{self.dataset_config.name}_{self.mode}.json"
        )
        with open(results_file_path, "w") as f:
            self.results = {k: float(v) for k, v in self.results.items()}
            f.write(json.dumps(self.results))
        print(f"Saved results to {results_file_path}.")

    # adds needles with all specified needle configs to data
    def generate_text_with_needles(self):
        """
        Adds to the experiemnt dataframe based on the experiment mode and needle configurations.
        """
        ablation_method = self.get_ablation_method(self.mode)

        for config in self.needle_configs:
            config_name = config["name"]
            col_name = f"text_{config_name}"
            self.similarity_data[col_name] = self.similarity_data["original"].apply(
                lambda x: ablation_method(x, **config["params"])
            )  # replace with appropriate keys for config

    def get_ablation_method(self, mode):
        """
        Return the ablation method that matches the given mode.

        Parameters:
            mode (str): "insert" to add needles, "remove" to delete part of the text.

        Returns:
            function: ``self.add_needle_single`` for "insert",
            ``self.add_removal_single`` for "remove".

        Raises:
            ValueError: If the mode is neither 'insert' nor 'remove'.
        """
        if mode == "insert":
            return self.add_needle_single
        if mode == "remove":
            return self.add_removal_single
        # Raise (not return) so an invalid mode fails here instead of later,
        # when the caller would try to call the exception object.
        raise ValueError("Mode must be 'insert' or 'remove'.")

    # Generates embeddings for the original text and for every needle-modified text column.
    def generate_embeddings(self, embedding_func, **kwargs):
        """
        Generate embeddings for the original text column and each modified text column.

        Results are stored in ``embeddings_original`` and, for every needle config,
        in ``embeddings_<config name>``. A config whose ``text_<config name>`` column
        is missing is skipped with a printed warning.

        Parameters:
            embedding_func: The function used for generating embeddings. It is called
                with ``prompts`` (list of texts), ``pbar=False`` and ``**kwargs``.
            **kwargs: Additional keyword arguments for the embedding function. When
                the source of ``embedding_func`` does not mention "huggingface",
                ``model_name`` is renamed to ``model`` and ``use_gpu`` is dropped.

        Returns:
            None
        """
        # For models that are not from huggingface
        source_code = inspect.getsource(embedding_func)
        if "huggingface" not in source_code:
            # Tolerate callers that omit either key instead of raising KeyError.
            if "model_name" in kwargs:
                kwargs["model"] = kwargs.pop("model_name")
            kwargs.pop("use_gpu", None)

        def embed_column(column):
            """Embed the non-null entries of one text column and return them as a list."""
            # NOTE(review): dropna() shortens the prompt list, so if the column holds
            # missing values the result no longer matches the dataframe length and
            # the column assignment below will fail -- confirm inputs are never null.
            embeds = embedding_func(
                prompts=self.similarity_data[column].dropna().tolist(),
                pbar=False,
                **kwargs,
            )
            return embeds if isinstance(embeds, list) else embeds.tolist()

        self.similarity_data["embeddings_original"] = embed_column("original")

        # Generate embeddings for each modified text column.
        for config in tqdm(self.needle_configs, desc="needle configurations"):
            needle_column = f"text_{config['name']}"
            embeddings_column = f"embeddings_{config['name']}"

            if needle_column not in self.similarity_data:
                print(
                    f"Warning: Column {needle_column} does not exist in the DataFrame"
                )
                continue

            self.similarity_data[embeddings_column] = embed_column(needle_column)
    def calculate_similarities(self):
        """
        Compute every similarity metric between the original and each needle variant.

        For each needle config this adds five columns to the dataframe, named
        ``<metric>_similarity_<config name>``: cosine (computed on the embedding
        columns) and levenshtein, rouge, bm25 and jaccard (computed on the text
        columns).
        """
        frame = self.similarity_data
        original_embeddings = frame["embeddings_original"]

        print("Calculating similarities...")

        # Metrics that compare the raw texts, in the order their columns are added.
        text_metrics = (
            ("levenshtein", levenshtein_ratio),
            ("rouge", rouge_score),
            ("bm25", bm25_score),
            ("jaccard", jaccard_similarity),
        )

        for config in tqdm(self.needle_configs, desc="needle configurations"):
            name = config["name"]
            comparison_embeddings = frame[f"embeddings_{name}"]
            comparison_text = frame[f"text_{name}"]

            # Cosine similarity is the only metric that works on embeddings.
            frame[f"cosine_similarity_{name}"] = cosine_similarity(
                original_embeddings, comparison_embeddings
            )

            for label, metric_fn in text_metrics:
                frame[f"{label}_similarity_{name}"] = metric_fn(
                    frame["original"], comparison_text
                )

    def fit_ensembling(self, n_splits=1000, test_size=0.2):
        """
        Fit a linear ensemble of all similarity metrics and record its mean score.

        Every (needle config, example) pair becomes one sample: the features are
        the values of each metric in ``metrics`` and the target is the expected
        similarity ``1 - length / (1 + length)`` for that config. A
        ``LinearRegression`` without intercept is fitted on ``n_splits`` random
        train/test splits (seeded 0..n_splits-1), each scored with ``self.score``
        on the held-out part. The mean score is stored in ``self.results["ensembled"]``.

        Parameters:
            n_splits: Number of random train/test splits to average over
                (default is 1000).
            test_size: Fraction of samples held out in each split (default is 0.2).

        Returns:
            None

        Raises:
            ValueError: If ``n_splits`` is smaller than 1.
        """
        if n_splits < 1:
            raise ValueError("n_splits must be at least 1.")

        similarities = {metric: [] for metric in metrics}
        expected_values = []
        n_rows = len(self.similarity_data)

        for config in self.needle_configs:
            needle_length = config["params"]["length"]
            expected = 1 - (needle_length / (1 + needle_length))
            # Every example under this config shares the same target value.
            expected_values.extend([expected] * n_rows)

            for metric in metrics:
                column = f"{metric}_similarity_{config['name']}"
                similarities[metric].extend(self.similarity_data[column])

        # Rows are samples, columns are metrics (in the order of `metrics`).
        X = np.array([similarities[metric] for metric in metrics]).T
        y = np.array(expected_values)

        ensembled_scores = []
        for seed in tqdm(range(n_splits), desc="Ensembling"):
            # The loop index doubles as the seed so the splits are reproducible.
            X_train, X_test, y_train, y_test = train_test_split(
                X, y, test_size=test_size, random_state=seed
            )

            ensemble = LinearRegression(fit_intercept=False)
            ensemble.fit(X_train, y_train)
            predictions = ensemble.predict(X_test)
            ensembled_scores.append(self.score(y_test, predictions))

        self.results["ensembled"] = np.mean(ensembled_scores)

    def score(self, x, y):
        expected = x
        mae = np.mean(np.abs(y - expected))
        return 1 - mae

    def get_results(self):
        """
        Score each metric against the expected similarity and store the per-metric mean.

        For every metric in ``metrics`` and every needle config, the column
        ``<metric>_similarity_<config name>`` is scored with ``self.score`` against
        the expected similarity ``1 - length / (1 + length)`` of that config. The
        mean over all configs is written to ``self.results[metric]``.

        Returns:
            dict: ``self.results`` after the per-metric means have been stored.
        """

        def expected_similarity(needle_config):
            needle_length = needle_config["params"]["length"]
            return 1 - (needle_length / (1 + needle_length))

        for metric in metrics:
            per_config_scores = [
                self.score(
                    expected_similarity(needle_config),
                    self.similarity_data[
                        f"{metric}_similarity_{needle_config['name']}"
                    ],
                )
                for needle_config in self.needle_configs
            ]
            self.results[metric] = np.mean(per_config_scores)

        return self.results

    def add_needle_single(self, text, corpus, length, posn):
        """
        Insert a needle taken from ``corpus`` into ``text`` at a relative position.

        Parameters:
            text: The input text that receives the needle.
            corpus: The text the needle is cut from; must not be None.
            length: Needle size as a fraction of the token count of ``text``
                (at least one token is always used).
            posn: Insertion point as a fraction of the character length of
                ``text``; must lie in [0, 1].

        Returns:
            The text with the needle inserted at the computed character offset.

        Raises:
            ValueError: If ``posn`` is not between 0 and 1 (inclusive).
        """
        assert corpus is not None, "Corpora is required for needle addition."
        position_is_valid = 0 <= posn <= 1
        if not position_is_valid:
            raise ValueError("Percent location must be between 0 and 1 (inclusive).")

        # The needle size is measured in tokens and never drops below one token.
        needle_token_length = int(length * num_tokens(text))
        if needle_token_length < 1:
            needle_token_length = 1
        needle = truncate(corpus, needle_token_length)

        # The insertion offset, in contrast, is measured in characters.
        split_at = int(posn * len(text))
        head, tail = text[:split_at], text[split_at:]
        return head + needle + tail

    def add_removal_single(self, text, corpus, length, posn):
        """
        Remove a contiguous span of tokens from the input text.

        Parameters:
            text: The input text from which to remove a portion.
            corpus: This should be None, as needle removal is not supported for corpora.
            length: The size of the removed span as a fraction of the token count
                of ``text``; must satisfy 0 <= length < 1.
            posn: Where the removal starts, as a fraction in [0, 1]. It is scaled
                by ``(1 - length)`` so the removed span always fits inside the text.

        Returns:
            The decoded text after dropping the selected tokens.

        Raises:
            ValueError: If ``length`` is negative, if ``length`` is 1 (which would
                remove the entire text at any position), or if ``posn`` is not
                between 0 and 1 (inclusive).
        """
        assert corpus is None, "Needle removal is not supported for corpora."
        assert length <= 1, "Length must be less than or equal to 1."
        if length < 0:
            # A negative span would make the slices below overlap and duplicate text.
            raise ValueError("Length must be non-negative.")
        if length == 1:
            # With length == 1 the start collapses to 0 for every posn, so the
            # whole text would be removed -- not only when posn == 0.
            raise ValueError("Cannot remove the entire text.")
        if not 0 <= posn <= 1:
            raise ValueError("Percent location must be between 0 and 1 (inclusive).")

        encoding = tokenizer.encode(text)
        n_tokens = num_tokens(text)
        # posn - length * posn == posn * (1 - length): start of the removed window.
        adjusted_posn = posn - length * posn
        loc = int(n_tokens * adjusted_posn)
        n_removed = int(n_tokens * length)
        return tokenizer.decode(encoding[:loc] + encoding[loc + n_removed :])


def main(
    mode="insert",
    dataset_name="paul_graham",
    num_examples=5,
    needle_keywords=None,
    needle_size=None,
    needle_posn=None,
    model_name="embed-english-v3.0",
    max_length=8192,
):
    """
    Build a SensitivityExperimentConfig from the given settings and run it.

    Parameters:
        mode: "insert" or "remove".
        dataset_name: Name of the dataset to load.
        num_examples: Number of examples to load.
        needle_keywords: Needle keywords; defaults to ["lorem"] when None.
        needle_size: Needle sizes; defaults to [0.1] when None.
        needle_posn: Needle positions; defaults to [0.5] when None.
        model_name: Name of the embedding model.
        max_length: Maximum length forwarded to the experiment config.

    Returns:
        None
    """
    # None sentinels avoid sharing one mutable default list across calls.
    if needle_keywords is None:
        needle_keywords = ["lorem"]
    if needle_size is None:
        needle_size = [0.1]
    if needle_posn is None:
        needle_posn = [0.5]

    exp_config = SensitivityExperimentConfig(
        mode,
        dataset_name,
        num_examples,
        needle_keywords,
        needle_size,
        needle_posn,
        model_name,
        max_length,
    )
    exp_config.run()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", type=str, default="insert")
    parser.add_argument("--dataset_name", type=str, default="paul_graham")
    parser.add_argument("--num_examples", type=int, default=5)
    parser.add_argument("--needle_keyword", type=list, default=["lorem"])
    parser.add_argument("--needle_size", type=list, nargs="+", default=[0.1])
    parser.add_argument("--needle_posn", type=list, nargs="+", default=[0.5])
    parser.add_argument("--model_name", type=str, default="embed-english-v3.0")
    parser.add_argument("--max_length", type=int, default=8192)
    args = parser.parse_args()

    mode = args.mode
    dataset_name = args.dataset_name
    num_examples = args.num_examples
    needle_keywords = args.needle_keyword
    needle_size = args.needle_size
    needle_posn = args.needle_posn
    model_name = args.model_name
    max_length = args.max_length

    main(
        mode,
        dataset_name,
        num_examples,
        needle_keywords,
        needle_size,
        needle_posn,
        model_name,
        max_length,
    )
