"""Determining top examples given a fixed set of cluster assignments and centroid distances."""
import dataclasses
from typing import Dict, List, Optional

import numpy as np

from npeff_torch.examination.top_examples import top_examples_common


###############################################################################
# Local aliases for the shared types defined in top_examples_common; the class
# below is written against these names, and binding them here also makes them
# importable from this module.
TopExamplesReaderAbc, TopExampleInfo, TopLogProbs = (
    top_examples_common.TopExamplesReaderAbc,
    top_examples_common.TopExampleInfo,
    top_examples_common.TopLogProbs,
)
###############################################################################


@dataclasses.dataclass
class TopExamplesReaderFromClusters(TopExamplesReaderAbc):
    """Top-examples reader where each component is a cluster of samples.

    A sample belongs to the component whose id equals its entry in
    ``cluster_assignments``. The top examples of a component are its member
    samples with the smallest ``centroid_distances``. Each returned
    ``TopExampleInfo.coefficient`` is that sample's centroid distance, so a
    smaller value means the example is closer to its cluster centroid.

    Every per-sample array (``cluster_assignments``, ``centroid_distances``,
    each value of ``examples`` and the optional arrays below) is indexed by
    sample index along its first axis.
    """

    # shape = [n_samples], dtype=np.int64
    cluster_assignments: np.ndarray

    # shape = [n_samples]
    centroid_distances: np.ndarray

    # Named per-sample example data; each value is indexed by sample index.
    examples: Dict[str, np.ndarray]

    # Optional per-sample arrays. When None, the corresponding TopExampleInfo
    # field is None.
    labels: Optional[np.ndarray] = None
    logits: Optional[np.ndarray] = None

    # Optional per-sample top-log-prob data. If class indices are provided,
    # values must be provided as well.
    top_log_probs_class_indices: Optional[np.ndarray] = None
    top_log_probs_values: Optional[np.ndarray] = None

    token_positions: Optional[np.ndarray] = None

    # Not read by the methods of this class; presumably consumed by
    # TopExamplesReaderAbc — confirm against the base class.
    n_components: Optional[int] = None

    #######################################################

    def _get_top_example_indices(self, component_index: int, n_top_examples: int) -> Optional[np.ndarray]:
        """Return the sample indices of a cluster's members closest to its centroid.

        Args:
            component_index: Cluster id to match against ``cluster_assignments``.
            n_top_examples: Maximum number of indices to return.

        Returns:
            Up to ``n_top_examples`` sample indices, ordered by increasing
            centroid distance with ties broken by ascending sample index.
            Returns ``None`` if no sample is assigned to the cluster.
        """
        cluster_example_indices, = np.nonzero(self.cluster_assignments == component_index)
        if cluster_example_indices.size == 0:
            return None

        cluster_centroid_distances = self.centroid_distances[cluster_example_indices]
        # The default argsort (quicksort) is not stable, so tied distances could
        # come back in an arbitrary, platform-dependent order. A stable sort keeps
        # ties in the ascending order produced by np.nonzero.
        order = np.argsort(cluster_centroid_distances, kind='stable')
        return cluster_example_indices[order[:n_top_examples]]

    #######################################################

    def get_top_examples_for_component(self, component_index: int, n_top_examples: int) -> List[TopExampleInfo]:
        """Build ``TopExampleInfo`` records for the top examples of one cluster.

        Args:
            component_index: Cluster id whose examples to return.
            n_top_examples: Maximum number of examples to return.

        Returns:
            Up to ``n_top_examples`` records, ordered by increasing centroid
            distance. Returns an empty list if the cluster has no members.
            Optional fields are None when the corresponding array was not
            provided.

        Raises:
            ValueError: If ``top_log_probs_class_indices`` is provided but
                ``top_log_probs_values`` is None.
        """
        ex_inds = self._get_top_example_indices(component_index, n_top_examples)
        if ex_inds is None:
            return []

        has_top_log_probs = self.top_log_probs_class_indices is not None
        if has_top_log_probs and self.top_log_probs_values is None:
            raise ValueError(
                'top_log_probs_values must be provided when '
                'top_log_probs_class_indices is provided.'
            )

        def maybe_get(arr: Optional[np.ndarray], i):
            # Per-sample lookup for optional arrays; None propagates through.
            return None if arr is None else arr[i]

        ret = []
        for i in ex_inds:
            if has_top_log_probs:
                top_log_probs = TopLogProbs(
                    class_indices=self.top_log_probs_class_indices[i],
                    values=self.top_log_probs_values[i],
                )
            else:
                top_log_probs = None

            ret.append(TopExampleInfo(
                coefficient=self.centroid_distances[i],
                example={k: v[i] for k, v in self.examples.items()},
                label=maybe_get(self.labels, i),
                logits=maybe_get(self.logits, i),
                top_log_probs=top_log_probs,
                token_position=maybe_get(self.token_positions, i),
            ))

        return ret
