"""Tests for predictor.py"""

from unittest.mock import patch

import numpy as np

from fmri2music.predictor import BestPredictor, ExportedPredictor, RandomPredictor


def test_random_predictor():
    """RandomPredictor should return one embedding vector of the FMA dimensionality.

    The FMA embedding loader is mocked so no data files are needed. The mock
    returns ``(None, matrix)`` with a 5x10 matrix, so the predicted embedding
    is expected to be a 1-D vector of length 10.
    """
    # A seeded Generator keeps the mocked embeddings reproducible across runs,
    # unlike the unseeded global np.random state.
    rng = np.random.default_rng(0)
    with patch("fmri2music.emb_loader.get_fma_emb") as mock_get_fma_emb:
        mock_get_fma_emb.return_value = (None, rng.random((5, 10)))
        predictor = RandomPredictor("mv101", "small")
        emb = predictor.predict_emb("test_key")
        assert emb.shape == (10,)
        # The loader receives the constructor's two arguments in reverse order.
        mock_get_fma_emb.assert_called_with("small", "mv101")


def test_best_predictor():
    """BestPredictor should return the FMA embedding row nearest the GTZAN target.

    Both embedding loaders are mocked. The FMA mock provides two candidate rows
    and the GTZAN mock provides the target embedding stored under key "c".
    """
    fma_keys = ["a", "b"]
    fma_matrix = np.asarray([[1.0, 1.0, 0.0], [0.0, -1.0, -1.0]])
    gtzan_table = {"c": np.asarray([-1.0, -1.0, -1.0])}

    fma_patcher = patch("fmri2music.emb_loader.get_fma_emb")
    gtzan_patcher = patch("fmri2music.emb_loader.get_gtzan_emb")
    with fma_patcher as fma_mock, gtzan_patcher as gtzan_mock:
        fma_mock.return_value = (fma_keys, fma_matrix)
        gtzan_mock.return_value = gtzan_table

        result = BestPredictor("mv108", "small").predict_emb("c")

        assert result.shape == (3,)
        # Row 1 ([0, -1, -1]) is nearer to the target [-1, -1, -1] than row 0
        # ([1, 1, 0]). Its Euclidean distance is 1 vs 3, and its dot product
        # is 2 vs -2, so it should win under either measure.
        np.testing.assert_array_equal(result, fma_matrix[1])
        fma_mock.assert_called_with("small", "mv108")
        gtzan_mock.assert_called_with("mv108")


def test_exported_predictor():
    """ExportedPredictor should return the stored embedding for a key unchanged.

    ``load_predictions`` is mocked to return a dict mapping "test_key" to a
    known vector. The predictor must hand back exactly that vector.
    """
    # A seeded Generator makes the fixture vector reproducible across runs.
    rng = np.random.default_rng(0)
    with patch("fmri2music.emb_loader.load_predictions") as mock_load_predictions:
        expected_emb = rng.random(10)
        mock_load_predictions.return_value = {"test_key": expected_emb}
        predictor = ExportedPredictor("name", "file_name", "mv101")
        actual_emb = predictor.predict_emb("test_key")
        # Explicit shape check: assert_array_equal alone would accept a scalar
        # that broadcasts against expected_emb.
        assert expected_emb.shape == actual_emb.shape
        np.testing.assert_array_equal(expected_emb, actual_emb)
        # Only the file-name argument is forwarded to the loader.
        mock_load_predictions.assert_called_with("file_name")
