"""Gymnasium wrapper for the environment."""
import gymnasium as gym
from gymnasium import spaces
from terrain_mass.task import get_example_task
from terrain_mass.environment import EnvironmentInstance
from typing import Any
import numpy as np
import random
import torch
import math
from PIL import Image
import tempfile
from pathlib import Path

from terrain_mass.plotting import plot_state
from terrain_mass.plotting import get_bounds
from terrain_mass.environment import EnvironmentInstance


def get_distance_to_target(
    state: torch.Tensor,
    target_position: torch.Tensor,
) -> torch.Tensor:
    """Return the norm of the offset between the state's position and the target.

    The position is extracted from ``state`` with
    ``EnvironmentInstance.get_pos``; the result is whatever tensor ``.norm()``
    yields for the difference ``position - target_position``.
    """
    offset = EnvironmentInstance.get_pos(state) - target_position
    return offset.norm()


def get_closest_island_position(
    state: torch.Tensor,
    environment: EnvironmentInstance,
) -> torch.Tensor:
    """Return the center of the island nearest to the state's position.

    Args:
        state: State tensor; its position is read via ``environment.get_pos``.
        environment: Object exposing ``get_pos`` and an iterable ``islands``
            whose items carry a ``center`` attribute.

    Returns:
        The center (as a tensor) of the first island with the strictly
        smallest distance to the current position, or ``[0.0, 0.0]`` when no
        island yields a finite-and-smaller distance (e.g. there are none).
    """
    origin = environment.get_pos(state)
    nearest_center = torch.tensor([0.0, 0.0])
    shortest = math.inf
    centers = (torch.tensor(island.center) for island in environment.islands)
    for center in centers:
        gap = (center - origin).norm()
        # Strict comparison: on ties the earliest island is kept.
        if gap < shortest:
            shortest, nearest_center = gap, center
    return nearest_center


def get_reward(
        x: torch.Tensor,
        xn: torch.Tensor,
        success_distance_to_target: float,
        success_reward: torch.Tensor,
        target_position: torch.Tensor,
        max_distance_to_target: float,
        time_penalty: float,
        ) -> torch.Tensor:
    """Compute the reward for the transition from state ``x`` to state ``xn``.

    Args:
        x: State before the transition.
        xn: State after the transition.
        success_distance_to_target: Distance below which ``xn`` counts as a
            success.
        success_reward: Reward returned on success; its negation is returned
            when ``xn`` is too far from the target.
        target_position: Position the distances are measured to.
        max_distance_to_target: Distance above which ``xn`` is penalized with
            ``-success_reward``.
        time_penalty: Constant subtracted from the progress reward.

    Returns:
        ``success_reward`` on success, ``-success_reward`` when beyond
        ``max_distance_to_target``, otherwise the decrease in distance to the
        target minus ``time_penalty``.
    """
    previous_gap = get_distance_to_target(x, target_position)
    current_gap = get_distance_to_target(xn, target_position)
    current_gap_value = current_gap.item()

    # The success check takes precedence over the too-far check.
    if current_gap_value < success_distance_to_target:
        return success_reward
    if current_gap_value > max_distance_to_target:
        return -success_reward

    # Positive when the transition moved the mass closer to the target.
    return (previous_gap - current_gap) - time_penalty


class TerrainMassEnv(gym.Env):
    """Gymnasium environment wrapping a terrain-mass navigation task.

    An example task is generated when the environment is constructed. Every
    episode starts from the task environment's initial state and succeeds
    once the mass is closer to the task's target position than the task's
    ``success_distance_to_target``.

    Observation (10 ``float32`` values, each within ``[-1, 1]``):
        * the state divided by 15 and clipped,
        * the target position divided by 15 and clipped,
        * the unit direction from the current position to the target,
        * the offset to the closest island center divided by 15 and clipped.

    The observation size is hard-coded to 10, which assumes a 4-element
    state and 2-D positions -- TODO confirm against ``EnvironmentInstance``.

    Args:
        render_mode: ``None`` or ``"rgb_array"``.
        success_reward: Reward passed to ``get_reward`` for reaching the
            target (negated when the mass gets too far away).
        animation_padding: Margin added around the plot bounds when rendering.
        episode_max_steps: Number of steps after which an episode is
            truncated.
        plotting_mass_radius: Radius of the mass handed to ``plot_state``.
        max_distance_margin: Added to the initial distance to the target to
            obtain the out-of-bounds distance of an episode.
        time_penalty: Per-step penalty passed to ``get_reward``.
    """

    metadata = {"render_modes": ["rgb_array"], "render_fps": 60}

    def __init__(
            self,
            render_mode=None,
            success_reward: float = 10.0,
            animation_padding: float = 2.0,
            episode_max_steps: int = 300,
            plotting_mass_radius: float = 0.1,
            max_distance_margin: float = 100.0,
            time_penalty: float = 0.001,
            ):
        # The task seed is drawn from the global ``random`` module, so each
        # constructed environment gets its own example task.
        _random = random.Random(str(random.random()))
        self.task = get_example_task(str(_random.random()))
        self.instance = self.task.environment

        # State entries + target position + direction to target + island
        # offset; see the class docstring for the assumption behind "10".
        obs_size = 10
        self.observation_space = spaces.Box(
            low=-1.0,
            high=1.0,
            shape=(obs_size,),
            dtype=np.float32,
        )
        self.action_space = spaces.Box(
            low=self.instance.action_min,
            high=self.instance.action_max,
            shape=(2,),
            dtype=np.float32,
        )
        self.dt = self.task.dt
        self.success_distance_to_target = self.task.success_distance_to_target
        self.success_reward = success_reward
        self.animation_padding = animation_padding
        self.episode_max_steps = episode_max_steps
        self.plotting_mass_radius = plotting_mass_radius
        self.max_distance_margin = max_distance_margin
        self.time_penalty = time_penalty

        assert render_mode is None or render_mode in self.metadata["render_modes"]
        self.render_mode = render_mode

    def _get_observation(self) -> np.ndarray:
        """Build the normalized observation vector for the current state."""
        observation_scaling = 15.0
        # Smallest norm used for normalization; avoids 0/0 = NaN when the
        # mass sits exactly on the target.
        min_norm = 1e-8

        state = (self.current_state / observation_scaling)
        state = state.clip(min=-1.0, max=1.0).detach()

        target_position = (self.target_position / observation_scaling)
        target_position = target_position.clip(min=-1.0, max=1.0).detach()

        current_position = self.instance.get_pos(self.current_state)

        # Direction to the target (unit length, or zero when on the target).
        relative_target = self.target_position - current_position
        relative_target = relative_target / relative_target.norm().clamp(
            min=min_norm,
        )
        relative_target = relative_target.detach()

        closest_island_position = get_closest_island_position(
            state=self.current_state,
            environment=self.task.environment,
        )
        relative_island = closest_island_position - current_position
        relative_island = relative_island / observation_scaling
        # Clip like the other scaled entries so the observation always lies
        # inside the declared observation space.
        relative_island = relative_island.clip(min=-1.0, max=1.0).detach()

        observation = [
            *(state.tolist()),
            *(target_position.tolist()),
            *(relative_target.tolist()),
            *(relative_island.tolist()),
        ]
        return np.array(observation, dtype=np.float32)

    def reset(
            self,
            seed=None,
            options=None,
            ) -> tuple[np.ndarray, dict[str, Any]]:
        """Start a new episode and return ``(observation, info)``.

        ``info["state"]`` holds the initial state as a plain list. ``options``
        is accepted for API compatibility and is not used.
        """
        super().reset(seed=seed)

        initial_state = self.task.environment.get_initial_state()
        target_position = torch.tensor(self.task.target_position)

        self.current_state = initial_state
        self.target_position = target_position
        self.episode_steps = 0
        # Moving further away than the initial distance plus the margin
        # counts as leaving the playable area.
        self.max_distance_to_target = get_distance_to_target(
            initial_state,
            target_position,
        ).item() + self.max_distance_margin

        info = dict(state=self.current_state.tolist())
        return self._get_observation(), info

    def _distance_to_target(self) -> float:
        """Return the current distance to the target as a Python float."""
        return get_distance_to_target(
            self.current_state,
            self.target_position,
        ).item()

    def _is_out_of_bounds(self) -> bool:
        """Return whether the mass is further from the target than allowed."""
        return self._distance_to_target() > self.max_distance_to_target

    def _is_terminated(self) -> bool:
        """Return whether the episode reached its success state.

        The step limit is deliberately not part of termination: running out
        of time is a truncation, reported by ``_is_truncated``.
        """
        return bool(
            self._distance_to_target() < self.success_distance_to_target
        )

    def _is_truncated(self) -> bool:
        """Return whether the episode must stop without reaching the target.

        This is the case once ``episode_max_steps`` steps have been taken or
        the mass is out of bounds.
        """
        # ``>=`` so that an episode lasts exactly ``episode_max_steps`` steps.
        if self.episode_steps >= self.episode_max_steps:
            return True
        return self._is_out_of_bounds()

    def step(
            self,
            action,
            ) -> tuple[np.ndarray, float, bool, bool, dict[str, Any]]:
        """Advance the simulation by one time step.

        Args:
            action: Array-like action handed to the environment instance.

        Returns:
            ``(observation, reward, terminated, truncated, info)`` where
            ``info["state"]`` holds the new state as a plain list.
        """
        self.episode_steps += 1
        next_state = self.instance.step(
            x=self.current_state,
            # Accepts any array-like; for ndarrays this shares memory just
            # like ``torch.from_numpy``.
            action=torch.as_tensor(np.asarray(action)),
            dt=self.dt,
        )
        reward = get_reward(
            x=self.current_state,
            xn=next_state,
            success_distance_to_target=self.success_distance_to_target,
            success_reward=torch.tensor(self.success_reward),
            target_position=self.target_position,
            max_distance_to_target=self.max_distance_to_target,
            time_penalty=self.time_penalty,
        ).item()
        self.current_state = next_state
        terminated = self._is_terminated()
        truncated = self._is_truncated()
        info = dict(state=next_state.tolist())
        observation = self._get_observation()
        return observation, reward, terminated, truncated, info

    def render(self):
        """Return an RGB frame of the current state.

        Raises:
            ValueError: If the render mode is not ``"rgb_array"``.
        """
        if self.render_mode == "rgb_array":
            return self._render_frame()
        raise ValueError(f"Unknown render mode {self.render_mode}!")

    def _render_frame(self) -> np.ndarray:
        """Plot the current state to a temporary PNG and load it as an array."""
        target_position = tuple(self.target_position.tolist())
        bounds = get_bounds(
            [self.current_state],
            environment=self.instance,
            target_position=target_position,
        )
        # A temporary directory (instead of an open NamedTemporaryFile) lets
        # the plotting code and PIL open the file by name on every platform.
        with tempfile.TemporaryDirectory() as tmp_dir:
            frame_path = Path(tmp_dir) / "frame.png"
            plot_state(
                x=self.current_state,
                x_min=bounds.x_min-self.animation_padding,
                x_max=bounds.x_max+self.animation_padding,
                y_min=bounds.y_min-self.animation_padding,
                y_max=bounds.y_max+self.animation_padding,
                mass_radius=self.plotting_mass_radius,
                environment=self.instance,
                target_position=target_position,
                output_path=frame_path,
            )
            # Close the image file before the directory is removed.
            with Image.open(frame_path) as im:
                frame = np.asarray(im.convert("RGB"))
        return frame
