"""Helpers for exploring the top component examples for NMF."""
import dataclasses
import os
from typing import List, Sequence, Union

import numpy as np
import tensorflow as tf

from em import datasets as em_datasets

from em.fishers import per_example
from em.fishers import lrm_pefs
from em.tools.nmf import nmf_common
from em.tools.nmf import lrm_npeff


@dataclasses.dataclass
class ImageNetExample:
    """A single ImageNet example: its label and, optionally, its image."""

    # Label of the example as yielded by the dataset.
    label: int
    # Image array, or None when the example was loaded without its image.
    image: Union[None, np.ndarray]


@dataclasses.dataclass
class ImageNetCC:
    """Components context (the "CC") for an NMF decomposition over ImageNet.

    On construction this loads the per-example Fisher metadata from
    ``pef_filepath``, the sparse NMF decomposition from ``nmf_filepath`` and
    the ImageNet examples that the rows of ``nmf.W`` are matched against.

    Attributes set in ``__post_init__``:
        pef: The loaded ``PerExampleFlatFishers`` (Fishers themselves are
            not loaded), restricted to the first ``nmf.W.shape[0]`` examples
            when it holds more than that.
        nmf: The loaded ``SparseNmfDecomposition`` after
            ``normalize_components_to_unit_norm()`` has been called on it.
        examples: List of ``ImageNetExample``, one per row of ``nmf.W``.
    """

    # Path handed to PerExampleFlatFishers.load.
    pef_filepath: str
    # Path handed to SparseNmfDecomposition.load.
    nmf_filepath: str

    # Dataset split passed through to em_datasets.load.
    ds_split: str
    # Number of leading dataset elements skipped before taking examples.
    ds_offset: int = 0

    # When False, only labels are kept and every example's image is None.
    include_images: bool = True

    def __post_init__(self):
        """Load the PEF metadata, the NMF decomposition and the examples."""
        # Requesting the empty Fisher range [0, 0) means the Fishers are not
        # loaded at all, which ends up being much faster.
        self.pef = per_example.PerExampleFlatFishers.load(
            self.pef_filepath,
            n_examples=None,
            start_fisher_index=0,
            end_fisher_index=0,
        )

        self.nmf = nmf_common.SparseNmfDecomposition.load(self.nmf_filepath)
        self.nmf.normalize_components_to_unit_norm()

        # TODO: Allow us to specify other subsets of example indices.
        n_pef_examples = self.pef.input_ids.shape[0]
        n_rows = self.nmf.W.shape[0]
        if n_pef_examples > n_rows:
            # Keep only the leading examples that the decomposition covers.
            self.pef = self.pef.create_for_subset(list(range(n_rows)))

        self.examples = self._load_examples()

    def _load_examples(self):
        """Read ``nmf.W.shape[0]`` examples from the dataset into a list.

        The first ``ds_offset`` elements of the split are skipped. With
        ``include_images`` set, each image is divided by 255.0; otherwise the
        images are dropped and only the labels are kept.
        """
        dataset = em_datasets.load(
            'imagenet/default',
            split=self.ds_split,
            tokenizer=None,
            sequence_length=224,
        )
        dataset = dataset.skip(self.ds_offset)
        if not self.include_images:
            # Drop the image so that each element is just the label.
            dataset = dataset.map(lambda _image, label: label)

        n_rows = self.nmf.W.shape[0]
        records = dataset.take(n_rows).cache().as_numpy_iterator()

        if self.include_images:
            return [
                ImageNetExample(image=record[0] / 255.0, label=record[1])
                for record in records
            ]
        return [ImageNetExample(image=None, label=record) for record in records]

    # ----------------------------------------------------------

    @property
    def n_components(self) -> int:
        """Size of the last axis of ``nmf.W``."""
        w_shape = self.nmf.W.shape
        return w_shape[-1]

    # ----------------------------------------------------------

    def get_examples_by_indices(self, example_indices: Sequence[int]) -> List[ImageNetExample]:
        """Return the loaded examples at ``example_indices``, in that order."""
        loaded = self.examples
        return [loaded[idx] for idx in example_indices]

    # ----------------------------------------------------------

    def get_top_example_indices(self, component_index: int, n_examples: int) -> np.ndarray:
        """Return indices of the ``n_examples`` largest ``W`` entries of a component.

        Indices are ordered from the largest coefficient downwards.
        """
        coefficients = self.nmf.W[:, component_index]
        # Sorting the negated column ascending yields descending order.
        descending = np.argsort(-coefficients)
        return descending[:n_examples]

    def get_top_examples(self, component_index: int, n_examples: int) -> List[ImageNetExample]:
        """Return the examples with the largest coefficients for a component."""
        top = self.get_top_example_indices(component_index, n_examples)
        return self.get_examples_by_indices(top)

    def get_top_example_labels(self, component_index: int, n_examples: int) -> np.ndarray:
        """Return ``pef.labels`` for the top examples of a component."""
        top = self.get_top_example_indices(component_index, n_examples)
        return self.pef.labels[top]

    def get_top_example_logits(self, component_index: int, n_examples: int) -> np.ndarray:
        """Return ``pef.predicted_logits`` for the top examples of a component."""
        top = self.get_top_example_indices(component_index, n_examples)
        return self.pef.predicted_logits[top]


###############################################################################
###############################################################################


@dataclasses.dataclass
class LrmImageNetCC:
    """Components context (the "CC") for an LRM-NPEFF decomposition over ImageNet.

    On construction this loads the labels and logits stored alongside the
    sparse LRM PEFs at ``pef_filepath``, the decomposition at ``nmf_filepath``
    (without its ``G``, via ``read_G=False``) and the ImageNet examples that
    the rows of ``nmf.W`` are matched against.

    Attributes set in ``__post_init__``:
        labels: Result of ``SparseLrmPefs.load_labels``, truncated to the
            first ``nmf.W.shape[0]`` entries when longer.
        logits: Result of ``SparseLrmPefs.load_logits``, truncated likewise.
        nmf: The loaded ``LrmNpeffDecomposition``.
        examples: List of ``ImageNetExample``, one per row of ``nmf.W``.
    """

    # Path handed to the SparseLrmPefs label/logit loaders.
    pef_filepath: str
    # Path handed to LrmNpeffDecomposition.load.
    nmf_filepath: str

    # Dataset split passed through to em_datasets.load.
    ds_split: str
    # Number of leading dataset elements skipped before taking examples.
    ds_offset: int = 0

    # When False, only labels are kept and every example's image is None.
    include_images: bool = True

    def __post_init__(self):
        """Load the labels, logits, decomposition and examples."""
        # NOTE: these two loaders were previously crossed (labels came from
        # load_logits and logits from load_labels), which made the
        # get_top_example_labels / get_top_example_logits accessors return
        # each other's data.
        self.labels = lrm_pefs.SparseLrmPefs.load_labels(self.pef_filepath)
        self.logits = lrm_pefs.SparseLrmPefs.load_logits(self.pef_filepath)

        self.nmf = lrm_npeff.LrmNpeffDecomposition.load(self.nmf_filepath, read_G=False)

        # TODO: Allow us to specify other subsets of example indices.
        n_rows = self.nmf.W.shape[0]
        if self.labels.shape[0] > n_rows:
            # Keep only the leading examples that the decomposition covers.
            self.labels = self.labels[:n_rows]
            self.logits = self.logits[:n_rows]

        self.examples = self._load_examples()

    def _load_examples(self):
        """Read ``nmf.W.shape[0]`` examples from the dataset into a list.

        The first ``ds_offset`` elements of the split are skipped. With
        ``include_images`` set, each image is divided by 255.0; otherwise the
        images are dropped and only the labels are kept.
        """
        ds = em_datasets.load('imagenet/default', split=self.ds_split, tokenizer=None, sequence_length=224)
        ds = ds.skip(self.ds_offset)
        if not self.include_images:
            # Drop the image so that each element is just the label.
            ds = ds.map(lambda x, y: y)
        ds = ds.take(self.nmf.W.shape[0]).cache().as_numpy_iterator()
        return [
            ImageNetExample(image=b[0] / 255.0, label=b[1])
            if self.include_images
            else ImageNetExample(image=None, label=b)
            for b in ds
        ]

    ############################################################

    @property
    def n_components(self) -> int:
        """Size of the last axis of ``nmf.W``."""
        return self.nmf.W.shape[-1]

    ############################################################

    def get_examples_by_indices(self, example_indices: Sequence[int]) -> List[ImageNetExample]:
        """Return the loaded examples at ``example_indices``, in that order."""
        return [self.examples[i] for i in example_indices]

    ############################################################

    def get_top_example_indices(self, component_index: int, n_examples: int) -> np.ndarray:
        """Return indices of the ``n_examples`` largest ``W`` entries of a component.

        Indices are ordered from the largest coefficient downwards.
        """
        # Sorting the negated column ascending yields descending order.
        return np.argsort(-self.nmf.W[:, component_index])[:n_examples]

    def get_top_examples(self, component_index: int, n_examples: int) -> List[ImageNetExample]:
        """Return the examples with the largest coefficients for a component."""
        inds = self.get_top_example_indices(component_index, n_examples)
        return self.get_examples_by_indices(inds)

    def get_top_example_labels(self, component_index: int, n_examples: int) -> np.ndarray:
        """Return the loaded labels for the top examples of a component."""
        inds = self.get_top_example_indices(component_index, n_examples)
        return self.labels[inds]

    def get_top_example_logits(self, component_index: int, n_examples: int) -> np.ndarray:
        """Return the loaded logits for the top examples of a component."""
        inds = self.get_top_example_indices(component_index, n_examples)
        return self.logits[inds]
