import numpy as np

from offline_rl.rewards.evaluation.distance_matrix import DistanceMatrix


class TestDistanceMatrix:
    """Persistence round-trip checks for ``DistanceMatrix``."""

    def test_save_load(self, tmp_path):
        """Save a matrix under ``tmp_path``, load it back, compare fields.

        The reloaded object must expose the same ``labels`` and
        (approximately) the same ``distances`` that were passed to the
        constructor.
        """
        expected_labels = ["1", "2", "3"]
        expected_distances = np.eye(3)

        # save() is handed a plain string path rather than a Path object.
        target = str(tmp_path / "mat.pkl")
        DistanceMatrix(expected_labels, expected_distances).save(target)

        # Read the file back through the loader invoked on the class itself.
        restored = DistanceMatrix.load(target)

        assert restored.labels == expected_labels
        np.testing.assert_array_almost_equal(restored.distances, expected_distances)
