"""Defines discrete-time dynamics for the marble + line segment segments
system. These could easily be included in the simulation's main code, but are
instead written separately to minimize the amount of code in the simulation,
which is intended as an illustrative example of pylic's capabilities."""
import torch
from trajectory_dynamics import State
from trajectory_dynamics import collision_dynamics
from trajectory_dynamics import collision_soft_predicate
from trajectory_dynamics import normal_dynamics_force


def get_observation(
        state: State,
        t: int,
        T: int,
        coordinate_scale: float,
        ) -> torch.Tensor:
    """Build the policy observation for the given state and time step.

    The observation is a 1-D torch tensor with five entries, in order:
    the first two components of the marble position multiplied by
    `coordinate_scale`, the first two components of the marble velocity
    (not rescaled), and the fraction `t/T` of the horizon already elapsed.

    Note that the tensor is rebuilt from scalar values with `torch.tensor`,
    so it is a fresh tensor rather than a view of the state.
    """
    marble = state.marble
    scaled_position = marble.position*coordinate_scale
    raw_velocity = marble.velocity
    features = [
        # Marble position, rescaled
        scaled_position[0],
        scaled_position[1],
        # Marble velocity, as stored in the state
        raw_velocity[0],
        raw_velocity[1],
        # Progress through the episode
        t/T,
    ]
    return torch.tensor(features)


def get_coordinate_scale(state: State) -> float:
    """Get a scalar factor to scale coordinates to obtain observations from
    states.

    The factor is the reciprocal of the largest absolute coordinate found
    among the endpoints (`p1`, `p2`) of the segments in the state, so that
    multiplying any segment coordinate by it yields a value in [-1, 1].

    Returns 1.0 (i.e. no rescaling) when the state has no segments, or when
    every endpoint coordinate is zero and no meaningful scale exists.
    """
    # Collect the endpoint coordinates of every segment, one row per segment
    coordinate_values = [
        [
            *segment.p1,
            *segment.p2,
        ]
        for segment in state.segments
    ]
    if not coordinate_values:
        return 1.0
    max_coordinate = torch.tensor(coordinate_values).abs().max().item()
    if max_coordinate == 0:
        # All endpoints sit at the origin; avoid dividing by zero and fall
        # back to the same neutral scale used when there are no segments
        return 1.0
    return 1/max_coordinate


def simulation(
        policy: (torch.nn.Module|torch.Tensor),
        state: State,
        step_n: int,
        ) -> State:
    """Roll the system forward `step_n` steps and return the final state.

    `policy` is either a tensor, which is indexed by the step number and
    used as an open-loop action trajectory, or a callable that is queried
    at every step with the current observation (closed loop).

    At every step the action is clipped to [-1, 1] and multiplied by
    `state.impulse_scale` before being applied. If the collision predicate
    is positive for a segment, the collision-free next state is replaced
    by the result of the collision dynamics for that segment; when several
    segments qualify, the one with the highest index determines the result.
    """
    scale = get_coordinate_scale(state)
    is_open_loop = isinstance(policy, torch.Tensor)
    for step in range(step_n):
        # Select the raw action for this step
        if is_open_loop:
            raw_action = policy[step]
        else:
            raw_action = policy(get_observation(
                state=state,
                t=step,
                T=step_n,
                coordinate_scale=scale,
            ))
        impulse = raw_action.clip(min=-1, max=1)*state.impulse_scale

        # Tentative successor, assuming the marble hits nothing
        successor = normal_dynamics_force(
            state,
            impulse,
            state.drag_constant,
            state.dt
        )

        # Override the successor for every segment flagged as colliding
        for segment_index in range(len(state.segments)):
            predicate = collision_soft_predicate(
                state, state.dt, segment_index)
            if predicate > 0:
                successor = collision_dynamics(
                    state,
                    segment_index,
                    state.dt,
                    state.coefficient_of_restitution
                )

        state = successor
    return state


def get_policy_actions(
        policy: torch.nn.Module,
        state: State,
        step_n: int,
        ) -> list[torch.Tensor]:
    """Return the action sequence induced by running the given policy
    for the given number of steps, starting on the given starting state.

    The returned actions are the raw policy outputs, one per step; clipping
    and impulse scaling are applied by `simulation` when an action is
    executed. The policy is queried exactly once per step, with the same
    observation (including the time feature `t/step_n`) it would receive in
    a closed-loop call to `simulation(policy, state, step_n)`.
    """
    actions = []
    coordinate_scale = get_coordinate_scale(state)
    for t in range(step_n):
        action = policy(get_observation(
            state=state,
            t=t,
            T=step_n,
            coordinate_scale=coordinate_scale,
        ))
        actions.append(action)
        # Advance the state with the action that was just recorded, by
        # replaying it as a length-1 open-loop trajectory. Re-running the
        # policy through a one-step closed-loop simulation would query it
        # with t=0 and T=1, so a time-dependent or stochastic policy could
        # apply a different action than the one stored above.
        state = simulation(action.unsqueeze(0), state, step_n=1)
    return actions
