from os import path

import gym
import numpy as np
from gym import spaces

class SimplePendulumEnv(gym.Env):
    """An inverted pendulum environment with a random, safe initial state.

    The internal state is ``(theta, thdot)``: the pole angle (``theta = 0`` is
    drawn pointing straight up by ``render``) and its angular velocity.
    Observations are ``[cos(theta), sin(theta), thdot]`` as ``float32``.

    Characteristics:
        - random initial state, drawn from ``safe_region`` on every reset
        - ``step`` never ends the episode itself (``done`` is always ``False``)
        - ``render`` colors the bob red once the rendered state leaves
          ``safe_region`` (the check only happens while rendering)

    Args:
        safe_region: Object that provides ``sample()`` (used to draw the
            initial state) and supports ``state in safe_region`` (used by
            ``render`` to detect safety violations).
    """

    # RGB colors of the pendulum bob in ``render``.
    _SAFE_COLOR = (0 / 255, 92 / 255, 171 / 255)
    _VIOLATION_COLOR = (227 / 255, 27 / 255, 35 / 255)

    def __init__(self, safe_region):

        self._safe_region = safe_region

        self.l = 1.    # pole length
        self.m = 1.    # bob mass
        self.g = 9.81  # gravitational acceleration
        self.dt = .05  # integration time step

        max_torque = 30
        self.action_space = spaces.Box(
            low=-max_torque,
            high=max_torque,
            shape=(1,),
            dtype=np.float32
        )

        obs_high = np.array([1.0, 1.0, np.inf], dtype=np.float32)
        self.observation_space = spaces.Box(low=-obs_high, high=obs_high, dtype=np.float32)

        self.state = None
        self.viewer = None
        # Created here (not only in ``reset``) so that ``render`` does not
        # raise AttributeError when it runs before the first reset.
        self._last_action = None
        self._is_safety_violated = None

    def reset(self, **kwargs):
        """Start a new episode from a state sampled from the safe region.

        Keyword arguments are accepted for API compatibility and ignored.

        Returns:
            The initial observation ``[cos(theta), sin(theta), thdot]``.
        """
        self._last_action = None
        self._is_safety_violated = None
        self.state = np.asarray(self._safe_region.sample())
        return self._get_obs()

    def step(self, action):
        """Apply ``action`` as a torque for one time step of length ``dt``.

        NOTE(review): the torque is not clipped to ``action_space``; confirm
        this is intended (e.g. callers clip, or dynamics must stay unclipped).

        Args:
            action: Torque as a scalar or a single-element array.

        Returns:
            ``(observation, reward, done, info)``. The reward is evaluated on
            the state *before* the transition, ``done`` is always ``False``
            and ``info`` is an empty dict.
        """
        theta, thdot = self.state
        # Policies usually emit shape-(1,) arrays; reduce to a Python float
        # once so reward, dynamics and rendering all work on a true scalar.
        torque = np.asarray(action, dtype=np.float64).item()
        self._last_action = torque
        reward = self._get_reward(theta, thdot, torque=torque)
        self.state = self.dynamics(theta, thdot, torque)
        return self._get_obs(), reward, False, {}

    def _get_obs(self):
        """Return the observation ``[cos(theta), sin(theta), thdot]`` as float32."""
        theta, thdot = self.state
        return np.array([np.cos(theta), np.sin(theta), thdot], dtype=np.float32)

    def dynamics(self, theta, thdot, torque):
        """Advance ``(theta, thdot)`` by one semi-implicit Euler step.

        The angular acceleration is ``(g / l) * sin(theta) + torque / (m * l**2)``.
        The velocity is updated first and the new velocity integrates the angle.

        Args:
            theta: Current angle.
            thdot: Current angular velocity.
            torque: Applied torque; a NumPy array is reduced with ``.item()``.

        Returns:
            ``np.array([new_theta, new_thdot])``; the angle is not wrapped.
        """
        # sb3.common.distributions produces actions as arrays.
        if isinstance(torque, np.ndarray):
            torque = torque.item()

        new_thdot = thdot + self.dt * ((self.g / self.l) * np.sin(theta) + 1. / (self.m * self.l ** 2) * torque)
        new_theta = theta + self.dt * new_thdot
        return np.array([new_theta, new_thdot])

    def angle_normalize(self, x):
        """Wrap angle ``x`` into the interval ``[-pi, pi)``."""
        return ((x + np.pi) % (2 * np.pi)) - np.pi

    def _get_reward(self, theta, thdot, torque):
        """Return the negative quadratic cost of angle, velocity and torque as a float."""
        rew = -(self.angle_normalize(theta) ** 2 + 0.1 * thdot ** 2 + 0.001 * (torque ** 2))
        # ``.item()`` also accepts single-element arrays without relying on the
        # deprecated ``float(ndarray)`` conversion for ndim > 0.
        return float(np.asarray(rew).item())

    def _build_viewer(self):
        """Create the viewer and the rod, bob, axle and torque-arrow geometries."""
        from gym.envs.classic_control import rendering
        self.viewer = rendering.Viewer(500, 500)
        self.viewer.set_bounds(-2.2, 2.2, -2.2, 2.2)

        rod = rendering.make_capsule(1, .035)
        rod.set_color(0, 0, 0)
        self.pole_transform = rendering.Transform()
        rod.add_attr(self.pole_transform)
        self.viewer.add_geom(rod)

        self.mass = rendering.make_circle(.15)
        self.mass.set_color(*self._SAFE_COLOR)
        self.mass_transform = rendering.Transform()
        self.mass.add_attr(self.mass_transform)
        self.viewer.add_geom(self.mass)

        axle = rendering.make_circle(.035)
        axle.set_color(0, 0, 0)
        self.viewer.add_geom(axle)

        self.img_black = rendering.Image(path.join(path.dirname(__file__), "assets/clockwise.png"), 1., 1.)
        self.imgtrans_black = rendering.Transform()
        self.img_black.add_attr(self.imgtrans_black)
        self.imgtrans_black.scale = (0., 0.)

    def render(self, mode="human", **kwargs):
        """Draw the current state; the viewer is created lazily on first call.

        Args:
            mode: ``"rgb_array"`` returns the frame as an array; any other
                value renders to the window.

        Returns:
            Whatever ``viewer.render`` returns for the requested mode.
        """
        if self.viewer is None:
            self._build_viewer()

        if self._last_action is not None:
            # Arrow image sized and mirrored according to the last torque.
            self.viewer.add_onetime(self.img_black)
            self.imgtrans_black.scale = (-self._last_action / 8, -abs(self._last_action) / 8)

        # Screen angle: theta = 0 points up, increasing theta turns clockwise.
        theta_trans = -self.state[0] + np.pi / 2
        self.pole_transform.set_rotation(theta_trans)
        self.mass_transform.set_translation(np.cos(theta_trans), np.sin(theta_trans))

        if not self._is_safety_violated and self.state not in self._safe_region:
            self._is_safety_violated = True
        # Recolor every frame: the viewer outlives ``reset``, so a bob turned
        # red in an earlier episode must go back to the safe color.
        bob_color = self._VIOLATION_COLOR if self._is_safety_violated else self._SAFE_COLOR
        self.mass.set_color(*bob_color)

        return self.viewer.render(return_rgb_array=mode == "rgb_array")

    def close(self):
        """Close the viewer window, if one was opened."""
        if self.viewer is not None:
            self.viewer.close()
            self.viewer = None
