import gymnasium as gym
from gymnasium import spaces
from gymnasium.utils import seeding
import numpy as np
from env.simulate2 import TurtlebotSim
from env.simulate_real import TurtlebotSim as TBReal
from typing import Tuple, Optional, List, Dict

class TurtlebotEnv2(gym.Env):
    """Depth-only Gymnasium environment wrapping a Turtlebot simulation.

    Observations are the ``num_frames`` most recent depth frames (each
    ``depth_height`` x ``depth_width``), stacked oldest-first and flattened
    into a 1-D ``float32`` vector. Actions are a 2-vector
    ``[linear_velocity, angular_velocity]`` with every component in [-1, 1].
    """

    metadata = {
        'render_modes': ['human', 'rgb_array'],
        'render.modes': ['human', 'rgb_array'],  # for backwards compatibility
        'render_fps': 30  # Adding render_fps to fix the warning
    }

    def __init__(self, gui=False, use_egl=False, ego_speed=0.4, tb2_speed=0.3,
                 render_mode=None, real=False, image_processor=lambda x: x,
                 num_frames=4):
        """Create the simulation and define the action/observation spaces.

        Args:
            gui: Forwarded to the simulation; forced to True when
                ``render_mode == 'human'``.
            use_egl: Forwarded to the simulation unchanged.
            ego_speed: Forwarded to the simulation unchanged.
            tb2_speed: Forwarded to the simulation unchanged.
            render_mode: One of ``None``, ``'human'`` or ``'rgb_array'``.
            real: Currently ignored; the ``TBReal`` simulation is always used.
            image_processor: Callable applied to every raw depth image before
                it is stacked. Its output must contain exactly
                ``depth_height * depth_width`` values.
            num_frames: Number of depth frames stacked into one observation.
        """
        super().__init__()

        # Store render mode
        self.render_mode = render_mode

        # Camera dimensions and frame stacking
        self.depth_width = 84
        self.depth_height = 84
        self.num_frames = num_frames
        self.num_depth_values = self.depth_width * self.depth_height * self.num_frames

        # If render_mode is 'human', force GUI mode
        if render_mode == 'human':
            gui = True

        # Action space: [linear_velocity, angular_velocity].
        # Bounds are built as float32 so they match the declared dtype exactly
        # instead of being down-cast from float64.
        self.action_space = spaces.Box(
            low=np.array([-1.0, -1.0], dtype=np.float32),
            high=np.array([1.0, 1.0], dtype=np.float32),
            dtype=np.float32
        )

        # Depth-only observation: the flattened frame stack.
        # NOTE(review): the space declares values in [0, 1], but no
        # normalisation happens in this class -- presumably image_processor is
        # expected to return normalised depth; confirm against callers.
        self.observation_space = spaces.Box(
            low=0.0,
            high=1.0,
            shape=(self.num_depth_values,),
            dtype=np.float32
        )
        self.image_processor = image_processor

        # NOTE: `real` is not consulted; TBReal is used unconditionally.
        self.sim = TBReal(
            gui=gui,
            use_egl=use_egl,
            ego_speed=ego_speed,
            tb2_speed=tb2_speed,
            obs_size=self.observation_space.shape[0],
            action_size=self.action_space.shape[0]
        )

        # Start with an all-zero frame stack; reset() replaces it with copies
        # of the first real frame.
        self.depth_stack = [self._blank_frame() for _ in range(self.num_frames)]

        # Episode tracking variables
        self.episode_reward = 0
        self.episode_length = 0

    def _blank_frame(self):
        """Return a zero-filled float32 frame of shape (depth_height, depth_width)."""
        return np.zeros((self.depth_height, self.depth_width), dtype=np.float32)

    def _capture_depth_frame(self):
        """Grab the current depth image, process it and coerce it to a frame.

        The result is always a fresh float32 array of shape
        ``(depth_height, depth_width)``, so frames kept in the stack never
        alias a buffer owned by the simulation or the processor.

        Raises:
            ValueError: If the processed image does not contain exactly
                ``depth_height * depth_width`` values.
        """
        _, depth_img, _ = self.sim.agent.get_camera_image()
        depth_img = self.image_processor(depth_img)
        return np.array(depth_img, dtype=np.float32).reshape(
            self.depth_height, self.depth_width
        )

    def _flatten_stack(self):
        """Stack the buffered frames oldest-first and flatten them to 1-D."""
        # Stacking on axis 0 gives (num_frames, depth_height, depth_width).
        return np.stack(self.depth_stack, axis=0).flatten()

    def seed(self, seed: Optional[int] = None):
        """Fix seed of environment

        In order to make the environment completely reproducible, call this function and seed the action space as well.
            env = gym.make(...)
            env.seed(seed)
            env.action_space.seed(seed)

        This function does not need to be used for this assignment, it is given only for reference.

        Returns:
            A one-element list holding the seed that was actually used.
        """
        self.np_random, seed = seeding.np_random(seed)
        self.action_space.seed(seed)
        return [seed]

    def step(self, action):
        """Execute one time step within the environment using only depth image frames.

        Args:
            action: Passed to the simulation's ``step`` unchanged.

        Returns:
            A 5-tuple ``(observation, reward, terminated, truncated, info)``.
            ``truncated`` is always False. ``info`` is empty except on the
            final step of an episode, where ``info["final_info"]`` is a
            one-element list holding the simulation's info dict extended with
            an ``"episode"`` entry (``"r"``: return, ``"l"``: length).
        """
        # Step the simulation; its observation is discarded because the
        # observation is rebuilt from the depth camera below.
        _, reward, done, info = self.sim.step(action)

        # Same capture/coercion path as reset(), so every frame in the stack
        # has an identical shape and dtype.
        self.depth_stack.append(self._capture_depth_frame())

        # Drop the oldest frames so that at most num_frames remain.
        overflow = len(self.depth_stack) - self.num_frames
        if overflow > 0:
            del self.depth_stack[:overflow]

        # Ensure we always have the required number of frames
        while len(self.depth_stack) < self.num_frames:
            self.depth_stack.append(self._blank_frame())

        # Our full state is only the depth data (stacked frames)
        full_state = self._flatten_stack().astype(np.float32)

        # Verify that the observation has the expected shape
        assert full_state.shape == self.observation_space.shape, \
            f"State shape mismatch: got {full_state.shape}, expected {self.observation_space.shape}"

        # Track episode stats
        self.episode_reward += reward
        self.episode_length += 1

        final_infos = {}
        if done:
            info["episode"] = {
                "r": float(self.episode_reward),
                "l": int(self.episode_length)
            }
            final_infos["final_info"] = [info]

        return full_state, reward, done, False, final_infos

    def reset(self, seed: Optional[int] = None, options: Optional[dict] = None):
        """Start a new episode and return the initial observation.

        Args:
            seed: Optional seed forwarded to ``gym.Env.reset``; it also
                determines the per-episode speed randomisation below.
            options: Accepted for API compatibility; not used.

        Returns:
            ``(observation, info)`` where ``observation`` is the first depth
            frame repeated ``num_frames`` times and flattened, and ``info`` is
            an empty dict.
        """
        super().reset(seed=seed)

        # Reset episode tracking
        self.episode_reward = 0
        self.episode_length = 0

        # Randomize turtlebot speeds for each episode. Draw from the
        # environment's own generator (seeded by super().reset above) rather
        # than the global numpy RNG, so seeded resets are reproducible.
        tb2_speed = self.np_random.uniform(0.3, 0.65)
        tb3_speed = self.np_random.uniform(0.3, 0.65)

        # Update the speeds in the simulation
        # NOTE(review): the simulation is constructed with `ego_speed` and
        # `tb2_speed`; confirm that it actually reads a `tb3_speed` attribute.
        self.sim.tb2_speed = tb2_speed
        self.sim.tb3_speed = tb3_speed

        # Reset simulation
        self.sim.reset()

        # Fill the whole stack with independent copies of the first frame.
        first_frame = self._capture_depth_frame()
        self.depth_stack = [first_frame.copy() for _ in range(self.num_frames)]

        depth_data = self._flatten_stack()

        # Ensure depth_data has correct shape
        assert depth_data.shape == (self.num_depth_values,), \
            f"Depth data shape mismatch: got {depth_data.shape}, expected {(self.num_depth_values,)}"

        return depth_data, {}

    def render(self):
        """Render according to ``render_mode``.

        Returns:
            The simulation's environment image for ``'rgb_array'``; otherwise
            None (for ``'human'`` the GUI is expected to draw by itself).
        """
        if self.render_mode is None:
            return

        if self.render_mode == "rgb_array":
            # Get camera image from simulation
            rgb = self.sim.get_env_image()
            return rgb

        elif self.render_mode == "human":
            # PyBullet's GUI mode handles rendering automatically
            return None

    def close(self):
        """Shut down the underlying simulation."""
        self.sim.close()