"""Visualize the result of running the MPC controller on a
simple system."""
from terrain_mass.plotting import plot_animation
from swmpo_experiments.terrain_mass_utils.mpc.planner import get_plan
from terrain_mass.task import DT as dt
from terrain_mass.task import get_example_task
from pathlib import Path
import torch


def main():
    """Run the MPC controller on the example task and render an animation.

    At every simulation step the plan is re-optimized with ``get_plan``
    starting from the latest state. Only the first action of the optimized
    plan is applied to the environment; the plan, rotated by one step, is
    kept as the initial candidate for the next optimization. The loop stops
    as soon as the newly reached state is closer to the target than
    ``task.success_distance_to_target``, or after ``simulation_step_n``
    steps. The visited states are then passed to ``plot_animation`` with the
    output path ``test.mp4`` (relative to the current working directory).
    """
    # Define environment
    task = get_example_task("asd")
    environment = task.environment

    # Initialize plan: one all-zero 2-component action per planning step
    mpc_plan_len = 60*1
    plan = torch.zeros((mpc_plan_len, 2))

    # Call MPC in a loop
    states = [environment.get_initial_state()]
    simulation_step_n = 60*2
    tp = torch.tensor(task.target_position)
    for i in range(simulation_step_n):
        print(f"MPC step {i+1}/{simulation_step_n}")
        x = states[-1]

        # Optimize local MPC plan, warm-started from the previous one
        actions = get_plan(
            initial_state=x,
            initial_candidate_plan=plan,
            environment=environment,
            iter_n=100,
            target_position=task.target_position,
            dt=dt,
            initial_stdev=0.2,
            success_distance_to_target=task.success_distance_to_target,
            action_min=task.environment.action_min,
            action_max=task.environment.action_max,
            seed="asd",
            verbose=True,
        )

        # Update MPC plan looping first action
        # (this induces a bias towards periodical motion
        # but it doesn't matter too much).
        # The rotated plan must be stored in `plan`, not `actions`, so that
        # the next call to get_plan actually starts from it.
        action = actions[0]
        plan = torch.cat((actions[1:], actions[0:1]))

        # Step simulation
        next_state = environment.step(
            x=x,
            action=action,
            dt=dt,
        )
        states.append(next_state)

        # Stop once the state we just reached is close enough to the target.
        # Checking `next_state` (rather than the pre-step state) avoids one
        # superfluous plan-and-step after the target has been reached.
        distance = (environment.get_pos(next_state)-tp).norm()
        if distance < task.success_distance_to_target:
            break

    # Plot animation
    output_path = Path()/"test.mp4"
    plot_animation(
        states=states,
        environment=environment,
        fps=60,
        output_path=output_path,
        target_position=task.target_position,
        mass_radius=task.success_distance_to_target
    )
    print(f"Wrote {output_path}")


# Run the demo only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()
