import gymnasium as gym
from gymnasium import spaces
import numpy as np
import pybullet as p
import os
from typing import Tuple, Optional, List
import time
from env.fetch_simulation import FetchSimulation
from env.fetch_controller import FetchController

class FetchEnv(gym.Env):
    """Gymnasium environment exposing a ``FetchSimulation`` through stacked depth frames.

    Observations are the last ``num_frames`` depth images (84x84 each), divided
    by ``sim.fetch_controller.far``, stacked oldest-first and flattened into a
    single ``float32`` vector. Actions are a 2-vector
    ``[linear_velocity, angular_velocity]`` with each component in ``[-1, 1]``.

    Args:
        gui: Passed to ``FetchSimulation``; forced to ``True`` when
            ``render_mode == 'human'``.
        render_mode: ``None``, ``'human'`` or ``'rgb_array'``.
        debug: Passed to ``FetchSimulation`` and stored on the instance.
        max_episode_steps: Stored on the instance. This class does not itself
            truncate episodes; termination comes from the simulation.
        use_egl, ego_speed, tb2_speed: Accepted for interface compatibility;
            not used by this class.
        image_processor: Optional callable stored on the instance (identity by
            default); not applied by this class.
        num_frames: Number of depth frames kept in the observation stack.
        **kwargs: Ignored.
    """

    metadata = {
        'render_modes': ['human', 'rgb_array'],
        'render.modes': ['human', 'rgb_array'],  # for backwards compatibility
        'render_fps': 30  # Adding render_fps to fix the warning
    }

    def __init__(self, gui=False, render_mode=None, debug=False, max_episode_steps=1000,
                 use_egl=False, image_processor=None, ego_speed=0.4, tb2_speed=0.4,
                 num_frames=8, **kwargs):
        super().__init__()

        self.render_mode = render_mode

        # A human render mode implies the simulator must open its GUI.
        if render_mode == 'human':
            gui = True

        self.debug = debug

        # Episode bookkeeping.
        self.max_episode_steps = max_episode_steps
        self.current_step = 0

        # Camera dimensions and frame stacking
        self.depth_width = 84
        self.depth_height = 84
        self.num_frames = num_frames
        self.num_depth_values = self.depth_width * self.depth_height * self.num_frames

        # Store image_processor (even if not used)
        self.image_processor = image_processor if image_processor is not None else lambda x: x

        # Action space: [linear_velocity, angular_velocity].
        # Bounds are created as float32 so they match the declared dtype and
        # Gymnasium does not warn about precision being lowered.
        self.action_space = spaces.Box(
            low=np.array([-1.0, -1.0], dtype=np.float32),
            high=np.array([1.0, 1.0], dtype=np.float32),
            dtype=np.float32
        )

        # Observation space: flattened depth frames only (like TurtlebotEnv2)
        self.observation_space = spaces.Box(
            low=0.0,
            high=1.0,
            shape=(self.num_depth_values,),
            dtype=np.float32
        )

        # Frame stack starts out filled with blank frames.
        self.depth_stack = [self._blank_frame() for _ in range(self.num_frames)]

        # Create the simulation with GUI if render_mode is "human"
        self.sim = FetchSimulation(gui=gui, debug=debug)

        # Running totals reported in info["episode"] when an episode ends.
        self.episode_reward = 0
        self.episode_length = 0

    def _blank_frame(self) -> np.ndarray:
        """Return an all-zero depth frame of the configured size."""
        return np.zeros((self.depth_height, self.depth_width), dtype=np.float32)

    def _normalize_depth(self, depth_img) -> np.ndarray:
        """Reshape a raw depth image to (height, width) and divide by ``far``.

        ``far`` is read from ``self.sim.fetch_controller``.
        NOTE(review): presumably the camera far-plane distance, which would put
        the result in [0, 1] — confirm against FetchController.
        """
        frame = np.array(depth_img, dtype=np.float32).reshape(
            self.depth_height, self.depth_width
        )
        return frame / self.sim.fetch_controller.far

    def _flatten_stack(self) -> np.ndarray:
        """Stack the stored frames along axis 0 and flatten to a float32 vector."""
        stacked = np.stack(self.depth_stack, axis=0)
        return stacked.flatten().astype(np.float32)

    def _get_info(self):
        """Get additional information"""
        raise NotImplementedError("Not implemented")

    def reset(self, seed=None, options=None):
        """Reset the simulation and return ``(observation, info)``.

        The frame stack is filled with copies of the first depth image so the
        observation has its full length from the first step. ``info`` is an
        empty dict. ``options`` is accepted but unused.
        """
        super().reset(seed=seed)

        # Reset episode tracking
        self.episode_reward = 0
        self.episode_length = 0
        self.current_step = 0

        # The simulation's reset returns the initial depth image.
        first_frame = self._normalize_depth(self.sim.reset())

        # Each slot gets its own copy so frames never alias one another.
        self.depth_stack = [first_frame.copy() for _ in range(self.num_frames)]

        # Return only the flattened depth data (like TurtlebotEnv2)
        return self._flatten_stack(), {}

    def step(self, action):
        """Advance the simulation by one action.

        Returns:
            ``(observation, reward, done, truncated, infos)`` where
            ``truncated`` is always ``False``. ``infos`` is empty while the
            episode is running; when ``done`` is true it is
            ``{"final_info": [info]}``, with ``info`` being the simulation's
            info dict extended with ``distance_to_target``, ``episode``
            (``r`` total reward, ``l`` length), ``reason`` and ``stopped``.
        """
        depth_img, reward, done, info = self.sim.step(action)
        self.current_step += 1

        # Push the newest frame and drop the oldest to keep the stack bounded.
        self.depth_stack.append(self._normalize_depth(depth_img))
        while len(self.depth_stack) > self.num_frames:
            self.depth_stack.pop(0)

        # Ensure we always have the required number of frames
        while len(self.depth_stack) < self.num_frames:
            self.depth_stack.append(self._blank_frame())

        observation = self._flatten_stack()

        # Distance to goal, measured along the x axis only.
        # NOTE(review): y/z are ignored — confirm this is intended for the task layout.
        robot_pos, _ = p.getBasePositionAndOrientation(self.sim.fetch_robot)
        info["distance_to_target"] = abs(robot_pos[0] - self.sim.goal_position[0])

        self.episode_reward += reward
        self.episode_length += 1

        # Per-step info is only surfaced to the caller at episode end.
        final_infos = {}
        if done:
            info["episode"] = {
                "r": float(self.episode_reward),
                "l": int(self.episode_length)
            }
            info["reason"] = self.sim.end_reason
            info["stopped"] = done  # Add stopped flag
            final_infos["final_info"] = [info]

        # Return only the flattened depth data (like TurtlebotEnv2)
        return observation, reward, done, False, final_infos

    def render(self):
        """Return an RGB frame in ``'rgb_array'`` mode; otherwise return ``None``.

        In ``'human'`` mode nothing is drawn here; the simulator was created
        with ``gui=True`` in ``__init__``.
        """
        if self.render_mode == "rgb_array":
            rgb, _, _ = self.sim.get_env_image()
            return rgb
        return None

    def seed(self, seed: Optional[int] = None):
        """
        Fix the seed of the environment. This also seeds the action space.

        Returns:
            A one-element list containing the seed actually used.
        """
        # Imported locally so this method does not depend on file-level imports.
        from gymnasium.utils import seeding

        # The previous code called ``seed.np_random(seed)`` on the integer
        # argument itself, which raised AttributeError on every call.
        self.np_random, seed = seeding.np_random(seed)
        self.action_space.seed(seed)
        return [seed]

    def close(self):
        """Clean up resources"""
        self.sim.close()
