import warnings

import numpy as np
import gymnasium as gym
from gymnasium import Env
from gymnasium.wrappers.normalize import NormalizeObservation, NormalizeReward

from feature4irl.util.feature_gen import select_feat_extractor

##################################################
#      TransformRewardLearnedCont
###################################################
class TransformRewardLearnedCont(gym.RewardWrapper):
    """Replace the environment reward with a learned linear reward.

    On every step the latest observation (read via
    ``get_wrapper_attr('temp_state')``, e.g. as recorded by
    :class:`StoreObservation`) is turned into a feature vector by
    ``select_feat_extractor`` and the new reward is that vector dotted with
    the learned weights ``alpha``.
    """

    def __init__(self, env: gym.Env, alpha=None, configs=None):
        """Initialize the wrapper.

        Args:
            env: The environment to wrap. Some wrapper in its stack must expose
                a ``temp_state`` attribute holding the latest observation.
            alpha: Learned reward weights, combined with the extracted features
                via ``.dot(alpha)``. Must be set before :meth:`reward` runs.
            configs: Configuration forwarded to ``select_feat_extractor`` as
                its ``cfg`` keyword.
        """
        super().__init__(env)

        self.alpha = alpha
        self.configs = configs

    def reward(self, reward):
        """Compute the learned reward for the most recent observation.

        Args:
            reward: The original environment reward. It is ignored: the
                learned reward fully replaces it.

        Returns:
            ``select_feat_extractor(env_id, temp_state, cfg=configs).dot(alpha)``.

        Raises:
            ValueError: If the wrapped environment has no ``spec``; the
                environment id is required to pick the feature extractor.
        """
        spec = self.spec
        if spec is None:
            # Without a spec there is no env id to select the feature extractor.
            raise ValueError(
                "TransformRewardLearnedCont needs env.spec (its id selects the "
                "feature extractor), but the wrapped environment has no spec; "
                "create the environment with gym.make()."
            )
        state = self.get_wrapper_attr('temp_state')
        feature_expectations = select_feat_extractor(spec.id,
                                                     state,
                                                     cfg=self.configs)
        return feature_expectations.dot(self.alpha)


class BaseEnvWrapper(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
    """
    Base wrapper to handle observation storage and conditional reward transformation.

    The given environment is always wrapped in :class:`StoreObservation` so the
    latest observation is available as ``temp_state``. When a learned reward is
    configured, :class:`TransformRewardLearnedCont` is stacked on top of it and
    replaces the environment reward.
    """
    def __init__(self, env: gym.Env, reward_path=None, scaler_path=None, configs=None):
        """
        Args:
            env: The environment to wrap.
            reward_path: Path to the learned reward weights *without* the
                ``.npy`` extension (it is appended before ``np.load``). Pass
                the string ``'None'`` (or an empty string) to keep the
                environment's own reward. ``None`` itself is rejected.
            scaler_path: Only checked for truthiness here: the learned reward
                is applied only when it is set. Its contents are not loaded.
            configs: Configuration forwarded to
                :class:`TransformRewardLearnedCont`; must be truthy for the
                learned reward to be applied.

        Raises:
            ValueError: If ``reward_path`` is ``None``.
        """
        if reward_path is None:
            raise ValueError(
                "reward_path cannot be None; pass the string 'None' to keep "
                "the environment's original reward"
            )

        # Always record raw observations so a learned reward can read them back.
        super().__init__(StoreObservation(env))

        wants_learned_reward = bool(reward_path) and reward_path != 'None'
        if not wants_learned_reward:
            return

        if scaler_path and configs:
            alpha = np.load(reward_path + '.npy')
            self.env = TransformRewardLearnedCont(self.env, alpha, configs)
        else:
            # A reward path was requested but cannot be used; make this visible
            # instead of silently training on the environment reward.
            warnings.warn(
                f"reward_path={reward_path!r} was given but scaler_path and/or "
                "configs is missing; the learned reward is NOT applied and the "
                "environment reward is used instead.",
                stacklevel=2,
            )
        

class StoreObservation(gym.ObservationWrapper):
    """Pass observations through unchanged while keeping the latest one.

    The most recent observation is stored on ``self.temp_state`` so wrappers
    higher in the stack can fetch it with ``get_wrapper_attr('temp_state')``.
    """

    def __init__(self, env: gym.Env):
        """Wrap ``env``.

        Args:
            env: The environment whose observations are recorded.
        """
        super().__init__(env)

    def observation(self, observation):
        """Record ``observation`` as ``temp_state`` and return it unmodified."""
        self.temp_state = observation
        return observation
    
    
class CropObservation(gym.ObservationWrapper):
    """Drop the first element of every (1-D) observation.

    The cropped observation is also stored on ``self.temp_state``.
    """

    def __init__(self, env: gym.Env):
        """Wrap ``env`` and declare the cropped observation space.

        Args:
            env: The environment to wrap. Its observation space is expected to
                be 1-D; the declared space has one element fewer.
        """
        super().__init__(env)

        inner_shape = getattr(env.observation_space, "shape", None)
        if inner_shape is not None and len(inner_shape) == 1 and inner_shape[0] > 0:
            # Derive the cropped size from the wrapped space instead of hard-coding it.
            n_features = inner_shape[0] - 1
        else:
            # No usable 1-D shape to derive from: keep the historical default.
            n_features = 17

        # Assign on this wrapper only, so the wrapped env's own space is untouched.
        self.observation_space = gym.spaces.Box(
            low=-np.inf, high=np.inf, shape=(n_features,), dtype=np.float64
        )

    def observation(self, observation):
        """Return ``observation[1:]``, also storing it as ``temp_state``."""
        self.temp_state = observation[1:]
        return self.temp_state


class StoreAction(gym.ActionWrapper):
    """Pass actions through unchanged while keeping the latest one.

    The most recent action is stored on ``self.temp_action``.
    """

    def __init__(self, env: gym.Env):
        """Wrap ``env``.

        Args:
            env: The environment whose actions are recorded.
        """
        super().__init__(env)

    def action(self, action):
        """Record ``action`` as ``temp_action`` and return it unmodified."""
        self.temp_action = action
        return action


class NoEarlyTerminationWrapper(gym.Wrapper):
    """Force every episode to last exactly ``max_steps`` steps.

    The wrapped environment's ``terminated``/``truncated`` signals are
    suppressed until the step budget is used up. On the final step the
    episode is reported as terminated, and the wrapped env's ``truncated``
    flag is passed through for that step only.

    NOTE(review): reaching a fixed step budget is normally a *truncation* in
    Gymnasium's step API; reporting it as ``terminated`` affects value
    bootstrapping in many RL algorithms. Confirm this is intended.
    """

    def __init__(self, env: Env, max_steps: int):
        super().__init__(env)
        self.max_steps = max_steps
        self.current_step = 0

    def reset(self, **kwargs):
        """Reset the wrapped environment and restart the step counter."""
        self.current_step = 0
        return self.env.reset(**kwargs)

    def step(self, action):
        """Step the wrapped environment, overriding its episode-end signals."""
        obs, rew, _inner_terminated, inner_truncated, info = self.env.step(action)
        self.current_step += 1

        budget_exhausted = self.current_step >= self.max_steps
        if not budget_exhausted:
            # Keep the episode running regardless of what the inner env said.
            return obs, rew, False, False, info
        return obs, rew, True, inner_truncated, info


class NoRenderWrapper(gym.Wrapper):
    """Wrapper whose ``render`` is a no-op that always yields ``None``."""

    def __init__(self, env: Env):
        super().__init__(env)

    def render(self, *args, **kwargs):
        """Skip rendering entirely; every argument is ignored."""
        nothing = None
        return nothing

class WalkerWrapper(BaseEnvWrapper):
    """
    Specific wrapper for the Walker environment.

    ``NoEarlyTerminationWrapperX`` is not applied here, so the environment's
    own early-termination behaviour is kept.
    """
    def __init__(self, env, reward_path=None, env_name=None, scaler_path=None, configs=None):
        # ``env_name`` is currently unused (presumably kept so all wrappers in
        # this module share one constructor signature).
        super().__init__(env, reward_path, scaler_path, configs)
        
class NoEarlyTerminationWrapperX(gym.Wrapper):
    """Disable "terminate when unhealthy" on the base environment, if supported.

    This sets the private ``_terminate_when_unhealthy`` flag on
    ``env.unwrapped`` to ``False`` when that attribute exists (assumption: as
    provided by Gymnasium's MuJoCo locomotion envs; confirm for other envs).
    Steps and resets are otherwise passed through unchanged.
    """
    def __init__(self, env):
        super().__init__(env)
        base_env = env.unwrapped  # innermost env, below every wrapper
        if hasattr(base_env, '_terminate_when_unhealthy'):
            base_env._terminate_when_unhealthy = False
        else:
            # Use the warnings machinery (filterable, shows caller location)
            # rather than printing to stdout.
            warnings.warn(
                "The environment does not support custom termination behavior.",
                stacklevel=2,
            )

class HopperWrapper(BaseEnvWrapper):
    """
    Specific wrapper for the Hopper environment.

    Unhealthy-state termination is switched off (when the base env supports
    it) via ``NoEarlyTerminationWrapperX`` before the common
    :class:`BaseEnvWrapper` setup is applied.
    """
    def __init__(self, env, reward_path=None, env_name=None, scaler_path=None, configs=None):
        # ``env_name`` is currently unused.
        env = NoEarlyTerminationWrapperX(env)
        super().__init__(env, reward_path, scaler_path, configs)

class AntWrapper(BaseEnvWrapper):
    """
    Specific wrapper for the Ant environment.

    Unhealthy-state termination is switched off (when the base env supports
    it) via ``NoEarlyTerminationWrapperX`` before the common
    :class:`BaseEnvWrapper` setup is applied.
    """
    def __init__(self, env, reward_path=None, env_name=None, scaler_path=None, configs=None):
        # ``env_name`` is currently unused.
        env = NoEarlyTerminationWrapperX(env)
        super().__init__(env, reward_path, scaler_path, configs)


class CheetahWrapper(BaseEnvWrapper):
    """
    Specific wrapper for the Cheetah environment.

    Only the common :class:`BaseEnvWrapper` setup is applied; neither
    ``NoEarlyTerminationWrapper`` nor ``NoRenderWrapper`` is used here.
    """
    def __init__(self, env, reward_path=None, env_name=None, scaler_path=None, configs=None):
        # ``env_name`` is currently unused.
        super().__init__(env, reward_path, scaler_path, configs)