"""Common data classes, ABCs, and interfaces for working with top examples."""
import abc
import dataclasses
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch


@dataclasses.dataclass
class TopLogProbs:
    """The highest log probabilities kept for one example, with their classes."""

    # Class index of each retained log prob.
    # shape = [n_log_probs], dtype=np.int32
    class_indices: np.ndarray

    # The log prob values; expected to be sorted in descending order.
    # shape = [n_log_probs], dtype=np.float32
    values: np.ndarray

    @property
    def n_log_probs(self) -> int:
        """Number of retained log probs, read from the last axis of `class_indices`."""
        shape = self.class_indices.shape
        return shape[-1]


@dataclasses.dataclass
class TopExampleInfo:
    """Information about a top example.

    Holds the example's arrays together with the coefficient of the component
    it is a top example for, and optionally a label and model outputs (full
    logits or a top subset of log probs).
    """

    # The coefficient of the associated component for this example.
    coefficient: float

    # The example itself, as named arrays.
    example: Dict[str, np.ndarray]

    # Class index the prediction is compared against in prediction_is_correct().
    label: Optional[int] = None

    # Usually, only at most one of logits or top_log_probs are assumed to be present.
    # If both are set, the methods below use logits and ignore top_log_probs.
    # The methods below index logits as a 1-D array of per-class scores.
    logits: Optional[np.ndarray] = None
    top_log_probs: Optional[TopLogProbs] = None

    # Not used by any method here; presumably the token position within the
    # example that this entry refers to -- TODO confirm against callers.
    token_position: Optional[int] = None

    def get_top_probs(self, n_top_probs: int) -> List[Tuple[int, float]]:
        """Return the most probable classes as (class_index, probability) pairs.

        With logits, probabilities are the softmax of the logits and pairs are
        ordered from most to least probable; if there are fewer classes than
        `n_top_probs`, all of them are returned. With top_log_probs, the first
        `n_top_probs` stored entries are exponentiated and returned in their
        stored order.

        Raises:
            ValueError: if both logits and top_log_probs are None, or if
                `n_top_probs` exceeds the number of stored top log probs.
        """
        if self.logits is not None:
            probs = torch.softmax(torch.from_numpy(self.logits), dim=-1).numpy()
            # A stable sort keeps tied classes in index order, so the first
            # entry always names the same class as np.argmax in get_prediction().
            ranked = np.argsort(-self.logits, kind='stable')[:n_top_probs]
            return [(int(ci), float(probs[ci])) for ci in ranked]

        if self.top_log_probs is not None:
            top = self.top_log_probs
            if n_top_probs > top.n_log_probs:
                raise ValueError('Requested more probabilities than were provided.')
            return [
                (int(top.class_indices[i]), float(np.exp(top.values[i])))
                for i in range(n_top_probs)
            ]

        raise ValueError('Cannot get top probs if both logits and top_log_probs are None.')

    def get_prediction(self) -> int:
        """Return the predicted class index.

        This is the argmax of the logits when present, otherwise the first
        stored class in top_log_probs (which is expected to be sorted in
        descending order).

        Raises:
            ValueError: if both logits and top_log_probs are None.
        """
        if self.logits is not None:
            return int(np.argmax(self.logits))
        if self.top_log_probs is not None:
            return int(self.top_log_probs.class_indices[0])
        raise ValueError('Cannot get prediction if both logits and top_log_probs are None.')

    def prediction_is_correct(self) -> bool:
        """Return whether the predicted class equals the label.

        Raises:
            ValueError: if the label is None, or if no prediction is available.
        """
        if self.label is None:
            raise ValueError('Cannot see if the prediction is correct when the label is None.')
        return bool(self.label == self.get_prediction())

    def has_same_example(self, other: 'TopExampleInfo') -> bool:
        """Return whether `other` holds the same example arrays under the same keys.

        Only the `example` dicts are compared; coefficient, label and model
        outputs are ignored.
        """
        if self.example.keys() != other.example.keys():
            return False
        # np.array_equal requires matching shapes. A plain `==` would broadcast
        # (so [5] would "equal" [5, 5, 5]) or raise on incompatible shapes.
        return all(
            np.array_equal(value, other.example[key])
            for key, value in self.example.items()
        )


class TopExamplesReaderAbc(abc.ABC):
    """Abstract interface for objects that provide the top examples of a component."""

    @abc.abstractmethod
    def get_top_examples_for_component(self, component_index: int, n_top_examples: int) -> List[TopExampleInfo]:
        """Return the top examples for the component at `component_index`.

        Implementations should return the list sorted in DESCENDING order of
        component coefficient magnitude. `n_top_examples` is presumably the
        maximum number of examples to return; the exact handling is left to
        the implementation.
        """
        raise NotImplementedError()

    @classmethod
    def create(cls, **kwargs):
        """Construct an instance of `cls`, forwarding all keyword arguments to its constructor."""
        reader = cls(**kwargs)
        return reader