from typing import Any

from gymnasium import Env
from gymnasium.error import DependencyNotInstalled
from gymnasium.spaces import Box
import numpy as np

from offline.types import ArrayLike, FloatArray

try:
    from offline.utils.suppress_warnings import pygame
except ImportError:
    pygame = None


# Symmetric per-axis bound of the 2-D action space Box.
ACTION_BOUND = 0.2
# An episode terminates once the agent is closer than this to the target.
CRITERION = 0.1
# Not referenced in this module; presumably the arena size of a larger
# variant of the task — TODO confirm against callers.
LARGE_SIZE = 30
# World-to-pixel scale for the default arena:
# WINDOW_SIZE / (SIZE + 2 * RADIUS) == 512 / 3.2 == 160.
MAP_SCALE = 160
# Message for DependencyNotInstalled when pygame cannot be imported.
PYGAME_ERROR_MESSAGE = "pygame is not installed."
# Half-width of the square start region sampled on reset; also used by
# rendering.
RADIUS = 0.1
# Default arena side length (Navigate's own `size` default is the literal 3).
SIZE = 3
# Side length, in pixels, of the square render surface / window.
WINDOW_SIZE = 512


class Navigate(Env[FloatArray, ArrayLike]):
    """Continuous 2-D point-navigation task on the square ``[0, size]^2``.

    On reset the agent is placed uniformly at random in a square of
    half-width ``RADIUS`` around the start location (the origin), clipped to
    the observation bounds. Each action is a 2-D displacement added to the
    agent's position; the resulting position is clipped to ``[0, size]`` per
    axis. The action itself is not clipped to ``action_space``.

    The episode terminates once the agent is within ``CRITERION`` of the
    target ``(size, size)``. The base reward is ``-1`` per step and ``0`` on
    the terminating step, plus a term scaled by ``penalty``:

    * ``version=0``: ``penalty`` is added whenever ``x > size / 2`` and
      ``y < size / 3``.
    * ``version=1``: ``penalty * max(0.5 - d / size, 0)`` is added, where
      ``d`` is the Euclidean distance to ``(size, 0)``.

    ``truncated`` is always ``False``; any time limit presumably comes from a
    wrapper (e.g. ``TimeLimit``) — confirm at registration.
    """

    metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 60}

    def __init__(
        self,
        size: float = 3,
        penalty: float = 0,
        render_mode: str = "rgb_array",
        version: int = 0,
        **kwargs,
    ):
        """Create the environment.

        Args:
            size: Side length of the square arena; the target sits at
                ``(size, size)``.
            penalty: Scale of the version-dependent reward term added each
                step (added as-is, so a negative value penalizes).
            render_mode: ``"human"`` (pygame window, drawn on every reset and
                step) or ``"rgb_array"`` (frames returned by ``render()``).
            version: Reward variant, ``0`` or ``1`` (see class docstring).
            **kwargs: Forwarded to the ``Env`` base initializer.

        Raises:
            NotImplementedError: If ``version`` or ``render_mode`` is not
                supported.
        """
        super().__init__(**kwargs)
        if version not in (0, 1):
            raise NotImplementedError(f"Version {version}")
        if render_mode not in self.metadata["render_modes"]:
            raise NotImplementedError(f"Render mode: {render_mode}")
        self.action_space = Box(
            low=-ACTION_BOUND, high=ACTION_BOUND, shape=(2,), dtype=np.float32
        )
        observation_space = Box(
            low=0, high=float(size), shape=(2,), dtype=np.float32
        )
        self.observation_space = observation_space
        self._high = observation_space.high
        self._low = observation_space.low
        self._agent_location = np.zeros((2,), dtype=np.float32)
        self._start_location = np.zeros((2,), dtype=np.float32)
        self._target_location = np.ones((2,), dtype=np.float32) * size
        # Reference point for the version-1 penalty: the (size, 0) corner.
        self._penalty_location = np.asarray([1, 0], dtype=np.float32) * size
        self.version = version
        self.window: Any = None
        self.clock: Any = None
        self.render_mode = render_mode
        self.size = size
        self.penalty = penalty

    def _reset(self):
        """Sample a start position and return ``(observation, info)``.

        The position is drawn uniformly from the ``RADIUS`` box around the
        start location, intersected with the observation bounds.
        """
        top_right = self._start_location + np.asarray([RADIUS, RADIUS])
        bottom_left = self._start_location - np.asarray([RADIUS, RADIUS])
        top_right = np.minimum(top_right, self._high, dtype=np.float32)
        bottom_left = np.maximum(bottom_left, self._low, dtype=np.float32)
        dimension = top_right - bottom_left
        self._agent_location = (
            bottom_left
            + self.np_random.random(size=(2,), dtype=np.float32) * dimension
        )
        if self.render_mode == "human":
            canvas = self._render_frame()
            self._render_postprocess(canvas)
        # Hand out a copy so callers cannot mutate the env's internal state.
        return self._agent_location.copy(), {}

    def reset(self, seed=None, options=None):
        """Seed the RNG (via ``Env.reset``) and start a new episode."""
        super().reset(seed=seed, options=options)
        return self._reset()

    def step(self, action):
        """Apply ``action``; return ``(obs, reward, terminated, truncated, info)``.

        Raises:
            ValueError: If ``action`` does not have the action-space shape.
        """
        # The declared action type is ArrayLike, so accept lists/tuples too
        # (dtype is left as-is to keep the original arithmetic unchanged).
        action = np.asarray(action)
        if action.shape != self.action_space.shape:
            raise ValueError(
                "Unrecognized action shape. "
                f"Expected {self.action_space.shape} but got {action.shape}"
            )
        self._agent_location = np.clip(
            self._agent_location + action, 0, self.size, dtype=np.float32
        )
        terminated = bool(
            self._compute_distance(self._target_location) < CRITERION
        )
        reward: float = 0 if terminated else -1
        if self.version == 0:
            # Flat penalty inside the region x > size/2, y < size/3.
            if (
                self._agent_location[0] > self.size / 2
                and self._agent_location[1] < self.size / 3
            ):
                reward += self.penalty
        elif self.version == 1:
            # Penalty decays linearly with normalized distance to (size, 0)
            # and vanishes beyond half the arena size.
            penalty_distance = (
                np.linalg.norm(self._agent_location - self._penalty_location)
                / self.size
            )
            penalty = 0.5 - penalty_distance
            reward += self.penalty * max(float(penalty), 0)
        if self.render_mode == "human":
            canvas = self._render_frame()
            self._render_postprocess(canvas)
        return self._agent_location.copy(), reward, terminated, False, {}

    def _render_frame(self):
        """Draw the target (red) and agent (blue) onto a fresh surface.

        In ``human`` mode this also lazily creates the window and clock.

        Raises:
            DependencyNotInstalled: If pygame is unavailable.
        """
        if pygame is None:
            raise DependencyNotInstalled(PYGAME_ERROR_MESSAGE)
        if self.render_mode == "human":
            if self.window is None:
                pygame.init()
                pygame.display.init()
                self.window = pygame.display.set_mode(
                    (WINDOW_SIZE, WINDOW_SIZE)
                )
            if self.clock is None:
                self.clock = pygame.time.Clock()
        canvas = pygame.Surface((WINDOW_SIZE, WINDOW_SIZE))
        canvas.fill((255, 255, 255))
        # Map world coordinates [0, size] into the window with a RADIUS margin
        # on each side so markers at the arena edge stay fully visible. The
        # offset and radius must be scaled too: in raw pixels RADIUS is 0.1,
        # and pygame draws nothing for a radius below 1. For the default
        # size of 3 the scale equals MAP_SCALE (512 / 3.2 == 160).
        scale = WINDOW_SIZE / (self.size + 2 * RADIUS)
        marker_radius = RADIUS * scale
        for location, color in (
            (self._target_location, (255, 0, 0)),
            (self._agent_location, (0, 0, 255)),
        ):
            center = (location + RADIUS) * scale
            pygame.draw.circle(canvas, color, center.tolist(), marker_radius)
        return canvas

    def _render_postprocess(self, canvas):
        """Show ``canvas`` in the window, or convert it to an RGB array.

        Returns:
            A ``(WINDOW_SIZE, WINDOW_SIZE, 3)`` ``uint8`` array (rows first)
            when no window/clock exists, otherwise ``None`` after updating
            the display at ``metadata["render_fps"]``.

        Raises:
            DependencyNotInstalled: If pygame is unavailable.
        """
        if pygame is None:
            raise DependencyNotInstalled(PYGAME_ERROR_MESSAGE)
        if self.window is None or self.clock is None:
            # array3d copies the pixels; a pixels3d view would alias the
            # surface memory and keep it locked while the array is alive.
            # surfarray is (x, y, c); transpose to the usual (row, col, c).
            return np.transpose(
                pygame.surfarray.array3d(canvas), axes=(1, 0, 2)
            )
        self.window.blit(canvas, canvas.get_rect())
        pygame.event.pump()
        pygame.display.update()
        self.clock.tick(self.metadata["render_fps"])
        return None

    def render(self):
        """Return the current frame in ``rgb_array`` mode, else ``None``.

        In ``human`` mode frames are already drawn during ``reset``/``step``.
        """
        if self.render_mode == "rgb_array":
            canvas = self._render_frame()
            return self._render_postprocess(canvas)
        return None

    def close(self):
        """Shut down the pygame window opened by ``human`` rendering, if any."""
        if self.window is not None and pygame is not None:
            pygame.display.quit()
            pygame.quit()
        self.window = None
        self.clock = None
        super().close()

    def _compute_distance(self, goal: FloatArray):
        """Return the Euclidean distance from the agent to ``goal``."""
        return np.linalg.norm(self._agent_location - goal)
