from d4rl.locomotion.wrappers import NormalizedBoxEnv
from rlf.envs.env_interface import EnvInterface, register_env_interface
import gym
import numpy as np
from rlf.args import str2bool
import rlf.rl.utils as rutils

import os

from gym import register, utils
from gym.envs.mujoco import mujoco_env
from scipy.optimize import minimize
import torch


class ActionSpaceBoxWrapper(gym.Wrapper):
    """Restrict the action space to the symmetric box ``[-ub, ub]``.

    The advertised ``action_space`` keeps the wrapped env's action shape but
    uses ``-ub`` / ``ub`` as bounds. Every action passed to ``step`` is
    clipped element-wise to those bounds before it reaches the wrapped env.
    """

    def __init__(self, env, ub):
        """Wrap ``env`` and install a ``Box(-ub, ub)`` action space.

        Args:
            env: environment to wrap; its ``action_space.shape`` is reused.
            ub: bound used as ``high`` (and negated as ``low``) of the box.
        """
        super(ActionSpaceBoxWrapper, self).__init__(env)
        bounded_space = gym.spaces.Box(
            low=-ub,
            high=ub,
            shape=env.action_space.shape,
        )
        self.action_space = bounded_space

    def step(self, action):
        """Clip ``action`` to the box, step the env, and report what was sent.

        Returns:
            The wrapped env's ``(observation, reward, done, info)`` with
            ``info['real_action']`` set to the clipped action.
        """
        space = self.action_space
        clipped = np.clip(action, space.low, space.high)
        obs, rew, finished, extras = super().step(clipped)
        # Expose the action that was actually executed after clipping.
        extras['real_action'] = clipped
        return obs, rew, finished, extras



class FetchPushActionSpaceBoxWrapper(gym.Wrapper):
    """Bound actions to ``[-ub, ub]`` by uniform rescaling instead of clipping.

    When the largest-magnitude component of an action exceeds ``ub``, the
    whole action vector is scaled down by a common factor so that component
    has magnitude ``ub``. Actions already within the bound pass through
    untouched.
    """

    def __init__(self, env, ub):
        """Wrap ``env`` and install a ``Box(-ub, ub)`` action space.

        Args:
            env: environment to wrap; its ``action_space.shape`` is reused.
            ub: magnitude limit applied to every action component.
        """
        super().__init__(env)
        self.action_ub = ub
        self.action_space = gym.spaces.Box(
            low=-ub,
            high=ub,
            shape=env.action_space.shape,
        )

    def step(self, action):
        """Rescale ``action`` if needed, step the env, and report what was sent.

        Returns:
            The wrapped env's ``(observation, reward, done, info)`` with
            ``info['real_action']`` set to the (possibly rescaled) action.
        """
        # Locate the component with the largest absolute value.
        peak_idx = np.argmax(np.abs(action))
        peak = np.abs(action[peak_idx])
        if peak > self.action_ub:
            # Shrink every component by the same factor so the peak lands
            # exactly on the bound.
            action = action / peak * self.action_ub
        obs, rew, finished, extras = super().step(action)
        extras['real_action'] = action
        return obs, rew, finished, extras


class FetchPickActionSpaceBoxWrapper(gym.Wrapper):
    """Bound all but the last action component to ``[-ub, ub]`` by rescaling.

    When the largest-magnitude component among ``action[:-1]`` exceeds
    ``ub``, those leading components are scaled by a common factor so that
    component has magnitude ``ub``. The last component is forwarded
    unchanged: it is neither rescaled nor clipped here.
    (Presumably the last component is a gripper command - confirm against
    the wrapped env.)
    """

    def __init__(self, env, ub):
        """Wrap ``env`` and install a ``Box(-ub, ub)`` action space.

        Args:
            env: environment to wrap; its ``action_space.shape`` is reused.
            ub: magnitude limit applied to the leading action components.
        """
        super().__init__(env)
        self.action_ub = ub
        self.action_space = gym.spaces.Box(
            low=-ub,
            high=ub,
            shape=env.action_space.shape,
        )

    def step(self, action):
        """Rescale the leading components if needed and step the env.

        The caller's ``action`` array is not modified; a copy is sent.

        Returns:
            The wrapped env's ``(observation, reward, done, info)`` with
            ``info['real_action']`` set to the action actually sent.
        """
        sent_action = action.copy()
        leading = action[:-1]
        # An index into the leading slice is also valid for the full action
        # because the slice starts at position 0.
        peak_idx = np.argmax(np.abs(leading))
        peak = np.abs(action[peak_idx])
        if peak > self.action_ub:
            sent_action[:-1] = leading / peak * self.action_ub
        obs, rew, finished, extras = super().step(sent_action)
        extras['real_action'] = sent_action
        return obs, rew, finished, extras

class ActionDimBlockWrapper(gym.Wrapper):
    """Zero out the smallest-magnitude action components before stepping.

    On every step the ``dim_filter`` components with the smallest absolute
    value are set to 0; the remaining components are forwarded unchanged.

    NOTE(review): an earlier comment on this code said that only
    ``dim_filter`` components are kept and the rest zeroed. The
    implementation does the opposite count (it zeroes ``dim_filter``
    components and keeps the others). The two readings coincide only when
    the action has ``2 * dim_filter`` components - confirm which is intended.
    """

    def __init__(self, env, dim_filter=3):
        """Wrap ``env``.

        Args:
            env: environment to wrap.
            dim_filter: number of smallest-magnitude components to zero.
        """
        super(ActionDimBlockWrapper, self).__init__(env)
        self.dim_filter = dim_filter

    def action_block_func(self, action):
        """Return a copy of ``action`` with its smallest components zeroed.

        Args:
            action: sequence whose length must equal
                ``self.action_space.shape[0]``.

        Returns:
            A new array; the input is not modified.
        """
        blocked = np.array(action)
        assert len(blocked) == self.action_space.shape[0]
        # argsort is ascending, so the leading indices are the components
        # with the smallest absolute value.
        ascending = np.argsort(np.abs(blocked))
        weakest = ascending[:self.dim_filter]
        blocked[weakest] = 0
        return blocked

    def step(self, action):
        """Block components of ``action``, step the env, and report what was sent.

        Returns:
            The wrapped env's ``(observation, reward, done, info)`` with
            ``info['real_action']`` set to the blocked action.
        """
        blocked = self.action_block_func(action)
        obs, rew, finished, extras = super().step(blocked)
        extras['real_action'] = blocked
        return obs, rew, finished, extras
    
class EnvObsWrapper(gym.Wrapper):
    """Expose one entry of a dict observation as the whole observation.

    If the wrapped env's observation space has named sub-spaces and one of
    them is ``'observation'`` (preferred) or ``'image'``, that sub-space
    becomes this wrapper's ``observation_space`` and ``step`` / ``reset``
    return only that entry of each observation. For any other observation
    space the wrapper is a pass-through.
    """

    # Keys looked up in a dict observation space, in priority order.
    _OBS_KEYS = ('observation', 'image')

    def __init__(self, env):
        """Wrap ``env`` and pick which observation entry (if any) to expose.

        Args:
            env: environment to wrap.
        """
        super(EnvObsWrapper, self).__init__(env)
        self.env = env
        self.key = None

        # Only dict-like spaces carry a mapping of named sub-spaces. Guard
        # the lookup so that other spaces (no ``spaces`` attribute, or a
        # non-mapping one) reach the pass-through path instead of raising.
        sub_spaces = getattr(env.observation_space, 'spaces', None)
        if isinstance(sub_spaces, dict):
            for candidate in self._OBS_KEYS:
                if candidate in sub_spaces:
                    self.key = candidate
                    break

        if self.key is None:
            self.observation_space = self.env.observation_space
        else:
            self.observation_space = self.env.observation_space[self.key]

    def get_raw_obs(self, obs):
        """Return ``obs[self.key]`` when a key was selected, else ``obs``."""
        if self.key is None:
            return obs
        return obs[self.key]

    def step(self, action):
        """Step the wrapped env and return the selected observation entry.

        Returns:
            ``(observation, reward, done, info)`` where ``observation`` has
            been passed through ``get_raw_obs``.
        """
        obs, reward, done, info = self.env.step(action)
        return self.get_raw_obs(obs), reward, done, info

    def reset(self, **kwargs):
        """Reset the wrapped env and return the selected observation entry.

        Any keyword arguments are forwarded to the wrapped env's ``reset``.
        """
        obs = self.env.reset(**kwargs)
        return self.get_raw_obs(obs)
