"""Trajectory-optimization baseline for the button-pushing environment,
using the cross-entropy method (CEM)."""
import torch
from examples.cem import cem
from rl_baseline import TightlyGuidedAntPasswordEnv
from quality import get_quality


# Trajectory parameters, stored as a tensor.
Parameters = torch.Tensor
# Elapsed solver time; presumably seconds — TODO confirm against `cem`.
SolverTime = float
# Sequence of (solver time, parameters) pairs; used in the return
# annotation of `cem_solver`. Presumably recorded by `cem` as the search
# progresses — TODO confirm.
Log = list[tuple[SolverTime, Parameters]]


def cem_solver(
        target_password: tuple[int, ...],
        max_timesteps: int,
        actuator_n: int,
        button_n: int,
        worker_n: int,
        password_so_far_encoding_size: int,
        sub_step_s: float,
        horizon_len: int,
        init_stdev: float,
        cem_inner_iter_n: int,
        timeout_s: float,
        sample_n: int,
        elite_n: int,
        verbose: bool,
        ) -> tuple[Log, bool]:
    """Optimize a trajectory for `target_password` with the cross-entropy method.

    The search itself is delegated to `examples.cem.cem`. This function only
    builds the two objective callbacks that `cem` receives:

    * `get_cost`: the negated total reward of a `TightlyGuidedAntPasswordEnv`
      rollout. It is a dense, checkpoint-guided signal used to steer the
      search.
    * `local_get_quality`: the task score from `quality.get_quality`.

    Args:
        target_password: Password the environment and the quality metric are
            built for (presumably a sequence of button indices — TODO
            confirm).
        max_timesteps: At most this many actions are simulated per cost
            evaluation. It is also the environment's `episode_timestep_n`,
            and it is forwarded to `cem`.
        actuator_n: Forwarded as the environment's `actuator_n` and as
            `cem`'s `action_size`.
        button_n: Number of buttons, forwarded to the environment and to
            `get_quality` (as `num_buttons`).
        worker_n: Forwarded to `cem` (presumably a worker count — TODO
            confirm).
        password_so_far_encoding_size: Forwarded to the environment.
        sub_step_s: Simulation sub-step, forwarded to the environment and to
            `get_quality` (presumably seconds, per the `_s` suffix).
        horizon_len: CEM hyperparameter, forwarded unchanged to `cem`.
        init_stdev: CEM hyperparameter, forwarded unchanged to `cem`.
        cem_inner_iter_n: CEM hyperparameter, forwarded unchanged to `cem`.
        timeout_s: CEM hyperparameter, forwarded unchanged to `cem`.
        sample_n: CEM hyperparameter, forwarded unchanged to `cem`.
        elite_n: CEM hyperparameter, forwarded unchanged to `cem`.
        verbose: Forwarded unchanged to `cem`.

    Returns:
        Whatever `cem` returns, annotated here as `(Log, bool)`. The meaning
        of the bool is defined by `cem`.
    """
    # Cost is deliberately different from quality. The RL baseline's
    # checkpoint rewards give a dense signal with fewer local minima, which
    # guides the optimization better than the quality score alone.
    def get_cost(x: torch.Tensor) -> float:
        """Return the negated total reward of rolling out trajectory `x`.

        A fresh `TightlyGuidedAntPasswordEnv` is built for every call. The
        first `max_timesteps` entries of `x` are applied as actions, in
        order, and the rollout stops early once the environment reports
        `done`.
        """
        # NOTE(review): the env is neither reset nor closed here; confirm
        # the constructor leaves it ready to step and that nothing needs
        # releasing after each evaluation.
        env = TightlyGuidedAntPasswordEnv(
            target_password=target_password,
            button_n=button_n,
            actuator_n=actuator_n,
            episode_timestep_n=max_timesteps,
            sub_step_s=sub_step_s,
            password_so_far_encoding_size=password_so_far_encoding_size,
        )
        rewards = []
        for action in x[:max_timesteps]:
            _, reward, done, _ = env.step(action)
            rewards.append(reward)
            if done:
                break
        # Negate so that maximizing reward corresponds to minimizing cost.
        return -sum(rewards)

    def local_get_quality(x: torch.Tensor) -> float:
        """Return `get_quality`'s score for the full trajectory `x`."""
        return get_quality(
            parameters=x,
            password=target_password,
            num_buttons=button_n,
            sub_step_s=sub_step_s,
        )

    return cem(
        max_timesteps=max_timesteps,
        action_size=actuator_n,
        sample_n=sample_n,
        elite_n=elite_n,
        horizon_len=horizon_len,
        init_stdev=init_stdev,
        cem_inner_iter_n=cem_inner_iter_n,
        timeout_s=timeout_s,
        get_quality=local_get_quality,
        get_cost=get_cost,
        verbose=verbose,
        worker_n=worker_n,
    )


if __name__ == "__main__":
    # No script entry point: running this file directly does nothing.
    pass
