from d4rl.locomotion.wrappers import NormalizedBoxEnv
from rlf.envs.env_interface import EnvInterface
import gym
import numpy as np
from rlf.args import str2bool
import rlf.rl.utils as rutils

import os

from gym import register, utils
from gym.envs.mujoco import mujoco_env
from scipy.optimize import minimize
import torch
from demo_collection.utils.constrain_wrapper import ActionSpaceBoxWrapper

from d4rl import offline_env

class StateDependentConstraintWrapper(gym.Wrapper):
    """Project actions onto a state-dependent feasible set before stepping.

    For a 2-D action ``a`` the feasible set is
    ``sum(|a_i * w_i|) <= con_ub`` intersected with the first two dimensions
    of the action-space box, where ``w = state[2:4]`` is taken from the most
    recent observation.
    # NOTE(review): for maze2d, state[2:4] is presumably the velocity -- confirm.

    Args:
        env: Environment to wrap.
        con_ub: Upper bound on ``sum(|a_i * w_i|)``.
    """

    def __init__(self, env, con_ub=0.5):
        super().__init__(env)
        # Last observation seen by reset()/step(); None until reset() is called.
        self.current_state = None
        self.con_ub = con_ub

    def step_before_projection(self, action, state):
        """Project ``action`` using an explicitly supplied ``state`` and step.

        Unlike :meth:`step`, this neither updates ``self.current_state`` nor
        adds ``'real_action'`` to the returned info dict.
        """
        new_action = self.M_O_Projection(action, state)
        return super().step(new_action)

    def M_O_Projection(self, action, state):
        """Return the feasible action closest (in squared L2) to ``action``.

        Solves, with SciPy's SLSQP solver started from the zero action::

            minimize    sum((a - action[:2]) ** 2)
            subject to  sum(|a * state[2:4]|) <= self.con_ub
                        low[i] <= a[i] <= high[i]   for i in (0, 1)

        Args:
            action: Sequence whose first two entries are the desired action.
            state: Sequence whose entries 2 and 3 are the constraint weights.

        Returns:
            ``result.x`` from the solver (length-2 array). If the solver
            reports failure, a message is printed and ``np.zeros(2)`` is
            returned instead.
        """
        target = action[:2]
        weights = state[2:4]

        def objective(a):
            # Squared distance to the requested action.
            return np.sum((a - target) ** 2)

        def constraint(a):
            # SciPy 'ineq' constraints are satisfied when the value is >= 0,
            # i.e. con_ub - sum(|a_i * w_i|) >= 0.
            return self.con_ub - np.sum(np.abs(a * weights))

        # Box bounds for the two optimised dimensions, from the action space.
        bounds = list(zip(self.action_space.low[:2], self.action_space.high[:2]))

        start = np.zeros(2)
        constraints = [{'type': 'ineq', 'fun': constraint}]

        result = minimize(
            objective,
            start,
            bounds=bounds,
            constraints=constraints,
            method='SLSQP',
        )

        if not result.success:
            print("Optimization failed:", result.message)
            return np.zeros(2)  # fall back to the zero action

        return result.x

    def step(self, action):
        """Project ``action`` against the last observation, then step the env.

        The projected action that was actually executed is stored in
        ``info['real_action']``.

        Raises:
            RuntimeError: If called before :meth:`reset`, because there is no
                observation to evaluate the constraint against.
        """
        if self.current_state is None:
            raise RuntimeError(
                "StateDependentConstraintWrapper.step() called before reset(); "
                "no state is available to evaluate the action constraint."
            )
        action = self.M_O_Projection(action, self.current_state)
        observation, reward, done, info = super().step(action)
        self.current_state = observation
        info['real_action'] = action
        return observation, reward, done, info

    def reset(self, **kwargs):
        """Reset the wrapped env and remember the initial observation.

        Keyword arguments are forwarded unchanged to the wrapped env's reset.
        """
        self.current_state = super().reset(**kwargs)
        return self.current_state

class GoalFoundInfoWrapper(gym.Wrapper):
    """Mirror the env's ``'goal_achieved'`` info entry under ``'ep_found_goal'``.

    Construction is inherited unchanged from ``gym.Wrapper``.
    """

    def step(self, action):
        """Step the wrapped env and copy ``info['goal_achieved']``.

        The wrapped env's info dict must contain ``'goal_achieved'``; its value
        is additionally stored under ``'ep_found_goal'`` in the same dict.
        """
        obs, rew, finished, step_info = super().step(action)
        step_info['ep_found_goal'] = step_info['goal_achieved']
        return obs, rew, finished, step_info




class Maze2d(EnvInterface):
    """Environment interface that builds a wrapped D4RL medium Maze2d env.

    The reward variant and the optional action constraints are chosen through
    the ``--mz-*`` command-line options registered in :meth:`get_add_args`.
    """

    def create_from_id(self, env_id):
        """Create the Maze2d environment with the configured wrappers.

        ``env_id`` is not consulted; the gym id is selected from
        ``self.args.mz_reward_type``. Wrappers are applied in this order:
        ``GoalFoundInfoWrapper``, ``NormalizedBoxEnv``, then optionally
        ``ActionSpaceBoxWrapper`` (``--mz-box-constrained``) and
        ``StateDependentConstraintWrapper`` (``--mz-safe-constrained``).

        Raises:
            ValueError: If ``mz_reward_type`` is neither ``'sparse'`` nor
                ``'dense'``.
        """
        gym_ids = {
            'sparse': 'maze2d-medium-v1',
            'dense': 'maze2d-medium-dense-v1',
        }
        reward_type = self.args.mz_reward_type
        if reward_type not in gym_ids:
            raise ValueError(
                "Unknown --mz-reward-type %r; expected one of %s"
                % (reward_type, sorted(gym_ids))
            )

        env = gym.make(gym_ids[reward_type])
        env = GoalFoundInfoWrapper(env)
        env = NormalizedBoxEnv(env)
        if self.args.mz_box_constrained:
            env = ActionSpaceBoxWrapper(env, ub=self.args.mz_ub)
        if self.args.mz_safe_constrained:
            # Uses the wrapper's default constraint bound (con_ub).
            env = StateDependentConstraintWrapper(env)
        return env

    def get_add_args(self, parser):
        """Register the Maze2d-specific ``--mz-*`` options on ``parser``."""
        parser.add_argument('--mz-reward-type', type=str, default='dense', help="dense or sparse")
        parser.add_argument('--mz-box-constrained', type=str2bool, default=False, help="")
        parser.add_argument('--mz-ub', type=float, default=0.1, help="")
        parser.add_argument('--mz-safe-constrained', type=str2bool, default=False, help="")

# Medium maze layout as a single spec string: eight 8-character rows joined
# by a literal backslash.
# NOTE(review): presumably '#' = wall, 'O' = open cell, 'G' = goal, following
# the d4rl maze-spec convention -- confirm.
MEDIUM_MAZE = "\\".join(
    [
        "########",
        "#OO##OO#",
        "#OO#OOO#",
        "##OOO###",
        "#OO#OOO#",
        "#O#OO#O#",
        "#OOO#OG#",
        "########",
    ]
)

# register(
#     id='maze2d-medium-v1',
#     entry_point='d4rl.pointmaze:MazeEnv',
#     max_episode_steps=600,
#     kwargs={
#         'maze_spec':MEDIUM_MAZE,
#         'reward_type':'sparse',
#         'reset_target': True,
#         'ref_min_score': 13.13,
#         'ref_max_score': 277.39,
#         'dataset_url':'http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-medium-sparse-v1.hdf5'
#     }
# )

# register_env_interface("^MBRLMaze2d", Maze2d)






