"""Common classes, ABCs, and interfaces for PEF computers."""
import abc
import dataclasses

import h5py
import torch

###############################################################################


@dataclasses.dataclass()
class PefComputerInput:
    """Per-example input consumed by a PEF computer."""

    # Log-probabilities over every class, shape [n_classes]. Assumed to already
    # be wired into autograd so gradients with respect to them can be computed.
    log_probs: torch.Tensor

    # Which classes to compute the PEF for: shape [n_classes_to_compute_for],
    # dtype int32.
    class_indices: torch.Tensor


class PefComputerAbc(abc.ABC):
    """ABC for classes that compute PEFs.

    Subclasses must implement `is_output_projected` and `compute_dense_pef`.
    `get_n_original_parameters` additionally requires subclasses to provide a
    `parameters` attribute, which this ABC does not define.
    """

    @classmethod
    def create(cls, **kwargs) -> 'PefComputerAbc':
        """Constructs an instance of `cls`, forwarding all keyword arguments to it."""
        return cls(**kwargs)

    @abc.abstractmethod
    def is_output_projected(self) -> bool:
        """Whether the output PEFs are already projected."""
        raise NotImplementedError

    # Annotations referencing `PefComputerInput` / `h5py` are strings so the class
    # body does not need those names to be resolvable at definition time.
    @abc.abstractmethod
    def compute_dense_pef(self, example_info: 'PefComputerInput') -> torch.Tensor:
        """Computes the PEF for a single example as a dense tensor.

        Args:
            example_info: The example-specific input (log-probabilities and the
                indices of the classes to compute the PEF for).

        Returns:
            The representation of the PEF for the example in dense form.
        """
        raise NotImplementedError

    def write_additional_information_to_pefs_file(self, file: 'h5py.File') -> None:
        """Allows sub-classes to save class-specific information to a PEFs file.

        The default implementation is a no-op.

        Args:
            file: The open PEFs file to write into.
        """

    def get_n_original_parameters(self) -> int:
        """Returns the number of parameters before any reductions or projections take place.

        Reads `self.parameters`, which subclasses must provide. It may be either
        an iterable of tensors or a zero-argument callable returning such an
        iterable (e.g. a bound `torch.nn.Module.parameters` method).

        Returns:
            The total number of elements (`numel`) summed over all parameter tensors.

        Raises:
            AttributeError: If the subclass does not define `parameters`.
        """
        params = self.parameters
        try:
            param_iter = iter(params)
        except TypeError:
            # Not iterable itself: treat it as an accessor method. Only done on
            # this fallback path so iterable containers that also happen to be
            # callable (e.g. nn.ParameterList) are never invoked.
            if not callable(params):
                raise
            param_iter = iter(params())
        return sum(p.numel() for p in param_iter)


###############################################################################

# TODO: Maybe move the norm computation stuff elsewhere.


def compute_lrm_pef_frobenius_norm(pef: torch.Tensor) -> torch.Tensor:
    """Computes the Frobenius norm associated with a low-rank-matrix PEF.

    Given a factor ``A`` of shape ``[rank, n_parameters]``, returns
    ``||A A^T||_F``. By the trace identity this equals ``||A^T A||_F`` (the norm
    of the ``[n_parameters, n_parameters]`` matrix the factor presumably
    represents), but only the small ``[rank, rank]`` Gram matrix is formed.

    Leading batch dimensions are supported: an input of shape
    ``[..., rank, n_parameters]`` yields one norm per batch entry with shape
    ``[...]`` (a 0-dim tensor for a plain 2-D input).

    Args:
        pef: Low-rank factor(s), shape ``[..., rank, n_parameters]``.

    Returns:
        The Frobenius norm(s), differentiable with respect to ``pef``.

    Raises:
        ValueError: If ``pef`` has fewer than two dimensions.
    """
    if pef.dim() < 2:
        raise ValueError(
            f'Expected pef with shape [..., rank, n_parameters], got {tuple(pef.shape)}.')
    # A A^T (not A^T A): shape [..., rank, rank], cheap when rank << n_parameters.
    gram = pef @ pef.transpose(-2, -1)
    return torch.linalg.matrix_norm(gram, ord='fro')
