"""Regularizers related to Fishers.

# TODO: Support sparse fishers.

"""
import abc
from typing import Callable, Sequence
import tensorflow as tf

###############################################################################


class EwcLoss(tf.keras.losses.Loss):
    """Wrap a Keras loss and add a fisher-weighted quadratic penalty.

    ``call`` evaluates::

        wrapped(y_true, y_pred) + lmbda * sum_i reduce_sum(F_i * (v_i - e_i) ** 2)

    where ``v_i``, ``e_i`` and ``F_i`` are the i-th entries of
    ``model_variables``, ``ewc_variables`` and ``ewc_fishers``. Both ``e_i``
    and ``F_i`` go through ``tf.stop_gradient``, so the penalty only
    back-propagates into ``model_variables``.

    NOTE(review): "EWC" presumably stands for Elastic Weight Consolidation;
    confirm against the callers.

    Args:
        wrapped: Loss evaluated on ``(y_true, y_pred)``.
        model_variables: Variables the penalty is applied to.
        ewc_variables: Reference values, one per model variable.
        ewc_fishers: Weights multiplied with the squared differences, one
            per model variable.
        lmbda: Multiplier applied to the summed penalty.
        **kwargs: Forwarded to ``tf.keras.losses.Loss.__init__``.
    """

    def __init__(
        self,
        wrapped: tf.keras.losses.Loss,
        model_variables: Sequence[tf.Tensor],
        ewc_variables: Sequence[tf.Tensor],
        ewc_fishers: Sequence[tf.Tensor],
        lmbda: float,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # The three sequences are zipped in ``call``, so they must line up.
        num_variables = len(model_variables)
        assert num_variables == len(ewc_variables) == len(ewc_fishers)

        self.wrapped = wrapped
        self.lmbda = lmbda
        self.model_variables = model_variables
        self.ewc_variables = ewc_variables
        self.ewc_fishers = ewc_fishers

    def call(self, y_true, y_pred):
        """Return the wrapped loss plus ``lmbda`` times the penalty."""
        wrapped_loss = self.wrapped(y_true, y_pred)

        # One scalar penalty per (variable, reference, fisher) triple.
        per_variable_penalties = []
        triples = zip(self.model_variables, self.ewc_variables, self.ewc_fishers)
        for variable, anchor, fisher in triples:
            weight = tf.stop_gradient(fisher)
            squared_shift = tf.square(variable - tf.stop_gradient(anchor))
            per_variable_penalties.append(tf.reduce_sum(weight * squared_shift))
        penalty = tf.reduce_sum(per_variable_penalties)

        return wrapped_loss + self.lmbda * penalty


###############################################################################


class _BaseAblationLoss(tf.keras.losses.Loss, abc.ABC):
    """Wrap a Keras loss and add ``lmbda / g(sum_i f(F_i, v_i - b_i) + epsilon)``.

    ``v_i``, ``b_i`` and ``F_i`` are the i-th entries of ``model_variables``,
    ``baseline_variables`` and ``baseline_fishers``. The baseline variables
    and fishers are wrapped in ``tf.stop_gradient`` before use. Subclasses
    supply ``f`` (per-variable term) and ``g`` (applied to the total).

    Args:
        wrapped: Loss evaluated on ``(y_true, y_pred)``.
        model_variables: Variables compared against the baseline.
        baseline_variables: Reference values, one per model variable.
        baseline_fishers: Fishers handed to ``f``, one per model variable.
        lmbda: Numerator of the added term.
        epsilon: Constant added to the summed ``f`` values before ``g``.
        **kwargs: Forwarded to ``tf.keras.losses.Loss.__init__``.
    """

    def __init__(
        self,
        wrapped: tf.keras.losses.Loss,
        model_variables: Sequence[tf.Tensor],
        baseline_variables: Sequence[tf.Tensor],
        baseline_fishers: Sequence[tf.Tensor],
        lmbda: float,
        epsilon: float = 1e-9,
        **kwargs
    ):
        super().__init__(**kwargs)

        # The three sequences are zipped in ``call``, so they must line up.
        num_variables = len(model_variables)
        assert num_variables == len(baseline_variables) == len(baseline_fishers)

        self.wrapped = wrapped
        self.lmbda = lmbda
        self.epsilon = epsilon
        self.model_variables = model_variables
        self.baseline_variables = baseline_variables
        self.baseline_fishers = baseline_fishers

    @abc.abstractmethod
    def f(self, fisher: tf.Tensor, delta: tf.Tensor):
        """Map one (fisher, delta) pair to a scalar ``tf.Tensor``."""
        raise NotImplementedError

    @abc.abstractmethod
    def g(self, x: tf.Tensor):
        """Map a scalar ``tf.Tensor`` to a scalar ``tf.Tensor``."""
        raise NotImplementedError

    def call(self, y_true, y_pred):
        """Return the wrapped loss plus ``lmbda`` divided by the denominator."""
        wrapped_loss = self.wrapped(y_true, y_pred)

        # Collect f(fisher, variable - baseline) for every variable.
        per_variable_terms = []
        triples = zip(self.model_variables, self.baseline_variables, self.baseline_fishers)
        for variable, baseline, fisher in triples:
            frozen_fisher = tf.stop_gradient(fisher)
            shift = variable - tf.stop_gradient(baseline)
            per_variable_terms.append(self.f(frozen_fisher, shift))

        total = tf.reduce_sum(per_variable_terms)
        # epsilon is added before g is applied, not after.
        denominator = self.g(total + self.epsilon)

        return wrapped_loss + self.lmbda / denominator


class _BaseAblationLossWithStandardF(_BaseAblationLoss):
    """Ablation loss base whose ``f`` is ``reduce_sum(|fisher|**a * |delta|**b)``.

    Args:
        *args: Forwarded to ``_BaseAblationLoss.__init__``.
        a: Exponent applied to the absolute value of the fisher.
        b: Exponent applied to the absolute value of the delta.
        **kwargs: Forwarded to ``_BaseAblationLoss.__init__``.
    """

    def __init__(self, *args, a: float = 1, b: float = 2, **kwargs):
        super().__init__(*args, **kwargs)
        self._a, self._b = a, b

    def f(self, fisher: tf.Tensor, delta: tf.Tensor):
        """Return the sum of ``|fisher|**a * |delta|**b`` over all entries."""
        return tf.reduce_sum(
            tf.pow(tf.abs(fisher), self._a) * tf.pow(tf.abs(delta), self._b)
        )

######################################################


class PowerAblationLoss(_BaseAblationLossWithStandardF):
    """Ablation loss whose ``g`` raises its input to the power ``c``.

    Args:
        *args: Forwarded to the parent constructor.
        c: Exponent used by ``g``.
        **kwargs: Forwarded to the parent constructor.
    """

    def __init__(self, *args, c: float = 2, **kwargs):
        super().__init__(*args, **kwargs)
        self._c = c

    def g(self, x: tf.Tensor):
        """Return ``x`` raised to the power ``c``."""
        exponent = self._c
        return tf.pow(x, exponent)


class ExponentialAblationLoss(_BaseAblationLossWithStandardF):
    """Ablation loss whose ``g`` is ``exp(c * x)``.

    Args:
        *args: Forwarded to the parent constructor.
        c: Factor multiplied with the input before exponentiating.
        **kwargs: Forwarded to the parent constructor.
    """

    def __init__(self, *args, c: float = 1, **kwargs):
        super().__init__(*args, **kwargs)
        self._c = c

    def g(self, x: tf.Tensor):
        """Return ``exp(c * x)``."""
        scaled = self._c * x
        return tf.exp(scaled)


###############################################################################

class _BaseMultiComponentBySubsetAblationLoss(tf.keras.losses.Loss, abc.ABC):
    """Wrap a Keras loss and add one ablation term per (subset, component).

    The variables are grouped into subsets. Subset ``s`` has model variables
    ``mv``, baseline variables ``bv`` and a sequence of components; each
    component is itself a sequence with one tensor per variable of the
    subset. For every component ``cs`` of every subset, ``call`` adds::

        lmbda / g(sum_i f(cs_i, mv_i - bv_i) + epsilon)

    Baseline variables and component tensors are wrapped in
    ``tf.stop_gradient`` before use. Subclasses supply ``f`` and ``g``.

    Args:
        wrapped: Loss evaluated on ``(y_true, y_pred)``.
        model_variables_per_subset: Model variables, grouped by subset.
        baseline_variables_per_subset: Reference values with the same
            grouping and per-subset lengths as the model variables.
        components_per_subset: For each subset, a sequence of components;
            every component has one tensor per variable of that subset.
        lmbda: Numerator of each added term.
        epsilon: Constant added to the summed ``f`` values before ``g``.
        **kwargs: Forwarded to ``tf.keras.losses.Loss.__init__``.
    """

    def __init__(
        self,
        wrapped: tf.keras.losses.Loss,
        model_variables_per_subset: Sequence[Sequence[tf.Tensor]],
        baseline_variables_per_subset: Sequence[Sequence[tf.Tensor]],
        components_per_subset: Sequence[Sequence[Sequence[tf.Tensor]]],
        lmbda: float,
        epsilon: float = 1e-9,
        **kwargs
    ):
        super().__init__(**kwargs)

        # Same number of subsets in all three top-level sequences.
        num_subsets = len(model_variables_per_subset)
        assert num_subsets == len(baseline_variables_per_subset) == len(components_per_subset)

        # Within each subset, baselines and every component match the
        # number of model variables.
        subsets = zip(model_variables_per_subset, baseline_variables_per_subset, components_per_subset)
        for subset_vars, subset_baselines, subset_components in subsets:
            assert len(subset_vars) == len(subset_baselines)
            for component in subset_components:
                assert len(component) == len(subset_vars)

        self.wrapped = wrapped
        self.lmbda = lmbda
        self.epsilon = epsilon
        self.model_variables_per_subset = model_variables_per_subset
        self.baseline_variables_per_subset = baseline_variables_per_subset
        self.components_per_subset = components_per_subset

    @abc.abstractmethod
    def f(self, fisher: tf.Tensor, delta: tf.Tensor):
        """Map one (fisher, delta) pair to a scalar ``tf.Tensor``."""
        raise NotImplementedError

    @abc.abstractmethod
    def g(self, x: tf.Tensor):
        """Map a scalar ``tf.Tensor`` to a scalar ``tf.Tensor``."""
        raise NotImplementedError

    def call(self, y_true, y_pred):
        """Return the wrapped loss plus the sum of all per-component terms."""
        wrapped_loss = self.wrapped(y_true, y_pred)

        penalties = []
        subsets = zip(
            self.model_variables_per_subset,
            self.baseline_variables_per_subset,
            self.components_per_subset,
        )
        for subset_vars, subset_baselines, subset_components in subsets:
            assert len(subset_vars) == len(subset_baselines)

            # The shifts are shared by every component of this subset.
            shifts = []
            for variable, baseline in zip(subset_vars, subset_baselines):
                shifts.append(variable - tf.stop_gradient(baseline))

            for component in subset_components:
                terms = []
                for tensor, shift in zip(component, shifts):
                    terms.append(self.f(tf.stop_gradient(tensor), shift))
                total = tf.reduce_sum(terms)
                # epsilon is added before g is applied, not after.
                denominator = self.g(total + self.epsilon)
                penalties.append(self.lmbda / denominator)

        return wrapped_loss + tf.reduce_sum(penalties)


class _BaseMultiComponentBySubsetAblationLossWithStandardF(_BaseMultiComponentBySubsetAblationLoss):
    """Multi-component base whose ``f`` is ``reduce_sum(|fisher|**a * |delta|**b)``.

    Args:
        *args: Forwarded to the parent constructor.
        a: Exponent applied to the absolute value of the fisher.
        b: Exponent applied to the absolute value of the delta.
        **kwargs: Forwarded to the parent constructor.
    """

    def __init__(self, *args, a: float = 1, b: float = 2, **kwargs):
        super().__init__(*args, **kwargs)
        self._a, self._b = a, b

    def f(self, fisher: tf.Tensor, delta: tf.Tensor):
        """Return the sum of ``|fisher|**a * |delta|**b`` over all entries."""
        return tf.reduce_sum(
            tf.pow(tf.abs(fisher), self._a) * tf.pow(tf.abs(delta), self._b)
        )


######################################################


class MultiComponentBySubsetPowerAblationLoss(_BaseMultiComponentBySubsetAblationLossWithStandardF):
    """Multi-component ablation loss whose ``g`` raises its input to the power ``c``.

    Args:
        *args: Forwarded to the parent constructor.
        c: Exponent used by ``g``.
        **kwargs: Forwarded to the parent constructor.
    """

    def __init__(self, *args, c: float = 2, **kwargs):
        super().__init__(*args, **kwargs)
        self._c = c

    def g(self, x: tf.Tensor):
        """Return ``x`` raised to the power ``c``."""
        exponent = self._c
        return tf.pow(x, exponent)


class MultiComponentBySubsetExponentialAblationLoss(_BaseMultiComponentBySubsetAblationLossWithStandardF):
    """Multi-component ablation loss whose ``g`` is ``exp(c * x)``.

    Args:
        *args: Forwarded to the parent constructor.
        c: Factor multiplied with the input before exponentiating.
        **kwargs: Forwarded to the parent constructor.
    """

    def __init__(self, *args, c: float = 1, **kwargs):
        super().__init__(*args, **kwargs)
        self._c = c

    def g(self, x: tf.Tensor):
        """Return ``exp(c * x)``."""
        scaled = self._c * x
        return tf.exp(scaled)
