"""Regularized by distance to subspace spanned by components.

Right now just supporting doing this for a single component.
"""
from typing import List, Optional

import numpy as np
import tensorflow as tf

from em.util import flat_pack


def find_unnormalized_reduced_perturbation(G: np.ndarray, component_index: int, max_sim: float = 1e9):
    """Strip from one row of ``G`` its projections onto the other rows.

    Starting from a copy of ``G[component_index]``, every other row is
    visited in index order and the current residual's projection onto that
    row is subtracted. The subtraction is sequential: later rows see the
    residual left by earlier ones. The result is not re-normalized.

    The projection formula treats each row of ``G`` as a unit vector; rows
    are not normalized here.

    Args:
        G: 2-D array whose rows are the component vectors.
        component_index: Row of ``G`` to start from.
        max_sim: Rows whose absolute dot product with the *original*
            ``G[component_index]`` exceeds this value are left out. If
            ``max_sim`` is not strictly positive, no row is projected out
            and an unmodified copy of the chosen row is returned.

    Returns:
        A new array holding the reduced row; ``G`` itself is not modified.
    """
    reference = G[component_index]
    residual = np.copy(reference)

    # Non-positive threshold: projection is disabled altogether.
    if not max_sim > 0.0:
        return residual

    for row_index, row in enumerate(G):
        if row_index == component_index:
            continue
        # Similarity is measured against the untouched reference row,
        # not against the residual built up so far.
        if np.abs(reference.dot(row)) > max_sim:
            continue
        residual -= residual.dot(row) * row

    return residual


class SingleComponentRegularized(tf.keras.Model):
    """Keras model wrapper that regularizes weight changes toward a single direction.

    The wrapped model's trainable variables are snapshotted at construction
    time. During training, the per-variable differences ``delta`` between the
    current and snapshotted values are penalized by the sum of:

    * ``lmbda_ss``  * squared distance from ``delta`` to the line spanned by
      the (normalized) component ``g``,
    * ``lmbda_ewc`` * ``sum(fisher * delta**2)``,
    * ``lmbda_iso`` * ``sum(delta**2)``.

    The component lives on a subset of the flattened parameters, selected by
    ``new_to_old_col_indices``; parameters outside that subset contribute
    their full squared change to the subspace distance.
    """

    def __init__(
        self,
        model: tf.keras.Model,
        component_g: np.ndarray,
        new_to_old_col_indices: np.ndarray,
        lmbda_ss: float,
        dataset_fisher: Optional[List[tf.Tensor]] = None,
        lmbda_ewc: float = 0.0,
        lmbda_iso: float = 0.0,
        **kwargs,
    ):
        """Snapshot the model's weights and precompute the per-variable component.

        Args:
            model: Model to wrap; its trainable variables are the ones regularized.
            component_g: 1-D component vector over the reduced parameter set.
                It is normalized to unit L2 norm internally.
            new_to_old_col_indices: Indices handed to
                ``FlatPacker.convert_global_indices_to_flat_per_tensor_indices``.
                Presumably maps each entry of ``component_g`` to a position in
                the concatenated flat parameter vector; verify against
                ``flat_pack``.
            lmbda_ss: Weight of the distance-to-subspace penalty.
            dataset_fisher: One tensor per trainable variable, multiplied
                elementwise with the squared deltas. Required when
                ``lmbda_ewc > 0``.
            lmbda_ewc: Weight of the Fisher-weighted penalty.
            lmbda_iso: Weight of the isotropic squared-L2 penalty.
            **kwargs: Forwarded to ``tf.keras.Model.__init__``.
        """
        super().__init__(**kwargs)
        self.model = model
        # Frozen copies of the starting weights; deltas are measured against these.
        self.og_variables = [tf.identity(v) for v in model.trainable_variables]

        self.packer = flat_pack.FlatPacker([v.shape for v in self.og_variables])

        self._dataset_fisher = dataset_fisher

        self._lmbda_ss = lmbda_ss
        self._lmbda_ewc = lmbda_ewc
        self._lmbda_iso = lmbda_iso

        if lmbda_ewc > 0.0:
            assert self._dataset_fisher is not None, "dataset_fisher is required when lmbda_ewc > 0"

        self._precompute_stuff(component_g, new_to_old_col_indices)

    def _precompute_stuff(self, component_g: np.ndarray, new_to_old_col_indices: np.ndarray):
        """Normalize the component and split it into one float32 piece per variable.

        Sets ``self.per_var_new_to_old_indices`` (one index collection per
        variable, as returned by the packer) and ``self.g_by_var`` (the
        matching consecutive slices of the normalized component).
        """
        # Normalize the component to unit L2 norm.
        component_g = component_g / np.sqrt(np.sum(component_g**2))

        self.per_var_new_to_old_indices = self.packer.convert_global_indices_to_flat_per_tensor_indices(
            new_to_old_col_indices)

        # The component is consumed in order: the first len(inds_0) entries
        # belong to variable 0, the next len(inds_1) to variable 1, and so on.
        self.g_by_var = []
        i = 0
        for inds in self.per_var_new_to_old_indices:
            size = len(inds)
            self.g_by_var.append(tf.cast(component_g[i : i + size], tf.float32))
            i += size

    def call(self, *args, **kwargs):
        """Forward all arguments to the wrapped model."""
        return self.model(*args, **kwargs)

    @tf.function
    def _get_sq_distance_to_subspace(self, deltas: List[tf.Tensor]) -> tf.Tensor:
        """Return the squared distance from ``deltas`` to the line spanned by the component.

        With ``g`` the unit-norm component (zero outside the selected
        indices), this evaluates ``|delta - <delta, g> g|^2`` as the sum of
        the squared entries outside the selected indices and the squared
        residual of the selected entries after removing their projection on ``g``.

        Args:
            deltas: One tensor per trainable variable, current minus original value.

        Returns:
            Scalar tensor with the squared distance.
        """
        reduced_deltas = []
        dot_prod_contrs = []
        dist_contrs = []

        for delta, g, inds in zip(deltas, self.g_by_var, self.per_var_new_to_old_indices):
            delta = tf.reshape(delta, [-1])
            delta_reduced = tf.gather(delta, inds)
            reduced_deltas.append(delta_reduced)
            # This variable's share of the global inner product <delta, g>.
            dot_prod_contrs.append(tf.einsum('i,i->', delta_reduced, g))
            # Entries not covered by the component are entirely off the
            # subspace, so their squared norm counts in full.
            # NOTE(review): assumes the indices within a variable are unique.
            dist_contrs.append(tf.reduce_sum(tf.square(delta)) - tf.reduce_sum(tf.square(delta_reduced)))

        # The projection coefficient is global: it must be summed across all
        # variables before any residual can be formed.
        dot_prod = tf.reduce_sum(dot_prod_contrs)

        for delta_reduced, g in zip(reduced_deltas, self.g_by_var):
            dist_contrs.append(tf.reduce_sum(tf.square(delta_reduced - dot_prod * g)))

        return tf.reduce_sum(dist_contrs)

    @tf.function
    def _compute_ewc_loss(self, deltas: List[tf.Tensor]) -> tf.Tensor:
        """Return ``sum(fisher * delta**2)`` over all variables (unweighted).

        Returns a float32 zero when ``lmbda_ewc`` is not positive.
        """
        if self._lmbda_ewc <= 0.0:
            return tf.cast(0.0, tf.float32)

        assert len(self._dataset_fisher) == len(deltas)
        contrs = [tf.reduce_sum(tf.square(delta) * f) for delta, f in zip(deltas, self._dataset_fisher)]

        return tf.reduce_sum(contrs)

    @tf.function
    def _compute_iso_loss(self, deltas: List[tf.Tensor]) -> tf.Tensor:
        """Return the total squared L2 norm of ``deltas`` (unweighted).

        Returns a float32 zero when ``lmbda_iso`` is not positive.
        """
        if self._lmbda_iso <= 0.0:
            return tf.cast(0.0, tf.float32)
        return tf.reduce_sum([tf.reduce_sum(tf.square(delta)) for delta in deltas])

    @tf.function
    def _compute_additional_losses(self, variables: List[tf.Tensor]) -> tf.Tensor:
        """Return the weighted sum of the three penalties for the given variable values.

        Args:
            variables: Current values, in the same order as ``self.og_variables``.
        """
        deltas = [v - og_v for v, og_v in zip(variables, self.og_variables)]
        losses = [
            self._lmbda_ss * self._get_sq_distance_to_subspace(deltas),
            self._lmbda_ewc * self._compute_ewc_loss(deltas),
            self._lmbda_iso * self._compute_iso_loss(deltas),
        ]
        return tf.reduce_sum(losses)

    def train_step(self, data):
        """Run one optimization step on the compiled loss plus the regularization penalties.

        Args:
            data: An ``(x, y)`` pair.

        Returns:
            Dict mapping each metric name to its current result.
        """
        x, y = data

        variables = self.model.trainable_variables

        with tf.GradientTape() as tape:
            y_pred = self.model(x, training=True)
            loss = self.compiled_loss(y, y_pred)
            loss += self._compute_additional_losses(variables)

        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update the metrics.
        # Metrics are configured in `compile()`.
        self.compiled_metrics.update_state(y, y_pred)

        # Return a dict mapping metric names to current value.
        # Note that it will include the loss (tracked in self.metrics).
        return {m.name: m.result() for m in self.metrics}

    def perturb_weights_by_component(self, multiplier: float):
        """Add ``multiplier * g`` to the wrapped model's weights in place.

        Only the entries selected by the per-variable indices are changed;
        all other weights receive a zero update.
        """
        variables = self.model.trainable_variables
        for v, g, inds in zip(variables, self.g_by_var, self.per_var_new_to_old_indices):
            # Scatter this variable's slice of the component into a flat
            # zero vector, then reshape to the variable's shape.
            delta = np.zeros([tf.size(v).numpy()], dtype=np.float32)
            delta[inds.numpy()] = multiplier * g.numpy()
            v.assign_add(tf.reshape(delta, v.shape))
