import torch
import torch.nn as nn


# Preset per-dimension level lists, keyed by an approximate codebook size in
# bits: the product of each list is close to 2**key (e.g. 8 * 6 * 5 = 240 for
# key 8, and 2 * 2 = 4 for key 2). The list length is the number of quantized
# dimensions for that preset.
fsq_level_book = {
    16: [8, 8, 8, 5, 5, 5],
    14: [8, 8, 8, 6, 5],
    12: [7, 5, 5, 5, 5],
    10: [8, 5, 5, 5],
    8: [8, 6, 5],
    6: [6, 4, 3],
    4: [5, 3],
    2: [2, 2],
}


def round_ste(z):
    """Round `z` to the nearest integer, passing gradients straight through.

    The forward value is `torch.round(z)`; the rounding residual is detached,
    so the backward pass treats the operation as the identity on `z`.
    """
    rounding_residual = (torch.round(z) - z).detach()
    return z + rounding_residual


class FSQ(nn.Module):
    """Finite scalar quantizer over a fixed per-dimension integer grid.

    Every entry along the last axis of the input is squashed with a tanh into
    a bounded range and rounded to one of ``levels[i]`` integer values. A
    quantized code can be converted to and from a single integer index using
    a mixed-radix representation whose place values are stored in ``basis``.

    Args:
        levels: Number of quantization levels for each input dimension.
        eps: Small factor; the bounded half-range is scaled by ``1 - eps``.
    """

    def __init__(self, levels, eps=1e-3):
        super().__init__()
        self.eps = eps
        self.register_buffer("levels", torch.tensor(levels))
        # Mixed-radix place values: basis[0] = 1, basis[i] = prod(levels[:i]).
        self.register_buffer("basis", torch.cat(
                (torch.tensor([1]),
                 torch.cumprod(self.levels[:-1], dim=0))).to(torch.long))
        # Registered as a buffer so it follows the module across
        # `.to(device)` / dtype casts like `levels` and `basis` do.
        # Non-persistent: it is fully derived from `levels`, so it stays out
        # of the state dict and existing checkpoints keep loading.
        self.register_buffer(
            "implicit_codebook",
            self.indices_to_codes(torch.arange(self.codebook_size)),
            persistent=False,
        )

    @property
    def num_dimensions(self) -> int:
        """Number of dimensions expected from inputs."""
        return self.levels.shape[0]

    @property
    def codebook_size(self):
        """Size of the codebook, i.e. the product of all levels."""
        return torch.prod(self.levels).item()

    @property
    def codebook(self):
        """Returns the implicit codebook. Shape (prod(levels), num_dimensions)."""
        return self.implicit_codebook

    def bound(self, z):
        """Bound `z`, an array of shape (..., d).

        Each dimension is mapped into the open interval
        ``(-half_l - offset, half_l - offset)``, where ``offset`` is 0.5 for
        dimensions with an even number of levels and 0 otherwise.
        """
        half_l = (self.levels - 1) * (1 - self.eps) / 2
        offset = torch.where(self.levels % 2 == 1, 0.0, 0.5)
        # NOTE(review): a shift that centers z = 0 exactly would satisfy
        # tanh(shift) * half_l == offset, i.e. shift = atanh(offset / half_l).
        # `tan` is kept to preserve existing numerical behavior; atanh would
        # also be undefined for 2-level dimensions, where offset / half_l > 1.
        shift = torch.tan(offset / half_l)
        return torch.tanh(z + shift) * half_l - offset

    def quantize(self, z):
        """Quantizes z, returns quantized zhat, same shape as z."""
        quantized = round_ste(self.bound(z))

        # Renormalize to [-1, 1].
        half_width = self.levels // 2
        return quantized / half_width

    def forward(self, z):
        """Alias for `quantize`."""
        return self.quantize(z)

    def _scale_and_shift(self, zhat_normalized):
        """Map normalized codes to per-dimension digits in [0, ..., L-1]."""
        half_width = self.levels // 2
        return (zhat_normalized * half_width) + half_width

    def _scale_and_shift_inverse(self, zhat):
        """Map per-dimension digits in [0, ..., L-1] back to normalized codes."""
        half_width = self.levels // 2
        return (zhat - half_width) / half_width

    def codes_to_indices(self, zhat):
        """Converts a `code` to an index in the codebook.

        Args:
            zhat: Normalized codes of shape (..., num_dimensions), as returned
                by `quantize`.

        Returns:
            A `torch.long` tensor of shape (...) holding the codebook indices.
        """
        assert zhat.shape[-1] == self.num_dimensions
        # Snap to the nearest integer digit before casting: truncating a value
        # such as 2.9999 would silently select the neighbouring code.
        digits = torch.round(self._scale_and_shift(zhat)).to(torch.long)
        # Combine the digits in integer arithmetic so the index is exact
        # regardless of the floating-point dtype of `zhat`.
        return (digits * self.basis).sum(dim=-1)

    def indices_to_codes(self, indices):
        """Inverse of `codes_to_indices`.

        Args:
            indices: Integer tensor of codebook indices, shape (...).

        Returns:
            Normalized codes of shape (..., num_dimensions).
        """
        indices = indices.unsqueeze(-1)
        # Extract each mixed-radix digit: (index // basis[i]) % levels[i].
        codes_non_centered = torch.remainder(
            torch.floor_divide(indices, self.basis), self.levels
        )
        return self._scale_and_shift_inverse(codes_non_centered)


if __name__ == "__main__":
    # Small demo: quantize a vector, map it to a codebook index and back.
    fsq = FSQ(levels=[3, 5, 4])
    z = torch.tensor([0.25, 0.6, -7])
    zhat = fsq.quantize(z)
    print(f"Quantized {z} -> {zhat}")  # values: [0.25, 0.6, -7.0] -> [0.0, 0.5, -1.0]

    # We can map to an index in the codebook.
    idx = fsq.codes_to_indices(zhat)
    print(f"Code {zhat} is the {idx}-th index.")  # values: [0.0, 0.5, -1.0], index 10

    # Back to code: print the decoded value so the round trip is actually shown.
    code_out = fsq.indices_to_codes(idx)
    print(f"Index {idx} mapped back to {code_out}.")  # values: index 10 -> [0.0, 0.5, -1.0]

    fsq_small = FSQ(levels=[3, 4])
    print("Codebook for small FSQ")
    print(fsq_small.codebook)
