from dataclasses import dataclass
from typing import Optional
import torch
from abc import ABC, abstractmethod
import torch.nn as nn


class EntropyModel(ABC):
    """Abstract interface for entropy models.

    Both :meth:`estimate` and :meth:`update` are abstract, so only
    subclasses that implement the two of them can be instantiated.
    """

    @abstractmethod
    def estimate(self) -> float:
        """Return the model's current estimate as a float.

        NOTE(review): presumably an entropy estimate; the unit (bits vs.
        nats) is not visible in this file -- confirm with implementations.
        """
        raise NotImplementedError()

    @abstractmethod
    def update(self, x: torch.Tensor):
        """Update the model using the tensor ``x``.

        Subclasses define what the update does; this base method only
        raises ``NotImplementedError``.
        """
        raise NotImplementedError()


@dataclass
class Quantised:
    """Result type returned by ``Quantiser.quantise``.

    Bundles a tensor ``x`` with a ``grid`` field. Both fields are required
    constructor arguments (neither has a default).
    """

    # Tensor carried by this result.
    x: torch.Tensor
    # NOTE(review): the annotation is literally ``None``, so type checkers
    # will only accept ``None`` here. The intended grid type is not visible
    # in this file -- confirm and tighten the annotation.
    grid: None

    def entropy(self, add_overhead: bool = False):
        """Placeholder that always raises ``NotImplementedError``.

        ``add_overhead`` is accepted but unused in this base implementation.
        """
        raise NotImplementedError()


class Quantiser(ABC, nn.Module):
    """Basic Quantiser interface. All quantisers should inherit from this class.

    Quantisers should be used in the following way:

    1. Instantiate a concrete subclass with the number of bits used, and
       possibly other parameters.
    2. Call :meth:`find_params` with the tensor to be quantised to determine
       the quantisation parameters.
    3. Call :meth:`quantise` to quantise the tensor. Alternatively, call
       :meth:`quantise_with_uncertainty` and pass a ``posterior_variance``,
       which should have the same shape as the original tensor. The
       posterior variance might for example be a tensor containing
       1/H_{ij}. The larger the variance, the less important the parameter
       is, and the more it can be quantised.

    ``Quantiser`` itself is abstract (``quantise`` has no implementation), so
    the example uses ``MyQuantiser`` as a stand-in for any concrete subclass::

        import torch
        x = torch.randn(10, 10)
        q = MyQuantiser(bits=8)
        q.find_params(x)
        x_hat = q.quantise(x).x
    """

    def __init__(self, bits) -> None:
        """Initialise the module and remember the bit width.

        Args:
            bits: Number of bits used by the quantiser; exposed read-only
                through :attr:`bits`.
        """
        super().__init__()
        self._bits = bits

    @property
    def bits(self):
        """Number of bits passed to the constructor."""
        return self._bits

    def find_params(self, x: torch.Tensor, weight=False) -> None:
        """Find the parameters of the quantiser.

        This should be called before quantising any tensor. The base
        implementation does nothing; subclasses override it as needed.

        Args:
            x: Tensor to derive the quantisation parameters from.
            weight: Unused here. NOTE(review): presumably flags ``x`` as a
                weight tensor -- confirm against the subclasses.
        """

    def quantise_with_uncertainty(
        self,
        x: torch.Tensor,
        posterior_variance: torch.Tensor,
        col: Optional[int] = None,
    ) -> Quantised:
        """Quantise tensor ``x``, optionally guided by a posterior variance.

        A quantiser might use the passed posterior variance to determine
        quantisation decisions. The posterior variance might for example be
        a tensor containing 1/H_{ij}. It must have the same shape as ``x``.

        The default implementation ignores ``posterior_variance`` and simply
        delegates to :meth:`quantise`.

        Args:
            x: Tensor to quantise.
            posterior_variance: Per-element variance, same shape as ``x``.
            col: Optional value forwarded unchanged to :meth:`quantise`.

        Returns:
            Whatever :meth:`quantise` returns for ``x`` and ``col``.
        """
        return self.quantise(x, col=col)

    @abstractmethod
    def quantise(self, x: torch.Tensor, col: Optional[int] = None) -> Quantised:
        """Quantise tensor ``x``; must be implemented by subclasses.

        Args:
            x: Tensor to quantise.
            col: Optional integer. NOTE(review): presumably a column index
                into the tensor being quantised -- confirm with subclasses.

        Returns:
            The quantisation result.
        """
        raise NotImplementedError

    def ready(self) -> bool:
        """Report whether the quantiser can be used; always ``True`` here."""
        return True

    def __call__(self, x: torch.Tensor, col: Optional[int] = None) -> Quantised:
        """Quantise ``x``; shorthand for :meth:`quantise`.

        This overrides ``nn.Module.__call__``, so calling the instance goes
        straight to :meth:`quantise` rather than through ``forward``.

        Args:
            x: Tensor to quantise.
            col: Optional value forwarded to :meth:`quantise`, mirroring
                :meth:`quantise_with_uncertainty`.

        Returns:
            The result of :meth:`quantise`.
        """
        # Keep the historical single-argument call when ``col`` is not given,
        # so subclasses whose ``quantise`` takes only ``x`` keep working.
        if col is None:
            return self.quantise(x)
        return self.quantise(x, col=col)
