import torch
from torch import Tensor

from src.utils.position_transforms import calculate_angle_from_prob


def _transform_to_probs_w_wrap(target: Tensor, num_bins: int, sigma: float) -> Tensor:
    """Transform target value to probability bins but wrap"""
    limit = 2 * torch.pi
    support = torch.linspace(
        -limit, limit, 2 * num_bins + 1, dtype=torch.float32, device=target.device
    )
    _divisor = torch.sqrt(torch.tensor(2.0, dtype=torch.float32)) * sigma
    cdf_evals: Tensor = torch.special.erf((support - target.unsqueeze(-1)) / _divisor)
    z = cdf_evals[..., -1] - cdf_evals[..., 0]
    bin_probs = torch.diff(cdf_evals, n=1, dim=-1)
    double_width = bin_probs / z.unsqueeze(-1)
    perimeter = double_width.shape[-1] // 4
    result = double_width[..., perimeter:-perimeter]
    result[..., -perimeter:] += double_width[..., :perimeter]
    result[..., :perimeter] += double_width[..., -perimeter:]
    return result


def test_angle_logit_to_value():
    """Wrapped Gaussian bin probabilities decode back to the original angles."""
    # Fall back to CPU so the test also runs on machines without a GPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Nine angles evenly spaced over [-pi, pi); the endpoint pi is dropped
    # because it is the same direction as -pi.
    angles = torch.linspace(-torch.pi, torch.pi, 10, device=device)[:-1]
    # The helper already returns normalised probabilities (each row sums to 1),
    # so they are passed straight to the decoder without a softmax.
    probs = _transform_to_probs_w_wrap(angles, 20, 0.2)
    recovered = calculate_angle_from_prob(probs)
    assert torch.allclose(angles, recovered, atol=1e-4, rtol=1e-2)
