"""Dynamics for the Ant task."""
import torch
from environment import get_simulation
from environment import AntSimulation


def distance(p1: tuple[float, float], p2: tuple[float, float]) -> float:
    """Compute the Euclidean distance between 2-D points `p1` and `p2`.

    Only the first two coordinates (indices 0 and 1) of each point are used.
    """
    dx = p1[0] - p2[0]
    dy = p1[1] - p2[1]
    squared_length = dx**2 + dy**2
    return squared_length ** 0.5


def simulation_(
        parameters: torch.Tensor,
        env: AntSimulation,
        ):
    """Run an ant simulation using the given parameters as control signal. The
    `parameters` define the time-varying control signal at each step and is
    expected to be of shape `(num_timesteps, 8)`."""
    # Step simulation
    T = len(parameters)
    debug_is = list()
    for t in range(T):
        ant_positions = env.step(parameters[t].detach().numpy())

        # Button-platform logic
        for ant_position in ant_positions:
            for i, button in enumerate(env.buttons):
                px, py, pr = button.x, button.y, button.r
                if distance(ant_position, (px, py)) < pr:
                    env.activate_button(i)
                    debug_is.append((i, t))


def simulation(
        parameters: torch.Tensor,
        password: tuple[int, ...],
        num_buttons: int,
        sub_step_s: float,
        ) -> AntSimulation:
    """Create a fresh ant environment, run the simulation, and return it.

    The environment is built with `get_simulation` and reset. It is then
    stepped once per row of `parameters` by `simulation_`, which activates
    any button an ant lands on.

    Args:
        parameters: Time-varying control signal, expected to be of shape
            `(num_timesteps, 8)`; row `t` drives simulation step `t`.
        password: Forwarded unchanged to `get_simulation`.
        num_buttons: Forwarded unchanged to `get_simulation`.
        sub_step_s: Forwarded unchanged to `get_simulation`.

    Returns:
        The environment after all `num_timesteps` steps have been applied.
    """
    # Initialize the environment. NOTE(review): `animation_fps=None`
    # presumably disables rendering — confirm in `get_simulation`.
    env = get_simulation(
        animation_fps=None,
        password=password,
        num_buttons=num_buttons,
        sub_step_s=sub_step_s,
        )
    env.reset()

    # Stepping and button-activation logic is shared with `simulation_`
    # rather than duplicated here.
    simulation_(parameters, env)
    return env
