"""A simple MPC controller for the terrain-mass environment."""
import torch
import cma
import random
from terrain_mass.environment import EnvironmentInstance


def dot(
    v1: tuple[float, float],
    v2: tuple[float, float],
) -> float:
    """Return the dot product of two 2-D vectors.

    Only the first two components of each argument are read.
    """
    first_term = v1[0] * v2[0]
    second_term = v1[1] * v2[1]
    return first_term + second_term


def get_targets(
    state: torch.Tensor,
    target_position: tuple[float, float],
    environment: EnvironmentInstance,
) -> list[tuple[float, float]]:
    """Return the positions the mass should be drawn towards.

    The first entry is always ``target_position``. It is followed by the
    center of every island in ``environment.islands`` that

    - does not currently contain the mass (``island.is_inside``), and
    - lies ahead of the mass, i.e. the offset from the mass to the island
      center has a strictly positive dot product with the offset from the
      mass to ``target_position``.

    The mass position is read from ``state`` with ``environment.get_pos``.
    """
    position_tensor = environment.get_pos(state)
    position_xy = tuple(position_tensor.tolist())
    offset_to_goal = tuple(
        (torch.tensor(target_position) - position_tensor).tolist()
    )

    def lies_ahead(island) -> bool:
        # Compare directions with plain Python floats, as `dot` expects.
        offset_to_island = tuple(
            (torch.tensor(island.center) - position_tensor).tolist()
        )
        return dot(offset_to_island, offset_to_goal) > 0.0

    # `and` short-circuits, so the direction test only runs for islands
    # that do not contain the mass.
    island_targets = [
        island.center
        for island in environment.islands
        if not island.is_inside(*position_xy) and lies_ahead(island)
    ]
    return [target_position] + island_targets


def get_cost(
    initial_state: torch.Tensor,
    plan: torch.Tensor,
    environment: EnvironmentInstance,
    target_position: tuple[float, float],
    dt: float,
    success_distance_to_target: float,
) -> float:
    """Return the accumulated cost of executing ``plan`` from ``initial_state``.

    The plan is rolled out one action at a time with ``environment.step``.
    After every step, the Euclidean distance from the mass to each target
    returned by ``get_targets`` is added to the cost. The rollout stops
    early as soon as the mass is closer than ``success_distance_to_target``
    to ``target_position``.

    A final "slowness" term, the fraction of the plan that was executed
    (between 0 and 1), is added so that plans reaching the target in fewer
    steps are cheaper.

    - initial_state: state the rollout starts from.
    - plan: sequence of actions, iterated along its first dimension.
    - environment: provides ``step`` and ``get_pos``.
    - target_position: goal position of the mass.
    - dt: time step passed to ``environment.step``.
    - success_distance_to_target: distance below which the rollout stops.

    Raises ``ZeroDivisionError`` if ``plan`` is empty.
    """
    goal = torch.tensor(target_position)
    state = initial_state
    steps_taken = 0
    distances: list[float] = []
    for steps_taken, action in enumerate(plan, start=1):
        state = environment.step(
            x=state,
            action=action,
            dt=dt,
        )
        # Read the position once per step and reuse it below
        # (assumes `get_pos` has no side effects).
        position = environment.get_pos(state)
        targets = get_targets(
            state,
            target_position,
            environment,
        )
        distances.extend(
            (torch.tensor(target) - position).norm().item()
            for target in targets
        )

        if (position - goal).norm() < success_distance_to_target:
            break

    slowness = steps_taken / len(plan)
    return sum(distances) + slowness


def get_plan(
        initial_state: torch.Tensor,
        initial_candidate_plan: torch.Tensor,
        environment: EnvironmentInstance,
        target_position: tuple[float, float],
        iter_n: int,
        initial_stdev: float,
        dt: float,
        success_distance_to_target: float,
        action_min: float,
        action_max: float,
        seed: str,
        verbose: bool,
        ) -> torch.Tensor:
    """Optimize an action plan with CMA-ES and return the best plan found.

    - initial_state: state every candidate plan is rolled out from.
    - initial_candidate_plan: starting point of the search; the returned
      plan has the same shape.
    - environment: environment used by ``get_cost`` to evaluate plans.
    - target_position: goal position of the mass.
    - iter_n: maximum number of optimization iterations (ask/tell rounds);
      the loop ends earlier if the optimizer reports a stop condition.
    - initial_stdev: initial standard deviation given to CMA-ES.
    - dt: time step used when rolling out a plan.
    - success_distance_to_target: distance at which a rollout is stopped.
    - action_min, action_max: lower/upper bound applied to every action
      value.
    - seed: seed for the random generator that derives the CMA-ES seed.
    - verbose: if true, let the optimizer print its progress.

    If no candidate was evaluated (``iter_n <= 0``), a copy of
    ``initial_candidate_plan`` is returned.
    """
    _random = random.Random(seed)

    # One bound per optimized variable: the optimizer works on the
    # flattened plan, so size the bounds with numel(), not len().
    variable_n = initial_candidate_plan.numel()
    bounds = (
        [action_min] * variable_n,
        [action_max] * variable_n,
    )

    # cma interprets a seed of 0 as "seed from the clock", which would
    # silently break reproducibility; map that single value to 1.
    cma_seed = int.from_bytes(_random.randbytes(3), 'big', signed=False) or 1

    es = cma.CMAEvolutionStrategy(
        initial_candidate_plan.flatten().tolist(),
        initial_stdev,
        dict(
            seed=cma_seed,
            verbose=3 if verbose else -3,
            bounds=bounds,
        )
    )

    def to_plan(solution) -> torch.Tensor:
        # Flat optimizer vector -> tensor shaped like the candidate plan.
        return torch.tensor(
            solution.tolist()
        ).reshape_as(initial_candidate_plan)

    for _ in range(iter_n):
        solutions = es.ask()
        costs = [
            get_cost(
                initial_state=initial_state,
                plan=to_plan(solution),
                target_position=target_position,
                environment=environment,
                dt=dt,
                success_distance_to_target=success_distance_to_target,
            )
            for solution in solutions
        ]
        es.tell(solutions, costs)

        if verbose:
            es.disp()

        if es.stop():
            break

    best_solution = es.result[0]
    if best_solution is None:
        # Nothing has been evaluated yet, so there is no best solution.
        return initial_candidate_plan.detach().clone()
    return to_plan(best_solution)
