"""
Licensed Materials - Property of IBM
Restricted Materials of IBM
20190891
© Copyright IBM Corp. 2021 All Rights Reserved.
"""
import gym
from gym import spaces
from gym.utils import seeding
import numpy as np
from os import path
from ibmfl.data.env_spec import EnvHandler

class PendulumEnv(EnvHandler):
    """
    Inverted pendulum swing-up environment.

    Adapted from the OpenAI Gym classic-control pendulum:
    https://github.com/openai/gym/blob/master/gym/envs/classic_control/pendulum.py

    Description:
        Try to keep a frictionless pendulum standing up.
        The inverted pendulum swing-up problem is a classic problem in the
        control literature. In this version of the problem, the pendulum
        starts in a random position, and the goal is to swing it up so it
        stays upright.

    Observation:
        Type: Box(3)
        Num   Observation   Min    Max
        0     cos(theta)    -1.0   1.0
        1     sin(theta)    -1.0   1.0
        2     theta dot     -8.0   8.0

    Actions:
        Type: Box(1)
        Num   Action         Min    Max
        0     Joint effort   -2.0   2.0

    Reward:
        The precise equation for the reward is

            -(theta^2 + 0.1*theta_dt^2 + 0.001*action^2)

        Theta is normalized between -pi and pi. Therefore, the lowest reward
        is -(pi^2 + 0.1*8^2 + 0.001*2^2) = -16.2736044, and the highest
        reward is 0. In essence, the goal is to remain at zero angle
        (vertical), with the least rotational velocity, and the least effort.

    Starting State:
        Random angle from -pi to pi, and random velocity between -1 and 1.

    Episode Termination:
        The dynamics themselves have no terminal state. This class keeps its
        own step counter and reports ``done=True`` once
        ``_max_episode_steps`` (200) steps have been taken since the last
        ``reset()``.
    """
    metadata = {
        'render.modes' : ['human', 'rgb_array'],
        'video.frames_per_second' : 30
    }

    def __init__(self, data=None, env_config=None, g=10.0):
        """
        Set up the physical constants, the action/observation spaces and
        the random number generator.

        :param data: not used by this environment
        :param env_config: not used by this environment
        :param g: gravitational acceleration used by the dynamics in `step`
        """
        # NOTE(review): EnvHandler.__init__ is not invoked here and `data` /
        # `env_config` are ignored -- confirm this is intended.
        self.max_speed = 8
        self.max_torque = 2.
        self.dt = .05
        self.g = g
        self.m = 1.
        self.l = 1.
        self.viewer = None

        # Build the bounds directly as float32 so they match the dtype of the
        # Box and are not silently down-cast from float64.
        high = np.array([1., 1., self.max_speed], dtype=np.float32)
        self.action_space = spaces.Box(
            low=-self.max_torque, high=self.max_torque, shape=(1,), dtype=np.float32)
        self.observation_space = spaces.Box(low=-high, high=high, dtype=np.float32)

        self.seed()
        self._max_episode_steps = 200
        # Stays None until reset() is called; step() uses this to detect a
        # missing reset.
        self._elapsed_steps = None

    def seed(self, seed=None):
        """
        (Re)create the random number generator used by `reset`.

        :param seed: seed forwarded to `gym.utils.seeding.np_random`
        :return: single-element list holding the seed returned by
            `seeding.np_random`
        """
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    def step(self, u):
        """
        Advance the simulation by one time step of length `self.dt`.

        :param u: indexable action; only element 0 (the torque) is used,
            after being clipped to [-max_torque, max_torque]
        :return: tuple `(observation, reward, done, info)` where `reward` is
            the negated cost, `done` becomes True once `_max_episode_steps`
            steps have elapsed since the last reset, and `info` is an empty
            dict
        :raises RuntimeError: if called before `reset()`
        """
        if self._elapsed_steps is None:
            # Fail early with a clear message instead of an opaque
            # AttributeError / TypeError further down.
            raise RuntimeError("Cannot call step() before calling reset()")

        th, thdot = self.state  # th := theta

        g = self.g
        m = self.m
        l = self.l
        dt = self.dt

        u = np.clip(u, -self.max_torque, self.max_torque)[0]
        self.last_u = u  # for rendering
        # The cost is evaluated on the state *before* this step's update.
        costs = angle_normalize(th)**2 + .1*thdot**2 + .001*(u**2)

        newthdot = thdot + (-3*g/(2*l) * np.sin(th + np.pi) + 3./(m*l**2)*u) * dt
        # The angle is integrated with the not-yet-clipped velocity; only the
        # stored velocity is limited to [-max_speed, max_speed] afterwards.
        newth = th + newthdot*dt
        newthdot = np.clip(newthdot, -self.max_speed, self.max_speed) #pylint: disable=E1111

        self.state = np.array([newth, newthdot])

        self._elapsed_steps += 1
        done = self._elapsed_steps >= self._max_episode_steps

        return self._get_obs(), -costs, done, {}

    def reset(self):
        """
        Start a new episode from a random state.

        The angle is drawn uniformly from [-pi, pi) and the angular velocity
        from [-1, 1); the step counter and the last applied torque are
        cleared.

        :return: observation of the new initial state
        """
        high = np.array([np.pi, 1])
        self.state = self.np_random.uniform(low=-high, high=high)
        self.last_u = None
        self._elapsed_steps = 0
        return self._get_obs()

    def _get_obs(self):
        """
        Convert the internal `(theta, theta_dot)` state into an observation.

        :return: array `[cos(theta), sin(theta), theta_dot]`
        """
        theta, thetadot = self.state
        return np.array([np.cos(theta), np.sin(theta), thetadot])

    def render(self, mode='human'):
        """
        Draw the pendulum and the last applied torque.

        The viewer and its geometry are created lazily on the first call.

        :param mode: when equal to 'rgb_array' the viewer is asked to return
            an RGB array
        :return: whatever the viewer's `render` call returns
        """
        if self.viewer is None:
            # Imported lazily so the rendering module is only needed when
            # something is actually drawn.
            from gym.envs.classic_control import rendering
            self.viewer = rendering.Viewer(500, 500)
            self.viewer.set_bounds(-2.2, 2.2, -2.2, 2.2)
            rod = rendering.make_capsule(1, .2)
            rod.set_color(.8, .3, .3)
            self.pole_transform = rendering.Transform()
            rod.add_attr(self.pole_transform)
            self.viewer.add_geom(rod)
            axle = rendering.make_circle(.05)
            axle.set_color(0, 0, 0)
            self.viewer.add_geom(axle)
            fname = path.join(path.dirname(__file__), "assets/clockwise.png")
            self.img = rendering.Image(fname, 1., 1.)
            self.imgtrans = rendering.Transform()
            self.img.add_attr(self.imgtrans)

        self.viewer.add_onetime(self.img)
        self.pole_transform.set_rotation(self.state[0] + np.pi/2)
        # Compare against None explicitly: a torque of exactly 0 is a valid
        # value and must shrink the arrow rather than leave the previous
        # step's scale in place.
        if self.last_u is not None:
            self.imgtrans.scale = (-self.last_u/2, np.abs(self.last_u)/2)

        return self.viewer.render(return_rgb_array=(mode == 'rgb_array'))

    def close(self):
        """
        Close the viewer, if one was created, and forget it.
        """
        if self.viewer:
            self.viewer.close()
            self.viewer = None

def angle_normalize(x):
    """
    Wrap an angle given in radians into the interval [-pi, pi).

    :param x: angle in radians (scalar or numpy array)
    :return: the equivalent angle shifted by a multiple of 2*pi
    """
    full_turn = 2*np.pi
    # Shift by pi so the modulo maps onto [0, 2*pi), then shift back.
    shifted = x + np.pi
    wrapped = shifted % full_turn
    return wrapped - np.pi