"""Common contexts around a set of perturbation runs.

Currently, this is only supported for LRM-NPEFFs.
"""
import abc
import functools
from typing import List, Optional, Tuple, Union

import torch
from transformers import PreTrainedModel

from npeff_torch.compressed_sensing import compressed_sensing_common
from npeff_torch.decomps.kmeans import kmeans
from npeff_torch.decomps.npeff import lrm_npeff_decomps
from npeff_torch.models import parameter_infos
from npeff_torch.perturbations import evaluation_contexts
from npeff_torch.perturbations import perturbation_results
from npeff_torch.util import flat_pack
###############################################################################


def _normalize_vector_(x: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    assert len(x.shape) == 1
    norm = torch.linalg.vector_norm(x)
    if norm < eps:
        x /= eps
    else:
        x /= norm
    return x


###############################################################################


class DecompositionWrapper(abc.ABC):
    """Wrapper around an LRM-NPEFF (or k-means) decomposition.

    Contains additional information needed to go from reduced to full representation of the pseudo-Fishers.
    Subclasses provide the reduced -> full mapping by implementing
    `_up_project_reduced_pseudo_fisher`.
    """

    # Must be called in constructors of subclasses
    def __init__(
        self, *,

        # The decomposition itself. Must have the `G` field (or `centroids` for k-means) as non-None.
        # Its components must be normalized.
        decomposition: Union['lrm_npeff_decomps.LrmNpeffDecomposition', 'kmeans.KmeansClusteringTorch'],

        # When creating the perturbation for a component, its reduced pseudo-Fisher will be
        # orthogonally rejected from every other component pseudo-Fisher whose absolute cosine
        # similarity with it is at most this value. Components that are more similar than this
        # are left untouched.
        #
        # Set to None to not do this.
        rejection_max_abs_cos_similarity: Optional[float],

        # Whether or not to cache the absolute values of cosine similarities. If so, this introduces
        # the storage of a tensor of size [n_components, n_components].
        cache_abs_cos_similarities: bool = True,

        # Information about the parameters that were used to compute the PEFs.
        parameter_infos: List['parameter_infos.ParameterInfo']

    ):
        """Validates the decomposition and sets up the cached state.

        Raises:
            AssertionError: If the decomposition lacks its pseudo-Fisher vectors or reports
                that its components are not normalized.
            ValueError: If `decomposition` is not one of the supported decomposition types.
        """
        if isinstance(decomposition, lrm_npeff_decomps.LrmNpeffDecomposition):
            assert decomposition.G is not None, 'The decomposition must have the pseudo-Fisher vectors attached.'
            assert decomposition.components_are_normalized, 'The decomposition must have its components be normalized.'
        elif isinstance(decomposition, kmeans.KmeansClusteringTorch):
            assert decomposition.centroids is not None, 'The decomposition must have the pseudo-Fisher vectors attached.'
            assert decomposition.components_are_normalized, 'The decomposition must have its components be normalized.'
        else:
            raise ValueError(f'Unsupported decomposition type: {type(decomposition).__name__}')

        self.decomposition = decomposition
        self.rejection_max_abs_cos_similarity = rejection_max_abs_cos_similarity
        self.parameter_infos = tuple(parameter_infos)

        # The similarity matrix is only ever read when rejection is enabled, so skip
        # computing it otherwise.
        if cache_abs_cos_similarities and rejection_max_abs_cos_similarity is not None:
            # shape = [n_components, n_components]
            self._abs_cos_similarities = self._compute_abs_cos_similarities(self._get_all_pseudo_fishers())
        else:
            self._abs_cos_similarities = None

        # Used to split a flat full-size vector back into per-parameter tensors.
        self.flat_packer = flat_pack.FlatPacker([p.shape for p in self.parameter_infos])

    #######################################################
    # Abstract methods:

    @abc.abstractmethod
    def _up_project_reduced_pseudo_fisher(self, reduced_pseudo_fisher: torch.Tensor) -> torch.Tensor:
        """Maps a pseudo-Fisher from its reduced representation to the full representation.

        The result is passed to `self.flat_packer.unpack_vector`, so it is presumably a flat
        vector covering all of `self.parameter_infos` -- TODO confirm and document exact shape.
        """
        raise NotImplementedError

    #######################################################

    @torch.no_grad()
    def _get_all_pseudo_fishers(self) -> torch.Tensor:
        """Returns the component pseudo-Fishers (`G` or `centroids`) of the wrapped decomposition."""
        # shape = [n_components, n_features]
        if isinstance(self.decomposition, lrm_npeff_decomps.LrmNpeffDecomposition):
            return self.decomposition.G
        elif isinstance(self.decomposition, kmeans.KmeansClusteringTorch):
            return self.decomposition.centroids
        else:
            raise ValueError(f'Unsupported decomposition type: {type(self.decomposition).__name__}')

    #######################################################

    @torch.no_grad()
    def _compute_abs_cos_similarities(self, G: torch.Tensor) -> torch.Tensor:
        """Returns the absolute values of all pairwise dot products between rows of `G`.

        These are cosine similarities only when the rows of `G` have unit norm, which the
        constructor requires via `components_are_normalized`.
        """
        return torch.einsum('cf,kf->ck', G, G).abs_()

    #######################################################

    @torch.no_grad()
    def get_top_example_indices_for_component(self, component_index: int, n_top_examples: int) -> torch.Tensor:
        """Returns the indices of the top examples for a component.

        For an LRM-NPEFF decomposition, the decomposition must have the coefficients `W` attached
        and the top examples are those with the largest coefficients. If fewer than
        `n_top_examples` have a non-zero coefficient for the component, then only the indices of
        the examples with a non-zero component are returned.

        For a k-means decomposition, the top examples are the members of the cluster that are
        closest to its centroid; at most the cluster's size many indices are returned.

        If `n_top_examples` exceeds the total number of examples, it is clamped to that number
        instead of raising.
        """
        if isinstance(self.decomposition, lrm_npeff_decomps.LrmNpeffDecomposition):
            W = self.decomposition.W
            assert W is not None, 'decomposition must have the coefficients W attached'

            coefficients = W[:, component_index]
            # torch.topk raises when k exceeds the size of the dimension, so clamp it.
            k = min(n_top_examples, coefficients.shape[-1])
            top_coeffs, top_indices = torch.topk(coefficients, k=k)
            return top_indices[top_coeffs > 0.0]

        elif isinstance(self.decomposition, kmeans.KmeansClusteringTorch):
            d = self.decomposition

            # Hopefully 1e12 is big enough. If the gradients were normalized beforehand, they should all lie
            # on the unit sphere, and all centroids should lie within the unit ball. Hence all distances should
            # be less than 2, so 1e12 should be more than big enough.
            big_dist = 1e12

            # Examples outside the cluster get a sentinel distance so that they sort last and
            # can be filtered out afterwards. Clone to leave the decomposition untouched.
            centroid_distances = d.centroid_distances.clone()
            centroid_distances[d.cluster_assignments != component_index] = big_dist

            # torch.topk raises when k exceeds the size of the dimension, so clamp it.
            k = min(n_top_examples, centroid_distances.shape[-1])
            top_coeffs, top_indices = torch.topk(centroid_distances, k=k, largest=False)
            return top_indices[top_coeffs < big_dist]

        else:
            raise ValueError(f'Unsupported decomposition type: {type(self.decomposition).__name__}')

    #######################################################

    @torch.no_grad()
    def _get_reduced_normalized_perturbation(self, component_index: int) -> torch.Tensor:
        """Returns the unit-norm perturbation direction for a component in the reduced space.

        Starts from the component's own pseudo-Fisher and, if rejection is enabled, removes its
        projection onto every other sufficiently dissimilar component before normalizing.
        """
        # ret.shape = [reduced_n_features]
        G = self._get_all_pseudo_fishers()

        # Work on a copy so that the decomposition's stored vectors are never mutated.
        g = torch.clone(G[component_index])

        if self.rejection_max_abs_cos_similarity is not None:
            if self._abs_cos_similarities is not None:
                component_abs_cos_similarities = self._abs_cos_similarities[component_index]
            else:
                # No cache: compute only the row of similarities that is needed.
                component_abs_cos_similarities = torch.einsum('cf,f->c', G, G[component_index]).abs_()

            # NOTE: The rejections are applied one after another using the current `g`. The result
            # is exactly orthogonal to all rejected components only if those components are
            # mutually orthogonal.
            for i in range(G.shape[0]):
                if i == component_index:
                    continue
                # Components that are too similar to this one are not rejected.
                if component_abs_cos_similarities[i] > self.rejection_max_abs_cos_similarity:
                    continue
                g -= torch.dot(g, G[i]) * G[i]

        # Normalize to unit vector L2 norm
        _normalize_vector_(g)

        return g

    @torch.no_grad()
    def make_labeled_parameter_perturbations(self, component_index: int) -> List[Tuple['parameter_infos.ParameterInfo', torch.Tensor]]:
        """Returns (parameter info, perturbation tensor) pairs for a component.

        The reduced perturbation is up-projected to the full representation, re-normalized to
        unit L2 norm there, and then unpacked into one tensor per entry of `self.parameter_infos`.
        """
        reduced_perturbation = self._get_reduced_normalized_perturbation(component_index)
        full_perturbation = self._up_project_reduced_pseudo_fisher(reduced_perturbation)
        # Up-projection need not preserve the norm, so normalize again in the full space.
        _normalize_vector_(full_perturbation)

        unpacked_perturbation = self.flat_packer.unpack_vector(full_perturbation)
        return list(zip(self.parameter_infos, unpacked_perturbation))


###############################################################################


class RandomlyProjectedDecompositionWrapper(DecompositionWrapper):
    """Decomposition wrapper that delegates the reduced -> full mapping to a reconstructor."""

    def __init__(
        self, *,
        reconstructor: 'compressed_sensing_common.ReconstructorAbc',
        **other_super_kwargs,
    ):
        """Stores `reconstructor`; all remaining keyword arguments go to the base class."""
        # Let the base class validate and set up its own state first.
        super().__init__(**other_super_kwargs)
        self.reconstructor = reconstructor

    def _up_project_reduced_pseudo_fisher(self, reduced_pseudo_fisher: torch.Tensor) -> torch.Tensor:
        """Returns the reconstructor's `reconstruct_vector` output for the reduced vector."""
        full_pseudo_fisher = self.reconstructor.reconstruct_vector(reduced_pseudo_fisher)
        return full_pseudo_fisher


###############################################################################
###############################################################################


class ExamplesBatchPerturbationInfo:
    """Compares one batch's evaluation under the original model and the perturbed model.

    Everything is derived from the `logits` attribute of the two batch infos. Classes are
    taken to lie along the last dimension of the logits.
    """

    def __init__(
        self, *,
        # Assumes that the examples are consistent amongst these.
        original_batch_info: 'evaluation_contexts.EvaluationBatchInfo',
        perturbed_batch_info: 'evaluation_contexts.EvaluationBatchInfo',
    ):
        self.original_batch_info = original_batch_info
        self.perturbed_batch_info = perturbed_batch_info

    #######################################################
    # Uncached views; each access recomputes from the logits.

    @property
    @torch.no_grad()
    def original_predictions(self) -> torch.Tensor:
        """Arg-max class index per example according to the original model."""
        logits = self.original_batch_info.logits
        return logits.argmax(dim=-1)

    @property
    @torch.no_grad()
    def perturbed_predictions(self) -> torch.Tensor:
        """Arg-max class index per example according to the perturbed model."""
        logits = self.perturbed_batch_info.logits
        return logits.argmax(dim=-1)

    @property
    @torch.no_grad()
    def original_log_probs(self) -> torch.Tensor:
        """Log-softmax of the original model's logits over the last dimension."""
        logits = self.original_batch_info.logits
        return torch.log_softmax(logits, dim=-1)

    @property
    @torch.no_grad()
    def perturbed_log_probs(self) -> torch.Tensor:
        """Log-softmax of the perturbed model's logits over the last dimension."""
        logits = self.perturbed_batch_info.logits
        return torch.log_softmax(logits, dim=-1)

    #######################################################
    # Cached on first access.

    @functools.cached_property
    @torch.no_grad()
    def kls(self) -> torch.Tensor:
        """Per-example D_KL(original || perturbed), summed over the last (class) dimension."""
        # NOTE: torch's kl_div takes (input, target) and computes D_KL(target || input), so the
        # perturbed distribution is the input and the original distribution is the target.
        input_log_probs = self.perturbed_log_probs
        target_log_probs = self.original_log_probs
        pointwise = torch.nn.functional.kl_div(
            input_log_probs,
            target_log_probs,
            reduction='none',
            log_target=True,
        )
        return pointwise.sum(dim=-1)

    @functools.cached_property
    @torch.no_grad()
    def changed_predictions(self) -> torch.Tensor:
        """Boolean tensor, True where the perturbed prediction differs from the original one."""
        return torch.ne(self.original_predictions, self.perturbed_predictions)


###############################################################################


class ComponentPerturber:
    """A class that performs perturbations for a single component.

    The perturbation direction is computed once at construction time. Each evaluation
    overwrites the perturbed model's parameters with
    `original + multiplier * perturbation` and compares both models on the top examples
    and on the baseline examples.
    """

    def __init__(
        self, *,
        component_index: int,
        decomposition_wrapper: 'DecompositionWrapper',

        # These should be different copies of the same model. The original_model will
        # remain unchanged while the perturbed_model will be updated with the perturbation.
        original_model: PreTrainedModel,
        perturbed_model: PreTrainedModel,

        # The examples with the highest coefficients for the component of interest.
        top_examples: 'evaluation_contexts.Examples',
        # A baseline set of examples either with low coefficients for the component or selected at
        # random with no regard for the coefficient values.
        baseline_examples: 'evaluation_contexts.Examples',

        evaluation_batch_size: int,
    ):
        self.component_index = component_index
        self.decomposition_wrapper = decomposition_wrapper

        self.original_model = original_model
        self.perturbed_model = perturbed_model

        self.top_examples = top_examples
        self.baseline_examples = baseline_examples

        self.evaluation_batch_size = evaluation_batch_size

        # List of (parameter info, perturbation tensor) pairs for this component.
        self.labeled_parameter_perturbations = self.decomposition_wrapper.make_labeled_parameter_perturbations(component_index)

        self.original_model_evaluator = evaluation_contexts.ModelEvaluator(model=original_model)
        self.perturbed_model_evaluator = evaluation_contexts.ModelEvaluator(model=perturbed_model)

        # Name -> parameter lookups so that perturbations can be matched via `param_info.name`.
        self._name_to_original_parameter = dict(original_model.named_parameters())
        self._name_to_perturbed_parameter = dict(perturbed_model.named_parameters())

    @torch.no_grad()
    def _perturb_parameters(self, multiplier: float):
        """Sets each perturbed parameter to `original + multiplier * perturbation`.

        Always starts from the original parameter values, so successive calls do not
        accumulate. Parameters without a labeled perturbation are not touched.
        """
        for param_info, base_perturbation in self.labeled_parameter_perturbations:
            original_param = self._name_to_original_parameter[param_info.name]
            perturbed_param = self._name_to_perturbed_parameter[param_info.name]
            perturbed_param.copy_(original_param + multiplier * base_perturbation)

    @torch.no_grad()
    def _evaluate_for_examples(self, examples: 'evaluation_contexts.Examples') -> 'perturbation_results.ExamplesPerturbationInfo':
        """Returns the mean KL and the fraction of changed predictions over `examples`.

        Both models are evaluated on every batch using their current parameters.
        """
        kls = []
        changed_predictions = []

        for example_batch in examples.get_batches(self.evaluation_batch_size):
            original_batch_info = self.original_model_evaluator.compute_batch_info(example_batch)
            perturbed_batch_info = self.perturbed_model_evaluator.compute_batch_info(example_batch)
            batch_info = ExamplesBatchPerturbationInfo(
                original_batch_info=original_batch_info,
                perturbed_batch_info=perturbed_batch_info,
            )
            kls.append(batch_info.kls)
            changed_predictions.append(batch_info.changed_predictions)

        # Concatenate first so that the mean is over examples rather than over batches.
        # `.item()` yields a Python float directly and, unlike a round trip through NumPy,
        # also works for dtypes that NumPy cannot represent (e.g. bfloat16).
        kl = torch.cat(kls).mean().item()
        changed_prediction_fraction = torch.cat(changed_predictions).to(torch.float32).mean().item()

        return perturbation_results.ExamplesPerturbationInfo(
            kl=kl,
            changed_prediction_fraction=changed_prediction_fraction,
        )

    @torch.no_grad()
    def _evaluate_perturbation(self, multiplier: float) -> 'perturbation_results.PerturbationRunInfo':
        """Applies the perturbation scaled by `multiplier` and evaluates both example sets."""
        self._perturb_parameters(multiplier)

        top_examples_info = self._evaluate_for_examples(self.top_examples)
        baseline_examples_info = self._evaluate_for_examples(self.baseline_examples)

        return perturbation_results.PerturbationRunInfo(
            top_examples_info=top_examples_info,
            baseline_examples_info=baseline_examples_info,
        )

    @torch.no_grad()
    def evaluate_perturbation_pm(self, magnitude: float) -> 'perturbation_results.ComponentPerturbationResults':
        """Evaluates a perturbation for both +/- magnitude * normalized_perturbation.

        NOTE: The perturbed model is left with the `-magnitude` perturbation applied when
        this returns; it is not restored to the original parameters.
        """
        assert magnitude >= 0.0

        plus_results = self._evaluate_perturbation(magnitude)
        minus_results = self._evaluate_perturbation(-magnitude)

        return perturbation_results.ComponentPerturbationResults(
            component_index=self.component_index,
            perturbation_magnitude=magnitude,
            plus_results=plus_results,
            minus_results=minus_results,
        )


###############################################################################
###############################################################################


R"""
- only support the LRM-PEF style of NPEFF
- support both the random projected and sparse PEIs
    - Basically have a "down" projected and "up" projected state (maybe have more generic name for this)
        - probably "reduced" vs "full"
    - Need to be able to go from down to up
- Probably support taking rejection from other components.
    - Probably can do this in the "down" projected state for both RP and SP.
    - NEED TO CHECK IF THIS OPERATION APPROXIMATELY COMMUTES WITH THE RANDOM PROJECTION.


- Do I need an evaluation context or something?
    - Might want option whether the logits can/should be materialized all at once or should be compared on the fly.
- Probably should make some python objects wrapping up the NPEFF decompositions
    - Maybe added to the `decomps` folder stuff

"""

# Make some ABC for wrappers around the decomposition with some additional information.
#   The additional information differs on whether it was sparse or random projected PEFs
#   Implement those in different files.


# Do I want to abstract away getting the top examples for each component?
#   - If given coefficients for a set of examples, can just do a top
#   - Might also want to be able to provide explicitly and/or via some separate method.
#       - I don't think I'll have a version of this to implement.
