from nn_compression._core import DeepCabacRdQuantiser


def test_quantiser_correct_idx():
    """Check that DeepCabacRdQuantiser picks the rate-distortion optimal index.

    A pure-Python reference evaluates the rate-distortion objective

        RD = rate * lm + distortion / pv

    for every quantisation level. It does this twice, once in weight space
    and once in index space, and the two must agree. The C++ quantiser's
    chosen index is then compared against the reference.

    Indices (integers) are compared rather than weight-space values, so that
    float rounding such as ``3 * 0.1 != 0.3`` cannot make the test flaky.
    """
    delta = 0.1
    min_idx = -5
    max_idx = 5
    lm = 2
    # Setup: there are 11 quantisation levels (indices -5..5), i.e. the
    # weight-space points [-0.5, -0.4, ..., 0.4, 0.5].
    q = DeepCabacRdQuantiser(lm, delta, min_idx, max_idx)

    # Bit estimates per level, ordered from min_idx to max_idx.
    # NOTE(review): these are not passed to the quantiser; presumably they
    # mirror its internal rate model -- confirm.
    rate_estimations = [7, 6, 5, 4, 3, 1, 3, 4, 5, 6, 7]
    points_idxspace = list(range(min_idx, max_idx + 1))
    assert len(rate_estimations) == len(points_idxspace)

    w = 0.23
    pv = 0.3

    # Reference solution computed in weight space.
    points_wspace = [-0.5, -0.4, -0.3, -0.2, -0.1, 0, 0.1, 0.2, 0.3, 0.4, 0.5]
    assert len(points_wspace) == len(points_idxspace)
    rd_objectives = [
        rate * lm + (w - x) ** 2 / pv
        for rate, x in zip(rate_estimations, points_wspace)
    ]
    # min over positions with a key returns the first minimum, matching
    # list.index(min(list)).
    best_pos_wspace = min(range(len(rd_objectives)), key=rd_objectives.__getitem__)

    # Same reference repeated in index space: distances are measured in
    # index units and rescaled by delta**2 back to weight-space distortion.
    w_idxspace = w / delta
    rd_objectives_idxspace = [
        rate * lm + (w_idxspace - x) ** 2 * delta**2 / pv
        for rate, x in zip(rate_estimations, points_idxspace)
    ]
    # Take the minimum of the index-space objectives themselves (not the
    # weight-space list), so float differences between the two formulas
    # cannot cause a failed lookup or a wrong pick.
    best_pos_idxspace = min(
        range(len(rd_objectives_idxspace)), key=rd_objectives_idxspace.__getitem__
    )

    assert best_pos_wspace == best_pos_idxspace
    expected_idx = points_idxspace[best_pos_idxspace]

    # Test the C++ implementation against the reference above. quantize()
    # returns an index (the original test scaled it by delta to get a weight).
    assert q.quantize(w, pv) == expected_idx
