from typing import List
import numpy as np
from .encoder import Encoder
from .scores import normalize_scores, reverse_scores


class EmbeddingComparator:
    """
    Computes cosine similarity scores between a reference text and a list of other texts.

    Embeddings are generated using a provided Encoder. The reference text and all
    comparison texts are encoded in a single ``encode`` call. The raw cosine scores can
    optionally be post-processed with ``normalize_scores`` and ``reverse_scores`` from
    the ``.scores`` module.
    """

    def __init__(self, encoder: Encoder):
        """
        Initializes the EmbeddingComparator with a specific encoder.

        Args:
            encoder: Object exposing ``encode(texts=...)``. It is expected to return one
                embedding vector per input text, in input order.
        """
        self.__encoder = encoder

    def compare(
        self,
        reference_text: str,
        texts: List[str],
        do_normalize_scores: bool = True
    ) -> List[float]:
        """
        Alias for calculate_scores with reverse=True for backward compatibility.

        Args:
            reference_text: Text that every entry of ``texts`` is compared against.
            texts: Texts to score.
            do_normalize_scores: Forwarded to calculate_scores as ``normalize``.

        Returns:
            One score per entry of ``texts``, in the same order, after
            ``reverse_scores`` has been applied (and ``normalize_scores`` first, when
            requested). Empty input yields an empty list.
        """
        return self.calculate_scores(
            reference_text=reference_text,
            texts_to_compare=texts,
            normalize=do_normalize_scores,
            reverse=True)

    def compare_eval(
        self,
        reference_text: str,
        texts: List[str],
        do_normalize_scores: bool = True
    ) -> List[float]:
        """
        Alias for calculate_scores with reverse=False for backward compatibility.

        Args:
            reference_text: Text that every entry of ``texts`` is compared against.
            texts: Texts to score.
            do_normalize_scores: Forwarded to calculate_scores as ``normalize``.

        Returns:
            One score per entry of ``texts``, in the same order. These are cosine
            similarities, passed through ``normalize_scores`` when requested. Empty
            input yields an empty list.
        """
        return self.calculate_scores(
            reference_text=reference_text,
            texts_to_compare=texts,
            normalize=do_normalize_scores,
            reverse=False)

    def calculate_scores(
        self,
        reference_text: str,
        texts_to_compare: List[str],
        normalize: bool = True,
        reverse: bool = False
    ) -> List[float]:
        """
        Calculates similarity scores between a reference text and a list of other texts.

        Args:
            reference_text: Text that every entry of ``texts_to_compare`` is compared
                against.
            texts_to_compare: Texts to score. Any sequence of strings is accepted.
            normalize: When True, pass the scores through ``normalize_scores``.
            reverse: When True, pass the scores through ``reverse_scores``. This runs
                after normalization if both are enabled.

        Returns:
            One float per entry of ``texts_to_compare``, in the same order. Without
            post-processing, each value is the cosine similarity to the reference. A
            pair involving a zero-norm embedding scores 0.0, because cosine similarity
            is undefined there. Empty input yields an empty list without calling the
            encoder.

        Raises:
            ValueError: If the encoder does not return exactly one embedding vector per
                input text.
        """
        if not texts_to_compare:
            return []

        # Encode all texts in one batch for efficiency. The reference goes last so it
        # can be split off with a single slice. Unpacking (instead of ``+``) also
        # accepts tuples and other sequences, not just lists.
        all_texts = [*texts_to_compare, reference_text]
        embeddings = np.asarray(self.__encoder.encode(texts=all_texts))

        # If the encoder dropped or merged inputs, the last row would not be the
        # reference embedding and every score would be silently wrong.
        if embeddings.ndim != 2 or embeddings.shape[0] != len(all_texts):
            raise ValueError(
                f"Encoder returned embeddings of shape {embeddings.shape}; "
                f"expected ({len(all_texts)}, embedding_dim)."
            )

        # Separate the reference embedding from the comparison embeddings
        comparison_vectors = embeddings[:-1]
        reference_vector = embeddings[-1]

        # Vectorized cosine similarity. Where either vector has zero norm the ratio is
        # undefined (0/0 -> NaN plus a RuntimeWarning), so those entries stay at 0.0.
        dot_products = comparison_vectors @ reference_vector
        denominators = (
            np.linalg.norm(comparison_vectors, axis=1) * np.linalg.norm(reference_vector)
        )
        scores = np.divide(
            dot_products,
            denominators,
            out=np.zeros(dot_products.shape, dtype=np.float64),
            where=denominators != 0,
        )

        if normalize:
            scores = normalize_scores(scores=scores)
        if reverse:
            scores = reverse_scores(scores=scores)

        # The project helpers may return an ndarray or a plain sequence; normalize to
        # a list of Python floats either way.
        return np.asarray(scores).tolist()