"""Utilities for ablating multiple components."""
import dataclasses
from typing import Optional, Sequence, Tuple

import numpy as np
import tensorflow as tf

from em.tools.nmf import nmf_common
from em.util import sparse_util

from em.projects.pi.exps import ablation_exp_util
from em.projects.pi.exps import guided_ablations


# Short alias for the sparse NMF decomposition type, used in annotations below.
Nmf = nmf_common.SparseNmfDecomposition


##########################################################################

@tf.function
def _sp_mv_mul(H, vec):
    """Multiply sparse matrix `H` by the densified sparse tensor `vec`.

    `vec` is converted to a dense tensor and given a new axis at position 1
    (for a 1-D `vec` this is an [n, 1] column; presumably `vec` is always 1-D
    here — confirm against callers), then the product `H @ column` is returned.
    """
    column = tf.expand_dims(tf.sparse.to_dense(vec), axis=1)
    return tf.sparse.sparse_dense_matmul(H, column)


def compute_H_cos_sim_matrix(nmf: nmf_common.SparseNmfDecomposition):
    """Return the matrix of inner products between the sparse vectors of H.

    The sparse vectors come from `nmf.get_full_sparse_H()`. They are stacked
    into one sparse matrix `H`, and column j of the result is `H @ h_j`.
    This function does not normalize. The entries are cosine similarities
    only if every vector of H already has unit L2 norm.

    Returns:
        A numpy array built by concatenating the per-vector products along the
        last axis.
    """
    sparse_rows = nmf.get_full_sparse_H()
    H_matrix = sparse_util.stack_as_rows(sparse_rows)
    # One sparse matrix-vector product per vector. These could probably be
    # batched into a single sparse-dense matmul over several vectors at once.
    products = [_sp_mv_mul(H_matrix, row) for row in sparse_rows]
    similarities = tf.concat(products, axis=-1)
    return similarities.numpy()


def compute_W_cos_sim_matrix(nmf: nmf_common.SparseNmfDecomposition):
    """Return the cosine-similarity matrix between the columns of `nmf.W`.

    `nmf.W` is cast to float32. Each column (one per component) is scaled to
    unit L2 norm, and the Gram matrix `W^T W` is returned as a numpy array of
    shape [W.shape[1], W.shape[1]].

    A column that is entirely zero (e.g. a dead component) stays zero instead
    of being divided by a zero norm. Its row and column of the result are
    therefore 0 rather than NaN.
    """
    W = tf.cast(nmf.W, tf.float32)
    # l2_normalize computes x * rsqrt(max(sum(x**2), epsilon)). The epsilon
    # clamp avoids rsqrt(0) = inf, which would turn a zero column into NaNs.
    # For any column with a non-negligible norm the result is unchanged.
    W = tf.math.l2_normalize(W, axis=0)
    return tf.linalg.matmul(W, W, transpose_a=True).numpy()


##########################################################################

@dataclasses.dataclass
class TopExampleIndices:
    """Indices of the top-weighted examples for a set of NMF components.

    Usually constructed via `TopExampleIndices.select`.
    """

    # shape = [n_selected_comps]
    component_indices: np.ndarray

    # shape = [n_selected_comps, n_ex_per_component]
    example_indices_by_component: np.ndarray

    # Lazily populated cache for `unique_example_indices`. It is derived
    # state, so it is excluded from repr and equality comparisons. It is kept
    # as an init parameter for backward compatibility.
    _unique_example_indices: Optional[np.ndarray] = dataclasses.field(
        default=None, repr=False, compare=False)

    @property
    def unique_example_indices(self) -> np.ndarray:
        """Sorted, de-duplicated example indices across all components (int32).

        Computed on first access and cached afterwards.
        """
        if self._unique_example_indices is None:
            # np.unique flattens its input and returns sorted unique values.
            self._unique_example_indices = np.unique(
                self.example_indices_by_component).astype(np.int32)
        return self._unique_example_indices

    @classmethod
    def select(cls, nmf: Nmf, component_indices: Sequence[int], n_ex_per_component: int):
        """Pick the examples with the largest `W` weight for each component.

        Args:
            nmf: Decomposition whose `W` is indexed as `W[example, component]`.
            component_indices: Components to select examples for.
            n_ex_per_component: Number of top examples kept per component.
                Examples may repeat across components, so the number of unique
                examples can be smaller than
                `len(component_indices) * n_ex_per_component`.

        Returns:
            A `TopExampleIndices`. Row i of `example_indices_by_component`
            lists example indices in descending order of
            `W[:, component_indices[i]]`. The order among ties is not
            guaranteed.
        """
        component_indices = np.array(component_indices, dtype=np.int32)
        # Index one column at a time (rather than fancy-indexing all columns
        # at once) so this also works when W is not a numpy array.
        ex_by_comp = np.stack([
            np.argsort(-nmf.W[:, comp_ind])[:n_ex_per_component]
            for comp_ind in component_indices
        ], axis=0)
        return cls(
            component_indices=component_indices,
            example_indices_by_component=ex_by_comp)


##########################################################################

@dataclasses.dataclass
class ExperimentHelper1:
    """Bundles an experiment with the precomputed quantities its ablators use.

    On construction it validates the configuration. It then computes, once,
    the Fisher and the normalized gradient for the unique top examples of the
    selected components. The `get_*_ablator` methods pass these to
    `guided_ablations.Ablator1`.
    """

    exp: ablation_exp_util.Experiment1

    # Which components and top examples the ablation is built around.
    top_examples_info: TopExampleIndices

    # (low, high) KL range. Not read in this class; presumably consumed by
    # Ablator1 through `helper` (TODO confirm).
    kl_target_range: Tuple[float, float]

    # Quantity whose gradient is used. Must be in
    # `guided_ablations.GRADIENT_TARGETS`; `_compute_normalized_gradient`
    # handles "kl" and "loss".
    gradient_target: str = "kl"

    # Whether to ablate in the direction of decreasing or increasing gradient.
    negate_gradient: bool = False

    # Must be in `guided_ablations.ABLATING_VARIABLE_STYLES`. Not read in this
    # class; presumably consumed by Ablator1 through `helper` (TODO confirm).
    ablating_variable_style: str = "fixed_offset"

    def __post_init__(self):
        # Validate with explicit exceptions rather than `assert`, which is
        # stripped when Python runs with -O and gives no message.
        if self.gradient_target not in guided_ablations.GRADIENT_TARGETS:
            raise ValueError(
                f"Unknown gradient_target {self.gradient_target!r}; expected one "
                f"of {guided_ablations.GRADIENT_TARGETS}.")
        if self.ablating_variable_style not in guided_ablations.ABLATING_VARIABLE_STYLES:
            raise ValueError(
                f"Unknown ablating_variable_style {self.ablating_variable_style!r}; "
                f"expected one of {guided_ablations.ABLATING_VARIABLE_STYLES}.")

        self.eval_ctx = self.exp.mc.get_evaluation_context()

        self.top_example_inds = self.top_examples_info.unique_example_indices
        self.top_examples_fisher = self._compute_dense_fisher(self.top_example_inds)
        self.top_examples_gradient = self._compute_normalized_gradient(self.top_example_inds)

    ######################################################################

    def _compute_dense_fisher(self, example_inds: np.ndarray):
        """Return `self.exp.compute_fisher(example_inds)`.

        The name suggests the result is a dense matrix. The layout is
        determined by `Experiment1.compute_fisher` (not visible here).
        """
        return self.exp.compute_fisher(example_inds)

    def _compute_normalized_gradient(self, example_inds: np.ndarray):
        """Return the per-tensor gradients for `example_inds`, jointly normalized.

        Uses the loss gradient or the KL gradient according to
        `self.gradient_target`, negated if `self.negate_gradient` is set. The
        result is then scaled so the global L2 norm over all tensors is 1.

        Raises:
            ValueError: if `gradient_target` is neither "loss" nor "kl".
        """
        exp = self.exp

        if self.gradient_target == 'loss':
            grads = exp.compute_loss_gradient(example_inds)
        elif self.gradient_target == 'kl':
            grads = exp.compute_kl_gradient(example_inds, allow_recompile=True)
        else:
            raise ValueError(self.gradient_target)

        if self.negate_gradient:
            grads = [-g for g in grads]

        # Normalize by the global L2 norm across all gradient tensors.
        # NOTE(review): an all-zero gradient makes rsqrt return inf, and the
        # result becomes non-finite.
        inv_norm = tf.math.rsqrt(tf.reduce_sum([tf.reduce_sum(tf.square(g)) for g in grads]))
        grads = [inv_norm * g for g in grads]

        return grads

    ######################################################################

    def get_component_examples_ablator(self) -> guided_ablations.Ablator1:
        """Build an Ablator1 that uses the Fisher computed on the top examples."""
        return guided_ablations.Ablator1(
            helper=self,
            example_inds=self.top_example_inds,
            examples_fisher=self.top_examples_fisher,
            examples_gradient=self.top_examples_gradient,
        )

    def get_component_H_ablator(self) -> guided_ablations.Ablator1:
        """Build an Ablator1 that uses a Fisher made from the selected components.

        The Fisher comes from `exp.mc.make_fisher_for_components`. The gradient
        is still the one computed on the top examples.
        """
        return guided_ablations.Ablator1(
            helper=self,
            example_inds=self.top_example_inds,
            examples_fisher=self.exp.mc.make_fisher_for_components(self.top_examples_info.component_indices),
            examples_gradient=self.top_examples_gradient,
        )
