"""LRM-NPEFF decompositions.

Pretty much just wraps what gets written to disk.
"""
import dataclasses
import os
from typing import Optional

import h5py
import numpy as np
import torch

from npeff_torch.util import hdf5_utils

###############################################################################


@dataclasses.dataclass
class LrmNpeffDecomposition:
    """In-memory view of an LRM-NPEFF decomposition.

    Pretty much just wraps what gets written to disk. Any of the tensor fields may be
    None, e.g. when `load` was asked to skip them or the file does not contain them.
    """

    # NPEFF component coefficients.
    # shape = [n_examples, n_components]
    W: Optional[torch.Tensor]

    # Pseudo-Fisher vectors.
    # shape = [n_components, n_features]
    G: Optional[torch.Tensor]

    # Indices of parameters in original that are kept in the reduced per-example Fishers. Will
    # only be present for sparse LRM-NPEFF decompositions.
    # shape = [n_features], dtype=int32
    new_to_old_col_indices: Optional[torch.Tensor]

    # Equivalent to full dense size for sparse LRM-NPEFF. For dense LRM-NPEFF,
    # I'm not sure. None when the saved file has no `n_parameters` attribute.
    n_parameters: Optional[int]

    # This is in the saved file, but I'm not loading it at the moment. I'm not sure when
    # it will be set and, if so, what exactly it will be set to.
    # n_classes: int

    def __post_init__(self):
        # Deliberately not a dataclass field (so it is not part of __init__/fields()).
        # Set to True by normalize_reduced_components_to_unit_norm_.
        self.components_are_normalized = False

    @property
    def n_examples(self) -> int:
        """Number of examples, i.e. the number of rows of W.

        Raises:
            ValueError: if W is not present.
        """
        if self.W is None:
            raise ValueError('n_examples is unknown because W is not present.')
        return self.W.shape[0]

    @property
    def n_components(self) -> int:
        """Number of components, taken from W if present, otherwise from G.

        Raises:
            ValueError: if neither W nor G is present.
        """
        if self.W is not None:
            return self.W.shape[1]
        if self.G is not None:
            return self.G.shape[0]
        raise ValueError('n_components is unknown because neither W nor G is present.')

    #######################################################

    def normalize_reduced_components_to_unit_norm_(self, eps: float = 1e-12):
        """Normalizes G such that the rank-1 basis PSD matrices have unit Frobenius norm.

        Component c contributes W[:, c] * outer(G[c], G[c]). The Frobenius norm of
        outer(g, g) is ||g||^2, so each row of G is divided by its L2 norm. The matching
        column of W is multiplied by the *squared* norm, which leaves every
        coefficient * basis-matrix product unchanged. If W is not present (e.g. it was
        not loaded), only G is normalized. Operates in place.

        NOTE: This does this for the components in their reduced representations.

        Args:
            eps: lower bound applied to the squared row norms of G; avoids division by
                zero for all-zero components.

        Raises:
            ValueError: if G is not present.
        """
        if self.G is None:
            raise ValueError('Cannot normalize components because G is not present.')

        # Squared L2 norm of each row of G, as a [n_components, 1] column so that it
        # broadcasts across the feature dimension. clamp_min keeps G's dtype.
        sq_norms = torch.einsum('cf,cf->c', self.G, self.G)[:, None].clamp_min(eps)
        self.G /= torch.sqrt(sq_norms)

        if self.W is not None:
            # Transpose to [1, n_components] so it broadcasts across examples.
            self.W *= sq_norms.t()

        self.components_are_normalized = True

    #######################################################

    def to(self, device: torch.device) -> 'LrmNpeffDecomposition':
        """Moves every tensor attached to the instance to `device` and returns self."""
        for attr_name in ('W', 'G', 'new_to_old_col_indices'):
            tensor = getattr(self, attr_name)
            if tensor is not None:
                setattr(self, attr_name, tensor.to(device))
        return self

    #######################################################

    @classmethod
    def load(cls, filepath: str, *, load_W: bool = True, load_G: bool = True, **kwargs):
        """Load the instance from an HDF5 file.

        Reads the datasets `data/W`, `data/G` and `data/new_to_old_col_indices` when they
        are present in the file. Each is converted to a tensor via `torch.from_numpy`.
        The `n_parameters` attribute of the `data` group is converted to a Python int,
        or None if the attribute is absent.

        Args:
            filepath: path to the file; a leading `~` is expanded.
            load_W: if False, W is not read and is set to None.
            load_G: if False, G is not read and is set to None.
            **kwargs: accepted for call-site compatibility; currently ignored.
        """
        with h5py.File(os.path.expanduser(filepath), "r") as f:
            W = None
            if load_W and 'data/W' in f:
                W = torch.from_numpy(hdf5_utils.load_h5_ds(f['data/W']))

            G = None
            if load_G and 'data/G' in f:
                G = torch.from_numpy(hdf5_utils.load_h5_ds(f['data/G']))

            new_to_old_col_indices = None
            if 'data/new_to_old_col_indices' in f:
                new_to_old_col_indices = torch.from_numpy(hdf5_utils.load_h5_ds(f['data/new_to_old_col_indices']))

            n_parameters = f['data'].attrs.get('n_parameters', None)
            if n_parameters is not None:
                # h5py returns numpy scalars for attributes; store a plain int to match
                # the field's annotation.
                n_parameters = int(n_parameters)

            return cls(
                W=W,
                G=G,
                new_to_old_col_indices=new_to_old_col_indices,
                n_parameters=n_parameters,
            )
