import numpy as np
import torch as th
from best3 import prob_grid, prob_grid_rec, bv

# Use the GPU when one is available; otherwise fall back to the CPU so the
# module can still be imported and the tests run on machines without CUDA
# (a hard-coded "cuda" made every tensor conversion fail there).
device = "cuda" if th.cuda.is_available() else "cpu"


def to_th(x):
    """Convert array-like *x* to a float32 tensor on the module-level ``device``.

    ``x`` may be anything ``torch.tensor`` accepts (NumPy arrays, nested
    lists, scalars). ``torch.tensor`` always copies its input, so the result
    never aliases ``x``.
    """
    return th.tensor(x, dtype=th.float32, device=device)


def make_si(x, *xs, rng=None):
    """Draw a random mean vector and a random symmetric positive-definite matrix.

    The positional arguments give the output shape: the last one is the
    dimension ``N`` and any preceding ones are batch dimensions ``B``. For
    example ``make_si(5, 3)`` returns ``mu`` of shape ``(5, 3)`` and ``si`` of
    shape ``(5, 3, 3)``.

    Args:
        x, *xs: batch sizes followed by the dimension ``N``; at least one value
            (``N``) is required.
        rng: optional ``numpy.random.Generator`` (or any object with a
            compatible ``uniform(size=...)`` method) used for the draws, so
            callers can make fixtures reproducible. Defaults to the global
            ``numpy.random`` state, i.e. the original behaviour.

    Returns:
        ``(mu, si)``: ``mu`` has shape ``(*B, N)`` with entries drawn uniformly
        from ``[0, 1)``. ``si`` has shape ``(*B, N, N)`` and equals
        ``A @ A^T + 1e-3 * I`` for a uniformly drawn ``A``, so it is symmetric
        positive definite.
    """
    if rng is None:
        rng = np.random
    *batch, n = (x, *xs)
    # Draw order (mu first, then A) matches the original so seeded global-RNG
    # callers get identical values.
    mu = rng.uniform(size=(*batch, n))
    a = rng.uniform(size=(*batch, n, n))
    # A @ A^T is only positive *semi*-definite; the small ridge makes it
    # strictly positive definite, hence a valid covariance even if A is singular.
    si = a @ np.swapaxes(a, -2, -1) + 1e-3 * np.eye(n)
    return mu, si


def test_bv():
    """Check ``bv``'s log-density against torch's ``MultivariateNormal``.

    Draws a random 2-D Gaussian via ``make_si(2)``, samples 1000 points from
    it, and compares ``bv(mu, si)(x)`` with ``MultivariateNormal.log_prob(x)``
    to an absolute tolerance of 1e-4.
    """
    mu, si = map(to_th, make_si(2))

    lpdf = bv(mu, si)
    dist = th.distributions.MultivariateNormal(mu, si)
    # No gradients are needed, so plain sampling is enough (same draws as rsample).
    x = dist.sample((1000,))
    # Only compare where the density exceeds 1e-3; presumably bv is not
    # expected to match to 1e-4 in the far tails — TODO confirm intent.
    x = x[dist.log_prob(x).exp() > 0.001]
    # allclose on empty tensors is always True, so guard against a vacuous pass.
    assert len(x) > 0, "every sample was filtered out; nothing was compared"
    assert th.allclose(dist.log_prob(x), lpdf(x), atol=1e-4)


def test_prob():
    """Check that ``prob_grid`` and ``prob_grid_rec`` agree on random Gaussians.

    Random 3-D Gaussians are drawn until 10 of them give a ``prob_grid_rec``
    value of at least 0.01 (smaller values are skipped). For each of those,
    the two results must differ by less than 0.1. The number of draws is
    capped so an unlucky or degenerate setup fails instead of hanging.
    """
    n_required = 10
    max_draws = 1_000
    constraint: list[tuple[int, ...]] = [(0, 1, 2)]
    n_checked = 0
    n_draws = 0
    while n_checked < n_required:
        # The original loop had no bound: if prob_grid_rec kept returning
        # values below 0.01 it would spin forever.
        assert n_draws < max_draws, (
            f"only {n_checked}/{n_required} usable draws "
            f"(prob_grid_rec >= 0.01) after {max_draws} attempts"
        )
        n_draws += 1

        mu, si = make_si(3)
        p = prob_grid(constraint, to_th(mu), to_th(si), N=10_000)
        mc = prob_grid_rec(constraint, to_th(mu), to_th(si), N=10_000)
        # Skip near-zero probabilities; presumably the absolute tolerance
        # below is not meaningful there — TODO confirm intent.
        if mc < 0.01:
            continue

        print(mc, p)
        assert th.abs(p - mc) < 0.1, f"prob_grid={p} vs prob_grid_rec={mc}"
        n_checked += 1
