import numpy as np
import pytest
import torch

from ccvae.nn.loss import MultiScaleRBF


def test_MultiScaleRBF():
    """Check ``MultiScaleRBF`` against a NumPy reference on a tiny fixed input.

    Every row of ``x`` is paired with every row of ``y``. The module output
    is expected to be flat, with ``x`` as the outer index. For each pair,
    the expected value is the mean over ``gammas`` of
    ``exp(-0.5 * gamma * ||x_i - y_j||)``, using the plain (not squared)
    Euclidean distance, as this test has always encoded it.
    """
    # Keep the raw literals so the reference is computed in float64 from the
    # exact same numbers, not from float32-rounded tensor values.
    x_vals = [[1.0, 2.0], [3.0, 5.0]]
    y_vals = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
    gammas = [0.07, 8.1]

    x = torch.tensor(x_vals)
    y = torch.tensor(y_vals)

    mrbf = MultiScaleRBF(gammas)
    actual = mrbf(x, y)
    # One kernel value per (x row, y row) pair.
    assert list(actual.shape) == [len(x_vals) * len(y_vals)]

    # Pairwise Euclidean distances via broadcasting:
    # (n_x, 1, d) - (1, n_y, d) -> (n_x, n_y, d) -> norm over d -> (n_x, n_y).
    # ravel() flattens row-major, i.e. x is the outer index.
    x_ref = np.asarray(x_vals, dtype=np.float64)
    y_ref = np.asarray(y_vals, dtype=np.float64)
    dists = np.linalg.norm(
        x_ref[:, None, :] - y_ref[None, :, :], axis=-1
    ).ravel()

    # Average the single-scale kernels over however many gammas are given.
    kernels = [np.exp(-0.5 * gamma * dists) for gamma in gammas]
    expected = np.mean(kernels, axis=0)

    # The module output is float32: compare with a relative tolerance suited
    # to that precision, not a fixed 7-decimal absolute one. Detach first so
    # the conversion works even if the output carries autograd history.
    np.testing.assert_allclose(
        actual.detach().cpu().numpy(), expected, rtol=1e-5
    )
