"""Cross-entropy method solver for the marble drop environment."""
from examples.marble_drop.environment import simulation
from examples.cem import cem
import torch
from pylic.code_transformations import get_tape
from examples.marble_drop.trajectory_pylic import FPS_PER_ACTION
from examples.marble_drop.environment import Action
from examples.marble_drop.environment import Puzzle
from examples.marble_drop.environment import get_puzzle_button_n


def get_quality(actions: list[Action], puzzle: Puzzle) -> float:
    """Score an action sequence by the share of puzzle buttons it presses.

    The simulation is run for ``FPS_PER_ACTION`` timesteps per action,
    plus one extra action's worth of timesteps after the last action.

    Returns:
        Number of pressed buttons reported by the simulation divided by
        the number of buttons in the puzzle. The divisor is clamped to
        at least 1, so a puzzle reporting zero buttons does not raise.
    """
    # One additional slot beyond the supplied actions extends the run.
    action_slots = len(actions) + 1
    step_budget = int(FPS_PER_ACTION * action_slots)

    pressed = simulation(actions, puzzle, None, step_budget)

    button_total = get_puzzle_button_n(puzzle)
    return len(pressed) / max(1, button_total)


def custom_cem_solver(
        puzzle: Puzzle,
        timeout_s: float,
        worker_n: int,
        random_restart_n: int,
        ) -> tuple[list[tuple[float, torch.Tensor]], bool]:
    """Solve ``puzzle`` with CEM_MPC.

    The method follows:
    Sample-efficient Cross-Entropy Method for Real-time Planning, 2020

    Args:
        puzzle: The puzzle whose buttons should be pressed.
        timeout_s: Time budget, forwarded unchanged to ``cem``.
        worker_n: Worker count forwarded to ``cem``; the number of
            samples requested is twice this value.
        random_restart_n: Accepted for interface compatibility; it is
            not used by this function.

    Returns:
        A pair ``(log, is_solved)``: the log returned by ``cem`` with
        the two elements of every entry swapped, and the solved flag
        returned by ``cem``.
    """
    def quality_of(x: torch.Tensor) -> float:
        # The tensor is converted to nested Python lists before scoring.
        return get_quality(x.tolist(), puzzle)

    def cost_of(x: torch.Tensor) -> float:
        # In this solver the cost is simply the negated quality.
        return -quality_of(x)

    # The number of planning timesteps equals the number of buttons.
    button_n = get_puzzle_button_n(puzzle)

    cem_settings = dict(
        max_timesteps=button_n,
        action_size=4,
        sample_n=worker_n*2,
        elite_n=4,
        horizon_len=3,
        init_stdev=1.0,
        cem_inner_iter_n=64,
        timeout_s=timeout_s,
        get_quality=quality_of,
        get_cost=cost_of,
        worker_n=worker_n,
        verbose=True,
    )
    raw_log, is_solved = cem(**cem_settings)

    # Each log entry from CEM has its two elements in the opposite
    # order to what this function returns, so swap them.
    # NOTE(review): the return annotation says (float, Tensor) per
    # entry; confirm the element order against what `cem` emits.
    swapped_log = [(second, first) for (first, second) in raw_log]
    return swapped_log, is_solved
