from d4rl.locomotion.wrappers import NormalizedBoxEnv
from rlf.envs.env_interface import EnvInterface, register_env_interface
import gym
import numpy as np
from rlf.args import str2bool
import rlf.rl.utils as rutils

import os

from gym import register, utils
from gym.envs.mujoco import mujoco_env
from scipy.optimize import minimize
import torch

class HopperEnv(mujoco_env.MujocoEnv, utils.EzPickle):
    """MuJoCo Hopper whose observation also contains the x position.

    Differs from the stock gym Hopper in that ``_get_obs`` returns the full
    ``qpos`` (x position included) followed by ``qvel`` clipped to
    [-10, 10], and ``step`` reports ``info["ep_found_goal"] = False``.
    """

    def __init__(self):
        mujoco_env.MujocoEnv.__init__(self, "hopper.xml", 4)
        utils.EzPickle.__init__(self)

    def step(self, a):
        """Apply action ``a`` for ``frame_skip`` simulator steps.

        Reward is the x velocity of qpos[0] plus an alive bonus of 1.0,
        minus ``1e-3 * sum(a**2)``. The episode is done when the state
        vector is non-finite, any entry of ``state_vector()[2:]`` reaches
        magnitude 100, the height (qpos[1]) is not above 0.7, or the angle
        (qpos[2]) reaches magnitude 0.2.

        Returns:
            (observation, reward, done, info)
        """
        posbefore = self.sim.data.qpos[0]
        self.do_simulation(a, self.frame_skip)
        posafter, height, ang = self.sim.data.qpos[0:3]
        alive_bonus = 1.0
        reward = (posafter - posbefore) / self.dt
        reward += alive_bonus
        reward -= 1e-3 * np.square(a).sum()
        s = self.state_vector()
        done = not (
            np.isfinite(s).all()
            and (np.abs(s[2:]) < 100).all()
            and (height > 0.7)
            and (abs(ang) < 0.2)
        )
        ob = self._get_obs()
        info = {"ep_found_goal": False}
        return ob, reward, done, info

    def _get_obs(self):
        """Return ``qpos`` (x position included) followed by clipped ``qvel``."""
        return np.concatenate(
            [self.sim.data.qpos.flat[0:], np.clip(self.sim.data.qvel.flat, -10, 10)]
        )

    def reset_model(self):
        """Reset to the initial pose plus uniform noise in [-0.005, 0.005]."""
        qpos = self.init_qpos + self.np_random.uniform(
            low=-0.005, high=0.005, size=self.model.nq
        )
        qvel = self.init_qvel + self.np_random.uniform(
            low=-0.005, high=0.005, size=self.model.nv
        )
        self.set_state(qpos, qvel)
        return self._get_obs()

    def reset_model_to_certain_state(self, state):
        """Set the simulator to the configuration encoded by ``state``.

        Two layouts are accepted:

        * ``nq + nv`` entries -- the layout produced by ``_get_obs`` (full
          ``qpos`` followed by ``qvel``); every coordinate, x included, is
          restored.
        * ``nq - 1 + nv`` entries -- the legacy layout without the x
          position; x is taken from ``init_qpos`` plus Gaussian noise with
          std 0.001 drawn from the env's seeded ``np_random``.

        Observations clip ``qvel`` to [-10, 10], so a state taken from
        ``_get_obs`` restores the clipped velocities.

        Returns:
            The observation after the state has been set.

        Raises:
            ValueError: if ``state`` matches neither layout.
        """
        state = np.asarray(state, dtype=np.float64).ravel()
        nq, nv = self.model.nq, self.model.nv
        if state.size == nq + nv:
            qpos = state[:nq].copy()
            qvel = state[nq:].copy()
        elif state.size == nq - 1 + nv:
            # Use the env RNG (not global np.random) so seeding is honoured.
            qpos = self.init_qpos + self.np_random.normal(loc=0, scale=0.001, size=nq)
            qpos[1:] = state[:nq - 1]
            qvel = state[nq - 1:].copy()
        else:
            raise ValueError(
                f"expected a state with {nq + nv} or {nq - 1 + nv} entries, "
                f"got {state.size}"
            )
        self.set_state(qpos, qvel)
        self.prev_qpos = np.copy(self.sim.data.qpos.flat)
        return self._get_obs()

    def viewer_setup(self):
        """Configure the render camera to follow body 2 from slightly above."""
        self.viewer.cam.trackbodyid = 2
        self.viewer.cam.distance = self.model.stat.extent * 0.75
        self.viewer.cam.lookat[2] = 1.15
        self.viewer.cam.elevation = -20

class TimeLimit(gym.Wrapper):
    """End episodes after a fixed number of ``step`` calls.

    Args:
        env: environment to wrap.
        max_episode_steps: step budget per episode; ``None`` disables the
            limit.

    When the budget is exhausted, ``done`` is forced to ``True`` and
    ``info['TimeLimit.truncated']`` is set. It is ``True`` if the episode
    was cut short by the limit. It is ``False`` if the wrapped env also
    terminated on that same step.
    """

    def __init__(self, env, max_episode_steps=None):
        super().__init__(env)
        self._max_episode_steps = max_episode_steps
        self._elapsed_steps = 0

    def step(self, ac):
        observation, reward, done, info = self.env.step(ac)
        self._elapsed_steps += 1
        if (
            self._max_episode_steps is not None
            and self._elapsed_steps >= self._max_episode_steps
        ):
            # Only a genuine cut-off is a truncation; a natural terminal
            # reached on the last step must still be treated as terminal.
            info['TimeLimit.truncated'] = not done
            done = True
        return observation, reward, done, info

    def reset(self, **kwargs):
        """Restart the step counter and reset the wrapped env."""
        self._elapsed_steps = 0
        return self.env.reset(**kwargs)

# register(
#     id='MBRLHalfCheetah-v0',
#     entry_point='rl-toolkit.rlf.envs.half_cheetah_interface:HalfCheetahEnv'
# )

class ActionSpaceBoxWrapper(gym.Wrapper):
    """Advertise a symmetric ``[-ub, ub]`` Box action space.

    Only the ``action_space`` attribute is replaced; it keeps the wrapped
    env's action shape. ``step`` is not overridden, so actions reach the
    wrapped env unchanged.
    """

    def __init__(self, env, ub):
        super().__init__(env)
        shape = env.action_space.shape
        low, high = -ub, ub
        self.action_space = gym.spaces.Box(low=low, high=high, shape=shape)

class StateDependentConstraintWrapper(gym.Wrapper):
    """Project each action onto a state-dependent feasible set before stepping.

    The feasible set is the wrapped env's action Box intersected with

        sum_i |a_i * w_i| <= con_ub,

    where ``w`` is read from the latest observation as ``state[9:12]``.
    With ``HopperEnv``'s full-qpos observation these presumably index the
    last three ``qvel`` entries (joint velocities) -- confirm against the
    model. The projection is the Euclidean-nearest feasible action, found
    with SciPy's SLSQP.

    Args:
        env: environment to wrap.
        con_ub: right-hand side of the constraint (default 10).
    """

    # Observation indices holding the constraint weights w. The legacy
    # observation without the x position used state[8:11].
    _WEIGHT_SLICE = slice(9, 12)

    def __init__(self, env, con_ub=10):
        super().__init__(env)
        self.current_state = None
        self.con_ub = con_ub

    def step_before_projection(self, action, state):
        """Project ``action`` using an explicitly supplied ``state``, then step.

        Unlike ``step``, the constraint weights come from ``state`` instead
        of the latest observation. The resulting observation is still stored
        as ``current_state``, so a later ``step`` projects against the true
        current state. The executed action is reported in
        ``info['real_action']``.
        """
        new_action = self.H_M_Projection(action, state)
        observation, reward, done, info = super().step(new_action)
        self.current_state = observation
        info['real_action'] = new_action
        return observation, reward, done, info

    def H_M_Projection(self, action, state):
        """Return the feasible action closest (in L2) to ``action``.

        Solves

            min_a  ||a - neta||^2
            s.t.   sum_i |a_i * w_i| <= con_ub
                   action_space.low <= a <= action_space.high

        where ``neta`` is the first 3 entries of ``action`` and
        ``w = state[9:12]``. This is a SciPy replacement for a Gurobi-based
        ``HC_O_Projection``.

        If ``neta`` is already feasible, it is returned unchanged without
        calling the solver. If SLSQP fails, a message is printed and the
        all-zeros action is returned; that action is feasible whenever
        ``con_ub >= 0``.

        Returns:
            float64 numpy array of length 3.

        Raises:
            ValueError: if ``state`` is too short to provide 3 weights.
        """
        n = self._WEIGHT_SLICE.stop - self._WEIGHT_SLICE.start
        w = np.asarray(state[self._WEIGHT_SLICE], dtype=np.float64)
        if w.shape != (n,):
            raise ValueError(
                "state must provide %d constraint weights at indices %d:%d, got shape %s"
                % (n, self._WEIGHT_SLICE.start, self._WEIGHT_SLICE.stop, w.shape)
            )
        neta = np.asarray(action, dtype=np.float64).ravel()[:n]
        low = np.asarray(self.action_space.low[:n], dtype=np.float64)
        high = np.asarray(self.action_space.high[:n], dtype=np.float64)

        def objective(a):
            return np.sum((a - neta) ** 2)

        def constraint(a):
            # SLSQP 'ineq' constraints are satisfied when fun(a) >= 0.
            return self.con_ub - np.sum(np.abs(a * w))

        # Fast path: an already-feasible action is its own projection.
        if np.all(neta >= low) and np.all(neta <= high) and constraint(neta) >= 0:
            return neta.copy()

        result = minimize(
            objective,
            np.zeros(n),  # all-zeros start satisfies the constraint
            bounds=list(zip(low, high)),
            constraints=({'type': 'ineq', 'fun': constraint},),
            method='SLSQP',
        )

        if not result.success:
            print("Optimization failed:", result.message)
            return np.zeros(n)  # Return zero actions if failed

        return result.x

    def step(self, action):
        """Project ``action`` against the latest observation, then step.

        The executed (projected) action is reported in
        ``info['real_action']``.

        Raises:
            RuntimeError: if called before ``reset``.
        """
        if self.current_state is None:
            raise RuntimeError("reset() must be called before step()")
        action = self.H_M_Projection(action, self.current_state)
        observation, reward, done, info = super().step(action)
        self.current_state = observation
        info['real_action'] = action
        return observation, reward, done, info

    def reset(self, **kwargs):
        """Reset the wrapped env and remember the first observation."""
        self.current_state = super().reset(**kwargs)
        return self.current_state





class Hopper(EnvInterface):
    """rlf environment interface for ``HopperEnv``.

    Wrapping order: ``HopperEnv`` -> ``TimeLimit`` (``--hp-timelimit``)
    -> ``NormalizedBoxEnv``. With ``--hp-constrained`` the env is further
    wrapped in ``ActionSpaceBoxWrapper`` (bound ``--hp-ub``) and then
    ``StateDependentConstraintWrapper`` (bound ``--hp-con-ub``).
    """

    def create_from_id(self, env_id):
        """Build the wrapped Hopper env; ``env_id`` does not affect the result."""
        env = HopperEnv()
        env = TimeLimit(env, max_episode_steps=self.args.hp_timelimit)
        env = NormalizedBoxEnv(env)
        if self.args.hp_constrained:
            env = ActionSpaceBoxWrapper(env, ub=self.args.hp_ub)
            # getattr keeps args objects created before --hp-con-ub existed
            # working, falling back to the wrapper's previous default.
            env = StateDependentConstraintWrapper(
                env, con_ub=getattr(self.args, 'hp_con_ub', 10)
            )
        return env

    def get_add_args(self, parser):
        parser.add_argument('--hp-constrained', type=str2bool, default=False,
                            help="Apply the state-dependent action constraint.")
        parser.add_argument('--hp-ub', type=float, default=0.5,
                            help="Symmetric action-space bound used when constrained.")
        parser.add_argument('--hp-con-ub', type=float, default=10,
                            help="Right-hand side of sum_i |a_i * w_i| <= con_ub "
                                 "used when constrained.")
        parser.add_argument('--hp-timelimit', type=int, default=1000,
                            help="Maximum number of steps per episode.")

register_env_interface("^MBRLHopper", Hopper)






