import gym
from gym import spaces
import numpy as np
import pygame
from pygame.locals import *
import time
import datetime

class PlanSim3D(gym.Env):
    """
    Gym environment in which every action places one axis-aligned box in a 3D scene

    Boxes are stored in fixed-size arrays of length `max_shapes`; only the first
    `n_shapes` entries of each array describe boxes that have actually been placed.
    An optional fridge point cloud is used by `check_collision` to detect boxes
    that enclose at least one fridge point.
    """

    metadata = {"render_modes": [], "render_fps": 30}

    # Keys of the observation dictionary, one per stored box attribute
    _OBS_KEYS = ("x", "y", "z", "width", "height", "depth")

    def __init__(self, render_mode=None, window_size=(800, 600), fridge=None, log=False):
        """
        Creates an empty environment
        Parameters:
            render_mode (any)
                Stored on the instance; rendering is not implemented for this simulator
            window_size (tuple)
                Stored on the instance
            fridge (array-like or None)
                Fridge point cloud; it is reshaped to (-1, 3), one point per row
                If None, an empty point cloud is used and fridge collisions are never reported
            log (bool)
                If True, the point cloud is saved with np.save to a file named after the current time
        """
        super(PlanSim3D, self).__init__()

        self.max_shapes = 100

        if fridge is None:
            print("Fridge point cloud not provided. Using empty point cloud.")
            self.fridge_pcl = []
        else:
            # np.asarray leaves ndarrays untouched and additionally accepts nested lists
            self.fridge_pcl = np.asarray(fridge).reshape(-1, 3)

        if log:
            # Save the point cloud to a file named after the current time
            name = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
            np.save(name, self.fridge_pcl)

        # Every attribute shares the same bounds and shape, but gets its own Box instance
        self.observation_space = spaces.Dict(
            {
                key: spaces.Box(
                    low=0,
                    high=1,
                    shape=(self.max_shapes,),
                    dtype=np.float32,
                )
                for key in self._OBS_KEYS
            }
        )

        self.action_space = spaces.Box(
            low=0, high=1, shape=(6,), dtype=np.float32
        )  # x, y, z, width, height, depth
        self.render_mode = render_mode
        self.window_size = window_size

        # Allocates _xs, _ys, _zs, _widths, _heights, _depths and sets n_shapes to 0
        self._reset_shape_buffers()
        self._colors = np.zeros((self.max_shapes, 3), dtype=np.int32)  # colors as (r,g,b) tuples
        self._paths = np.empty(self.max_shapes, dtype=object)  # image paths as strings relative to the current directory

    def _reset_shape_buffers(self):
        """
        Allocates fresh zeroed geometry arrays and marks the scene as empty

        Shared by __init__ and reset so both always produce the same initial state.
        """
        self._xs = np.zeros(self.max_shapes, dtype=np.float32)  # x-coordinates for bottom-left corner
        self._ys = np.zeros(self.max_shapes, dtype=np.float32)  # y-coordinates for bottom-left corner
        self._zs = np.zeros(self.max_shapes, dtype=np.float32)  # z-coordinates for bottom-left corner
        self._widths = np.zeros(self.max_shapes, dtype=np.float32)  # widths
        self._heights = np.zeros(self.max_shapes, dtype=np.float32)  # heights
        self._depths = np.zeros(self.max_shapes, dtype=np.float32)  # depths

        # current number of shapes in the environment (e.g. self._xs[:n_shapes] are the x-coordinates of the shapes)
        self.n_shapes = 0

    def _get_obs(self):
        """
        Gets the current state of the environment as a dictionary

        The arrays are copies, so an observation kept by the caller is not
        changed by later calls to step.

        Returns:
            {"x": np.ndarray, "y": np.ndarray, "z": np.ndarray,
             "width": np.ndarray, "height": np.ndarray, "depth": np.ndarray}
                Each array has length max_shapes; entries at index >= n_shapes are 0
        """
        return {
            "x": self._xs.copy(),
            "y": self._ys.copy(),
            "z": self._zs.copy(),
            "width": self._widths.copy(),
            "height": self._heights.copy(),
            "depth": self._depths.copy(),
        }

    def step(self, action):
        """
        Steps the environment forward by one timestep by adding one shape
        Parameters:
            action (np.ndarray)
                The action to take in the environment
                Must contain 6 elements: [x, y, z, width, height, depth]
                All elements must be in the range [0,1]
                The absolute value of width, height and depth is stored
        Returns:
            observation (dict)
                The observation of the environment after the action
            reward (float)
                The reward after the action (always 0)
            terminated (bool)
                Whether the episode is done (always False)
            trunc (bool)
                Whether the episode was truncated (always False)
            info (dict)
                Additional information about the environment
        Raises:
            IndexError
                If the environment already holds max_shapes shapes, or if the
                action has fewer than 6 elements
        """
        slot = self.n_shapes
        if slot >= self.max_shapes:
            raise IndexError(
                f"Cannot add another shape: the environment already holds {self.max_shapes} shapes"
            )

        # Read every component before writing anything, so a malformed action
        # cannot leave a half-written shape in the buffers
        x, y, z = action[0], action[1], action[2]
        width, height, depth = abs(action[3]), abs(action[4]), abs(action[5])

        self._xs[slot] = x
        self._ys[slot] = y
        self._zs[slot] = z
        self._widths[slot] = width
        self._heights[slot] = height
        self._depths[slot] = depth
        self.n_shapes = slot + 1

        observation = self._get_obs()
        reward = 0
        terminated = False
        info = self._get_info()

        return observation, reward, terminated, False, info

    def check_collision(self):
        """
        Checks the scene if the most recent shape added collides with any other shape
        The fridge is tested first; the fridge test is skipped when the point cloud is empty
        Returns:
            the index of the first shape that the last shape added collides with (int)
            -1 if collision with fridge
            None if no collision, or if no shape has been added yet (None)
        """
        if self.n_shapes == 0:
            return None

        last = self.n_shapes - 1
        last_x = self._xs[last]
        last_y = self._ys[last]
        last_z = self._zs[last]
        last_width = self._widths[last]
        last_height = self._heights[last]
        last_depth = self._depths[last]

        # Check for collision with fridge
        # len() works for both the empty-list placeholder and an (N, 3) array;
        # comparing an array with [] cannot be broadcast and fails on recent NumPy
        if len(self.fridge_pcl) > 0:
            # NOTE(review): this test spans z in [last_z - last_depth, last_z + 1], while the
            # shape-vs-shape test below uses [last_z, last_z + last_depth] - confirm the
            # intended coordinate convention
            corner_a = np.array([last_x, last_y, last_z + 1])
            corner_b = np.array([last_x + last_width, last_y + last_height, last_z - last_depth])
            # Sort the two corners per axis so lower_bound <= upper_bound component-wise
            lower_bound = np.minimum(corner_a, corner_b)
            upper_bound = np.maximum(corner_a, corner_b)

            # A fridge point collides when it lies strictly inside the box on all three axes
            strictly_inside = np.all(
                (lower_bound < self.fridge_pcl) & (self.fridge_pcl < upper_bound), axis=1
            )
            if np.any(strictly_inside):
                return -1

        # Check for collision with other objects: two boxes collide when their
        # extents overlap (strictly) on every axis
        xs = self._xs[:last]
        ys = self._ys[:last]
        zs = self._zs[:last]
        is_x_overlap = (last_x < xs + self._widths[:last]) & (last_x + last_width > xs)
        is_y_overlap = (last_y < ys + self._heights[:last]) & (last_y + last_height > ys)
        is_z_overlap = (last_z < zs + self._depths[:last]) & (last_z + last_depth > zs)

        colliding = np.flatnonzero(is_x_overlap & is_y_overlap & is_z_overlap)
        if colliding.size > 0:
            return int(colliding[0])

        return None

    def _get_info(self):
        """
        Gets a dictionary containing collision info at key "collision"
        Returns:
            {"collision": int or None}
                Index of the first shape that the last shape added collides with
                -1 if the last shape added collides with the fridge
                None if no collision
        """
        return {"collision": self.check_collision()}

    def reset(self, seed=None, options=None):
        """
        Resets the environment to its initial, empty state. All objects/actions are removed
        Parameters:
            seed (int)
                The seed to use for the random number generator
            options (dict or None)
                Accepted for compatibility with the Gym reset signature; not used
        Returns:
            observation (dict: {"x": np.ndarray, "y": np.ndarray, "z": np.ndarray, "width": np.ndarray, "height": np.ndarray, "depth": np.ndarray})
                The observation of the environment after the reset
            info (dict: {"collision": int or None})
                Additional information about the environment
        """
        super().reset(seed=seed)

        self._reset_shape_buffers()

        observation = self._get_obs()
        info = self._get_info()

        return observation, info

    def render(self, mode="human"):
        """Rendering not currently supported for 3D Simulator"""
        pass

    def close(self):
        """Shuts down pygame"""
        pygame.quit()
