"""Information about what layers, variables, etc. component H's are concentrated on."""
import dataclasses
from typing import List

import numpy as np
import tensorflow as tf
from tqdm import tqdm

from em.util import flat_pack


###############################################################################


@dataclasses.dataclass
class ParameterInfo:
    """Summary of how the components' h values fall on one variable.

    ``to_json_obj`` converts an instance into plain Python types so it can be
    handed to a JSON encoder.
    """

    # Name of the variable.
    name: str

    # One entry per component: the sum of that component's h over this
    # parameter. Its length therefore equals the number of components.
    values: List[float]

    # Dimensions of this variable.
    shape: List[int]

    def to_json_obj(self):
        """Return a dict with keys ``name``, ``values`` and ``shape``.

        Every entry of ``values`` is coerced to ``float`` and every entry of
        ``shape`` to ``int``. This turns e.g. numpy scalars or
        ``TensorShape`` dimensions into builtin, JSON-serializable types.
        """
        json_values = list(map(float, self.values))
        json_shape = list(map(int, self.shape))
        return dict(name=self.name, values=json_values, shape=json_shape)


###############################################################################


# def make_parameter_infos(model: tf.keras.Model, H_values: np.ndarray, H_indices: np.ndarray) -> List[ParameterInfo]:
#     variables = model.trainable_variables
#     packer = flat_pack.FlatPacker([v.shape for v in variables])

#     H_by_params = packer.decode_sparse_tf(H_values, H_indices)

#     infos = []
#     for v, H in zip(variables, H_by_params):
#         # values = tf.sparse.reduce_sum(H, axis=list(range(1, len(H.shape))), output_is_sparse=False)
#         info = ParameterInfo(
#             # TODO: Some processing of the name?
#             name=v.name,
#             values=values,
#             shape=list(v.shape),
#         )
#         infos.append(info)

#     return infos


def make_parameter_infos(model: tf.keras.Model, H: List[tf.sparse.SparseTensor]) -> List[ParameterInfo]:
    """Sum each component's h over each trainable variable of ``model``.

    For every component ``h`` in ``H``, the sparse entries are split back into
    per-variable pieces using a ``flat_pack.FlatPacker`` built from the shapes
    of ``model.trainable_variables``. The entries landing in each piece are
    then summed.

    Args:
        model: Model whose ``trainable_variables`` (their order and shapes)
            define the layout the sparse tensors in ``H`` are decoded against.
        H: One sparse tensor per component. Presumably each indexes into the
            flattened concatenation of all trainable variables -- TODO confirm
            against callers.

    Returns:
        One ``ParameterInfo`` per trainable variable, in
        ``trainable_variables`` order. ``values[i]`` is the sum of component
        ``i``'s h over that variable, so every ``values`` list has length
        ``len(H)``. When ``H`` is empty, each list is empty.
    """
    variables = model.trainable_variables
    packer = flat_pack.FlatPacker([v.shape for v in variables])

    # sums_by_variable[j][i] is the sum of component i's h over variable j.
    # Accumulating per variable directly, rather than transposing a
    # per-component table with zip(*...), keeps one list per variable even
    # when H is empty. The zip transpose would produce zero rows there.
    sums_by_variable: List[List[float]] = [[] for _ in variables]

    for h in tqdm(H):
        # A leading axis is added to the values before decoding; presumably
        # decode_sparse_tf expects a component/batch dimension -- TODO confirm.
        h_by_params = packer.decode_sparse_tf(h.values[None, ...], h.indices)
        # Check per component so a mismatch is reported instead of being
        # silently truncated by zip below.
        assert len(h_by_params) == len(variables), (
            f'decoded {len(h_by_params)} parameter groups, '
            f'expected {len(variables)}'
        )
        for sums, hv in zip(sums_by_variable, h_by_params):
            sums.append(tf.reduce_sum(hv.values).numpy())

    return [
        ParameterInfo(
            # TODO: Some processing of the name?
            name=v.name,
            values=sums,
            shape=list(v.shape),
        )
        for v, sums in zip(variables, sums_by_variable)
    ]
