import numpy as np

from typing import Dict

from .estimator import Estimator


class SemanticDensity(Estimator):
    """Semantic Density sequence-level uncertainty estimator.

    For every input, the greedy answer is compared against the unique sampled
    answers using NLI-derived "contradiction" and "neutral" scores. Each unique
    sample contributes its length-normalised probability, weighted by the
    kernel ``1 - (contra + neutral / 2)`` (0 when that distance exceeds 1).
    The greedy answer contributes its own length-normalised probability with
    kernel weight 1. The density is the kernel-weighted sum divided by the
    plain sum of those probabilities. The returned uncertainty is its negation,
    so higher values mean more uncertain.
    """

    def __init__(
            self,
            verbose: bool = False,
            concat_input: bool = True,
            samples_source: str = "sample",
    ):
        """
        Parameters:
            verbose: stored on the instance; not used by this class itself.
            concat_input: if True, use the ``concat_greedy_...`` semantic
                matrices; otherwise the ``greedy_...`` ones.
            samples_source: prefix of the sample statistics to read
                (``{samples_source}_log_probs``, ``_tokens``, ``_texts``).
        """
        contra_key, neutral_key = self._semantic_matrix_keys(
            concat_input, samples_source
        )
        deps = [
            # Per-token log-likelihoods of the greedy answer; this is the stat
            # __call__ actually reads.
            "greedy_log_likelihoods",
            f"{samples_source}_log_probs",
            f"{samples_source}_tokens",
            f"{samples_source}_texts",
            contra_key,
            neutral_key,
        ]
        super().__init__(deps, "sequence")
        self.verbose = verbose
        self.concat_input = concat_input
        self.samples_source = samples_source

    def __str__(self):
        base = "SemanticDensity" if self.concat_input else "SemanticDensityOutput"
        if self.samples_source != "sample":
            base += f'_{self.samples_source}'
        return base

    @staticmethod
    def _semantic_matrix_keys(concat_input: bool, samples_source: str):
        """Return the (contra, neutral) stat keys of the greedy-vs-sample NLI matrices."""
        prefix = "concat_greedy" if concat_input else "greedy"
        return (
            f"{prefix}_{samples_source}_semantic_matrix_contra_forward",
            f"{prefix}_{samples_source}_semantic_matrix_neutral_forward",
        )

    @staticmethod
    def _length_normed_prob(log_prob, num_tokens: int):
        """Return ``exp(log_prob / num_tokens)``, the per-token geometric-mean probability.

        Evaluated in log space because the equivalent
        ``exp(log_prob) ** (1 / num_tokens)`` underflows to 0 once ``log_prob``
        falls below roughly -745 (float64), which happens for long sequences.

        Raises:
            ValueError: if ``num_tokens`` is 0 (the normalisation is undefined).
        """
        if num_tokens == 0:
            raise ValueError(
                "Cannot length-normalise the probability of an empty sequence."
            )
        return np.exp(log_prob / num_tokens)

    @classmethod
    def _sequence_density(
        cls,
        greedy_log_likelihoods,
        sample_log_probs,
        sample_tokens,
        sample_texts,
        semantic_matrix_contra,
        semantic_matrix_neutral,
    ):
        """Compute the semantic density of one input's greedy answer.

        Parameters:
            greedy_log_likelihoods: per-token log-likelihoods of the greedy answer
                (summed, and its length used for normalisation).
            sample_log_probs: one total log-probability per sample, indexed by
                sample position.
            sample_tokens: token sequence per sample (only its length is used).
            sample_texts: decoded text per sample; duplicates are counted once.
            semantic_matrix_contra, semantic_matrix_neutral: presumably one
                greedy-vs-sample NLI score per sample, indexed by sample
                position. TODO confirm shape with the producing calculator.
        """
        # Keep only the first occurrence of each distinct sample text.
        _, unique_sample_indices = np.unique(sample_texts, return_index=True)

        # The greedy answer is compared with itself: kernel weight 1.
        greedy_prob = cls._length_normed_prob(
            np.sum(greedy_log_likelihoods), len(greedy_log_likelihoods)
        )
        weighted_sum = greedy_prob
        prob_sum = greedy_prob

        for idx in unique_sample_indices:
            prob = cls._length_normed_prob(
                sample_log_probs[idx], len(sample_tokens[idx])
            )
            distance = semantic_matrix_contra[idx] + semantic_matrix_neutral[idx] / 2
            kernel_value = 1 - distance if distance <= 1 else 0
            weighted_sum += prob * kernel_value
            prob_sum += prob

        return weighted_sum / prob_sum

    def __call__(self, stats: Dict[str, np.ndarray]) -> np.ndarray:
        """Return the negated semantic density for each input in the batch.

        Parameters:
            stats: mapping from the dependency names declared in ``__init__``
                to per-batch sequences of values.

        Returns:
            1-D array with one value per batch item: ``-semantic_density``.
        """
        contra_key, neutral_key = self._semantic_matrix_keys(
            self.concat_input, self.samples_source
        )
        semantic_density = [
            self._sequence_density(
                greedy_log_likelihoods,
                sample_log_probs,
                sample_tokens,
                sample_texts,
                semantic_matrix_contra,
                semantic_matrix_neutral,
            )
            for (
                greedy_log_likelihoods,
                sample_log_probs,
                sample_tokens,
                sample_texts,
                semantic_matrix_contra,
                semantic_matrix_neutral,
            ) in zip(
                stats["greedy_log_likelihoods"],
                stats[f"{self.samples_source}_log_probs"],
                stats[f"{self.samples_source}_tokens"],
                stats[f"{self.samples_source}_texts"],
                stats[contra_key],
                stats[neutral_key],
            )
        ]
        return -np.array(semantic_density)
