import numpy as np
from sklearn.utils.validation import check_random_state
from gym.envs.classic_control import PendulumEnv


def angle_normalize(x):
    """Wrap an angle given in radians into the interval [-pi, pi)."""
    full_turn = 2 * np.pi
    # Shift by pi so the modulo lands in [0, 2*pi), then shift back.
    shifted = x + np.pi
    wrapped = shifted % full_turn
    return wrapped - np.pi

def safety_function(observations):
    """
    This function returns 1 if the state is unsafe and 0 if safe. It has to be able to handle the same inputs
    as the one used by the model env, because when testing with a new function, that new function is used for both.

    The input is indexed as a batch: column 0 and column 1 of ``observations`` are passed to ``arctan2``
    as the x and y components of the angle (presumably cos and sin of the pendulum angle; confirm against
    the caller). An observation is unsafe when its angle lies strictly between 20 and 30 degrees.

    Returns a float64 array with one entry per row of ``observations``.
    """
    angle = np.arctan2(observations[:, 1], observations[:, 0])
    unsafe_min_angle = np.pi * (20. / 180)  # 20 degree converted in rad
    unsafe_max_angle = np.pi * (30. / 180)  # 30 degree converted in rad

    is_unsafe = np.logical_and(unsafe_min_angle < angle, angle < unsafe_max_angle)
    # np.float was only an alias of the builtin float and was removed in
    # NumPy 1.24; np.float64 yields the same dtype on every NumPy version.
    return is_unsafe.astype(np.float64)

class Env(PendulumEnv):
    """Modified Open AI Gym Pendulum.

    Differences with the parent environment visible in this class:

    * ``step`` reports ``done`` once ``max_episode_steps`` steps have
      elapsed and stores the safety cost of the new observation in
      ``info['cost']``. The reward is the one returned by the parent
      ``step``. NOTE(review): earlier documentation describes it as the
      reward defined in the benchmark paper; confirm against the parent.
    * We also provide a seed method accepting instances of
      numpy.random.RandomState.
    * ``render`` draws the unsafe angular sector and colors the pole
      according to the current safety cost.
    """
    def __init__(self, max_episode_steps=200):
        """Create the environment.

        max_episode_steps is the number of steps after which ``step``
        reports ``done``.
        """
        self.max_episode_steps = max_episode_steps
        super(Env, self).__init__()

        # Bounds of the unsafe angular sector; within this class they are
        # only read by render() to draw the sector.
        self.unsafe_min_angle = np.pi * (20. / 180) # 20 degree converted in rad
        self.unsafe_max_angle = np.pi * (30. / 180) # 30 degree converted in rad
        self.safety_func = safety_function

    def seed(self, seed=None):
        """Same as parent method but passing a RandomState instance is allowed.

        Returns a one-element list holding the ``seed`` argument as given.
        """
        self.np_random = check_random_state(seed)
        return [seed]

    def reset(self):
        """Same as parent method but also resets the elapsed step counter.

        Returns whatever the parent ``reset`` returns.
        """
        observations = super(Env, self).reset()
        self._elapsed_steps = 0
        return observations

    def step(self, action):
        """Same as parent method but with a time-limit ``done`` and a safety cost.

        The action is wrapped in an extra array dimension before being
        passed to the parent ``step``. The parent's ``done`` flag is
        discarded and replaced by the step-count check below; the reward is
        returned unchanged. ``info['cost']`` holds the value of
        ``self.safety_func`` for the new observation.
        """
        action = np.array([action])

        observation, reward, _, info = super(Env, self).step(action)
        self._elapsed_steps += 1
        # using >= in case we need the info when planning with the real env
        done = (self._elapsed_steps >= self.max_episode_steps)
        # The safety function works on batches, hence the wrapping and [0].
        safety_cost = self.safety_func(np.array([observation]))[0]
        info['cost'] = safety_cost
        return observation, reward, done, info

    def set_state(self, state_dict):
        """Set state of the environment.

        ``state_dict`` must provide the keys '_elapsed_steps', 'qpos' and
        'qvel'; the state is the concatenation of 'qpos' and 'qvel'.
        """
        self._elapsed_steps = state_dict['_elapsed_steps']
        self.state = np.r_[state_dict['qpos'], state_dict['qvel']]

    def get_state(self):
        """Get state of the environment.

        Returns a dict with 'qpos' (first two entries of the state), 'qvel'
        (the remaining entries) and '_elapsed_steps'. Concatenating 'qpos'
        and 'qvel', as ``set_state`` does, gives back the full state.
        """
        state = self.state
        state_dict = {
            'qpos': state[:2],
            'qvel': state[2:],
            '_elapsed_steps': self._elapsed_steps,
        }
        return state_dict

    def get_numpy_state(self):
        """Get the state numpy array from the environment."""
        state_dict = self.get_state()
        return np.r_[state_dict['qpos'], state_dict['qvel']].squeeze()

    def set_numpy_state(self, numpy_state):
        """Set the state from a numpy array.

        Note that the _elapsed_steps attribute is reset to 0.
        """
        self.state = numpy_state
        self._elapsed_steps = 0

    def render(self, mode='human'):
        """Draw the pendulum, the unsafe sector and the safety status.

        The pole is drawn with ``set_color(0, 255, 0)`` when the current
        state is safe and ``set_color(255, 0, 0)`` when it is unsafe.
        Returns the result of the viewer's ``render``, called with
        ``return_rgb_array`` set when ``mode`` is 'rgb_array'.
        """
        if self.viewer is None:
            from gym.envs.classic_control import rendering
            self.viewer = rendering.Viewer(500, 500)
            self.viewer.set_bounds(-2.2, 2.2, -2.2, 2.2)
            self.pole = rendering.make_capsule(1, .2)

            self.pole.set_color(0,255,0)
            self.pole_transform = rendering.Transform()
            self.pole_color = self.pole.attrs[0]
            self.pole.add_attr(self.pole_transform)
            self.viewer.add_geom(self.pole)
            axle = rendering.make_circle(.05)
            axle.set_color(0, 0, 0)
            self.viewer.add_geom(axle)
            # Triangle between the pivot and the two bounds of the unsafe
            # sector, rotated by pi/2 like the pole itself.
            unsafe_area_coord = [(np.cos(self.unsafe_min_angle + np.pi / 2), np.sin(self.unsafe_min_angle + np.pi / 2)),
                           (np.cos(self.unsafe_max_angle + np.pi / 2), np.sin(self.unsafe_max_angle + np.pi / 2)),
                            (0,0)]
            unsafe_area = self.viewer.draw_polygon(unsafe_area_coord)
            unsafe_area._color.vec4 = (255, 0, 0, 0.4)
            self.viewer.add_geom(unsafe_area)

        # Flatten first: get_numpy_state() squeezes the state, which suggests
        # it can carry an extra singleton axis (assumption, confirm).
        theta, thetadot = np.asarray(self.state, dtype=float).ravel()[:2]
        self.pole_transform.set_rotation(theta + np.pi / 2)

        # The safety function indexes a batch of observations by column
        # (as in step()), not the raw state, so build a one-row batch of
        # (cos, sin, velocity) before calling it.
        observation = np.array([[np.cos(theta), np.sin(theta), thetadot]])
        safety_cost = self.safety_func(observation)[0]
        if safety_cost:
            self.pole.set_color(255, 0, 0)
        else:
            self.pole.set_color(0, 255, 0)

        # This override never creates ``imgtrans``; only scale it when some
        # other code has set it, instead of raising AttributeError.
        imgtrans = getattr(self, 'imgtrans', None)
        if imgtrans is not None and self.last_u:
            imgtrans.scale = (-self.last_u / 2, np.abs(self.last_u) / 2)

        return self.viewer.render(return_rgb_array=mode == 'rgb_array')
