import torch
from abc import ABC, abstractmethod
from data_utils.statistics import entropy_counts
from data_utils.arrays import make_numpy
import numpy as np


class Grid(ABC):
    """Abstract interface shared by every quantisation grid in this module.

    Concrete grids provide the tensor of gridpoints and the dimension of a single
    gridpoint; ``numel`` is derived from the gridpoints.
    """

    @property
    def numel(self):
        """Total number of elements of the ``gridpoints`` tensor.

        This is ``Tensor.numel()``: for a 2-D (vector-valued) gridpoints tensor it is
        rows * columns, not the number of gridpoints.
        """
        points = self.gridpoints
        return points.numel()

    @property
    @abstractmethod
    def gridpoints(self) -> torch.Tensor:
        """All gridpoints of the grid.

        Scalar grids return them sorted from lowest to highest in a 1-D tensor;
        vector-valued grids may instead return one gridpoint per row.
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def dimension(self) -> int:
        """Dimension of a single gridpoint (1 for scalar grids, >1 for vector grids)."""
        raise NotImplementedError


def _calculate_grid_overhead(nlevels: int, regular_grid: bool):
    if regular_grid:
        gridpoint_overhead_estimate = 5
    else:
        gridpoint_overhead_estimate = 16 + 5
    if regular_grid:
        if nlevels > 1:
            return nlevels * gridpoint_overhead_estimate + 32
        else:
            return 16
    else:
        return nlevels * gridpoint_overhead_estimate


def rtn_grid(x: torch.Tensor, grid: torch.Tensor) -> torch.Tensor:
    """Round every element of x to its nearest gridpoint (round-to-nearest).

    Args:
        x: Tensor of any shape with the values to round.
        grid: 1-D tensor of gridpoints sorted in ascending order.

    Returns:
        A tensor with the same shape as x, holding the nearest gridpoint value for
        each element, on x's device and with grid's dtype.
    """
    # rtn_grid_idx returns indices on x.device; the grid must live on the same
    # device, otherwise indexing a CPU grid with CUDA indices raises.
    grid = grid.to(x.device)
    return grid[rtn_grid_idx(x, grid)]


def rtn_grid_idx(x: torch.Tensor, grid: torch.Tensor) -> torch.Tensor:
    """Return, for every element of x, the index of its nearest gridpoint.

    Args:
        x: Tensor of any shape with the values to round.
        grid: 1-D tensor of gridpoints sorted in ascending order.

    Returns:
        An int64 tensor with the same shape as x, on x's device. A value exactly
        halfway between two gridpoints maps to the lower one. Values outside the
        grid, including -inf and +inf, map to the nearest end of the grid.

    Raises:
        ValueError: If grid is empty.
    """
    if grid.numel() == 0:
        raise ValueError("Cannot round to an empty grid.")
    grid = grid.to(x.device)
    if grid.numel() == 1:
        # A single gridpoint has no decision boundaries: everything maps to it.
        return torch.zeros_like(x, dtype=torch.long)
    # The decision boundaries are the midpoints between neighbouring gridpoints.
    # bucketize's default (right=False) puts v in bucket i iff
    # midpoints[i-1] < v <= midpoints[i]. That is exactly the set of values whose
    # nearest gridpoint is grid[i]. Values below the first midpoint (including
    # -inf) land in bucket 0 without needing a sentinel boundary.
    midpoints = (grid[1:] + grid[:-1]) / 2
    return torch.bucketize(x, midpoints)


class RatedGrid(Grid):
    """Grid with counts for each gridpoint. This is useful for entropy calculations and
    to estimate the probability distribution of the gridpoints."""

    def __init__(self, grid: Grid, counts: torch.Tensor) -> None:
        """Initialises the RatedGrid with a grid and counts for each gridpoint.

        Args:
            grid: The underlying grid whose gridpoints are counted.
            counts: One count per gridpoint; ``counts.shape[0]`` must equal
                ``grid.gridpoints.shape[0]``.

        Raises:
            ValueError: If the number of counts does not match the number of gridpoints.
        """
        super().__init__()
        n_counts = counts.shape[0]
        n_points = grid.gridpoints.shape[0]
        # Explicit check instead of ``assert`` so validation survives ``python -O``.
        if n_counts != n_points:
            raise ValueError(
                f"counts has {n_counts} entries but the grid has {n_points} gridpoints"
            )
        self._grid = grid
        self._counts = counts

    @property
    def dimension(self) -> int:
        """Dimension of the underlying grid's gridpoints."""
        return self._grid.dimension

    def update_counts_with_tensor(self, x: torch.Tensor) -> "RatedGrid":
        """Returns a new RatedGrid over the same grid with counts taken from x.

        Every element of x is rounded to its nearest gridpoint, and the number of
        elements landing on each gridpoint becomes that gridpoint's count. The counts
        of this RatedGrid are not used: the result is recomputed from x alone.

        Args:
            x: Tensor of any shape with the values to count.

        Returns:
            A new RatedGrid with int64 counts on the gridpoints' device.

        Raises:
            NotImplementedError: If the grid is multi-dimensional.
        """
        if self._grid.dimension > 1:
            raise NotImplementedError(
                "Only 1-dimensional grids are supported at the moment."
            )
        gridpoints = self._grid.gridpoints
        idx = rtn_grid_idx(x, gridpoints)
        # bincount with minlength gives one integer count per gridpoint, with zeros
        # for gridpoints that never occur. Its int64 dtype matches from_tensor. idx
        # lives on x.device, so the result is moved back to the gridpoints' device.
        new_counts = torch.bincount(idx.flatten(), minlength=gridpoints.shape[0])
        return RatedGrid(self._grid, new_counts.to(gridpoints.device))

    @staticmethod
    def from_tensor(x: torch.Tensor) -> "RatedGrid":
        """Returns a RatedGrid whose gridpoints are the distinct scalar values of x.

        Each count is the number of occurrences of the corresponding value in x.
        """
        levels, counts = torch.unique(x, return_counts=True)
        return RatedGrid(FixedValueScalarGrid(levels.flatten()), counts)

    @staticmethod
    def from_vectors(x: torch.Tensor) -> "RatedGrid":
        """Returns a RatedGrid whose gridpoints are the distinct rows of x.

        x may be vector valued; each row is treated as one gridpoint, and each count
        is the number of occurrences of that row.
        """
        levels, counts = torch.unique(x, return_counts=True, dim=0)
        return RatedGrid(CodebookGrid(levels), counts)

    @property
    def gridpoints(self) -> torch.Tensor:
        """Returns all the gridpoints of the underlying grid."""
        return self._grid.gridpoints

    @property
    def counts(self) -> torch.Tensor:
        """Returns the counts of the gridpoints, one count per gridpoint."""
        return self._counts

    def entropy(self, add_overhead: bool = False) -> float:
        """Returns the entropy of the grid in bits per element.

        The per-element entropy from ``entropy_counts`` (assumed to be in bits) is
        multiplied by the total count to get a total bit budget. If ``add_overhead``
        is set, an estimate of the bits needed to describe the grid is added. The
        estimate is cheaper for an EquidistantGrid. The result is the total divided
        by the total count.

        Raises:
            ValueError: If the counts sum to zero, so no per-element value exists.
        """
        c = self.counts
        total_numel = int(c.sum().item())
        if total_numel == 0:
            raise ValueError(
                "Cannot compute the entropy of a RatedGrid whose counts sum to zero"
            )
        total_bits = entropy_counts(c) * total_numel

        if add_overhead:
            # Only gridpoints with a non-zero count are charged as levels.
            nlevels = c[c > 0].numel()
            isregular = isinstance(self._grid, EquidistantGrid)
            total_bits += _calculate_grid_overhead(nlevels, isregular)
        return total_bits / total_numel

    def __repr__(self):
        return str(self)

    def __str__(self) -> str:
        return f"RatedGrid(grid={self.gridpoints}, counts={self.counts})"


class EquidistantGrid(Grid):
    """Equidistant grid with a start value, a step size and a number of points.

    For example, if start=0, step=0.1 and npoints=10, the gridpoints will be
    [0.0, 0.1, 0.2, ..., 0.9]. This grid allows for a lower entropy when compressing,
    as the overhead is lower (no codebook has to be sent).
    """

    def __init__(self, start: float, step: float, npoints: int) -> None:
        """Builds the gridpoints ``start + i * step`` for ``i`` in ``range(npoints)``.

        Args:
            start: Value of the first (lowest) gridpoint.
            step: Positive distance between neighbouring gridpoints.
            npoints: Number of gridpoints; the grid always has exactly this many.

        Raises:
            ValueError: If step is below 1e-12 (or NaN), or npoints is negative.
        """
        super().__init__()
        self.start = start
        self.step = step
        self.npoints = npoints
        eps = 1e-12
        # ``not >=`` also rejects a NaN step, which ``step < eps`` would let through.
        if not step >= eps:
            raise ValueError("Step size under floating point precision!")
        if npoints < 0:
            raise ValueError("Number of gridpoints must be non-negative!")
        # Build from an integer index rather than a float torch.arange(start, end, step).
        # A float arange derives its length from (end - start) / step, and rounding can
        # add a point. For example, start=1e6, step=0.1, npoints=3 gives 4 points, because
        # a 1e-12 epsilon is below one ulp at that magnitude. The values are computed in
        # float64 and then cast to the default floating dtype, which is the dtype the
        # float arange produced.
        index = torch.arange(npoints, dtype=torch.float64)
        self._gridpoints = (start + step * index).to(torch.get_default_dtype())

    @property
    def dimension(self) -> int:
        """Always 1: the gridpoints are scalars."""
        return 1

    @property
    def gridpoints(self) -> torch.Tensor:
        """The ``npoints`` gridpoints in ascending order as a 1-D tensor."""
        return self._gridpoints


class CodebookGrid(Grid):
    """Vector-valued grid defined by an explicit codebook.

    Each row of ``values`` is one gridpoint, for example as used in vector
    quantisation. This is the vector counterpart of FixedValueScalarGrid. Unlike that
    class, the values are stored exactly as given, with no flattening or sorting.
    """

    def __init__(self, values: torch.Tensor) -> None:
        """Stores ``values`` as the codebook (one codeword per row)."""
        super().__init__()
        self.values = values

    @property
    def gridpoints(self) -> torch.Tensor:
        """The codebook tensor itself."""
        return self.values

    @property
    def dimension(self) -> int:
        """Length of one codeword, i.e. the size of the codebook's second axis.

        Assumes ``values`` is at least 2-D (rows x dimension).
        """
        codebook_shape = self.values.shape
        return codebook_shape[1]


class FixedValueScalarGrid(Grid):
    """Scalar grid whose gridpoints are an explicit set of values.

    The input tensor is flattened, so every element becomes its own scalar gridpoint
    regardless of the input's shape. Use CodebookGrid for vector-valued gridpoints.
    Duplicate values are kept as they are.
    """

    def __init__(self, values: torch.Tensor) -> None:
        """Stores the flattened ``values`` sorted in ascending order as the gridpoints."""
        super().__init__()
        sorted_values, _ = torch.sort(values.flatten())
        self.values = sorted_values

    @property
    def gridpoints(self) -> torch.Tensor:
        """The sorted 1-D tensor of gridpoints."""
        return self.values

    @property
    def dimension(self) -> int:
        """Always 1: every gridpoint is a scalar."""
        return 1
