from d4rl.locomotion.wrappers import NormalizedBoxEnv
from rlf.envs.env_interface import EnvInterface, register_env_interface
import gym
import numpy as np
from rlf.args import str2bool
import rlf.rl.utils as rutils

import os

from gym import register, utils
from gym.envs.mujoco import mujoco_env
from scipy.optimize import minimize
import torch

class HalfCheetahEnv(mujoco_env.MujocoEnv, utils.EzPickle):
    """Half-cheetah MuJoCo environment that rewards forward motion.

    The reward is the displacement of ``qpos[0]`` per unit of time minus a
    quadratic control penalty. ``step`` always reports ``done = False``, so
    episode termination has to be imposed from outside (e.g. by a time-limit
    wrapper).
    """

    def __init__(self):
        # Load 'half_cheetah.xml'; the second argument (5) is the frame skip
        # later used by `step` via `self.frame_skip`.
        mujoco_env.MujocoEnv.__init__(self, 'half_cheetah.xml', 5)
        utils.EzPickle.__init__(self)

    def step(self, action):
        """Advance the simulation by one environment step.

        Args:
            action: Control vector handed to ``do_simulation``.

        Returns:
            ``(ob, reward, done, info)`` where ``done`` is always ``False`` and
            ``info`` contains ``"ep_found_goal": False``.
        """
        xposbefore = self.sim.data.qpos[0]
        self.do_simulation(action, self.frame_skip)
        xposafter = self.sim.data.qpos[0]

        ob = self._get_obs()
        # Quadratic penalty on the control effort.
        reward_ctrl = -0.1 * np.square(action).sum()
        # Average velocity of qpos[0] over this step.
        reward_run = (xposafter - xposbefore) / self.dt
        reward = reward_run + reward_ctrl

        done = False

        info = {"ep_found_goal": False}
        return ob, reward, done, info

    def _get_obs(self):
        """Return the observation ``[qpos[1:], qvel]`` as a flat array.

        The first position coordinate is left out of the observation.
        """
        position = self.sim.data.qpos.flat.copy()
        velocity = self.sim.data.qvel.flat.copy()

        observation = np.concatenate((position[1:], velocity)).ravel()
        return observation

    def reset_model(self):
        """Reset to the initial state perturbed by small Gaussian noise.

        NOTE: the noise is drawn from NumPy's global RNG (``np.random``), not
        from an environment-local generator.

        Returns:
            The observation of the freshly reset state.
        """
        qpos = self.init_qpos + np.random.normal(loc=0, scale=0.001, size=self.model.nq)
        qvel = self.init_qvel + np.random.normal(loc=0, scale=0.001, size=self.model.nv)

        self.set_state(qpos, qvel)
        self.prev_qpos = np.copy(self.sim.data.qpos.flat)
        return self._get_obs()

    def reset_model_to_certain_state(self, state):
        """Reset the simulator to a state given in observation layout.

        Args:
            state: Vector laid out like the output of ``_get_obs``: the first
                ``nq - 1`` entries are ``qpos[1:]`` and the remaining entries
                are ``qvel``.

        Returns:
            The observation of the restored state.
        """
        # The noise draws are kept so that the global RNG stream is consumed
        # exactly as in `reset_model`; only qpos[0] retains its noisy initial
        # value, everything else is overwritten from `state` below.
        qpos = self.init_qpos + np.random.normal(loc=0, scale=0.001, size=self.model.nq)
        qvel = self.init_qvel + np.random.normal(loc=0, scale=0.001, size=self.model.nv)

        # Observations omit qpos[0], so positions occupy the first nq - 1 slots.
        n_obs_pos = self.model.nq - 1
        qpos[1:] = state[:n_obs_pos]
        qvel[:] = state[n_obs_pos:]

        self.set_state(qpos, qvel)
        self.prev_qpos = np.copy(self.sim.data.qpos.flat)
        return self._get_obs()

    def viewer_setup(self):
        """Configure the viewer camera distance and elevation."""
        self.viewer.cam.distance = self.model.stat.extent * 0.5
        self.viewer.cam.elevation = -55

class TimeLimit(gym.Wrapper):
    """Truncate episodes after a fixed number of steps.

    Args:
        env: Environment to wrap.
        max_episode_steps: Number of steps after which ``done`` is forced to
            ``True``. ``None`` (the default) disables the limit.
    """

    def __init__(self, env, max_episode_steps=None):
        super().__init__(env)
        self._max_episode_steps = max_episode_steps
        self._elapsed_steps = 0

    def step(self, ac):
        """Step the wrapped env and flag truncation once the limit is hit.

        Returns:
            ``(observation, reward, done, info)``; when the step budget is
            exhausted ``done`` is ``True`` and ``info['TimeLimit.truncated']``
            is set to ``True``.
        """
        observation, reward, done, info = self.env.step(ac)
        self._elapsed_steps += 1
        limit = self._max_episode_steps
        # Comparing an int against None raises TypeError on Python 3, so the
        # default of None is treated as "no time limit".
        if limit is not None and self._elapsed_steps >= limit:
            done = True
            info['TimeLimit.truncated'] = True
        return observation, reward, done, info

    def reset(self, **kwargs):
        """Reset the step counter and the wrapped environment."""
        self._elapsed_steps = 0
        return self.env.reset(**kwargs)

# register(
#     id='MBRLHalfCheetah-v0',
#     entry_point='rl-toolkit.rlf.envs.half_cheetah_interface:HalfCheetahEnv'
# )

class ActionSpaceBoxWrapper(gym.Wrapper):
    """Replace the action space with a symmetric box ``[-ub, ub]``.

    Incoming actions are clipped to that box before being forwarded, and the
    action that was actually applied is reported under ``info['real_action']``.

    Args:
        env: Environment to wrap; its action-space shape is reused.
        ub: Bound used as ``high`` (and negated as ``low``) of the new box.
    """

    def __init__(self, env, ub):
        super().__init__(env)
        self.action_space = gym.spaces.Box(
            low=-ub,
            high=ub,
            shape=env.action_space.shape,
        )

    def step(self, action):
        """Clip ``action`` to the box, step the wrapped env, record the clip."""
        box = self.action_space
        applied = np.clip(action, box.low, box.high)
        observation, reward, finished, extras = self.env.step(applied)
        extras['real_action'] = applied
        return observation, reward, finished, extras
        

class StateDependentConstraintWrapper(gym.Wrapper):
    """Project actions onto a state-dependent feasible set before stepping.

    The feasible set is ``{a : sum_i |a_i * w_i| <= con_ub}`` intersected with
    the wrapper's action-space box, where ``w`` is read from entries 11..16 of
    the most recent observation.

    ``reset`` must be called before ``step`` so that an observation is
    available for the projection.

    Args:
        env: Environment to wrap.
        con_ub: Upper bound on ``sum_i |a_i * w_i|``.
    """

    def __init__(self, env, con_ub=10):
        super().__init__(env)
        # Latest observation; populated by `reset` and refreshed by `step`.
        self.current_state = None
        self.con_ub = con_ub

    def step_before_projection(self, action, state):
        """Project ``action`` w.r.t. an explicitly supplied ``state`` and step.

        Unlike `step`, this neither reads nor updates ``self.current_state``
        and does not add ``info['real_action']`` itself.
        """
        new_action = self.HC_O_Projection(action, state)
        return super().step(new_action)

    def HC_O_Projection(self, action, state):
        """Return the feasible action closest (in L2) to ``action``.

        Alternative implementation of HC_O_Projection without Gurobi; the
        quadratic program is solved with SciPy's SLSQP.

        Args:
            action: Desired action; only the first 6 entries are used.
            state: Observation; entries 11..16 are used as weights. With the
                ``[qpos[1:], qvel]`` layout of ``HalfCheetahEnv._get_obs``
                these are the last six velocity entries -- presumably the
                velocities of the six actuated joints; verify against the
                model.

        Returns:
            A length-6 NumPy array: the optimizer's solution, or all zeros if
            the optimizer reports failure.
        """
        # Target action and state-dependent weights.
        neta = action[:6]
        w = state[11:17]

        # Objective: squared distance to the requested action.
        def objective(a):
            return np.sum((a - neta) ** 2)

        # SciPy 'ineq' constraints are satisfied when fun(a) >= 0, so
        # sum_i |a_i * w_i| <= con_ub is written as con_ub - sum_i |a_i * w_i|.
        def constraint(a):
            abs_u = np.abs(a * w)
            return self.con_ub - np.sum(abs_u)

        # Box bounds come from this wrapper's action space.
        bounds = [(self.action_space.low[i], self.action_space.high[i]) for i in range(6)]

        # The zero action satisfies the weighted constraint for any
        # non-negative con_ub, so it serves as the starting point.
        a0 = np.zeros(6)

        constraints = [{'type': 'ineq', 'fun': constraint}]

        result = minimize(objective, a0, bounds=bounds, constraints=constraints, method='SLSQP')

        if not result.success:
            # Best effort: fall back to the zero action instead of raising.
            print("Optimization failed:", result.message)
            return np.zeros(6)

        return result.x

    def step(self, action):
        """Project ``action`` using the last observation, then step.

        Returns:
            ``(observation, reward, done, info)`` with ``info['real_action']``
            set to the projected action that was actually applied.
        """
        action = self.HC_O_Projection(action, self.current_state)
        observation, reward, done, info = super().step(action)
        self.current_state = observation
        info['real_action'] = action
        return observation, reward, done, info

    def reset(self, **kwargs):
        """Reset the wrapped env and remember the initial observation.

        Keyword arguments are forwarded to the wrapped environment's ``reset``.
        """
        self.current_state = super().reset(**kwargs)
        return self.current_state





class HalfCheetah(EnvInterface):
    """Environment interface that assembles the wrapped half-cheetah env."""

    def create_from_id(self, env_id):
        """Build the environment stack; ``env_id`` is not used.

        The base env is wrapped in ``TimeLimit`` (``--hf-timelimit`` steps) and
        then ``NormalizedBoxEnv``. When ``--hf-constrained`` is set, the action
        box of half-width ``--hf-ub`` and the state-dependent constraint
        projection are stacked on top.
        """
        cfg = self.args
        env = NormalizedBoxEnv(
            TimeLimit(HalfCheetahEnv(), max_episode_steps=cfg.hf_timelimit)
        )
        if not cfg.hf_constrained:
            return env

        boxed = ActionSpaceBoxWrapper(env, ub=cfg.hf_ub)
        return StateDependentConstraintWrapper(boxed)

    def get_add_args(self, parser):
        """Register the ``--hf-*`` command-line options on ``parser``."""
        option_specs = (
            ('--hf-constrained', str2bool, False),
            ('--hf-ub', float, 0.5),
            ('--hf-timelimit', int, 1000),
        )
        for flag, caster, fallback in option_specs:
            parser.add_argument(flag, type=caster, default=fallback, help="")

# register_env_interface("^MBRLHalfCheetah", HalfCheetah)






