import gym
from copy import deepcopy
from gym import spaces
import pygame
import numpy as np


class PlanSim2D(gym.Env):
    """2D shape-placement environment.

    Each ``step`` adds one axis-aligned shape to the scene. Shapes are stored in
    normalized coordinates: ``(x, y)`` is the bottom-left corner with ``y``
    growing upwards (the renderer flips it to screen space), and
    ``width``/``height`` are fractions of the scene size. At most
    ``max_shapes`` shapes fit in one episode. When rendered, ``n_shelves - 1``
    horizontal shelf lines are drawn over the shapes.
    """

    metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 4}

    def __init__(self, render_mode=None, debug=False):
        """
        Parameters:
            render_mode (str or None)
                One of ``metadata["render_modes"]``, or None for no rendering
            debug (bool)
                If True, a red bounding box is drawn around image shapes
        """
        self.window_size = 512  # The size of the PyGame window (pixels, square)

        self.debug = debug  # Affects if bounding boxes are drawn around images
        self.max_shapes = 100

        self.n_shelves = 3

        # Each feature is a (max_shapes,) float32 vector; only the first
        # n_shapes entries describe real shapes, the rest are zero padding.
        self.observation_space = spaces.Dict(
            {
                key: spaces.Box(
                    low=0,
                    high=1,
                    shape=(self.max_shapes,),
                    dtype=np.float32,
                )
                for key in ("x", "y", "width", "height")
            }
        )

        # NOTE(review): step() reads a colour triple from action[4] and an image
        # path from action[5], which this flat float Box cannot represent —
        # confirm whether the declared action space is intentional.
        self.action_space = spaces.Box(low=0, high=1, shape=(6,), dtype=np.float32)

        assert render_mode is None or render_mode in self.metadata["render_modes"]
        self.render_mode = render_mode

        # In human mode, `self.window` is the PyGame display surface and
        # `self.clock` throttles rendering to metadata["render_fps"]. Both are
        # created lazily on the first human-mode frame and stay None otherwise.
        self.window = None
        self.clock = None

        self._clear_shapes()

    def _clear_shapes(self):
        """Allocate empty per-shape storage and set the shape count to zero."""
        self._xs = np.zeros(self.max_shapes, dtype=np.float32)  # x-coordinates for bottom-left corner
        self._ys = np.zeros(self.max_shapes, dtype=np.float32)  # y-coordinates for bottom-left corner
        self._widths = np.zeros(self.max_shapes, dtype=np.float32)  # widths
        self._heights = np.zeros(self.max_shapes, dtype=np.float32)  # heights
        self._colors = np.zeros((self.max_shapes, 3), dtype=np.int32)  # colors as (r, g, b) rows
        # Image paths relative to the current directory; np.empty(dtype=object)
        # fills with None, which means "draw a plain rectangle".
        self._paths = np.empty(self.max_shapes, dtype=object)
        # Current number of shapes (e.g. self._xs[:n_shapes] are the x-coordinates in use).
        self.n_shapes = 0

    def _get_obs(self):
        """
        Gets the current state of the environment as a dictionary.

        The arrays are copies, so later steps do not mutate observations that
        were already handed out.

        Returns:
            {"x": np.ndarray, "y": np.ndarray, "width": np.ndarray, "height": np.ndarray}
        """
        return {
            "x": self._xs.copy(),
            "y": self._ys.copy(),
            "width": self._widths.copy(),
            "height": self._heights.copy(),
        }

    def check_collision(self):
        """
        Checks whether the most recently added shape collides with any other
        shape or lies (partly) outside the unit square.

        Overlap is tested with strict inequalities, so shapes that only touch
        along an edge do not collide. A collision with another shape is reported
        in preference to being out of bounds.

        Returns:
            the index of the first shape that the last shape added collides with (int)
            -1 if the added shape is out of bounds at all (int)
            None if there is no collision or the scene is empty (None)
        """
        if self.n_shapes == 0:
            return None

        last = self.n_shapes - 1
        x, y = self._xs[last], self._ys[last]
        w, h = self._widths[last], self._heights[last]

        # Compare the newest shape against all earlier shapes at once.
        xs, ys = self._xs[:last], self._ys[:last]
        overlaps = (
            (x < xs + self._widths[:last])
            & (x + w > xs)
            & (y < ys + self._heights[:last])
            & (y + h > ys)
        )
        hits = np.flatnonzero(overlaps)
        if hits.size:
            return int(hits[0])

        # Check if out of bounds
        if x < 0 or x + w > 1 or y < 0 or y + h > 1:
            return -1
        return None

    def _get_info(self):
        """
        Gets a dictionary containing collision info at key "collision".

        Returns:
            {"collision": int or None}
                Index of the first shape that the last shape added collides with,
                -1 if it is out of bounds, None if no collision
        """
        return {"collision": self.check_collision()}

    def reset(self, seed=None, options=None):
        """
        Resets the environment to its initial, empty state. All shapes are removed.

        Parameters:
            seed (int or None)
                The seed for the environment's random number generator
                (``self.np_random``, used to draw random shape colors)
            options (dict or None)
                Accepted for Gym API compatibility; currently ignored

        Returns:
            observation (dict: {"x": np.ndarray, "y": np.ndarray, "width": np.ndarray, "height": np.ndarray})
                The observation of the environment after the reset
            info (dict: {"collision": int or None})
                Additional information about the environment
        """
        super().reset(seed=seed)

        self._clear_shapes()

        observation = self._get_obs()
        info = self._get_info()

        if self.render_mode == "human":
            self._render_frame()

        return observation, info

    def step(self, action):
        """
        Steps the environment forward by one timestep by adding one shape.

        Parameters:
            action (sequence)
                The action to take in the environment.
                Must contain at least 4 elements: [x, y, width, height].
                If it has exactly 5 elements, the 5th is the (r, g, b) color.
                Otherwise a random color is drawn from ``self.np_random``.
                If it has exactly 6 elements, the 6th is a path to an image
                file. The image is drawn instead of a colored rectangle.

        Returns:
            observation (dict: {"x": np.ndarray, "y": np.ndarray, "width": np.ndarray, "height": np.ndarray})
                The observation of the environment after the action
            reward (float)
                The reward after the action (fixed to 0)
            terminated (bool)
                Whether the episode is done (fixed to False)
            truncated (bool)
                Whether the episode was truncated (fixed to False)
            info (dict: {"collision": int or None})
                Additional information about the environment

        Raises:
            IndexError
                If the scene already holds ``max_shapes`` shapes
        """
        idx = self.n_shapes
        if idx >= self.max_shapes:
            raise IndexError(
                f"cannot add more than {self.max_shapes} shapes; call reset() first"
            )

        self._xs[idx] = action[0]
        self._ys[idx] = action[1]
        self._widths[idx] = action[2]
        self._heights[idx] = action[3]
        if len(action) == 5:
            self._colors[idx] = action[4]
        else:
            # Use the env's seeded generator so reset(seed=...) makes colors reproducible.
            self._colors[idx] = self.np_random.integers(0, 256, 3)
        if len(action) == 6:
            self._paths[idx] = action[5]
        self.n_shapes += 1

        terminated = False
        truncated = False
        reward = 0
        observation = self._get_obs()
        info = self._get_info()

        if self.render_mode == "human":
            self._render_frame()

        return observation, reward, terminated, truncated, info

    def render(self):
        """Return an RGB frame in "rgb_array" mode; otherwise return None."""
        if self.render_mode == "rgb_array":
            return self._render_frame()

    def _render_frame(self):
        """
        Draw the current scene.

        In "human" mode the frame is shown in the PyGame window at
        metadata["render_fps"]. Otherwise it is returned as a
        (window_size, window_size, 3) RGB array.
        """
        if self.render_mode == "human":
            if self.window is None:
                pygame.init()
                pygame.display.init()
                self.window = pygame.display.set_mode((self.window_size, self.window_size))
            if self.clock is None:
                self.clock = pygame.time.Clock()

        size = self.window_size
        canvas = pygame.Surface((size, size))
        canvas.fill((255, 255, 255))

        for i in range(self.n_shapes):
            w = self._widths[i] * size
            h = self._heights[i] * size
            # Convert the normalized bottom-left corner (y up) into the
            # pixel top-left corner PyGame expects (y down).
            left = self._xs[i] * size
            top = size - self._ys[i] * size - h

            path = self._paths[i]
            if path is not None:  # Draw image from path
                img = pygame.transform.scale(pygame.image.load(path), (int(w), int(h)))
                if self.debug:
                    pygame.draw.rect(img, (255, 0, 0), [0, 0, int(w), int(h)], 1)
                canvas.blit(img, (left, top))
            else:  # Draw rectangle
                color = tuple(int(c) for c in self._colors[i])
                pygame.draw.rect(canvas, color, pygame.Rect((left, top), (w, h)))

        # Finally, draw black shelf lines on top of the shapes.
        shelf_height = size / self.n_shelves
        for shelf in range(1, self.n_shelves):
            line_y = shelf * shelf_height
            pygame.draw.line(canvas, (0, 0, 0), (0, line_y), (size, line_y), width=3)

        if self.render_mode == "human":
            # Copy the drawing from `canvas` to the visible window.
            self.window.blit(canvas, canvas.get_rect())
            pygame.event.pump()
            pygame.display.update()

            # Delay as needed to keep human rendering at the predefined framerate.
            self.clock.tick(self.metadata["render_fps"])
        else:  # rgb_array
            # surfarray is indexed (x, y); transpose to the usual (row, col, channel).
            return np.transpose(
                np.array(pygame.surfarray.pixels3d(canvas)), axes=(1, 0, 2)
            )

    def close(self):
        """Shut down PyGame if a window was opened. Safe to call more than once."""
        if self.window is not None:
            pygame.display.quit()
            pygame.quit()
            # Drop stale handles so a later human-mode render re-initializes
            # PyGame and a second close() is a no-op.
            self.window = None
            self.clock = None

    def __deepcopy__(self, memo):
        """Gets a deep copy of the environment.

        The PyGame window and clock cannot be deep-copied, so the copy gets None
        for both. They are recreated lazily if the copy renders in human mode.

        Parameters:
            memo (dict)
                The deepcopy memo mapping object ids to already-copied objects

        Returns:
            result (gym.Env)
                A deep copy of the environment.
        """
        cls = self.__class__
        result = cls.__new__(cls)
        memo[id(self)] = result
        for k, v in self.__dict__.items():
            if k in ("window", "clock"):
                setattr(result, k, None)
            else:
                setattr(result, k, deepcopy(v, memo))
        return result
