"""Common stuff for NMF."""
import dataclasses
import os
from typing import List, Optional

from absl import flags
import h5py
import numpy as np
import scipy.sparse as sps
import tensorflow as tf

from em.util import hdf5_util

# absl's global flag container; `_get_flag` below reads flag values from it.
FLAGS = flags.FLAGS


@dataclasses.dataclass
class NmfDecomposition:
    """Encapsulates the results of an NMF decomposition (dense factors).

    The factors can be persisted to / restored from an HDF5 file with `save`
    and `load`. When the decomposition was computed on a reduced feature set,
    `reduce_kept_indices` maps each column of `H` back to its position in the
    original feature space, whose size is `full_dense_size`.
    """

    # shape = [n_examples, n_components]
    W: np.ndarray
    # shape = [n_components, n_features]
    H: np.ndarray

    # shape = [n_features], dtype=np.int32
    # Indices of parameters in original that are kept in the reduced
    # per-example Fishers.
    reduce_kept_indices: Optional[np.ndarray] = None

    # Number of features in the original (unreduced) space. It is only
    # persisted by `save` when `reduce_kept_indices` is set.
    full_dense_size: Optional[int] = None

    @staticmethod
    def _read_ds(ds, dtype) -> np.ndarray:
        """Read the whole HDF5 dataset `ds` into a new array of `dtype`."""
        out = np.zeros(ds.shape, dtype=dtype)
        ds.read_direct(out)
        return out

    def save(self, filepath: str):
        """Write the decomposition to an HDF5 file at `filepath`.

        A leading `~` in the path is expanded. `W` and `H` are stored as
        float32 datasets in the `data` group. If `reduce_kept_indices` is set
        it is stored as an int32 dataset carrying `full_dense_size` as an
        attribute.

        Raises:
            ValueError: if `reduce_kept_indices` is set but `full_dense_size`
                is not. This is checked before the file is opened so that an
                existing file is not truncated by a save that cannot finish.
        """
        if self.reduce_kept_indices is not None and self.full_dense_size is None:
            raise ValueError(
                'full_dense_size must be set when reduce_kept_indices is set.')

        filepath = os.path.expanduser(filepath)
        with h5py.File(filepath, "w") as f:
            data = f.create_group('data')

            W_ds = data.create_dataset('W', self.W.shape, dtype=np.float32)
            hdf5_util.set_h5_ds(W_ds, self.W)

            H_ds = data.create_dataset('H', self.H.shape, dtype=np.float32)
            hdf5_util.set_h5_ds(H_ds, self.H)

            if self.reduce_kept_indices is not None:
                reduce_kept_indices_ds = data.create_dataset(
                    'reduce_kept_indices', self.reduce_kept_indices.shape, dtype=np.int32)
                hdf5_util.set_h5_ds(reduce_kept_indices_ds, self.reduce_kept_indices)
                reduce_kept_indices_ds.attrs['full_dense_size'] = self.full_dense_size

    @classmethod
    def load(cls, filepath: str) -> 'NmfDecomposition':
        """Read a decomposition previously written by `save`.

        `W` and `H` are returned as float32 arrays. `reduce_kept_indices`
        (int32) and `full_dense_size` are populated only when the file
        contains the `data/reduce_kept_indices` dataset; otherwise both are
        None.
        """
        filepath = os.path.expanduser(filepath)
        with h5py.File(filepath, "r") as f:
            W = cls._read_ds(f['data/W'], np.float32)
            H = cls._read_ds(f['data/H'], np.float32)

            reduce_kept_indices = None
            full_dense_size = None
            if 'data/reduce_kept_indices' in f:
                reduce_kept_indices_ds = f['data/reduce_kept_indices']
                reduce_kept_indices = cls._read_ds(reduce_kept_indices_ds, np.int32)
                full_dense_size = reduce_kept_indices_ds.attrs['full_dense_size']

        # The attribute may come back as a length-1 array; unwrap it.
        if isinstance(full_dense_size, np.ndarray):
            assert len(full_dense_size) == 1
            full_dense_size = full_dense_size[0]
        # Hand back a plain Python int, matching the declared field type.
        if full_dense_size is not None:
            full_dense_size = int(full_dense_size)

        return cls(W=W, H=H, reduce_kept_indices=reduce_kept_indices, full_dense_size=full_dense_size)

    def normalize_components_to_unit_norm(self, eps=1e-12):
        """Rescale so every row of `H` has (approximately) unit L2 norm.

        Each row of `H` is divided by `norm + eps` and the matching column of
        `W` is multiplied by the same amount, so the product `W @ H` is
        unchanged up to rounding. `eps` guards against division by zero for
        all-zero components.
        """
        # NOTE: This will modify W, H in place!
        scale = np.sqrt(np.sum(self.H**2, axis=-1, keepdims=True)) + eps
        self.H /= scale
        self.W *= scale.T

    def get_full_H(self) -> np.ndarray:
        """Return `H` expressed in the original (unreduced) feature space.

        Without `reduce_kept_indices` this returns `self.H` itself (not a
        copy). Otherwise it returns a new float32 array of shape
        [n_components, full_dense_size] that is zero except at the kept
        columns, which hold the values of `H`.
        """
        if self.reduce_kept_indices is None:
            return self.H

        assert self.full_dense_size is not None

        full_H = np.zeros([self.H.shape[0], self.full_dense_size], dtype=np.float32)
        # Scatter all rows at once rather than looping over components.
        full_H[:, self.reduce_kept_indices] = self.H

        return full_H

    def get_full_sparse_H(self) -> 'List[tf.sparse.SparseTensor]':
        """Return one rank-1 sparse tensor per component of `H`.

        Each tensor has dense shape [full_dense_size], with the component's
        values placed at `reduce_kept_indices`. Both `reduce_kept_indices`
        and `full_dense_size` must be set.
        """
        assert self.reduce_kept_indices is not None
        assert self.full_dense_size is not None

        # Every component shares the same index column and dense shape.
        indices = self.reduce_kept_indices[:, None]
        dense_shape = [self.full_dense_size]

        return [tf.sparse.SparseTensor(indices, h, dense_shape) for h in self.H]


@dataclasses.dataclass
class SparseNmfDecomposition:
    """Encapsulates the results of an NMF decomposition with a sparse H.

    `W` is dense. `H` is stored in compressed-sparse-row form: the entries of
    row `i` are `H_values[H_row_indices[i]:H_row_indices[i + 1]]`, located at
    the columns given by the same slice of `H_column_indices`.
    """

    # shape = [n_examples, n_components]
    W: np.ndarray

    # The two-element array [n_components, n_features], dtype=np.int32
    H_shape: np.ndarray
    # shape = [nnz], dtype=float32
    H_values: np.ndarray
    # shape = [n_components + 1]
    # Offsets into H_values / H_column_indices delimiting each row of H.
    # Stored and loaded as int32 by `save` / `load`.
    H_row_indices: np.ndarray
    # shape = [nnz], dtype=int32
    # Column of each entry of H_values within the reduced feature set.
    H_column_indices: np.ndarray

    # shape = [n_features], dtype=np.int32
    # Indices of parameters in original that are kept in the reduced
    # per-example Fishers.
    reduce_kept_indices: Optional[np.ndarray] = None

    # Number of features in the original (unreduced) space. It is only
    # persisted by `save` when `reduce_kept_indices` is set.
    full_dense_size: Optional[int] = None

    @staticmethod
    def _read_ds(ds, dtype) -> np.ndarray:
        """Read the whole HDF5 dataset `ds` into a new array of `dtype`."""
        out = np.zeros(ds.shape, dtype=dtype)
        ds.read_direct(out)
        return out

    def get_full_sparse_H(self) -> 'List[tf.sparse.SparseTensor]':
        """Return one rank-1 sparse tensor per component (row) of `H`."""
        # Do `sparse_util.stack_as_rows(self.get_full_sparse_H())` to get
        # this as a single sparse tensor.
        #
        # TODO: If I do the above a lot, it'll probably be more efficient to
        # write a method that does it directly.
        return [self.get_single_sparse_h(i) for i in range(self.H_shape[0])]

    def get_single_sparse_h(self, component_index: int) -> 'tf.sparse.SparseTensor':
        """Return row `component_index` of `H` as a rank-1 sparse tensor.

        The tensor has dense shape [full_dense_size]; the stored column
        indices are translated to original-space indices through
        `reduce_kept_indices`. Both `reduce_kept_indices` and
        `full_dense_size` must be set.
        """
        assert self.reduce_kept_indices is not None
        assert self.full_dense_size is not None

        start = self.H_row_indices[component_index]
        end = self.H_row_indices[component_index + 1]

        values = self.H_values[start:end]
        col_inds = self.H_column_indices[start:end]
        og_col_inds = self.reduce_kept_indices[col_inds]

        return tf.sparse.SparseTensor(
            og_col_inds[:, None],
            values,
            [self.full_dense_size]
        )

    def normalize_components_to_unit_norm(self, eps=1e-12):
        """Rescale so every row of `H` has (approximately) unit L2 norm.

        Each row's stored values are divided by `norm + eps` and the matching
        column of `W` is multiplied by that same factor, so the product of
        the two factors is unchanged up to rounding. `eps` guards against
        division by zero for empty or all-zero rows.
        """
        # NOTE: This will modify W, H in place!
        for i in range(self.H_shape[0]):
            start = self.H_row_indices[i]
            end = self.H_row_indices[i + 1]

            values = self.H_values[start:end]
            # Use one scale for both factors; dividing H by `norm + eps` but
            # multiplying W by `norm` alone would change W @ H.
            scale = np.sqrt((values**2).sum()) + eps
            self.H_values[start:end] /= scale
            self.W[:, i] *= scale

    def save(self, filepath: str):
        """Write the decomposition to an HDF5 file at `filepath`.

        A leading `~` in the path is expanded. `W` is stored as a float32
        dataset in the `data` group; the sparse pieces of `H` go in the
        `data/H` group, with `H_shape` as its `shape` attribute. If
        `reduce_kept_indices` is set it is stored as an int32 dataset
        carrying `full_dense_size` as an attribute.

        Raises:
            ValueError: if `reduce_kept_indices` is set but `full_dense_size`
                is not. This is checked before the file is opened so that an
                existing file is not truncated by a save that cannot finish.
        """
        if self.reduce_kept_indices is not None and self.full_dense_size is None:
            raise ValueError(
                'full_dense_size must be set when reduce_kept_indices is set.')

        filepath = os.path.expanduser(filepath)
        with h5py.File(filepath, "w") as f:
            data = f.create_group('data')

            W_ds = data.create_dataset('W', self.W.shape, dtype=np.float32)
            hdf5_util.set_h5_ds(W_ds, self.W)

            # H

            H_group = data.create_group('H')
            H_group.attrs['shape'] = self.H_shape

            H_values_ds = H_group.create_dataset('values', self.H_values.shape, dtype=np.float32)
            hdf5_util.set_h5_ds(H_values_ds, self.H_values)

            # NOTE(review): row offsets are written as int32, so this format
            # presumably cannot represent nnz >= 2**31 -- confirm if needed.
            H_row_indices_ds = H_group.create_dataset('row_indices', self.H_row_indices.shape, dtype=np.int32)
            hdf5_util.set_h5_ds(H_row_indices_ds, self.H_row_indices)

            H_column_indices_ds = H_group.create_dataset('column_indices', self.H_column_indices.shape, dtype=np.int32)
            hdf5_util.set_h5_ds(H_column_indices_ds, self.H_column_indices)

            if self.reduce_kept_indices is not None:
                reduce_kept_indices_ds = data.create_dataset(
                    'reduce_kept_indices', self.reduce_kept_indices.shape, dtype=np.int32)
                hdf5_util.set_h5_ds(reduce_kept_indices_ds, self.reduce_kept_indices)
                reduce_kept_indices_ds.attrs['full_dense_size'] = self.full_dense_size

    @classmethod
    def load(cls, filepath: str) -> 'SparseNmfDecomposition':
        """Read a decomposition previously written by `save`.

        `reduce_kept_indices` and `full_dense_size` are populated only when
        the file contains the `data/reduce_kept_indices` dataset; otherwise
        both are None.
        """
        filepath = os.path.expanduser(filepath)
        with h5py.File(filepath, "r") as f:
            W = cls._read_ds(f['data/W'], np.float32)

            H_group = f['data/H']
            H_shape = H_group.attrs['shape']

            # NOTE: TEMPORARY DUE TO BUG WHEN WRITING!!!!!
            # Some files have the H datasets directly under `data` rather
            # than under `data/H`; fall back to that layout.
            if 'values' not in H_group:
                H_group = f['data']

            H_values = cls._read_ds(H_group['values'], np.float32)
            H_row_indices = cls._read_ds(H_group['row_indices'], np.int32)
            H_column_indices = cls._read_ds(H_group['column_indices'], np.int32)

            reduce_kept_indices = None
            full_dense_size = None
            if 'data/reduce_kept_indices' in f:
                reduce_kept_indices_ds = f['data/reduce_kept_indices']
                reduce_kept_indices = cls._read_ds(reduce_kept_indices_ds, np.int32)
                full_dense_size = reduce_kept_indices_ds.attrs['full_dense_size']

        # The attribute may come back as a length-1 array; unwrap it.
        if isinstance(full_dense_size, np.ndarray):
            assert len(full_dense_size) == 1
            full_dense_size = full_dense_size[0]
        # Hand back a plain Python int, matching the declared field type.
        if full_dense_size is not None:
            full_dense_size = int(full_dense_size)

        return cls(
            W=W,
            H_shape=H_shape,
            H_values=H_values,
            H_row_indices=H_row_indices,
            H_column_indices=H_column_indices,
            reduce_kept_indices=reduce_kept_indices,
            full_dense_size=full_dense_size,
        )


###############################################################################


# Flag-name prefix used by the helpers below when the caller does not pass one.
_DEFAULT_FLAG_PREFIX = 'nmf'


def add_nmf_flags(prefix: str = _DEFAULT_FLAG_PREFIX):
    """Define the NMF command-line flags, each named `<prefix>_<option>`.

    Integer flags: `n_components` (default None) and `max_iter` (200).
    Float flags: `tol` (1e-6), `alpha` (0.0), `beta` (1.0), `l1_ratio` (0.0).
    All are registered with an empty help string.
    """
    integer_options = (
        ("n_components", None),
        ("max_iter", 200),
    )
    float_options = (
        ("tol", 1e-6),
        ("alpha", 0.0),
        ("beta", 1.0),
        ("l1_ratio", 0.0),
    )

    for option, default in integer_options:
        flags.DEFINE_integer(f"{prefix}_{option}", default, '')
    for option, default in float_options:
        flags.DEFINE_float(f"{prefix}_{option}", default, '')


def _get_flag(prefix: str, name: str):
    """Return the current value of the flag named `<prefix>_<name>`."""
    flag_name = f'{prefix}_{name}'
    return getattr(FLAGS, flag_name)


def get_nmf_init_kwargs_from_flags(prefix: str = _DEFAULT_FLAG_PREFIX):
    """Return a dict whose single key `rank` holds the `<prefix>_n_components` flag value."""
    n_components = _get_flag(prefix, 'n_components')
    return dict(rank=n_components)


def get_nmf_fit_kwargs_from_flags(prefix: str = _DEFAULT_FLAG_PREFIX):
    """Return a dict of fitting options read from the `<prefix>_*` flags.

    The keys are `max_iter`, `tol`, `alpha`, `beta` and `l1_ratio`; each
    value is the current value of the flag `<prefix>_<key>`.
    """
    option_names = ("max_iter", "tol", "alpha", "beta", "l1_ratio")
    return {option: _get_flag(prefix, option) for option in option_names}
