from nn_compression.quantisation import (
    FixedValueScalarGrid,
    EquidistantGrid,
    RatedGrid,
    Grid,
)
import torch
import math


def test_equidistant_grid_points():
    """Check that EquidistantGrid(0, 0.1, 10) exposes the points 0.0, 0.1, ..., 0.9.

    The constructor arguments are presumably (start, step, number of points);
    confirm against EquidistantGrid before relying on that reading.
    """
    expected_points = torch.tensor(
        [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
    )
    equidistant = EquidistantGrid(0, 0.1, 10)
    # Tolerance-based comparison: the grid points are floating-point values.
    assert expected_points.allclose(equidistant.gridpoints)


def test_rated_grid_entropy():
    """Check RatedGrid.entropy() on a two-point grid for two count vectors.

    Equal counts are expected to yield an entropy of 1.0; counts in a 1:2
    ratio are expected to yield the base-2 entropy of the probabilities
    (1/3, 2/3).
    """
    grid = FixedValueScalarGrid(torch.tensor([0, 1]))

    # Entropy values are floating point, so compare with a tolerance instead
    # of exact equality; exact `==` breaks on harmless changes in dtype or
    # summation order. float() accepts a Python number as well as a
    # single-element tensor.
    uniform_grid = RatedGrid(grid, torch.tensor([10, 10]))
    assert math.isclose(float(uniform_grid.entropy()), 1.0, rel_tol=1e-6)

    skewed_grid = RatedGrid(grid, torch.tensor([10, 20]))
    expected_entropy = -1 / 3 * math.log2(1 / 3) - 2 / 3 * math.log2(2 / 3)
    assert math.isclose(
        float(skewed_grid.entropy()), expected_entropy, rel_tol=1e-6
    )


def test_rated_grid_update_counts():
    """Check the counts on the grid returned by update_counts_with_tensor().

    A RatedGrid over the points (0, 1) is updated with the tensor [1, 0, 1],
    which contains one 0 and two 1s; the returned grid's counts are expected
    to be [1, 2].
    """
    grid = FixedValueScalarGrid(torch.tensor([0, 1]))
    counts = torch.tensor([10, 10])
    rated_grid = RatedGrid(grid, counts)
    new_counts = torch.tensor([1, 2])
    new_tensor = torch.tensor([1, 0, 1])
    new_grid = rated_grid.update_counts_with_tensor(new_tensor)
    # The comparison must be asserted; a bare `allclose(...)` expression
    # discards its result and lets the test pass regardless of the counts.
    # NOTE(review): this now enforces that the initial [10, 10] counts are
    # replaced rather than accumulated — confirm against RatedGrid.
    assert new_grid.counts.allclose(new_counts)
