import numpy as np
import torch

from offline_rl.rewards.evaluation.distances import (
    compute_pearson_correlation,
    compute_distance_between_reward_pairs,
)
from offline_rl.rewards.evaluation.reward_collection import RewardCollection


def test_compute_distance_between_reward_pairs():
    """Check the pairwise Pearson-correlation matrix over three reward series.

    "a" and "b" are identical ramps and "c" is the negated ramp, so every
    pair correlates perfectly: +1 between same-sign series and -1 between
    opposite-sign series (including each series with itself on the diagonal).
    """
    rewards = RewardCollection({
        "a": torch.arange(10).to(float),
        "b": torch.arange(10).to(float),
        "c": -torch.arange(10).to(float),
    })
    distance_matrix = compute_distance_between_reward_pairs(rewards, compute_pearson_correlation)

    # Materialize once so the same label order is used for both checks.
    labels = list(distance_matrix.labels)
    assert sorted(labels) == ["a", "b", "c"]

    # Build the expectation in the matrix's own label order rather than
    # hard-coding a/b/c: the label check above is order-agnostic, so the
    # value check must be too. For x scaled by signs s_i and s_j,
    # corr(s_i * x, s_j * x) == s_i * s_j, i.e. the outer product of signs.
    sign_by_label = {"a": 1.0, "b": 1.0, "c": -1.0}
    signs = np.array([sign_by_label[label] for label in labels])
    expected = np.outer(signs, signs)

    # numpy.testing takes (actual, desired); keep that order so failure
    # messages label the two arrays correctly.
    np.testing.assert_array_almost_equal(distance_matrix.distances, expected)
