"""Cross-entropy method solver for the maze environment."""
from tasks import TaskImpulse
from tasks import get_quality
from examples.cem import cem
from rl_baseline import AssistedMazeEnv
import torch


def custom_cem_solver(
        task: TaskImpulse,
        sample_n: int,
        elite_n: int,
        horizon_len: int,
        init_stdev: float,
        cem_inner_iter_n: int,
        timeout_s: float,
        verbose: bool,
        ) -> tuple[list[tuple[float, torch.Tensor]], bool]:
    """
    Solve `task` with CEM_MPC, following:
    Sample-efficient Cross-Entropy Method for Real-time Planning, 2020

    This is a thin adapter around `cem`: it builds a cost callback and a
    quality callback bound to `task`, then forwards them together with
    the hyper-parameters below.

    Args:
        task: Problem instance. This function reads `initial_state`,
            `max_timesteps`, `goal_circle` and `checkpoints` from it.
        sample_n: Forwarded unchanged to `cem`.
        elite_n: Forwarded unchanged to `cem`.
        horizon_len: Forwarded unchanged to `cem`.
        init_stdev: Forwarded unchanged to `cem`.
        cem_inner_iter_n: Forwarded unchanged to `cem`.
        timeout_s: Forwarded unchanged to `cem`.
        verbose: Forwarded unchanged to `cem`.

    Returns:
        Whatever `cem` returns. NOTE(review): per the annotation this is
        a list of (float, tensor) pairs plus a bool; presumably a
        (time, parameters) log and a success flag -- confirm against
        `cem`.
    """
    def local_get_quality(x: torch.Tensor) -> float:
        """Score `x` with the task's own quality measure."""
        return get_quality(x, task)

    # The optimizer minimizes a cost that is deliberately not the task
    # quality: the RL baseline's reward, guided by its checkpoints, gives
    # a dense signal with fewer local minima.
    def get_cost(x: torch.Tensor) -> float:
        """Return the negated total reward of rolling out actions `x`."""
        # A fresh environment per evaluation, so every rollout starts
        # from the task's initial state.
        rollout_env = AssistedMazeEnv(
            initial_state=task.initial_state,
            restart_timestep_n=task.max_timesteps,
            goal=task.goal_circle,
            checkpoints=task.checkpoints,
            checkpoint_sat_distance=task.goal_circle.radius,
        )
        # Candidates are intentionally not logged here: there are too
        # many cost evaluations during optimization.
        total_reward = 0
        for action in x:
            _, step_reward, finished, _ = rollout_env.step(action)
            total_reward = total_reward + step_reward
            if finished:
                # The episode is over; remaining actions are ignored.
                break
        return -total_reward

    return cem(
        get_cost=get_cost,
        get_quality=local_get_quality,
        max_timesteps=task.max_timesteps,
        action_size=2,
        horizon_len=horizon_len,
        sample_n=sample_n,
        elite_n=elite_n,
        init_stdev=init_stdev,
        cem_inner_iter_n=cem_inner_iter_n,
        timeout_s=timeout_s,
        worker_n=1,
        verbose=verbose,
    )
