from pathlib import Path

import pymc as pm

from llm_mcts.mcts_algo.pymc_mixed.pymc_interface import Observation, PyMCInterface


def _sample_prior_and_save_graph(model: pm.Model, figure_name: str) -> None:
    """Run prior-predictive sampling on *model* and render its graph to disk.

    The samples are not inspected; drawing them only checks that the model
    can be sampled without raising. The graph is written to
    ``figures/<figure_name>`` next to this test file (graph rendering
    presumably needs Graphviz available -- confirm in the test environment).
    """
    with model:
        pm.sample_prior_predictive()

    figure_path = Path(__file__).parent / "figures" / figure_name
    model.to_graphviz(save=figure_path)


def test_build_model():
    """Check observation preprocessing and smoke-test both PyMC models.

    Preprocessing must keep only the observations for the requested model
    name. The fitting and prediction models built from the same observations
    must both support prior-predictive sampling and graph rendering.
    """
    pymc_interface = PyMCInterface()

    observations = [
        Observation(reward=1, model_name="model_1", child_idx=0, node=None),
        Observation(reward=0, model_name="model_1", child_idx=1, node=None),
        Observation(reward=0.2, model_name="model_1", child_idx=2, node=None),
        Observation(reward=0.2, model_name="model_1", child_idx=2, node=None),
        Observation(reward=0.8, model_name="model_1", child_idx=0, node=None),
        Observation(reward=0.8, model_name="model_2", child_idx=3, node=None),
    ]

    child_indices, rewards, coords = pymc_interface.preprocess_observations(
        observations, model_name="model_1"
    )

    # The single model_2 observation (child_idx=3) must be filtered out, so
    # neither its data nor its child index appears below.
    assert list(child_indices) == [0, 1, 2, 2, 0]
    assert list(rewards) == [1, 0, 0.2, 0.2, 0.8]
    assert coords == {"child_idx": [0, 1, 2]}

    fitting_model = pymc_interface.build_fitting_model(
        observations=observations,
        model_name="model_1",
    )
    _sample_prior_and_save_graph(fitting_model, "fitting_model.jpg")

    prediction_model = pymc_interface.build_prediction_model(
        observations=observations,
        model_name="model_1",
    )
    _sample_prior_and_save_graph(prediction_model, "prediction_model.jpg")
