"""
Base Ant environment classes.
"""

import numpy as np
import mujoco as mj
import os
import gymnasium as gym
from gymnasium.envs.mujoco import MujocoEnv
from gymnasium.envs.mujoco.ant_v4 import AntEnv
from gymnasium.spaces import Box
from gymnasium import utils
from warnings import filterwarnings
from utils import *
from dotenv import load_dotenv

import gymnasium as gym
import gymnasium_robotics
gym.register_envs(gymnasium_robotics)
import gymnasium.spaces as spaces

# Load environment variables (python-dotenv; presumably from a local .env file)
# before BASE_XML_DIR is read below.
load_dotenv()
# Directory prefix used to build the MuJoCo XML paths in this module
# (ant.xml, humanoid.xml, swimmer.xml). No default is supplied, so this is
# None when the BASE_XML_DIR environment variable is not set.
BASE_XML_DIR = os.getenv('BASE_XML_DIR')

# Suppress the DeprecationWarning emitted for the `np.bool8` alias.
filterwarnings(action="ignore", category=DeprecationWarning, 
               message="`np.bool8` is a deprecated alias for `np.bool_`")

class Ant(AntEnv):
    """
    Base Ant environment class.
    Implements common functionality for all ant environments.

    The model is loaded from ``{BASE_XML_DIR}/ant.xml``. ``step`` delegates the
    reward to :meth:`calculate_reward`, which subclasses are expected to
    override (the base implementation always returns 0).
    """

    def __init__(self, test_mode=False, **kwargs):
        """
        Build the environment on top of ``AntEnv``.

        Args:
            test_mode: Accepted for interface compatibility; not used by this class.
            **kwargs: Forwarded unchanged to ``AntEnv.__init__``.
        """
        super().__init__(
            exclude_current_positions_from_observation=True,
            healthy_z_range=(0.1, 10.0),
            xml_file=f"{BASE_XML_DIR}/ant.xml",
            **kwargs,
        )

    def calculate_reward(self, x_velocity, y_velocity):
        """
        Calculate the step reward. To be implemented by subclasses.

        NOTE: despite the parameter names, ``step`` calls this positionally with
        the torso xy position *before* the simulation step as the first argument
        and the torso xy position *after* it as the second.

        Returns:
            0 in this base class.
        """
        return 0

    def step(self, action):
        """
        Advance the simulation by ``self.frame_skip`` frames using ``action``.

        Returns:
            ``(observation, reward, terminated, truncated, info)`` where
            ``truncated`` is always ``False`` and ``terminated`` is ``True`` when
            ``self.data.sensordata[0]`` is greater than zero.
        """
        xy_position_before = self.get_body_com("torso")[:2].copy()
        self.do_simulation(action, self.frame_skip)
        xy_position_after = self.get_body_com("torso")[:2].copy()

        # Calculate reward using the implemented method
        reward = self.calculate_reward(xy_position_before, xy_position_after)

        position_torso = self.get_body_com("torso").copy()
        observation = self._get_obs()

        info = {
            "rewards": reward,
            "ant_x_position": position_torso[0],
            "ant_y_position": position_torso[1],
            "x_position": position_torso[0],
            "y_position": position_torso[1],
            # Norm of the full torso position vector (z included).
            "distance_from_origin": np.linalg.norm(position_torso, ord=2),
        }

        # For each leg, report the z coordinate of its "<leg>_site" site and the
        # sensor reading stored at sensordata[1..4] (same order as the legs).
        leg_names = ("front_left_leg", "front_right_leg", "back_left_leg", "back_right_leg")
        for sensor_index, leg in enumerate(leg_names, start=1):
            site_id = mj.mj_name2id(self.model, mj.mjtObj.mjOBJ_SITE, f"{leg}_site")
            info[f"{leg}_height"] = self.data.site_xpos[site_id][2]
            info[f"{leg}_site"] = self.data.sensordata[sensor_index]

        if self.render_mode == "human":
            self.render()

        # The first sensor reading acts as the termination signal.
        terminated = bool(self.data.sensordata[0] > 0.0)
        return observation, reward, terminated, False, info

    def get_obs(self):
        """Public accessor for the current observation (``self._get_obs()``)."""
        return self._get_obs()

# Exposes only the 'observation' entry of the dict observation (i.e. without the goal in the state).
class FetchVectorWrapper(gym.Wrapper):
    """
    Wrapper for environments whose observations are dicts with an
    ``'observation'`` key: only that entry is returned from ``reset`` and
    ``step``, and the observation space is narrowed to match.
    """

    def __init__(self, env):
        super().__init__(env)
        self.env = env
        # Keep the wrapped action space; narrow the observation space to the
        # 'observation' sub-space.
        self.observation_space = self.env.observation_space['observation']
        self.action_space = self.env.action_space

    def reset(self, **kwargs):
        """Reset the wrapped env and return ``(observation['observation'], info)``."""
        obs_dict, info = self.env.reset(**kwargs)
        return obs_dict['observation'], info

    def step(self, action):
        """Step the wrapped env, replacing the dict observation by its 'observation' entry."""
        obs_dict, reward, terminated, truncated, info = self.env.step(action)
        return obs_dict['observation'], reward, terminated, truncated, info

    def render(self):
        """Delegate rendering to the wrapped env."""
        return self.env.render()

class EnvRandomDiscreteActions(gym.Wrapper):
    """
    Discretise a continuous-action environment with a fixed random action set.

    At construction, ``env_rd_num_actions`` continuous actions are drawn once;
    every component is picked from ``{-magnitude, 0, +magnitude}``. The wrapper
    then exposes a ``Discrete`` action space whose ids index into that set.
    """

    def __init__(self, env, env_rd_num_actions=64, env_rd_action_magnitude=0.5):
        super().__init__(env)
        self.action_space = spaces.Discrete(env_rd_num_actions)
        self.observation_space = env.observation_space
        self.metadata = env.metadata
        self.steps = 0      # steps taken since the last reset
        self.obs = None     # most recent observation
        self.action_magnitude = env_rd_action_magnitude
        self.num_actions = env_rd_num_actions

        # One independent draw per action, each of length equal to the wrapped
        # action dimension.
        n_dims = env.action_space.shape[0]
        levels = [-self.action_magnitude, 0, self.action_magnitude]
        self.actions = [
            np.random.choice(levels, size=n_dims)
            for _ in range(env_rd_num_actions)
        ]

    def step(self, k):
        """
        Step the wrapped env with the continuous action stored under id ``k``.

        Raises:
            ValueError: if ``k`` is outside ``[0, num_actions)``.
        """
        if k < 0 or k >= self.num_actions:
            raise ValueError(f"Action ID {k} is out of bounds. Must be between 0 and {self.num_actions-1}")

        self.steps += 1
        result = self.env.step(self.actions[k])
        observation, rewards, terminated, truncated, info = result
        self.obs = observation
        return observation, rewards, terminated, truncated, info

    def reset(self, **kwargs):
        """Reset the wrapped env and the step counter; returns ``(observation, info)``."""
        self.steps = 0
        first_obs, info = self.env.reset(**kwargs)
        self.obs = first_obs
        return first_obs, info
    
class UniformActionSetEnv(gym.Wrapper):
    """
    Environment wrapper that creates K uniformly sampled actions from the action space bounds.
    Each action dimension is uniformly sampled between the environment's action limits.

    The K actions are drawn once at construction and exposed through a
    ``Discrete(K)`` action space whose ids index into ``self.actions``.
    """

    def __init__(self, env, num_actions=64):
        """
        Args:
            env: Environment whose action space provides ``low`` and ``high``.
            num_actions: Number of continuous actions to sample (K).

        Raises:
            ValueError: if the wrapped action space lacks ``low``/``high``.
        """
        super().__init__(env)
        self.action_space = spaces.Discrete(num_actions)
        self.observation_space = env.observation_space
        self.metadata = env.metadata
        self.steps = 0      # steps taken since the last reset
        self.obs = None     # most recent observation
        self.num_actions = num_actions

        # Validate the bounds first so an unsupported space reports this error
        # rather than failing on an unrelated attribute access.
        wrapped_space = env.action_space
        if not (hasattr(wrapped_space, 'low') and hasattr(wrapped_space, 'high')):
            raise ValueError("Environment action space must have 'low' and 'high' attributes")
        action_low = wrapped_space.low
        action_high = wrapped_space.high

        # Generate K actions by uniformly sampling from action space bounds
        self.actions = [
            np.random.uniform(action_low, action_high)
            for _ in range(num_actions)
        ]

    def _check_action_id(self, k):
        """Raise ValueError unless ``k`` is a valid index into ``self.actions``."""
        if k < 0 or k >= self.num_actions:
            raise ValueError(f"Action ID {k} is out of bounds. Must be between 0 and {self.num_actions-1}")

    def step(self, k):
        """
        Execute step using action ID k

        Raises:
            ValueError: if ``k`` is outside ``[0, num_actions)``.
        """
        self._check_action_id(k)

        self.steps += 1
        # Get the action using the provided ID
        continuous_action = self.actions[k]
        observation, rewards, terminated, truncated, info = self.env.step(continuous_action)
        self.obs = observation
        return observation, rewards, terminated, truncated, info

    def reset(self, **kwargs):
        """
        Reset the environment and the step counter; returns ``(observation, info)``.
        """
        self.steps = 0
        observation, info = self.env.reset(**kwargs)
        self.obs = observation
        return observation, info

    def get_action(self, k):
        """
        Get the continuous action corresponding to action ID k

        Raises:
            ValueError: if ``k`` is outside ``[0, num_actions)``.
        """
        self._check_action_id(k)
        return self.actions[k]

    def get_all_actions(self):
        """
        Get all available actions as a single array (one row per action id).
        """
        return np.array(self.actions)

class GaussianActionSetEnv(gym.Wrapper):
    """
    Environment wrapper that creates K Gaussian sampled actions from the action space.
    Each action dimension is independently sampled from a Gaussian distribution with
    specified mean and sigma, then clipped to the wrapped action space bounds.

    The K actions are exposed through a ``Discrete(K)`` action space whose ids
    index into ``self.actions``.
    """

    def __init__(self, env, num_actions=64, mean=None, sigma=None):
        """
        Args:
            env: Environment whose action space provides ``low``, ``high`` and ``shape``.
            num_actions: Number of continuous actions to sample (K).
            mean: Scalar mean shared by all dimensions. Defaults to the average
                midpoint of the action bounds.
            sigma: Scalar standard deviation shared by all dimensions. Defaults
                to the average of ``(high - low) / 6``.

        Raises:
            ValueError: if the wrapped action space lacks ``low``/``high``.
        """
        super().__init__(env)
        self.action_space = spaces.Discrete(num_actions)
        self.observation_space = env.observation_space
        self.metadata = env.metadata
        self.steps = 0      # steps taken since the last reset
        self.obs = None     # most recent observation
        self.num_actions = num_actions

        # Validate the bounds before touching anything else on the space.
        wrapped_space = env.action_space
        if not (hasattr(wrapped_space, 'low') and hasattr(wrapped_space, 'high')):
            raise ValueError("Environment action space must have 'low' and 'high' attributes")
        action_low = wrapped_space.low
        action_high = wrapped_space.high

        # Set default mean and sigma if not provided
        if mean is None:
            # Default to center of action space
            self.mean = np.mean((action_low + action_high) / 2.0)  # Single float value
        else:
            self.mean = float(mean)  # Ensure it's a single float

        if sigma is None:
            # Default to 1/6 of action range (3-sigma rule covers most of range)
            self.sigma = np.mean((action_high - action_low) / 6.0)  # Single float value
        else:
            self.sigma = float(sigma)  # Ensure it's a single float

        self.actions = self._sample_actions()

    def _sample_actions(self):
        """Draw ``num_actions`` clipped Gaussian actions for the wrapped action space."""
        wrapped_space = self.env.action_space
        action_low = wrapped_space.low
        action_high = wrapped_space.high
        action_dim = wrapped_space.shape[0]
        # One independent draw per action; every dimension shares mean/sigma.
        return [
            np.clip(
                np.random.normal(self.mean, self.sigma, size=action_dim),
                action_low,
                action_high,
            )
            for _ in range(self.num_actions)
        ]

    def _check_action_id(self, k):
        """Raise ValueError unless ``k`` is a valid index into ``self.actions``."""
        if k < 0 or k >= self.num_actions:
            raise ValueError(f"Action ID {k} is out of bounds. Must be between 0 and {self.num_actions-1}")

    def step(self, k):
        """
        Execute step using action ID k

        Raises:
            ValueError: if ``k`` is outside ``[0, num_actions)``.
        """
        self._check_action_id(k)

        self.steps += 1
        # Get the action using the provided ID
        continuous_action = self.actions[k]
        observation, rewards, terminated, truncated, info = self.env.step(continuous_action)
        self.obs = observation
        return observation, rewards, terminated, truncated, info

    def reset(self, **kwargs):
        """
        Reset the environment and the step counter; returns ``(observation, info)``.
        """
        self.steps = 0
        observation, info = self.env.reset(**kwargs)
        self.obs = observation
        return observation, info

    def get_action(self, k):
        """
        Get the continuous action corresponding to action ID k

        Raises:
            ValueError: if ``k`` is outside ``[0, num_actions)``.
        """
        self._check_action_id(k)
        return self.actions[k]

    def get_all_actions(self):
        """
        Get all available actions as a single array (one row per action id).
        """
        return np.array(self.actions)

    def resample_actions(self):
        """
        Resample all actions from the Gaussian distribution

        Replaces ``self.actions`` with a fresh set drawn using the stored
        ``mean`` and ``sigma``.
        """
        # The action dimension is read from the wrapped space inside the helper;
        # it is not available as a local here.
        self.actions = self._sample_actions()
            
# Environment factory used for training base agents.
def build_env(args, render_mode="human", camera_id=0):
    """
    Return a zero-argument factory that builds the environment named by ``args.env``.

    Supported names:
      * any name containing ``"Fetch"``: wrapped in ``FetchVectorWrapper``;
      * ``"Humanoid-v5"`` / ``"Swimmer-v5"``: built from the XML files under
        ``BASE_XML_DIR``;
      * ``"<base>-rc"``, ``"<base>-ru"``, ``"<base>-gaussian"`` for ``<base>`` in
        ``Humanoid-v5``, ``Walker2d-v5``, ``ant``: the base env wrapped in
        ``EnvRandomDiscreteActions``, ``UniformActionSetEnv`` or
        ``GaussianActionSetEnv`` respectively;
      * anything else is passed straight to ``gym.make``.

    Args:
        args: Namespace-like object. ``args.env`` is always read; the wrapped
            variants additionally read ``env_rd_num_actions`` and, depending on
            the variant, ``env_rd_action_magnitude`` or ``env_rd_mean`` /
            ``env_rd_std``. These are read when the factory is called.
        render_mode: Forwarded to ``gym.make``.
        camera_id: Forwarded to ``gym.make`` for the non-wrapped, non-Fetch envs.

    Returns:
        A callable taking no arguments that creates the environment.

    Raises:
        NotImplementedError: for ``"Humanoid-v4"``, which is not supported.
    """
    logger.debug(f"{YELLOW}Building environment: {args.env}{ENDC}") 

    if "Fetch" in args.env:
        return lambda: FetchVectorWrapper(
            gym.make(args.env, render_mode=render_mode, width=600, height=600)
        )

    if args.env == "Humanoid-v4":
        # NotImplementedError subclasses RuntimeError, which is what the
        # previous bare ``raise`` produced, so existing handlers still match.
        raise NotImplementedError("build_env does not support 'Humanoid-v4'")

    # Observation flags shared by every Humanoid-v5 based environment below.
    humanoid_obs_kwargs = dict(
        include_cinert_in_observation=False,
        include_cvel_in_observation=False,
        include_cfrc_ext_in_observation=False,
        include_qfrc_actuator_in_observation=False,
    )

    if args.env == "Humanoid-v5":
        return lambda: gym.make(
            args.env,
            exclude_current_positions_from_observation=True,
            render_mode=render_mode,
            height=1600,
            width=1600,
            camera_id=camera_id,
            xml_file=f"{BASE_XML_DIR}/humanoid.xml",
            **humanoid_obs_kwargs,
        )

    # Base environments that have discretised variants, with the extra
    # ``gym.make`` keyword arguments each one needs.
    base_env_kwargs = {
        "Humanoid-v5": humanoid_obs_kwargs,
        "Walker2d-v5": {},
        "ant": {},
    }
    # Suffix -> wrapper constructor. ``args`` attributes are read only when the
    # factory runs, after the base environment has been created.
    discretizers = {
        "rc": lambda env: EnvRandomDiscreteActions(
            env,
            env_rd_action_magnitude=args.env_rd_action_magnitude,
            env_rd_num_actions=args.env_rd_num_actions,
        ),
        "ru": lambda env: UniformActionSetEnv(
            env,
            num_actions=args.env_rd_num_actions,
        ),
        "gaussian": lambda env: GaussianActionSetEnv(
            env,
            num_actions=args.env_rd_num_actions,
            mean=args.env_rd_mean,
            sigma=args.env_rd_std,
        ),
    }

    # "Humanoid-v5-rc" -> ("Humanoid-v5", "rc"); "ant-gaussian" -> ("ant", "gaussian").
    base_id, _, variant = args.env.rpartition("-")
    if base_id in base_env_kwargs and variant in discretizers:
        return lambda: discretizers[variant](
            gym.make(base_id, render_mode=render_mode, **base_env_kwargs[base_id])
        )

    if args.env == "Swimmer-v5":
        return lambda: gym.make(
            args.env,
            render_mode=render_mode,
            camera_id=camera_id,
            width=600,
            height=600,
            xml_file=f"{BASE_XML_DIR}/swimmer.xml",
        )

    return lambda: gym.make(
        args.env, render_mode=render_mode, camera_id=camera_id, width=600, height=600
    )



