"""Determining top examples given a fixed coefficients matrix."""
import dataclasses
from typing import Dict, List, Optional

import numpy as np

from npeff_torch.examination.top_examples import top_examples_common


###############################################################################
# Module-level aliases for names defined in `top_examples_common`, so the code
# below can refer to them without the module prefix.
TopExamplesReaderAbc = top_examples_common.TopExamplesReaderAbc
TopExampleInfo = top_examples_common.TopExampleInfo
TopLogProbs = top_examples_common.TopLogProbs
###############################################################################


@dataclasses.dataclass
class TopExamplesReaderFromCoeffs(TopExamplesReaderAbc):
    """Looks up the top examples of a component from a coefficients matrix.

    For a given component, examples are ranked by descending coefficient. All
    per-example arrays held by this reader are indexed by example along their
    first axis.
    """
    # shape = [n_examples, n_components]
    coefficients: np.ndarray

    # Maps a field name to an array that is indexed by example index.
    examples: Dict[str, np.ndarray]

    # Optional per-example arrays; each is indexed by example index when set.
    labels: Optional[np.ndarray] = None
    logits: Optional[np.ndarray] = None

    # These two are read together: when `top_log_probs_class_indices` is set,
    # `top_log_probs_values` is indexed as well, so it must also be set.
    top_log_probs_class_indices: Optional[np.ndarray] = None
    top_log_probs_values: Optional[np.ndarray] = None

    token_positions: Optional[np.ndarray] = None
    #######################################################

    @property
    def n_examples(self) -> int:
        """Number of rows of the coefficients matrix."""
        return self.coefficients.shape[0]

    @property
    def n_components(self) -> int:
        """Number of columns of the coefficients matrix."""
        return self.coefficients.shape[1]

    #######################################################

    def _get_top_example_indices(self, component_index: int, n_top_examples: int) -> np.ndarray:
        """Returns example indices ordered by descending coefficient.

        Args:
            component_index: Column of the coefficients matrix to rank by.
            n_top_examples: Maximum number of indices to return. Values that
                are zero or negative yield an empty array.

        Returns:
            Array of at most `n_top_examples` example indices.
        """
        comp_coeffs = self.coefficients[:, component_index]
        # Clamp so that a negative count is not interpreted as a negative
        # slice bound (which would mean "all but the last k").
        limit = max(n_top_examples, 0)
        # argsort is ascending, so sort the negated values to get descending.
        return np.argsort(-comp_coeffs)[:limit]

    #######################################################

    def make_top_example_info_by_indices(self, *, example_index: int, component_index: int) -> TopExampleInfo:
        """Builds the `TopExampleInfo` for one (example, component) pair.

        Args:
            example_index: Index of the example in every per-example array.
            component_index: Column of the coefficients matrix to read the
                coefficient from.

        Returns:
            A `TopExampleInfo` whose optional fields are None whenever the
            corresponding array of this reader is None.
        """
        def value_or_none(array: Optional[np.ndarray]):
            # Optional arrays contribute None instead of being indexed.
            return None if array is None else array[example_index]

        top_log_probs = None
        if self.top_log_probs_class_indices is not None:
            top_log_probs = TopLogProbs(
                class_indices=self.top_log_probs_class_indices[example_index],
                values=self.top_log_probs_values[example_index],
            )

        return TopExampleInfo(
            coefficient=self.coefficients[example_index, component_index],
            example={k: v[example_index] for k, v in self.examples.items()},
            label=value_or_none(self.labels),
            logits=value_or_none(self.logits),
            top_log_probs=top_log_probs,
            token_position=value_or_none(self.token_positions),
        )

    #######################################################

    def get_top_examples_for_component(self, component_index: int, n_top_examples: int) -> List[TopExampleInfo]:
        """Returns info for the examples with the largest coefficients.

        Args:
            component_index: Component whose coefficients are ranked.
            n_top_examples: Maximum number of examples to return.

        Returns:
            At most `n_top_examples` infos, ordered by descending coefficient.
        """
        ex_inds = self._get_top_example_indices(component_index, n_top_examples)
        return [
            self.make_top_example_info_by_indices(example_index=example_index, component_index=component_index)
            for example_index in ex_inds
        ]

    #######################################################

    def get_unique_top_examples_for_component(self, component_index: int, n_top_examples: int) -> List[TopExampleInfo]:
        """Returns info for the top examples, skipping repeated examples.

        Examples are visited by descending coefficient. An example is skipped
        when `has_same_example` reports it as matching one already kept.

        Args:
            component_index: Component whose coefficients are ranked.
            n_top_examples: Maximum number of examples to return.

        Returns:
            At most `n_top_examples` infos, ordered by descending coefficient.
            Empty when `n_top_examples` is zero or negative.
        """
        ret: List[TopExampleInfo] = []
        # Check the limit before collecting anything; otherwise a request for
        # zero examples would still return the first one.
        if n_top_examples <= 0:
            return ret

        # Rank every example, since an unknown number of them may be skipped
        # as duplicates before enough unique ones are found.
        example_indices = self._get_top_example_indices(component_index, self.n_examples)
        for example_index in example_indices:
            top_example_info = self.make_top_example_info_by_indices(
                example_index=example_index, component_index=component_index)
            if any(top_example_info.has_same_example(tei) for tei in ret):
                continue
            ret.append(top_example_info)
            if len(ret) >= n_top_examples:
                break
        return ret
