"""Measure the quality of a trajectory."""
import torch
from dynamics import simulation
from rl_baseline import TightlyGuidedAntPasswordEnv


def get_quality(
        parameters: torch.Tensor,
        password: tuple[int, ...],
        num_buttons: int,
        sub_step_s: float,
        ) -> float:
    """Return the quality of the given actions relative to the task
    induced by the given password.

    The actions are rolled out one per timestep (indexed along the first
    dimension of ``parameters``) in a fresh ``TightlyGuidedAntPasswordEnv``
    until the env reports ``done`` or the actions run out.

    Args:
        parameters: Sequence of per-timestep actions; ``parameters[t]`` is
            passed to ``env.step``. Each action is presumably of size 8
            (the env is built with ``actuator_n=8``) -- TODO confirm.
        password: Target sequence of button indices.
        num_buttons: Number of buttons in the environment.
        sub_step_s: Simulation sub-step duration forwarded to the env.

    Returns:
        * ``1.0`` if the env reports success.
        * ``0.0`` if the env's target password is empty, when there is no
          success.
        * ``-1.0`` if more buttons were activated than the password has.
        * Otherwise ``k / (len(password) + 1)``, where ``k`` is the length
          of the correctly entered prefix of the password. This value lies
          in ``[0, 1)``.

    So the overall range is ``{-1.0} U [0.0, 1.0]``.
    """
    env = TightlyGuidedAntPasswordEnv(
        target_password=password,
        button_n=num_buttons,
        actuator_n=8,
        episode_timestep_n=10000000,
        sub_step_s=sub_step_s,
        password_so_far_encoding_size=5,
    )
    # Close the env on every exit path, including exceptions from env.step.
    try:
        # NOTE(review): the env's per-step reward is not used in the score
        # computed below; only the final env state matters.
        for action in parameters:
            _, _, done, _ = env.step(action)
            if done:
                break

        # If we reached the goal, we get maximum score.
        if env.is_success:
            return 1.0

        if len(env.target_password) == 0:
            return 0.0

        # Pressing more buttons than the password contains is penalized.
        activated_buttons = env.password_so_far
        if len(activated_buttons) > len(password):
            return -1.0

        # Otherwise, count the length of the correct prefix.
        correct_buttons = 0
        for target_button, actual_button in zip(password, activated_buttons):
            if target_button != actual_button:
                break
            correct_buttons += 1

        # The +1 presumably keeps a fully correct prefix without env success
        # strictly below the 1.0 reserved for success -- confirm intent.
        return correct_buttons / (len(password) + 1)
    finally:
        env.close()
