from __future__ import annotations

import csv
import logging
import os
from contextlib import nullcontext
from typing import TYPE_CHECKING, Literal

import numpy as np
import torch.nn.functional as F  # noqa
from scipy.stats import pearsonr, spearmanr
from sentence_transformers.evaluation.SentenceEvaluator import SentenceEvaluator
from sentence_transformers.readers import InputExample
from sentence_transformers.similarity_functions import SimilarityFunction
from sklearn.metrics.pairwise import (
    paired_euclidean_distances,
    paired_manhattan_distances,
)

from box_similarity import (
    similarity_function_entailment_pairwise,
    similarity_function_pairwise,
    vector_entailment_similarity_csdelta,
)
from vector_entailment import VectorEntailmentClassifier
from vector_entailment_diff import VectorEntailmentClassifierDiff

if TYPE_CHECKING:
    from sentence_transformers.SentenceTransformer import SentenceTransformer

from sentence_transformers.util import (
    pairwise_cos_sim,
    pairwise_dot_score,
    pairwise_euclidean_sim,
    pairwise_manhattan_sim,
)

logger = logging.getLogger(__name__)


class EmbeddingSimilarityEvaluatorDiff(SentenceEvaluator):
    """
    Evaluate a model by correlating embedding-pair similarity scores with gold labels.

    For every configured similarity function, the evaluator scores each
    ``(sentences1[i], sentences2[i])`` pair and reports the Pearson and Spearman
    correlation between those scores and ``scores[i]``.

    Supported similarity function names are ``"cosine"``, ``"manhattan"``,
    ``"euclidean"``, ``"dot"``, ``"box_intersection"``, ``"csdelta"``,
    ``"vector_entailment"`` (requires ``classifier``) and
    ``"vector_entailment_diff"`` (requires ``classifier_diff``).

    The returned dictionary contains ``pearson_<fn>`` and ``spearman_<fn>`` for every
    evaluated function (plus ``pearson_max`` / ``spearman_max`` when more than one
    function is configured), with keys passed through ``prefix_name_to_metrics``.
    The primary metric is always a Spearman correlation.

    Example:
        ::

            from datasets import load_dataset
            from sentence_transformers import SentenceTransformer

            # Load a model
            model = SentenceTransformer('all-mpnet-base-v2')

            # Load the STSB dataset (https://huggingface.co/datasets/sentence-transformers/stsb)
            eval_dataset = load_dataset("sentence-transformers/stsb", split="validation")

            # Initialize the evaluator
            dev_evaluator = EmbeddingSimilarityEvaluatorDiff(
                sentences1=eval_dataset["sentence1"],
                sentences2=eval_dataset["sentence2"],
                scores=eval_dataset["score"],
                similarity_fn_names=["cosine"],
                name="sts_dev",
            )
            results = dev_evaluator(model)
            print(dev_evaluator.primary_metric)
            print(results[dev_evaluator.primary_metric])
    """

    def __init__(
        self,
        sentences1: list[str],
        sentences2: list[str],
        scores: list[float],
        batch_size: int = 16,
        main_similarity: str | SimilarityFunction | None = None,
        similarity_fn_names: (
            list[
                Literal[
                    "cosine",
                    "euclidean",
                    "manhattan",
                    "dot",
                    "box_intersection",
                    "csdelta",
                    "vector_entailment",
                    "vector_entailment_diff",
                ]
            ]
            | None
        ) = None,
        name: str = "",
        show_progress_bar: bool = False,
        write_csv: bool = True,
        precision: (
            Literal["float32", "int8", "uint8", "binary", "ubinary"] | None
        ) = None,
        truncate_dim: int | None = None,
        classifier: VectorEntailmentClassifier | None = None,
        classifier_diff: VectorEntailmentClassifierDiff | None = None,
    ):
        """
        Construct an evaluator for a dataset of scored sentence pairs.

        Args:
            sentences1 (List[str]): List with the first sentence in a pair.
            sentences2 (List[str]): List with the second sentence in a pair.
            scores (List[float]): Similarity score between sentences1[i] and sentences2[i].
            batch_size (int, optional): The batch size for processing the sentences. Defaults to 16.
            main_similarity (Optional[Union[str, SimilarityFunction]], optional): The similarity function whose
                Spearman correlation becomes the primary metric. Can be a string (e.g. "cosine", "csdelta")
                or a SimilarityFunction member. Defaults to None.
            similarity_fn_names (List[str], optional): List of similarity function names to use. If None, the
                ``similarity_fn_name`` attribute of the model is used. Defaults to None.
            name (str, optional): The name of the evaluator. Defaults to "".
            show_progress_bar (bool, optional): Whether to show a progress bar during evaluation. If None, it is
                enabled when the logger level is INFO or DEBUG. Defaults to False.
            write_csv (bool, optional): Whether to write the evaluation results to a CSV file. Defaults to True.
            precision (Optional[Literal["float32", "int8", "uint8", "binary", "ubinary"]], optional): The precision
                to use for the embeddings. Can be "float32", "int8", "uint8", "binary", or "ubinary". Defaults to None.
            truncate_dim (Optional[int], optional): The dimension to truncate sentence embeddings to. `None` uses the
                model's current truncation dimension. Defaults to None.
            classifier (Optional[VectorEntailmentClassifier], optional): Classifier whose ``get_scores`` is used
                for the "vector_entailment" similarity function. Defaults to None.
            classifier_diff (Optional[VectorEntailmentClassifierDiff], optional): Classifier whose ``get_scores``
                is used for the "vector_entailment_diff" similarity function. Defaults to None.
        """
        super().__init__()
        self.sentences1 = sentences1
        self.sentences2 = sentences2
        self.scores = scores
        # Honour the caller's choice; this used to be hard-coded to False.
        self.write_csv = write_csv
        self.precision = precision
        self.truncate_dim = truncate_dim

        assert len(self.sentences1) == len(
            self.sentences2
        ), "sentences1 and sentences2 must have the same length"
        assert len(self.sentences1) == len(
            self.scores
        ), "sentences1 and scores must have the same length"

        self.main_similarity = main_similarity
        self.similarity_fn_names = similarity_fn_names or []
        self.name = name

        self.batch_size = batch_size
        if show_progress_bar is None:
            show_progress_bar = logger.getEffectiveLevel() in (
                logging.INFO,
                logging.DEBUG,
            )
        self.show_progress_bar = show_progress_bar

        self.csv_file = (
            "similarity_evaluation"
            + ("_" + name if name else "")
            + ("_" + precision if precision else "")
            + "_results.csv"
        )
        self.csv_headers = ["epoch", "steps"]
        self._append_csv_headers(self.similarity_fn_names)

        self.classifier = classifier
        self.classifier_diff = classifier_diff

    def _append_csv_headers(self, similarity_fn_names: list[str]) -> None:
        """Append one ``<fn>_pearson`` and one ``<fn>_spearman`` CSV column per similarity function."""
        self.csv_headers.extend(
            f"{fn_name}_{metric}"
            for fn_name in similarity_fn_names
            for metric in ("pearson", "spearman")
        )

    @classmethod
    def from_input_examples(cls, examples: list[InputExample], **kwargs):
        """Build an evaluator from ``InputExample`` objects (``texts[0]``, ``texts[1]``, ``label``)."""
        sentences1 = [example.texts[0] for example in examples]
        sentences2 = [example.texts[1] for example in examples]
        scores = [example.label for example in examples]
        return cls(sentences1, sentences2, scores, **kwargs)

    @staticmethod
    def _scores_to_numpy(scores) -> np.ndarray:
        """Convert similarity scores (tensor, array or sequence) to a NumPy array for scipy."""
        if hasattr(scores, "detach"):
            # Tensor-like: drop any autograd graph and move to host memory first.
            scores = scores.detach().cpu().numpy()
        return np.asarray(scores)

    def _write_csv_row(
        self, output_path: str, epoch: int, steps: int, metrics: dict[str, float]
    ) -> None:
        """Append one result row to the CSV file, writing the header first if the file is new."""
        os.makedirs(output_path, exist_ok=True)
        csv_path = os.path.join(output_path, self.csv_file)
        is_new_file = not os.path.isfile(csv_path)
        # Column order mirrors ``_append_csv_headers``; metrics are stored as
        # ``<metric>_<fn>``. Functions that were skipped leave an empty cell so
        # that the remaining columns stay aligned with the header.
        row = [epoch, steps] + [
            metrics.get(f"{metric}_{fn_name}", "")
            for fn_name in self.similarity_fn_names
            for metric in ("pearson", "spearman")
        ]
        with open(
            csv_path,
            newline="",
            mode="w" if is_new_file else "a",
            encoding="utf-8",
        ) as f:
            writer = csv.writer(f)
            if is_new_file:
                writer.writerow(self.csv_headers)
            writer.writerow(row)

    def __call__(
        self,
        model: SentenceTransformer,
        output_path: str | None = None,
        epoch: int = -1,
        steps: int = -1,
    ) -> dict[str, float]:
        """
        Encode both sentence lists with ``model`` and compute the correlation metrics.

        Args:
            model (SentenceTransformer): The model to evaluate.
            output_path (str, optional): Directory for the CSV results file. Nothing is written when
                this is None or ``write_csv`` is False. Defaults to None.
            epoch (int, optional): Epoch number, used for logging and the CSV row. Defaults to -1.
            steps (int, optional): Step number, used for logging and the CSV row. Defaults to -1.

        Returns:
            Dict[str, float]: The metrics after ``prefix_name_to_metrics`` has been applied.
        """
        logger.debug(
            "output_path=%s, epoch=%s, steps=%s", output_path, epoch, steps
        )
        if epoch != -1:
            if steps == -1:
                out_txt = f" after epoch {epoch}"
            else:
                out_txt = f" in epoch {epoch} after {steps} steps"
        else:
            out_txt = ""
        if self.truncate_dim is not None:
            out_txt += f" (truncated to {self.truncate_dim})"

        logger.info(
            f"EmbeddingSimilarityEvaluator: Evaluating the model on the {self.name} dataset{out_txt}:"
        )

        with (
            nullcontext()
            if self.truncate_dim is None
            else model.truncate_sentence_embeddings(self.truncate_dim)
        ):
            embeddings1 = model.encode(
                self.sentences1,
                batch_size=self.batch_size,
                show_progress_bar=self.show_progress_bar,
                convert_to_tensor=True,
                precision=self.precision,
                normalize_embeddings=bool(self.precision),
            )
            embeddings2 = model.encode(
                self.sentences2,
                batch_size=self.batch_size,
                show_progress_bar=self.show_progress_bar,
                convert_to_tensor=True,
                precision=self.precision,
                normalize_embeddings=bool(self.precision),
            )
        # Binary and ubinary embeddings are packed, so we need to unpack them for the distance metrics
        # NOTE(review): `.astype` and `np.unpackbits` are NumPy operations while `encode` is called
        # with `convert_to_tensor=True` above — confirm the quantized path works end to end.
        if self.precision == "binary":
            embeddings1 = (embeddings1 + 128).astype(np.uint8)
            embeddings2 = (embeddings2 + 128).astype(np.uint8)
        if self.precision in ("ubinary", "binary"):
            embeddings1 = np.unpackbits(embeddings1, axis=1)
            embeddings2 = np.unpackbits(embeddings2, axis=1)

        labels = self.scores

        if not self.similarity_fn_names:
            self.similarity_fn_names = [model.similarity_fn_name]
            self._append_csv_headers(self.similarity_fn_names)

        # Manhattan, Euclidean and dot use the tensor-based pairwise helpers so that
        # every entry works on the tensors returned by `encode` (the distance-based
        # helpers return negated distances, i.e. higher means more similar).
        similarity_functions = {
            "cosine": lambda x, y: F.cosine_similarity(x, y, dim=1),
            "manhattan": pairwise_manhattan_sim,
            "euclidean": pairwise_euclidean_sim,
            "dot": pairwise_dot_score,
            "box_intersection": lambda x, y: similarity_function_pairwise(
                x, y, volume_temp=1.0, intersection_temp=0.001
            ),
            "vector_entailment": lambda x, y: self.classifier.get_scores(x, y).detach(),
            "vector_entailment_diff": lambda x, y: self.classifier_diff.get_scores(
                x, y
            ).detach(),
            "csdelta": vector_entailment_similarity_csdelta,
        }

        metrics = {}
        evaluated = []
        for fn_name in self.similarity_fn_names:
            score_fn = similarity_functions.get(fn_name)
            if score_fn is None:
                logger.warning(
                    "Skipping unknown similarity function %r; supported: %s",
                    fn_name,
                    ", ".join(similarity_functions),
                )
                continue
            scores = self._scores_to_numpy(score_fn(embeddings1, embeddings2))
            eval_pearson, _ = pearsonr(labels, scores)
            eval_spearman, _ = spearmanr(labels, scores)
            metrics[f"pearson_{fn_name}"] = eval_pearson
            metrics[f"spearman_{fn_name}"] = eval_spearman
            evaluated.append(fn_name)
            logger.info(
                f"{fn_name.capitalize()}-Similarity :\tPearson: {eval_pearson:.4f}\tSpearman: {eval_spearman:.4f}"
            )

        if output_path is not None and self.write_csv:
            self._write_csv_row(output_path, epoch, steps, metrics)

        # Aggregate only over functions that were actually evaluated so that a
        # skipped (unknown) name cannot trigger a KeyError here.
        if len(self.similarity_fn_names) > 1 and evaluated:
            metrics["pearson_max"] = max(
                metrics[f"pearson_{fn_name}"] for fn_name in evaluated
            )
            metrics["spearman_max"] = max(
                metrics[f"spearman_{fn_name}"] for fn_name in evaluated
            )

        if self.main_similarity:
            if isinstance(self.main_similarity, SimilarityFunction):
                self.primary_metric = {
                    SimilarityFunction.COSINE: "spearman_cosine",
                    SimilarityFunction.EUCLIDEAN: "spearman_euclidean",
                    SimilarityFunction.MANHATTAN: "spearman_manhattan",
                    SimilarityFunction.DOT_PRODUCT: "spearman_dot",
                }.get(self.main_similarity)
            else:
                # Plain strings (including the custom function names) map
                # directly onto the corresponding Spearman metric key.
                self.primary_metric = f"spearman_{self.main_similarity}"
        elif len(self.similarity_fn_names) > 1:
            self.primary_metric = "spearman_max"
        else:
            self.primary_metric = f"spearman_{self.similarity_fn_names[0]}"

        metrics = self.prefix_name_to_metrics(metrics, self.name)
        self.store_metrics_in_model_card_data(model, metrics)
        return metrics

    @property
    def description(self) -> str:
        """Short human-readable description of this evaluator."""
        return "Semantic Similarity"


# NOTE: this TYPE_CHECKING import and the logger assignment below repeat the
# identical statements near the top of this module; they are redundant but
# harmless (the logger is simply re-bound to the same named logger).
if TYPE_CHECKING:
    from sentence_transformers.SentenceTransformer import SentenceTransformer

logger = logging.getLogger(__name__)


class TripletEvaluatorDiff(SentenceEvaluator):
    """
    Evaluate a model based on a triplet: (sentence, positive_example, negative_example).
    Checks if ``similarity(sentence, positive_example) > similarity(sentence, negative_example) + margin``
    and reports the fraction of triplets for which this holds (the accuracy).

    Supported similarity function names are ``"cosine"``, ``"dot"``, ``"manhattan"``,
    ``"euclidean"``, ``"box_entailment"``, ``"csdelta"``, ``"vector_entailment"``
    (requires ``classifier``) and ``"vector_entailment_diff"`` (requires ``classifier_diff``).

    Args:
        anchors (List[str]): Sentences to check similarity to. (e.g. a query)
        positives (List[str]): List of positive sentences
        negatives (List[str]): List of negative sentences
        main_similarity_function (Union[str, SimilarityFunction], optional):
            The similarity function whose accuracy becomes the primary metric. It is converted
            with ``SimilarityFunction(...)``. If not specified, the primary metric is the single
            configured function's accuracy, or ``max_accuracy`` when several are configured.
            Defaults to None.
        margin (Union[float, Dict[str, float]], optional): Margins for various similarity metrics.
            If a float is provided, it is used for 'cosine', 'dot', 'manhattan' and 'euclidean';
            the other similarity functions keep a margin of 0.
            If a dictionary is provided, its entries override a default margin of 0 per function.
            The value specifies the minimum margin by which the positive score must exceed the
            negative score. Defaults to None (margin 0 everywhere).
        name (str): Name for the output. Defaults to "".
        batch_size (int): Batch size used to compute embeddings. Defaults to 16.
        show_progress_bar (bool): If true, prints a progress bar. Defaults to False.
        write_csv (bool): Write results to a CSV file. Defaults to True.
        truncate_dim (int, optional): The dimension to truncate sentence embeddings to.
            `None` uses the model's current truncation dimension. Defaults to None.
        similarity_fn_names (List[str], optional): List of similarity function names to evaluate.
            If not specified, evaluate using the ``model.similarity_fn_name``.
            Defaults to None.
        main_distance_function (Union[str, SimilarityFunction], optional): Deprecated alias of
            ``main_similarity_function``; only used when the latter is None.
        classifier (VectorEntailmentClassifier, optional): Classifier whose ``get_scores`` is used
            for the "vector_entailment" similarity function. Defaults to None.
        classifier_diff (VectorEntailmentClassifierDiff, optional): Classifier whose ``get_scores``
            is used for the "vector_entailment_diff" similarity function. Defaults to None.

    Example:
        ::

            from sentence_transformers import SentenceTransformer
            from datasets import load_dataset

            # Load a model
            model = SentenceTransformer('all-mpnet-base-v2')

            # Load a dataset with (anchor, positive, negative) triplets
            dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="dev")

            # Initialize the evaluator using anchors, positives, and negatives
            triplet_evaluator = TripletEvaluatorDiff(
                anchors=dataset[:1000]["anchor"],
                positives=dataset[:1000]["positive"],
                negatives=dataset[:1000]["negative"],
                similarity_fn_names=["cosine"],
                name="all_nli_dev",
            )
            results = triplet_evaluator(model)
            print(triplet_evaluator.primary_metric)
            print(results[triplet_evaluator.primary_metric])
    """

    def __init__(
        self,
        anchors: list[str],
        positives: list[str],
        negatives: list[str],
        main_similarity_function: str | SimilarityFunction | None = None,
        margin: float | dict[str, float] | None = None,
        name: str = "",
        batch_size: int = 16,
        show_progress_bar: bool = False,
        write_csv: bool = True,
        truncate_dim: int | None = None,
        similarity_fn_names: (
            list[
                Literal[
                    "cosine",
                    "euclidean",
                    "manhattan",
                    "dot",
                    "box_intersection",
                    "csdelta",
                    "box_entailment",
                    "vector_entailment",
                    "vector_entailment_diff",
                ]
            ]
            | None
        ) = None,
        main_distance_function: str | SimilarityFunction | None = "deprecated",
        classifier: VectorEntailmentClassifier | None = None,
        classifier_diff: VectorEntailmentClassifierDiff | None = None,
    ):
        super().__init__()
        self.anchors = anchors
        self.positives = positives
        self.negatives = negatives
        self.name = name
        self.truncate_dim = truncate_dim

        assert len(self.anchors) == len(
            self.positives
        ), "anchors and positives must have the same length"
        assert len(self.anchors) == len(
            self.negatives
        ), "anchors and negatives must have the same length"

        if main_distance_function != "deprecated" and main_similarity_function is None:
            main_similarity_function = main_distance_function
            logger.warning(
                "The 'main_distance_function' parameter is deprecated. Please use 'main_similarity_function' instead. "
                "'main_distance_function' will be removed in a future release."
            )

        self.main_similarity_function = (
            SimilarityFunction(main_similarity_function)
            if main_similarity_function
            else None
        )
        self.similarity_fn_names = similarity_fn_names or []

        # Every supported similarity function always has a margin entry, so the
        # lookup in `__call__` cannot fail regardless of how `margin` was given.
        default_margin = {
            "cosine": 0,
            "dot": 0,
            "manhattan": 0,
            "euclidean": 0,
            "box_entailment": 0,
            "vector_entailment": 0,
            "vector_entailment_diff": 0,
            "csdelta": 0,
        }
        if margin is None:
            self.margin = default_margin
        elif isinstance(margin, (float, int)):
            # A scalar margin applies to the four standard metrics only; the
            # custom similarity functions keep a margin of 0.
            self.margin = {
                **default_margin,
                "cosine": margin,
                "dot": margin,
                "manhattan": margin,
                "euclidean": margin,
            }
        elif isinstance(margin, dict):
            self.margin = {**default_margin, **margin}
        else:
            raise ValueError(
                "`margin` should be a float or a dictionary with keys 'cosine', 'dot', 'manhattan', and 'euclidean'"
            )

        self.batch_size = batch_size
        if show_progress_bar is None:
            show_progress_bar = logger.getEffectiveLevel() in (
                logging.INFO,
                logging.DEBUG,
            )
        self.show_progress_bar = show_progress_bar

        self.csv_file: str = (
            "triplet_evaluation" + ("_" + name if name else "") + "_results.csv"
        )
        self.csv_headers = ["epoch", "steps"]
        self.write_csv = write_csv

        self._append_csv_headers(self.similarity_fn_names)
        self.classifier = classifier
        self.classifier_diff = classifier_diff

    def _append_csv_headers(self, similarity_fn_names: list[str]) -> None:
        """Append one ``accuracy_<fn>`` CSV column per similarity function."""
        self.csv_headers.extend(
            f"accuracy_{fn_name}" for fn_name in similarity_fn_names
        )

    @classmethod
    def from_input_examples(cls, examples: list[InputExample], **kwargs):
        """Build an evaluator from ``InputExample`` objects whose ``texts`` are (anchor, positive, negative)."""
        anchors = [example.texts[0] for example in examples]
        positives = [example.texts[1] for example in examples]
        negatives = [example.texts[2] for example in examples]
        return cls(anchors, positives, negatives, **kwargs)

    def _write_csv_row(
        self, output_path: str, epoch: int, steps: int, metrics: dict[str, float]
    ) -> None:
        """Append one result row to the CSV file, writing the header first if the file is new."""
        os.makedirs(output_path, exist_ok=True)
        csv_path = os.path.join(output_path, self.csv_file)
        is_new_file = not os.path.isfile(csv_path)
        # Build the row per configured function (not from `metrics.values()`) so
        # the cells stay aligned with the header even if a function was skipped.
        row = [epoch, steps] + [
            metrics.get(f"{fn_name}_accuracy", "")
            for fn_name in self.similarity_fn_names
        ]
        with open(
            csv_path,
            newline="",
            mode="w" if is_new_file else "a",
            encoding="utf-8",
        ) as f:
            writer = csv.writer(f)
            if is_new_file:
                writer.writerow(self.csv_headers)
            writer.writerow(row)

    def __call__(
        self,
        model: SentenceTransformer,
        output_path: str | None = None,
        epoch: int = -1,
        steps: int = -1,
    ) -> dict[str, float]:
        """
        Encode anchors, positives and negatives with ``model`` and compute triplet accuracies.

        Args:
            model (SentenceTransformer): The model to evaluate.
            output_path (str, optional): Directory for the CSV results file. Nothing is written when
                this is None or ``write_csv`` is False. Defaults to None.
            epoch (int, optional): Epoch number, used for logging and the CSV row. Defaults to -1.
            steps (int, optional): Step number, used for logging and the CSV row. Defaults to -1.

        Returns:
            Dict[str, float]: ``<fn>_accuracy`` per evaluated function (plus ``max_accuracy`` when
            several functions are configured), after ``prefix_name_to_metrics`` has been applied.
        """
        if epoch != -1:
            if steps == -1:
                out_txt = f" after epoch {epoch}"
            else:
                out_txt = f" in epoch {epoch} after {steps} steps"
        else:
            out_txt = ""
        if self.truncate_dim is not None:
            out_txt += f" (truncated to {self.truncate_dim})"

        logger.info(
            f"TripletEvaluator: Evaluating the model on the {self.name} dataset{out_txt}:"
        )

        with (
            nullcontext()
            if self.truncate_dim is None
            else model.truncate_sentence_embeddings(self.truncate_dim)
        ):
            embeddings_anchors = model.encode(
                self.anchors,
                batch_size=self.batch_size,
                show_progress_bar=self.show_progress_bar,
                convert_to_tensor=True,
            )
            embeddings_positives = model.encode(
                self.positives,
                batch_size=self.batch_size,
                show_progress_bar=self.show_progress_bar,
                convert_to_tensor=True,
            )
            embeddings_negatives = model.encode(
                self.negatives,
                batch_size=self.batch_size,
                show_progress_bar=self.show_progress_bar,
                convert_to_tensor=True,
            )

        if not self.similarity_fn_names:
            self.similarity_fn_names = [model.similarity_fn_name]
            self._append_csv_headers(self.similarity_fn_names)

        # Each entry scores a batch of (left, right) embedding pairs; it is applied
        # once to (anchor, positive) and once to (anchor, negative).
        pair_score_functions = {
            "cosine": pairwise_cos_sim,
            "dot": pairwise_dot_score,
            "manhattan": pairwise_manhattan_sim,
            "euclidean": pairwise_euclidean_sim,
            "box_entailment": lambda left, right: similarity_function_entailment_pairwise(
                left, right, volume_temp=1.0, intersection_temp=0.001
            ),
            "vector_entailment": lambda left, right: self.classifier.get_scores(
                left, right
            ),
            "vector_entailment_diff": lambda left, right: self.classifier_diff.get_scores(
                left, right
            ),
            "csdelta": vector_entailment_similarity_csdelta,
        }

        metrics = {}
        for fn_name in self.similarity_fn_names:
            score_fn = pair_score_functions.get(fn_name)
            if score_fn is None:
                logger.warning(
                    "Skipping unknown similarity function %r; supported: %s",
                    fn_name,
                    ", ".join(pair_score_functions),
                )
                continue
            positive_scores = score_fn(embeddings_anchors, embeddings_positives)
            negative_scores = score_fn(embeddings_anchors, embeddings_negatives)
            accuracy = (
                (positive_scores > negative_scores + self.margin.get(fn_name, 0))
                .float()
                .mean()
                .item()
            )
            metrics[f"{fn_name}_accuracy"] = accuracy
            logger.info(
                f"Accuracy {fn_name.capitalize()} Similarity:\t{accuracy:.2%}"
            )

        if output_path is not None and self.write_csv:
            self._write_csv_row(output_path, epoch, steps, metrics)

        # `metrics` may be empty if every configured name was unknown; guard the
        # aggregate so that `max` is never called on an empty sequence.
        if len(self.similarity_fn_names) > 1 and metrics:
            metrics["max_accuracy"] = max(metrics.values())

        if self.main_similarity_function:
            self.primary_metric = {
                SimilarityFunction.COSINE: "cosine_accuracy",
                SimilarityFunction.DOT_PRODUCT: "dot_accuracy",
                SimilarityFunction.EUCLIDEAN: "euclidean_accuracy",
                SimilarityFunction.MANHATTAN: "manhattan_accuracy",
            }.get(self.main_similarity_function)
        elif len(self.similarity_fn_names) > 1:
            self.primary_metric = "max_accuracy"
        else:
            self.primary_metric = f"{self.similarity_fn_names[0]}_accuracy"

        metrics = self.prefix_name_to_metrics(metrics, self.name)
        self.store_metrics_in_model_card_data(model, metrics)
        return metrics

    def get_config_dict(self):
        """Return the non-default settings of this evaluator (``margin`` and ``truncate_dim``)."""
        config_dict = {}
        # The margin is only worth reporting when at least one entry is non-zero.
        if any(value != 0 for value in self.margin.values()):
            config_dict["margin"] = self.margin
        if self.truncate_dim is not None:
            config_dict["truncate_dim"] = self.truncate_dim
        return config_dict
