"""Optimizing the merging coefficients and classifier head via gradient descent."""
import time
from typing import Optional, Sequence

import numpy as np
import tensorflow as tf
from transformers import TFPreTrainedModel

from em.tools.nmf import nmf_common
from em.util import flat_pack

# Type aliases used in the annotations below.
SparseTensor = tf.sparse.SparseTensor


def _compute_n_components(components_by_var: Sequence[SparseTensor]) -> int:
    """Return the component count shared by every sparse tensor.

    The component count of each tensor is read from ``dense_shape[0]``.

    Args:
      components_by_var: sparse tensors whose leading dense dimension is the
        number of components.

    Returns:
      The single leading dimension common to all tensors.

    Raises:
      ValueError: if the tensors disagree on the leading dimension, or if
        ``components_by_var`` is empty.
    """
    distinct_counts = set()
    for component in components_by_var:
        distinct_counts.add(component.dense_shape[0].numpy())

    if len(distinct_counts) == 1:
        (n_components,) = distinct_counts
        return n_components
    raise ValueError('Number of components not consistent for components_by_var.')


def _sq_l2_norm_to_only_first_axis(sparse: tf.sparse.SparseTensor) -> tf.Tensor:
    """Return the squared L2 norm of every slice along the first axis.

    For a sparse tensor with dense_shape ``[n, *rest]``, the result is a dense
    tensor of shape ``[n]``. Entry ``k`` is the sum of the squares of the
    stored values whose first index is ``k``.

    Args:
      sparse: sparse tensor of rank >= 1.

    Returns:
      Dense tensor with one entry per index of the first axis.
    """
    sq = tf.sparse.map_values(tf.square, sparse)
    axes = list(range(1, len(sparse.dense_shape)))
    if not axes:
        # Rank-1 input: there are no trailing axes to reduce. Passing axis=[]
        # to tf.sparse.reduce_sum would reduce *every* dimension and return a
        # scalar, so group the squared values by first index instead.
        return tf.math.unsorted_segment_sum(
            sq.values, sq.indices[:, 0], sq.dense_shape[0])
    return tf.sparse.reduce_sum(sq, axis=axes, output_is_sparse=False)


# TODO: Give this class a better name.
class SparseFitter(tf.keras.Model):
    """Keras model that learns how to merge a donor model into an output model.

    For every merged variable, the value written into ``output_model`` is a
    weighted average of the donor ("component") value and the base value::

        D      = sum_j c_j * F_j                       (elementwise)
        merged = (D * theta_comp + c_0 * B * theta_base) / (D + c_0 * B)

    In this formula:

    - ``c = softmax(coeff_logits)``, and ``c_0`` is the base coefficient.
    - ``F_j`` are the sparse per-component weights in ``components_by_var``.
    - ``B`` is ``base_batch_fisher[i]`` if given, otherwise ``base_fisher_value``.

    The merging logits are trained by gradient descent on the compiled loss,
    jointly with ``output_train_variables`` (which are trained directly).
    """
    # TODO: This only supports transfer from a single model to another. Eventually
    # support many to one merges.

    def __init__(
        self,
        components_by_var: Sequence[SparseTensor],
        component_variables: Sequence[tf.Tensor],
        base_variables: Sequence[tf.Tensor],
        output_merge_variables: Sequence[tf.Variable],
        output_train_variables: Sequence[tf.Variable],
        output_model: TFPreTrainedModel,
        base_batch_fisher: Optional[Sequence[tf.Tensor]] = None,
        normalize_fishers: bool = True,
        base_fisher_value: float = 1.0,
        **kwargs,
    ):
        """Create the fitter.

        Args:
          components_by_var: one sparse tensor per merged variable, each with
            dense_shape ``[n_components, *var_shape]``.
          component_variables: donor values, one per merged variable.
          base_variables: base values, one per merged variable.
          output_merge_variables: variables of ``output_model`` that receive
            the merged values each training step.
          output_train_variables: variables of ``output_model`` that are
            trained directly by the optimizer.
          output_model: model whose call result exposes ``.logits``.
          base_batch_fisher: optional per-variable base weights. When None,
            ``base_fisher_value`` is used for every variable.
          normalize_fishers: if True, divide each component coefficient by
            that component's L2 norm taken across all variables.
          base_fisher_value: constant base weight used when
            ``base_batch_fisher`` is None.
          **kwargs: forwarded to ``tf.keras.Model``.
        """
        # TODO: Add option to merge each variable individually?
        super().__init__(**kwargs)

        self.output_model = output_model

        assert len(components_by_var) == len(component_variables) == len(base_variables) == len(output_merge_variables)
        self.n_merge_variables = len(components_by_var)

        self.component_variables = component_variables
        self.base_variables = base_variables
        self.output_merge_variables = output_merge_variables
        self.output_train_variables = output_train_variables
        self.base_batch_fisher = base_batch_fisher
        # TODO: Get the default base weight figured out better (1e-6 was also tried).
        self.base_fisher_value = base_fisher_value

        # Each element in the list should have dense_shape [n_components, *var_shape]
        self.components_by_var = components_by_var

        self.normalize_fishers = normalize_fishers

        self.n_components = _compute_n_components(self.components_by_var)

        # Logit 0 is the base model; logits 1..n_components are the components.
        self.coeff_logits = self.add_weight(
            name='coeff_logits',
            shape=[1 + self.n_components],
            initializer="random_normal",
            trainable=True)

        if self.normalize_fishers:
            self.component_norms = self._compute_component_norms()

    def _compute_component_norms(self):
        """Return each component's L2 norm over all variables (shape [n_components])."""
        # TODO: Causes segfault.
        var_squares = [
            _sq_l2_norm_to_only_first_axis(c)
            for c in self.components_by_var
        ]
        return tf.sqrt(tf.reduce_sum(var_squares, axis=0))

    def get_merge_coefficients(self) -> tf.Tensor:
        """Return softmax(coeff_logits): the base coefficient, then one per component."""
        return tf.nn.softmax(self.coeff_logits)

    def _get_merged_variable_iterator(self):
        """Yield the merged value of each variable, in ``output_merge_variables`` order."""
        coeffs = self.get_merge_coefficients()
        base_coeff = coeffs[0]
        comp_coeffs = coeffs[1:]

        if self.normalize_fishers:
            # A component with zero norm should contribute nothing. Plain
            # division would produce inf coefficients and NaN gradients.
            comp_coeffs = tf.math.divide_no_nan(comp_coeffs, self.component_norms)

        for i in range(self.n_merge_variables):
            base_var = self.base_variables[i]
            if self.base_batch_fisher is not None:
                base_fisher = self.base_batch_fisher[i]
            else:
                base_fisher = self.base_fisher_value
            comp_vars = self.component_variables[i]
            comp_fishers = self.components_by_var[i]

            # One coefficient per component, broadcast across the variable's axes.
            var_rank = len(comp_fishers.dense_shape) - 1
            reshaped_coeffs = tf.reshape(comp_coeffs, [self.n_components] + [1] * var_rank)

            comp_weight = tf.sparse.reduce_sum(comp_fishers * reshaped_coeffs, axis=0)
            base_weight = base_coeff * base_fisher
            numerator = comp_weight * comp_vars + base_weight * base_var
            denominator = comp_weight + base_weight

            # TODO: See if any numerical stability issue happens here.
            yield numerator / denominator

    def call(self, *args, **kwargs):
        """Run ``output_model`` and return its logits."""
        return self.output_model(*args, **kwargs).logits

    def train_step(self, data):
        """Run one optimization step on the merge logits and the trainable variables."""
        x, y = data

        # Compute the merged values under a tape that watches only the merge
        # logits, then write them into the output model. `Variable.assign` is
        # not differentiable, so the logits' gradient is recovered below via
        # the chain rule. Unlike tf.gradients, this also works eagerly.
        with tf.GradientTape(watch_accessed_variables=False) as merge_tape:
            merge_tape.watch(self.coeff_logits)
            merged_vars = list(self._get_merged_variable_iterator())
        for out_var, update in zip(self.output_merge_variables, merged_vars):
            out_var.assign(update)

        grad_vars = [self.output_merge_variables, self.output_train_variables]

        with tf.GradientTape(watch_accessed_variables=False) as tape:
            tape.watch(grad_vars)
            prediction = self.output_model(x, training=False)
            logits = prediction.logits
            loss = self.compiled_loss(y, logits)

        merge_grads, train_grads = tape.gradient(loss, grad_vars)

        # A None gradient means the variable did not affect the loss. Pass
        # zeros explicitly: a None output gradient is treated as all-ones.
        merge_grads = [
            tf.zeros_like(merged) if grad is None else tf.convert_to_tensor(grad)
            for merged, grad in zip(merged_vars, merge_grads)
        ]
        coeff_grads = merge_tape.gradient(
            merged_vars, self.coeff_logits, output_gradients=merge_grads)

        # Apply all updates in one call so that the optimizer's iteration
        # counter (used by learning-rate schedules and Adam bias correction)
        # advances exactly once per training step.
        self.optimizer.apply_gradients(zip(
            [*train_grads, coeff_grads],
            [*self.output_train_variables, self.coeff_logits],
        ))

        self.compiled_metrics.update_state(y, logits)
        return {m.name: m.result() for m in self.metrics}

###############################################################################


@tf.function(experimental_relax_shapes=True)
def _make_raggeds_for_var(comp: tf.sparse.SparseTensor, comp_var: tf.Tensor, n_components: int):
    """Split one variable's sparse components into per-component ragged rows.

    Row ``k`` of each returned ragged tensor describes the entries of ``comp``
    whose first index equals ``k``:

    - their remaining (variable-space) indices;
    - their stored values;
    - the values of ``comp_var`` gathered at those indices.

    Args:
      comp: sparse tensor whose first index selects the component.
      comp_var: dense tensor indexed by the trailing indices of ``comp``.
      n_components: number of rows to build. Presumably a Python int, in which
        case the loop below is unrolled when the function is traced.

    Returns:
      Tuple ``(indices, values, variable_values)`` of ``tf.ragged.stack`` results.
    """
    indices = comp.indices
    component_ids = indices[:, 0]

    index_rows, value_rows, var_value_rows = [], [], []
    for component in range(n_components):
        in_component = component_ids == component
        var_indices = tf.boolean_mask(indices, in_component)[:, 1:]
        index_rows.append(var_indices)
        value_rows.append(tf.boolean_mask(comp.values, in_component))
        var_value_rows.append(tf.gather_nd(comp_var, var_indices))

    return (
        tf.ragged.stack(index_rows),
        tf.ragged.stack(value_rows),
        tf.ragged.stack(var_value_rows),
    )


@tf.function(experimental_relax_shapes=True)
def _make_raggeds_for_comp_var(comp: tf.sparse.SparseTensor, comp_var: tf.Tensor, i: int):
    """Extract the entries of component ``i`` from ``comp``.

    Args:
      comp: sparse tensor whose first index selects the component.
      comp_var: dense tensor indexed by the trailing indices of ``comp``.
      i: component to select.

    Returns:
      Tuple ``(indices, values, variable_values)``. The first element holds the
      trailing indices of the selected entries, the second their stored values,
      and the third ``comp_var`` gathered at those indices.
    """
    in_component = comp.indices[:, 0] == i
    var_indices = tf.boolean_mask(comp.indices, in_component)[:, 1:]
    return (
        var_indices,
        tf.boolean_mask(comp.values, in_component),
        tf.gather_nd(comp_var, var_indices),
    )


# # TODO: Give this class a better name.
# class SparseFitter(tf.keras.Model):
#     # TODO: This only supports transfer from a single model to another. Eventually
#     # support many to one merges.

#     def __init__(
#         self,
#         components_by_var: Sequence[SparseTensor],
#         component_variables: Sequence[tf.Tensor],
#         base_variables: Sequence[tf.Tensor],
#         output_merge_variables: Sequence[tf.Variable],
#         output_train_variables: Sequence[tf.Variable],
#         output_model: TFPreTrainedModel,
#         # maybe optional base batch fisher?
#         normalize_fishers: bool = True,
#         **kwargs,
#     ):
#         # TODO: Add option to merge each variable individually?
#         super().__init__(**kwargs)

#         self.output_model = output_model

#         assert len(components_by_var) == len(component_variables) == len(base_variables) == len(output_merge_variables)
#         self.n_merge_variables = len(components_by_var)

#         self.component_variables = component_variables
#         self.base_variables = base_variables
#         self.output_merge_variables = output_merge_variables
#         self.output_train_variables = output_train_variables

#         # Each element in the list should have dense_shape [n_components, *var_shape]
#         self.components_by_var = components_by_var

#         self.normalize_fishers = normalize_fishers

#         self.n_components = _compute_n_components(self.components_by_var)

#         self.coeff_logits = self.add_weight(
#             name='coeff_logits',
#             shape=[1 + self.n_components],
#             initializer="random_normal",
#             trainable=True)

#         if self.normalize_fishers:
#             self.component_norms = self._compute_component_norms()

#         start = time.time()
#         self._make_ragged_by_vars()
#         print('Make ragged vars time:', time.time() - start)
#         # 32.3

#     def _make_ragged_by_vars(self):
#         self.component_indices_by_var = []
#         self.component_values_by_var = []
#         self.component_variable_values = []

#         for comp, comp_var in zip(self.components_by_var, self.component_variables):
#             all_indices = comp.indices

#             component_indices_for_var = []
#             component_values_for_var = []
#             component_variable_values_for_var = []

#             for i in tf.range(self.n_components):
#                 mask = all_indices[:, 0] == i
#                 comp_indices = tf.boolean_mask(all_indices, mask)[:, 1:]
#                 comp_vals = tf.boolean_mask(comp.values, mask)
#                 comp_var_vals = tf.gather_nd(comp_var, comp_indices)

#                 # (comp_indices, comp_vals, comp_var_vals) = _make_raggeds_for_comp_var(comp, comp_var, i)

#                 component_indices_for_var.append(comp_indices)
#                 component_values_for_var.append(comp_vals)
#                 component_variable_values_for_var.append(comp_var_vals)

#             self.component_indices_by_var.append(tf.ragged.stack(component_indices_for_var))
#             self.component_values_by_var.append(tf.ragged.stack(component_values_for_var))
#             self.component_variable_values.append(tf.ragged.stack(component_variable_values_for_var))
#             # (
#             #     component_indices_for_var,
#             #     component_values_for_var,
#             #     component_variable_values_for_var
#             # ) = _make_raggeds_for_var(comp, comp_var, self.n_components)
#             # self.component_indices_by_var.append(component_indices_for_var)
#             # self.component_values_by_var.append(component_values_for_var)
#             # self.component_variable_values.append(component_variable_values_for_var)

#     def _compute_component_norms(self):
#         # TODO: Causes segfault.
#         var_squares = [
#             _sq_l2_norm_to_only_first_axis(c)
#             for c in self.components_by_var
#         ]
#         return tf.sqrt(tf.reduce_sum(var_squares, axis=0))

#     def get_merge_coefficients(self) -> tf.Tensor:
#         return tf.nn.softmax(self.coeff_logits)

#     @tf.function
#     def _update_output_model(self):
#         coeffs = self.get_merge_coefficients()
#         base_coeff = coeffs[0]
#         comp_coeffs = coeffs[1:]

#         # TODO: Get this figured out better.
#         # base_fisher_value = 1e-6
#         base_fisher_value = 1.0

#         ret = []

#         for i, outvar in enumerate(self.output_merge_variables):
#             base_var = self.base_variables[i]

#             nom = base_coeff * base_fisher_value * tf.ones_like(outvar)
#             denom = base_coeff * base_fisher_value * base_var

#             # for j in tf.range(self.n_components):
#             for j in range(self.n_components):
#                 comp_inds = self.component_indices_by_var[i][j]
#                 if isinstance(comp_inds, tf.RaggedTensor):
#                     comp_inds = comp_inds.to_tensor()

#                 if tf.shape(comp_inds)[0] == 0:
#                     continue

#                 comp_vals = self.component_values_by_var[i][j]
#                 if isinstance(comp_vals, tf.RaggedTensor):
#                     comp_vals = comp_vals.to_tensor()

#                 comp_var_vals = self.component_variable_values[i][j]
#                 if isinstance(comp_var_vals, tf.RaggedTensor):
#                     comp_var_vals = comp_var_vals.to_tensor()

#                 denom = tf.tensor_scatter_nd_add(
#                     denom,
#                     comp_inds,
#                     comp_coeffs[j] * comp_vals
#                 )
#                 nom = tf.tensor_scatter_nd_add(
#                     nom,
#                     comp_inds,
#                     comp_coeffs[j] * comp_vals * comp_var_vals
#                 )
#             # TODO: See if any numerical stability issue happens here.
#             newval = nom / denom
#             outvar.assign(newval)
#             ret.append(newval)

#         return ret

#     def call(self, *args, **kwargs):
#         return self.output_model(*args, **kwargs).logits

#     def train_step(self, data):
#         x, y = data

#         # # Update the values of the merged variables.
#         # merged_vars = list(self._get_merged_variable_iterator())
#         # for out_var, update in zip(self.output_merge_variables, merged_vars):
#         #     out_var.assign(update)

#         # Update the values of the merged variables.
#         merged_vars = self._update_output_model()

#         grad_vars = [self.output_merge_variables, self.output_train_variables]

#         with tf.GradientTape(watch_accessed_variables=False) as tape:
#             tape.watch(grad_vars)
#             # prediction = self.output_model(x, training=True)
#             prediction = self.output_model(x, training=False)
#             logits = prediction.logits
#             loss = self.compiled_loss(y, logits)

#         merge_grads, train_grads = tape.gradient(loss, grad_vars)

#         # # The output_train_variables are trained directly, so update them.
#         self.optimizer.apply_gradients(zip(train_grads, self.output_train_variables))

#         # Use the chain rule to compute the gradients for the merged variables since
#         # we cannot differentiate through the variable.assign method.
#         coeff_grads, = tf.gradients(merged_vars, [self.coeff_logits], grad_ys=merge_grads)
#         self.optimizer.apply_gradients([[coeff_grads, self.coeff_logits]])

#         # all_vars = [self.coeff_logits, *self.output_train_variables]
#         # all_grads = [coeff_grads, *train_grads]
#         # self.optimizer.apply_gradients(zip(all_vars, all_grads))

#         self.compiled_metrics.update_state(y, logits)
#         return {m.name: m.result() for m in self.metrics}

###############################################################################


# def create_components_by_var_from_nmf_decomp(
#     decomp: nmf_common.NmfDecomposition,
#     variables: Sequence[tf.Tensor],
# ):
#     packer = flat_pack.FlatPacker([v.shape for v in variables])
#     assert packer.flat_size == decomp.full_dense_size

#     # H = decomp.H
#     reduced_index_to_og_index = decomp.reduce_kept_indices
