import pytest
import torch
from torch import Tensor
from margflow.utils.math_utils import (
    parzen_log_prob_spike,
)
from margflow.model_utils import log_prob_diagonal_normal
from margflow.utils.plot_utils import orthonormalize_second


@pytest.mark.parametrize(
    ("v", "v2", "expected"),
    [
        (Tensor([3, 0]), Tensor([2, 2]), Tensor([0, 1])),
        (Tensor([2, 0]), Tensor([-2, -2]), Tensor([0, 1])),
    ],
)
def test_orthonormalize_second(v, v2, expected) -> None:
    """Compare ``orthonormalize_second(v, v2)`` with hand-computed unit vectors.

    In both cases the component of ``v2`` orthogonal to ``v`` (which lies on
    axis 0) points along axis 1. Its normalized form is ``[0, 1]`` up to sign.
    """
    result = orthonormalize_second(v, v2)
    # Compare absolute values so the check does not depend on the sign of the
    # returned direction (e.g. the second case's v2 points into the negative half-plane).
    assert torch.allclose(result.abs(), expected)


def test_parzen_log_prob_spike():
    """Check ``parzen_log_prob_spike`` against a closed-form diagonal Gaussian.

    All ``m`` Parzen centres ``xp`` sit at the origin, and every centre shares
    the spike vector ``spike_sigma * e_0``. The reference density is a single
    zero-mean diagonal Gaussian. Its variance is ``sigma**2 + spike_sigma**2``
    along axis 0 and ``sigma**2`` along every other axis. (This assumes the
    spike contributes variance ``spike_sigma**2`` along its direction; confirm
    against ``parzen_log_prob_spike``.) The fast implementation must match the
    reference within ``torch.allclose`` default tolerances.
    """
    n = 20
    m = 40
    dim = 10
    sigma = 1
    spike_sigma = 2
    # Fixed seed so the test is reproducible instead of drawing new data each run.
    gen = torch.Generator().manual_seed(0)

    # Kernel variance sigma**2 on every axis (not a hard-coded 1, which is only
    # correct when sigma == 1). The spike adds spike_sigma**2 on axis 0.
    cov = sigma**2 * torch.ones(dim)
    cov[0] = sigma**2 + spike_sigma**2

    # Evaluation points: kernel-scale noise on every axis, plus extra spread
    # along the spike axis.
    x = sigma * torch.randn(n, dim, generator=gen)
    x[:, 0] = x[:, 0] + spike_sigma * torch.randn(n, generator=gen)

    # All Parzen centres at the origin.
    xp = torch.zeros(m, dim)
    # Spike vector spike_sigma * e_0, broadcast (without copying) to every centre.
    spike = torch.zeros(1, dim)
    spike[0, 0] = spike_sigma
    spike = spike.expand(m, dim)

    # Every centre is at the origin, so x - xp is just x. The inserted axis
    # gives shape (n, 1, dim).
    dists = x[:, None]
    correct_p = log_prob_diagonal_normal(x=dists, diag_cov=cov)
    quick_p = parzen_log_prob_spike(x, xp, spike, sigma, device="cpu")
    # NOTE(review): torch.allclose broadcasts its inputs. Confirm that correct_p
    # and quick_p have the same shape so that points are compared one-to-one.
    assert torch.allclose(correct_p, quick_p)
