#!/usr/bin/env python3
import torch
import pytest
from uimnet.algorithms.soft_labeler import soften, harden

# Values of `tau` that the tests in this module are parametrized over.
TAUS = (0.7, 0.8, 0.9)

@pytest.mark.parametrize('tau', TAUS)
def test_soften(tau):
  """Check that `soften` applied to a one-hot label yields values in [0, 1].

  NOTE(review): `tau` is never passed to `soften`, so the parametrization
  only repeats the same check once per entry of TAUS (each time with a
  fresh random label) - confirm whether `soften` was meant to receive it.
  """
  num_classes = 266
  label = torch.randint(high=num_classes, size=(1, ))
  # Without an explicit num_classes, one_hot infers the width as
  # max(label) + 1, so the encoding would span all classes only when the
  # last class happens to be drawn.
  y = torch.nn.functional.one_hot(label, num_classes=num_classes)
  s = soften(y, num_classes)
  assert s.max().item() <= 1.0
  assert s.min().item() >= 0.0

@pytest.mark.parametrize('tau', TAUS)
def test_harden(tau):
  """Run `harden` on a softened one-hot label for each `tau` in TAUS.

  The test fails if `harden` raises, or if the softened tensor falls
  outside [0, 1] once `harden` has been called on it.

  NOTE(review): the value returned by `harden` is not checked; its
  expected contract is not visible from this file - add an assertion on
  it once confirmed.
  """
  num_classes = 266
  label = torch.randint(high=num_classes, size=(1, ))
  # Without an explicit num_classes, one_hot infers the width as
  # max(label) + 1, so the encoding would span all classes only when the
  # last class happens to be drawn.
  y = torch.nn.functional.one_hot(label, num_classes=num_classes)
  s = soften(y, num_classes)
  # Smoke check: called before the assertions, as in the original ordering.
  harden(s, tau)
  assert s.max().item() <= 1.0
  assert s.min().item() >= 0.0
