"""Tests for hparam_tuner.py."""

import pathlib
import time

from fmri2music.hparam_tuner import HParamRanges, HParams, HyperparamTuner, hash_hparams


# Metric key under which mock_model_fn reports its score; also passed to
# HyperparamTuner as the metric it should optimize.
TARGET_METRIC = "metric_name"


def mock_model_fn(params: HParams) -> dict[str, float]:
    """Stand-in for a real training run, used to exercise the tuner.

    Sleeps for a fixed 0.01 seconds to simulate work, then reports the negated
    value of hyperparameter ``x`` under ``TARGET_METRIC``. Because the score is
    ``-x``, the largest ``x`` yields the lowest score, which is the combination
    the tests in this module expect the tuner to select.

    Args:
        params: Hyperparameter combination; must contain the key ``"x"``.

    Returns:
        A single-entry mapping ``{TARGET_METRIC: -float(params["x"])}``.
    """
    time.sleep(0.01)  # Simulate a (very short) training run.
    return {TARGET_METRIC: -float(params["x"])}


def test_hyperparam_tuner(tmp_path: pathlib.Path):
    """Single-threaded grid search over x in 0..9 picks x=9 and logs every combination."""
    candidate_xs = list(range(10))
    hparam_dict: HParamRanges = {"x": candidate_xs}
    csv_path = tmp_path / "log.csv"

    tuner = HyperparamTuner(hparam_dict, mock_model_fn, TARGET_METRIC, str(csv_path))
    best_params_hash, best_score = tuner.grid_search(num_threads=1)

    # mock_model_fn scores each x as -x, so the winning combination is the largest x.
    assert best_params_hash == hash_hparams({"x": 9})
    assert best_score == -9

    assert csv_path.is_file()
    with open(csv_path, "r", encoding="utf-8") as csv_file:
        log_lines = list(csv_file)
    assert log_lines[0].strip() == "score,hash,time,x,log"
    # One row per evaluated combination, plus the header row.
    assert len(log_lines) == len(candidate_xs) + 1


def test_hyperparam_tuner_resume(tmp_path: pathlib.Path):
    """A second search over a superset of ranges resumes from the existing log file."""
    log_path = tmp_path / "log.csv"

    # Initial, partial run covering only x=0 and x=1.
    first_ranges: HParamRanges = {"x": [0, 1]}
    first_tuner = HyperparamTuner(
        first_ranges, mock_model_fn, TARGET_METRIC, str(log_path)
    )
    first_tuner.grid_search(num_threads=2)

    # Resume with the full range; x=2 is the only combination not yet logged.
    full_ranges: HParamRanges = {"x": [0, 1, 2]}
    resumed_tuner = HyperparamTuner(
        full_ranges, mock_model_fn, TARGET_METRIC, str(log_path)
    )
    best_hparams_hash, best_score = resumed_tuner.grid_search(num_threads=2)

    assert best_hparams_hash == hash_hparams({"x": 2})
    assert best_score == -2

    assert log_path.is_file()
    with open(log_path, "r", encoding="utf-8") as log_file:
        row_count = sum(1 for _ in log_file)
    # Header plus exactly one row for each of the 3 distinct combinations.
    assert row_count == 4


def test_hash_hparams():
    """hash_hparams is deterministic, ignores key order, and distinguishes values."""
    # Repeated calls must agree, since resumed runs match logged hashes.
    assert hash_hparams({"x": 1}) == hash_hparams({"x": 1})
    assert hash_hparams({"x": 1, "y": 2}) == hash_hparams({"y": 2, "x": 1})
    assert hash_hparams({"x": 1}) != hash_hparams({"x": 2})
