"""Sparse versions of diagonal Fisher approximation."""
import dataclasses
from typing import List, Sequence

import h5py
import numpy as np
import tensorflow as tf

from em.fishers import fisher_abcs
from em.fishers import diagonal

from em.util import hdf5_util


@dataclasses.dataclass
class SparseDiagonalFisher(fisher_abcs.FisherAbc):
    """Diagonal Fisher approximation held as per-variable sparse tensors.

    Attributes:
        fishers: One sparse tensor of Fisher values per variable.
        parameters: One sparse tensor of parameter values per variable,
            paired by position with `fishers`.

    Raises:
        ValueError: On construction, if the two lists differ in length.
    """

    fishers: List[tf.sparse.SparseTensor]
    parameters: List[tf.sparse.SparseTensor]

    def __post_init__(self):
        n_fishers, n_parameters = len(self.fishers), len(self.parameters)
        if n_fishers != n_parameters:
            raise ValueError('The fisher and parameter lists must have the same length.')

    def _save(self, filepath: str):
        """Write all Fisher/parameter pairs to a new HDF5 file at `filepath`.

        File layout (``<i>`` is the variable's position as a string):
            metadata                  attrs: fisher_class
            data                      attrs: n_variables
            data/indices/<i>          int32, taken from the Fisher tensor
            data/dense_shapes/<i>     int32, taken from the Fisher tensor
            data/fisher_values/<i>    Fisher values in their own dtype
            data/parameter_values/<i> parameter values in their own dtype
        """

        def write_dataset(group, name, shape, dtype, array):
            # Create the dataset first, then fill it via the shared helper.
            dataset = group.create_dataset(name, shape, dtype=dtype)
            hdf5_util.set_h5_ds(dataset, array)

        # TODO: Move some of this logic to common code.
        with h5py.File(filepath, "w") as h5_file:
            metadata_group = h5_file.create_group('metadata')
            metadata_group.attrs['fisher_class'] = self.__class__.__name__

            data_group = h5_file.create_group('data')
            data_group.attrs["n_variables"] = len(self.fishers)

            index_group = data_group.create_group('indices')
            shape_group = data_group.create_group('dense_shapes')

            fisher_group = data_group.create_group('fisher_values')
            parameter_group = data_group.create_group('parameter_values')

            # Only the Fisher's indices and dense shape are stored: each
            # parameter tensor is assumed to share them with its Fisher.
            pairs = zip(self.fishers, self.parameters)
            for position, (fisher, parameter) in enumerate(pairs):
                key = str(position)

                write_dataset(
                    index_group, key, fisher.indices.shape, np.int32,
                    fisher.indices.numpy().astype(np.int32))

                write_dataset(
                    shape_group, key, fisher.dense_shape.shape, np.int32,
                    fisher.dense_shape.numpy().astype(np.int32))

                fisher_array = fisher.values.numpy()
                write_dataset(
                    fisher_group, key, fisher_array.shape, fisher_array.dtype,
                    fisher_array)

                parameter_array = parameter.values.numpy()
                write_dataset(
                    parameter_group, key, parameter_array.shape, parameter_array.dtype,
                    parameter_array)

    @classmethod
    def _load(cls, filepath: str) -> 'SparseDiagonalFisher':
        """Rebuild an instance from an HDF5 file written by `_save`.

        Each Fisher/parameter pair is reconstructed from the same stored
        indices and dense shape, with its own stored values.
        """
        with h5py.File(filepath, "r") as h5_file:
            index_group = h5_file['data/indices']
            shape_group = h5_file['data/dense_shapes']
            fisher_group = h5_file['data/fisher_values']
            parameter_group = h5_file['data/parameter_values']

            loaded_fishers = []
            loaded_parameters = []
            n_variables = h5_file['data'].attrs["n_variables"]
            for position in range(n_variables):
                key = str(position)
                loaded_fishers.append(
                    tf.sparse.SparseTensor(index_group[key], fisher_group[key], shape_group[key]))
                loaded_parameters.append(
                    tf.sparse.SparseTensor(index_group[key], parameter_group[key], shape_group[key]))

            return cls(fishers=loaded_fishers, parameters=loaded_parameters)


def _compute_top_k_threshold(tensors: Sequence[tf.Tensor], k: int) -> tf.Tensor:
    """Return the smallest of the `k` largest entries pooled over `tensors`.

    Every tensor is flattened and all of them are concatenated into a single
    vector before the top-k selection, so the threshold is global rather than
    per-tensor.

    Returns:
        A scalar tf.Tensor (i.e. shape = []).
    """
    # TODO: This is fairly poor from a time/space complexity perspective. It
    # can be made a lot better, but I should note that stuff in python will
    # be much slower than the same thing written tensorflow's C++.
    flattened = [tf.reshape(tensor, [-1]) for tensor in tensors]
    pooled = tf.concat(flattened, axis=0)
    # Only the values matter here; the accompanying indices are discarded.
    top_values = tf.math.top_k(pooled, k=k, sorted=False)[0]
    return tf.reduce_min(top_values)


def from_dense_uniformly(
    diagonal_fisher: diagonal.DiagonalFisher,
    model_variables: Sequence[tf.Tensor],
    sparsity: float,
) -> SparseDiagonalFisher:
    """Sparsify a dense diagonal Fisher by keeping its largest entries.

    A single global threshold is computed over all Fisher tensors so that the
    top ``int(sparsity * diagonal_fisher.n_parameters)`` Fisher values are
    retained. Entries tied with the threshold are all kept (``>=``).

    Args:
        diagonal_fisher: Dense diagonal Fisher providing `fishers` and
            `n_parameters`.
        model_variables: Parameter tensors, paired by position with
            `diagonal_fisher.fishers` (assumed to have matching shapes).
        sparsity: Fraction of entries to KEEP, between 0 and 1 inclusive.

    Returns:
        A SparseDiagonalFisher whose Fisher and parameter tensors for each
        variable share the same indices and dense shape.

    Raises:
        ValueError: If `sparsity` is outside [0, 1].
    """
    if not (0 <= sparsity <= 1):
        raise ValueError(f'The `sparsity` parameter must be between 0 and 1. Instead got {sparsity}.')

    n_params_full = diagonal_fisher.n_parameters
    n_params_sparse = int(sparsity * n_params_full)

    threshold = _compute_top_k_threshold(diagonal_fisher.fishers, n_params_sparse)

    sparse_fishers = []
    sparse_parameters = []
    for f, v in zip(diagonal_fisher.fishers, model_variables):
        # Build both sparse tensors from ONE index set. Sparsifying each
        # independently with tf.sparse.from_dense would also drop selected
        # entries whose value is exactly 0 (e.g. a zero-valued parameter),
        # leaving the Fisher and parameter with different indices, which
        # breaks the shared-indices assumption made when saving.
        indices = tf.where(f >= threshold)
        dense_shape = tf.shape(f, out_type=tf.int64)
        sparse_fishers.append(
            tf.sparse.SparseTensor(indices, tf.gather_nd(f, indices), dense_shape))
        sparse_parameters.append(
            tf.sparse.SparseTensor(indices, tf.gather_nd(v, indices), dense_shape))

    return SparseDiagonalFisher(fishers=sparse_fishers, parameters=sparse_parameters)


def _compute_kl_contribution_from_reference(
    diagonal_fisher: Sequence[tf.Tensor],
    model_variables: Sequence[tf.Tensor],
    reference_variables: Sequence[tf.Tensor]
) -> List[tf.Tensor]:
    """Compute each parameter's ``fisher * (model - reference)**2`` term.

    Makes use of the information geometric property of the Fisher as a
    Riemannian metric approximating the KL-divergence given perturbations.

    Args:
        diagonal_fisher: The per-variable diagonal Fisher tensors themselves
            (a sequence supporting `len` and iteration), not the wrapping
            DiagonalFisher object.
        model_variables: Current parameter tensors, paired by position.
        reference_variables: Reference parameter tensors, paired by position.

    Returns:
        One tensor per variable holding the elementwise contributions.

    Raises:
        ValueError: If the three sequences do not all have the same length.
    """
    fishers = diagonal_fisher

    if not (len(fishers) == len(model_variables) == len(reference_variables)):
        raise ValueError('The number of variables and/or Fishers do not match.')

    return [
        fisher * (mvar - rvar)**2
        for fisher, mvar, rvar in zip(fishers, model_variables, reference_variables)
    ]


def from_dense_by_metric_approximation(
    diagonal_fisher: diagonal.DiagonalFisher,
    model_variables: Sequence[tf.Tensor],
    reference_variables: Sequence[tf.Tensor],
    sparsity: float
) -> SparseDiagonalFisher:
    """Sparsify a dense diagonal Fisher by approximate KL contribution.

    Each entry is scored as ``fisher * (model - reference)**2``. A single
    global threshold over all scores keeps the top
    ``int(sparsity * diagonal_fisher.n_parameters)`` entries. Entries tied
    with the threshold are all kept (``>=``).

    Args:
        diagonal_fisher: Dense diagonal Fisher providing `fishers` and
            `n_parameters`.
        model_variables: Current parameter tensors, paired by position with
            `diagonal_fisher.fishers` (assumed to have matching shapes).
        reference_variables: Reference parameter tensors, paired by position.
        sparsity: Fraction of entries to KEEP, between 0 and 1 inclusive.

    Returns:
        A SparseDiagonalFisher whose Fisher and parameter tensors for each
        variable share the same indices and dense shape.

    Raises:
        ValueError: If `sparsity` is outside [0, 1], or if the Fisher, model
            and reference sequences differ in length.
    """
    if not (0 <= sparsity <= 1):
        raise ValueError(f'The `sparsity` parameter must be between 0 and 1. Instead got {sparsity}.')

    n_params_full = diagonal_fisher.n_parameters
    n_params_sparse = int(sparsity * n_params_full)

    kl_contributions = _compute_kl_contribution_from_reference(
        diagonal_fisher.fishers, model_variables, reference_variables)
    kl_threshold = _compute_top_k_threshold(kl_contributions, k=n_params_sparse)

    sparse_fishers = []
    sparse_parameters = []
    for f, p, kl in zip(diagonal_fisher.fishers, model_variables, kl_contributions):
        # Build both sparse tensors from ONE index set. Sparsifying each
        # independently with tf.sparse.from_dense would also drop selected
        # entries whose value is exactly 0 (e.g. a zero-valued parameter),
        # leaving the Fisher and parameter with different indices, which
        # breaks the shared-indices assumption made when saving.
        indices = tf.where(kl >= kl_threshold)
        dense_shape = tf.shape(f, out_type=tf.int64)
        sparse_fishers.append(
            tf.sparse.SparseTensor(indices, tf.gather_nd(f, indices), dense_shape))
        sparse_parameters.append(
            tf.sparse.SparseTensor(indices, tf.gather_nd(p, indices), dense_shape))

    return SparseDiagonalFisher(fishers=sparse_fishers, parameters=sparse_parameters)
