"""Analysis based on labeling HANS examples by words, heuristic, etc."""
import dataclasses
from typing import Callable, Dict, Sequence, Tuple, Union

import numpy as np
import tensorflow as tf
from tqdm import tqdm

from em.projects.anli import anli_misc1 as am
from em.projects.wino import nmf_components_fisher as ncf
from em.util import flat_pack

from . import hans_labeling


# typedefs
AnalysisContainer = am.PefNmfAnalysisContainer
HansIndicator = hans_labeling.HansIndicator
TCIs = Sequence['TunedComponentInfo']


###############################################################################

@dataclasses.dataclass
class TunedComponentInfo:
    """Record describing one NMF component found to be tuned to an indicator.

    The order of the fields below defines the positional constructor
    signature, so it must not be changed.
    """

    # Index of the NMF (within the analysis container) that owns the component.
    nmf_index: int
    # Index of the component within that NMF.
    component_index: int
    # Indices of the examples selected as the component's top examples
    # (built as an int32 array by compute_hans_tuning_info).
    top_example_indices: np.ndarray


class HansTuningInfo:
    """Information about the tuning of components to HansIndicators.

    Attributes:
        container: Analysis container; this class reads its ``nmfs``,
            ``n_nmfs`` and (through ``_get_denormalized_coefficients``)
            ``pef`` attributes.
        tci_dict: Maps a name either directly to a sequence of
            TunedComponentInfo or to a nested dict of such sequences. Only
            these one- and two-level layouts are traversed; ``None`` values
            are skipped.
        tcis_by_nmf_dict: Nested dict addressed by the same key tuples as
            ``tci_dict``. Every leaf is a list with one entry per NMF, each
            entry holding the TunedComponentInfos whose ``nmf_index`` matches.
    """

    def __init__(
        self,
        container: AnalysisContainer,
        tci_dict: Dict[str, Union[TCIs, Dict[str, TCIs]]],
    ):
        self.container = container
        self.tci_dict = tci_dict
        self.tcis_by_nmf_dict = self._make_tcis_by_nmf_dict()

    #################################################################

    def _make_tcis_by_nmf_dict(self):
        """Regroup every entry of ``tci_dict`` into per-NMF lists."""
        dikt = {}
        for key, tcis in self.iterate_over_tuned_component_infos():
            _set_item_tupled_key(dikt, key, _group_by_nmf(self.container, tcis))
        return dikt

    #################################################################

    def get_tci_by_nmf(self, key: Tuple[str, ...]):
        """Return the per-NMF grouping of tuned components stored under ``key``.

        Args:
            key: Tuple of names addressing an entry; a bare string is rejected
                with ValueError by the underlying lookup.
        """
        return _get_item_tupled_key(self.tcis_by_nmf_dict, key)

    #################################################################

    def iterate_over_tuned_component_infos(self):
        """Yield ``(key_tuple, tcis)`` pairs for every entry of ``tci_dict``.

        Top-level entries yield a 1-tuple key, entries of a nested dict yield
        a 2-tuple key, and top-level ``None`` values are skipped.
        """
        for k1, v1 in self.tci_dict.items():
            if v1 is None:
                continue
            elif isinstance(v1, dict):
                for k2, v2 in v1.items():
                    yield (k1, k2), v2
            else:
                yield (k1,), v1

    #################################################################

    def get_tuned_component_indices(self, key):
        """Return the component indices tuned to ``key``.

        Only valid when the container holds exactly one NMF: the per-NMF
        grouping is unpacked as a single-element sequence, so any other
        number of NMFs raises ValueError.
        """
        tcis_by_subset, = _get_item_tupled_key(self.tcis_by_nmf_dict, key)
        return [t.component_index for t in tcis_by_subset]

    def get_not_tuned_component_indices(self, key):
        """Return, sorted, the component indices of NMF 0 not tuned to ``key``."""
        tuned_inds = self.get_tuned_component_indices(key)
        n_components = self.container.nmfs[0].W.shape[-1]
        return sorted(set(range(n_components)) - set(tuned_inds))

    #################################################################

    def _make_fisher_from_component_indices(
        self,
        variables_by_subset: Sequence[Sequence[tf.Tensor]],
        select_component_indices,
    ) -> Sequence[tf.Tensor]:
        """Build the Fisher contributed by a chosen set of components per NMF.

        Args:
            variables_by_subset: For each NMF, the variables whose shapes
                define how that NMF's dense Fisher vector is unpacked.
            select_component_indices: Callable ``(nmf_index, nmf)`` returning
                the component indices to include for that NMF.

        Returns:
            The tensors decoded for every NMF, concatenated in NMF order.
        """
        fishers = []

        for nmf_index in range(self.container.n_nmfs):
            nmf = self.container.nmfs[nmf_index]
            subset_variables = variables_by_subset[nmf_index]
            packer = flat_pack.FlatPacker([v.shape for v in subset_variables])

            component_indices = np.array(
                select_component_indices(nmf_index, nmf), dtype=np.int32)

            # W.shape = [n_examples, n_components]
            # H.shape = [n_components, n_features]
            W = _get_denormalized_coefficients(
                self.container, nmf_index, component_indices)

            # Fisher: sum over examples of the selected components' reconstruction.
            fisher_values = tf.einsum(
                'ij,jk -> k', W, nmf.H[component_indices]).numpy()

            # Scatter the kept features back into the full dense vector; the
            # remaining positions stay zero.
            full_fisher = np.zeros([nmf.full_dense_size], dtype=np.float32)
            full_fisher[nmf.reduce_kept_indices] = fisher_values

            fishers.extend(packer.decode_tf(full_fisher))

        return fishers

    def make_fisher_for_components(
        self,
        variables_by_subset: Sequence[Sequence[tf.Tensor]],
        filter_fn: Callable[[int, int], bool]
    ) -> Sequence[tf.Tensor]:
        """Build the Fisher from the components accepted by ``filter_fn``.

        Args:
            variables_by_subset: One sequence of variables per NMF.
            filter_fn: Called as ``filter_fn(nmf_index, component_index)``;
                a component is included when the result is truthy.

        Returns:
            The decoded Fisher tensors for all NMFs, in NMF order.
        """
        assert len(variables_by_subset) == self.container.n_nmfs

        def accepted(nmf_index, nmf):
            # Collect the indices filter_fn accepts. Casting its boolean
            # results to int32 and indexing with them would instead select
            # columns 0 and 1 over and over.
            return [
                component_index
                for component_index in range(nmf.W.shape[-1])
                if filter_fn(nmf_index, component_index)
            ]

        return self._make_fisher_from_component_indices(
            variables_by_subset, accepted)

    def make_fisher_for_tuned_components(
            self, key: Sequence[str], variables_by_subset: Sequence[Sequence[tf.Tensor]]) -> Sequence[tf.Tensor]:
        """Build the Fisher from only the components tuned to ``key``."""
        assert len(variables_by_subset) == self.container.n_nmfs

        tcis_by_subset = _get_item_tupled_key(self.tcis_by_nmf_dict, key)

        def tuned(nmf_index, nmf):
            return [t.component_index for t in tcis_by_subset[nmf_index]]

        return self._make_fisher_from_component_indices(
            variables_by_subset, tuned)

    def make_fisher_for_all_but_tuned_components(
            self, key: Sequence[str], variables_by_subset: Sequence[Sequence[tf.Tensor]]) -> Sequence[tf.Tensor]:
        """Build the Fisher from every component except those tuned to ``key``."""
        assert len(variables_by_subset) == self.container.n_nmfs

        tcis_by_subset = _get_item_tupled_key(self.tcis_by_nmf_dict, key)

        def not_tuned(nmf_index, nmf):
            tuned_component_indices = {
                t.component_index for t in tcis_by_subset[nmf_index]}
            return [
                j for j in range(nmf.H.shape[0])
                if j not in tuned_component_indices
            ]

        return self._make_fisher_from_component_indices(
            variables_by_subset, not_tuned)

    def make_full_fisher(self, variables_by_subset: Sequence[Sequence[tf.Tensor]]) -> Sequence[tf.Tensor]:
        """Build the Fisher from all components of every NMF."""
        assert len(variables_by_subset) == self.container.n_nmfs

        def every_component(nmf_index, nmf):
            return list(range(nmf.H.shape[0]))

        return self._make_fisher_from_component_indices(
            variables_by_subset, every_component)

###############################################################################


def get_pef_fisher(
    pef,
    variables: Sequence[tf.Tensor],
    *,
    denormalize: bool = True,
    use_tqdm: bool = True,
) -> Sequence[tf.Tensor]:
    """Average the sparse per-example Fishers of ``pef`` into one dense Fisher.

    Args:
        pef: Object providing ``fishers``, ``fisher_indices``,
            ``dense_fisher_norms`` and ``fishers_dense_shape()``.
        variables: Variables whose shapes define how the dense Fisher vector
            is unpacked into the returned tensors.
        denormalize: When true, scale each example's values by its entry in
            ``pef.dense_fisher_norms`` before accumulating.
        use_tqdm: When true, show a progress bar over the examples.

    Returns:
        The averaged Fisher, decoded by the packer built from ``variables``.
    """
    packer = flat_pack.FlatPacker([v.shape for v in variables])

    # Shape = [n_examples, n_values]
    values = pef.fishers
    if denormalize:
        values = values * pef.dense_fisher_norms[:, None]

    indices = pef.fisher_indices
    n_examples = values.shape[0]

    # NOTE: IDK how fast this approach is.
    fisher = np.zeros([pef.fishers_dense_shape()[-1]], dtype=np.float32)
    # A single progress bar only; building a second, never-iterated tqdm
    # object would just print a stray bar.
    for i in tqdm(range(n_examples), disable=not use_tqdm):
        # Scatter-add this example's values at its dense positions.
        fisher[indices[i]] += values[i]
    fisher /= float(n_examples)

    return packer.decode_tf(fisher)


###############################################################################

def _get_denormalized_coefficients(
    container: AnalysisContainer,
    nmf_index: int,
    component_indices: Sequence[int],
):
    """Denormalization of components.

    The matrix V decomposed by NMF has each example's Fisher divided by the
    norm of the full Fisher. When we compute the NMF for a subset of parameters,
    we just take a subset of unit-normalized Fisher.

    When computing the diagonal approximation to the Fisher, we sum across the
    un-normalized per-example Fishers. Hence when computing the Fisher corresponding
    to a subset of components, we should multiply each example's coefficients in W
    by the magnitude of the full Fisher before summing.
    """
    nmf = container.nmfs[nmf_index]
    component_indices = np.array(component_indices, dtype=np.int32)
    norms = container.pef.dense_fisher_norms[:nmf.W.shape[0], None]
    # W.shape = [n_examples, n_components]
    W = nmf.W[:, component_indices]
    return norms * W


def _group_by_nmf(container: AnalysisContainer, infos: TCIs):
    return [
        [f for f in infos if f.nmf_index == i]
        for i in range(container.n_nmfs)
    ]


def _get_item_tupled_key(dikt, key):
    if isinstance(key, str):
        raise ValueError('Please pass a tuple as a key instead of a string.')

    if len(key) == 0:
        raise ValueError('Empty key tuple passed.')
    elif len(key) == 1:
        return dikt[key[0]]
    else:
        return _get_item_tupled_key(dikt[key[0]], key[1:])


def _set_item_tupled_key(dikt, key, value):
    if isinstance(key, str):
        raise ValueError('Please pass a tuple as a key instead of a string.')

    if len(key) == 0:
        raise ValueError('Empty key tuple passed.')
    elif len(key) == 1:
        dikt[key[0]] = value
    else:
        if key[0] not in dikt:
            dikt[key[0]] = {}
        _set_item_tupled_key(dikt[key[0]], key[1:], value)


###############################################################################


def compute_hans_tuning_info(
    container,
    indicators: HansIndicator,
    selection_parameters: ncf.SelectionParameters,
    *,
    use_tqdm: bool = True,
) -> HansTuningInfo:
    """Find, for every indicator, the components whose top examples match it.

    For each component the top examples are fetched from ``container``. A
    component with more than one top example counts as tuned to an indicator
    when both hold:
      * the fraction of its top examples for which the indicator is true is
        at least ``selection_parameters.frac_threshold``;
      * the sum of ``am._binomial_pmf`` over counts from the observed number
        of true examples up to the number of top examples is at most
        ``selection_parameters.p_value_threshold`` (presumably a binomial
        upper-tail p-value against the rate from ``ncf._get_p_true`` --
        confirm against those helpers).

    Args:
        container: Analysis container providing ``nmfs``, ``n_nmfs`` and
            ``get_top_examples_based_on_relative_coefficient``.
        indicators: Source of ``(name_tuple, indicator)`` pairs via
            ``iterate_over_indicators()``.
        selection_parameters: Supplies ``coeff_factor``, ``max_examples``,
            ``frac_threshold`` and ``p_value_threshold``.
        use_tqdm: When true, show a progress bar over the indicators.

    Returns:
        A HansTuningInfo built from the tuned components of every indicator.
    """
    params = selection_parameters
    labeled_indicators = list(indicators.iterate_over_indicators())

    def _top_example_indices(nmf_index, component_index):
        examples = container.get_top_examples_based_on_relative_coefficient(
            nmf_index=nmf_index,
            component_index=component_index,
            factor=params.coeff_factor,
            max_examples=params.max_examples,
        )
        return np.array([e.index for e in examples], dtype=np.int32)

    # top_examples_by_nmf[nmf_index][component_index] -> int32 example indices
    top_examples_by_nmf = [
        [
            _top_example_indices(nmf_index, component_index)
            for component_index in range(container.nmfs[nmf_index].W.shape[-1])
        ]
        for nmf_index in range(container.n_nmfs)
    ]

    p_true_by_name = {
        name_tuple: ncf._get_p_true(container, indicator, ignore_unlabeled=False)
        for name_tuple, indicator in labeled_indicators
    }

    tuned_by_name = {}
    progress = tqdm(labeled_indicators) if use_tqdm else labeled_indicators

    for name_tuple, indicator in progress:
        p_true = p_true_by_name[name_tuple]
        tuned_components = []

        for nmf_index, per_component in enumerate(top_examples_by_nmf):
            for component_index, example_indices in enumerate(per_component):
                n_examples = len(example_indices)
                # Components with at most one top example are never tested.
                if n_examples <= 1:
                    continue

                n_true = indicator[example_indices].astype(np.int32).sum()
                fraction = n_true / n_examples
                candidate_counts = np.arange(n_true, n_examples + 1)
                p_value = am._binomial_pmf(
                    n_examples, candidate_counts, p_true).sum()

                is_tuned = (
                    fraction >= params.frac_threshold
                    and p_value <= params.p_value_threshold
                )
                if is_tuned:
                    tuned_components.append(
                        TunedComponentInfo(
                            nmf_index=nmf_index,
                            component_index=component_index,
                            top_example_indices=example_indices,
                        )
                    )

        _set_item_tupled_key(tuned_by_name, name_tuple, tuned_components)

    return HansTuningInfo(container, tuned_by_name)
