import torch

import acquire


# Five 2-D float32 points; test_compute_min_deltas expects U == len(z1) == 5.
z1 = torch.tensor(
    [
        [0.4, 0.4],
        [1.0, 0.5],
        [-0.5, -0.5],
        [-0.25, 0.0],
        [0.0, 0.0],
    ],
    dtype=torch.float32,
)

# z2 is exactly the first three rows of z1 (the tests expect L == 3), so those
# rows of z1 lie at distance zero from z2. clone() gives z2 its own storage.
z2 = z1[:3].clone()


def test_compute_min_deltas():
    """Check ``GreedyCoreset.compute_min_deltas`` on the ``z1``/``z2`` fixtures.

    Rows 0-2 of ``z1`` are identical to ``z2``, so their minimum distance must
    be exactly zero.  Row 3 (``[-0.25, 0]``) is nearest to ``z2[2]`` and row 4
    (``[0, 0]``) is nearest to ``z2[0]``; their reported minimum distances must
    match the Euclidean norms to within float32 precision.
    """
    # The fixtures are float32, whose machine epsilon is ~1.2e-7.  The old
    # tolerance of 1e-16 effectively demanded bit-exact equality with a norm
    # computed here, which breaks if the implementation uses a different but
    # equivalent distance formula or upcasts.  Use a float32-aware bound.
    tol = 10 * torch.finfo(z1.dtype).eps

    (
        U,
        L,
        _D,  # not checked by this test
        min_deltas,
        min_deltas_indices,
        indices,
    ) = acquire.GreedyCoreset.compute_min_deltas(
        z1, z2, batch_size=2, device="cpu"
    )

    assert U == 5
    assert L == 3
    assert min_deltas.dim() == 1
    assert len(min_deltas) == U
    assert indices.dim() == 1
    assert len(indices) == U

    # Pin the shape explicitly so broadcasting in ``==`` cannot hide a
    # mis-shaped result.
    assert min_deltas_indices.shape == (U,)
    assert (min_deltas_indices == torch.LongTensor([0, 1, 2, 2, 0])).all()

    # Rows shared with z2 have zero distance to their nearest z2 row.
    assert (min_deltas[:L] == 0).all()

    expected_row3 = (z1[L] - z2[2]).norm()
    expected_row4 = (z1[L + 1] - z2[0]).norm()
    assert (min_deltas[L] - expected_row3).abs() < tol
    assert (min_deltas[L + 1] - expected_row4).abs() < tol


def test_score_zs():
    """Check the selection mask ``GreedyCoreset.score_zs`` returns per budget.

    Distance from each non-shared row of ``z1`` to its nearest ``z2`` row:
    row 3 ~= 0.559, row 4 ~= 0.566.  With budget 1 only row 4 is expected to
    be selected; with budget 2 rows 3 and 4 are expected (presumably a
    farthest-first / k-center greedy rule — confirm against ``acquire``).
    """
    cases = (
        (1, [0, 0, 0, 0, 1]),
        (2, [0, 0, 0, 1, 1]),
    )
    for budget, expected_mask in cases:
        selector = acquire.GreedyCoreset(
            budget=budget, batch_size=2, device="cpu"
        )
        selected = selector.score_zs(z1, z2)
        assert (selected == torch.FloatTensor(expected_mask)).all()