"""Test cases for the exploit bilevel planning explorer class."""

import pytest

from predicators import utils
from predicators.envs.cover import CoverEnv
from predicators.explorers import BaseExplorer, create_explorer
from predicators.ground_truth_nsrts import get_gt_nsrts
from predicators.option_model import _OracleOptionModel


def test_exploit_bilevel_planning_explorer():
    """Tests for ExploitBilevelPlanningExplorer class."""
    utils.reset_config({
        "env": "cover",
        "explorer": "exploit_planning",
    })
    env = CoverEnv()
    nsrts = get_gt_nsrts(env.get_name(), env.predicates, env.options)
    option_model = _OracleOptionModel(env)
    train_tasks = env.get_train_tasks()
    explorer = create_explorer("exploit_planning", env.predicates, env.options,
                               env.types, env.action_space, train_tasks, nsrts,
                               option_model)
    task_idx = 0
    task = train_tasks[task_idx]
    # With ground-truth NSRTs and a generous timeout, the exploration policy
    # is expected to terminate in a state where the train task's goal holds.
    policy, termination_function = explorer.get_exploration_strategy(
        task_idx, 500)
    traj, _ = utils.run_policy(
        policy,
        env,
        "train",
        task_idx,
        termination_function,
        max_num_steps=1000,
    )
    final_state = traj.states[-1]
    assert termination_function(final_state)
    assert task.goal_holds(final_state)

    # Test timeout. Should fall back.
    # A negative timeout is meant to make planning time out, at which point
    # the explorer should defer to its fallback explorer. The dummy fallback
    # below raises, which makes it observable that it was invoked.

    class _DummyExplorer(BaseExplorer):
        """Fallback stand-in that raises when asked for a strategy."""

        @classmethod
        def get_name(cls):
            return "dummy"

        def get_exploration_strategy(self, train_task_idx, timeout):
            raise NotImplementedError("Dummy explorer called")

    dummy_explorer = _DummyExplorer(env.predicates, env.options, env.types,
                                    env.action_space, train_tasks)
    assert dummy_explorer.get_name() == "dummy"

    explorer._fallback_explorer = dummy_explorer  # pylint: disable=protected-access

    # Match against the raised exception's message itself (rather than
    # str() of the ExceptionInfo wrapper, which is its repr in modern
    # pytest), so the test only passes if the dummy fallback raised.
    with pytest.raises(NotImplementedError, match="Dummy explorer called"):
        explorer.get_exploration_strategy(task_idx, -1)
