"""Generic container for ResNet decompositions."""
import dataclasses
import os
from typing import List, Sequence, Union

import numpy as np
import tensorflow as tf

from em import datasets as em_datasets


@dataclasses.dataclass(init=True, repr=True, eq=True)
class ImageNetExample:
    """A single labelled ImageNet example.

    `ResnetContainer` builds these with `image` set to the loaded image divided
    by 255.0; the annotation also permits `image` to be None.
    """

    # Class label as read from the dataset element.
    label: int
    # Image array, or None when no image is attached.
    image: Union[np.ndarray, None]


@dataclasses.dataclass
class ResnetContainer:
    """Pairs per-example decomposition coefficients with the ImageNet examples they describe.

    Row `i` of `coeffs` is assumed to correspond to the `i`-th example of split
    `ds_split`: `__post_init__` loads exactly `coeffs.shape[0]` examples from the
    start of that split, and the lookup methods index `coeffs` rows and
    `examples` with the same indices.

    Note: construction eagerly loads every example into memory (dataset I/O).
    """

    # shape = [n_examples, n_components]
    coeffs: np.ndarray

    # Split name forwarded to `em_datasets.load`.
    ds_split: str

    # Default number of examples returned by the top-example queries when
    # `n_examples` is not given explicitly.
    n_top_examples: int

    def __post_init__(self):
        # `examples` is a plain attribute, not a dataclass field, so it is not part
        # of the generated __init__, __repr__ or __eq__.
        self.examples = self._load_examples()

    def _load_examples(self):
        """Load the first `coeffs.shape[0]` examples of `ds_split`.

        Returns:
          A list of `ImageNetExample`, at most one per row of `coeffs`, whose
          images are the loaded arrays divided by 255.0.
        """
        # NOTE(review): `sequence_length=224` presumably sets the image size for the
        # 'imagenet/default' loader -- confirm against em_datasets.
        ds = em_datasets.load('imagenet/default', split=self.ds_split, tokenizer=None, sequence_length=224)
        ds = ds.take(self.coeffs.shape[0]).cache().as_numpy_iterator()
        # Assumes each element is an (image, label) pair -- TODO confirm against em_datasets.
        return [
            ImageNetExample(image=b[0] / 255.0, label=b[1])
            for b in ds
        ]

    ############################################################

    @property
    def n_components(self) -> int:
        """Number of decomposition components (last axis of `coeffs`)."""
        return self.coeffs.shape[-1]

    ############################################################

    def get_examples_by_indices(self, example_indices: Sequence[int]) -> List[ImageNetExample]:
        """Return the loaded examples at `example_indices`, in the given order."""
        return [self.examples[i] for i in example_indices]

    ############################################################

    def get_top_example_indices(self, component_index: int, n_examples: Union[int, None] = None) -> np.ndarray:
        """Return row indices of the examples with the largest coefficients for a component.

        Args:
          component_index: Column of `coeffs` to rank by.
          n_examples: How many indices to return. Defaults to `self.n_top_examples`.

        Returns:
          Up to `n_examples` row indices into `coeffs`, ordered by descending
          coefficient; tied coefficients are ordered by ascending row index.
        """
        if n_examples is None:
            n_examples = self.n_top_examples
        # A stable sort makes the order of tied coefficients deterministic (lower
        # row index first); numpy's default quicksort gives no such guarantee.
        order = np.argsort(-self.coeffs[:, component_index], kind="stable")
        return order[:n_examples]

    def get_top_examples(self, component_index: int, n_examples: Union[int, None] = None) -> List[ImageNetExample]:
        """Return the examples with the largest coefficients for a component.

        Args:
          component_index: Column of `coeffs` to rank by.
          n_examples: How many examples to return. Defaults to `self.n_top_examples`.

        Returns:
          Examples in the order produced by `get_top_example_indices`.
        """
        inds = self.get_top_example_indices(component_index, n_examples)
        return self.get_examples_by_indices(inds)
