# Modified from OpenAI gym Pendulum-v0 task
# https://github.com/openai/gym/blob/master/gym/envs/classic_control/pendulum.py
# https://github.com/openai/gym/blob/master/gym/envs/classic_control/rendering.py

import gym
from gym import spaces
from gym.utils import seeding
import numpy as np
from os import path
import scipy.integrate
solve_ivp = scipy.integrate.solve_ivp

class MassSpringDamperEnv(gym.Env):
    """Two-mass suspension (bogie frame + car body) as a gym environment.

    Adapted from the gym Pendulum-v0 template.  The state is
    ``[x1, x2, x3, x4]`` = bogie-frame displacement/velocity followed by
    car-body displacement/velocity (see ``dynamics``).  An action is the
    track excitation pair ``(u, up)``; an observation is the state followed
    by the six system parameters.

    ``reset`` must be called before ``step``, ``_get_obs`` or ``render``:
    it is what sets ``self.p``, ``self.state``, ``self.dt`` and the
    rendering scale/offsets.
    """

    metadata = {
        'render.modes' : ['human', 'rgb_array'],
        'video.frames_per_second' : 30
    }

    # Sizes of the state vector and of the parameter tuple ``p``; the
    # observation returned by ``_get_obs`` is their concatenation.
    _N_STATE = 4
    _N_PARAMS = 6

    def __init__(self, g=10.0):
        # Attributes inherited from the Pendulum template.  ``g``, ``L``,
        # ``mu`` and ``max_speed`` are not read anywhere in this class but are
        # kept so existing callers that access them keep working.
        self.max_speed = 100.
        self.max_torque = 10.
        self.dt = .05  # default step; reset() overwrites it
        self.g = g
        self.viewer = None

        # step() forwards the action to dynamics(), which unpacks it as the
        # track excitation (u, up) -- so an action has two components.  The
        # previous shape (1,) made action_space.sample() unusable with step().
        self.action_space = spaces.Box(
            low=-self.max_torque, high=self.max_torque, shape=(2,),
            dtype=np.float32)
        # _get_obs() returns the 4 states followed by the 6 parameters as an
        # unbounded float array; the old 3-D Pendulum box did not match it.
        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf,
            shape=(self._N_STATE + self._N_PARAMS,), dtype=np.float64)

        self.seed()

        self.L = 1
        self.mu = 0.1

    def seed(self, seed=None):
        """Seed ``self.np_random`` (used by ``reset``) and return ``[seed]``."""
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    def dynamics(self, t, w, TrackExci):
        """Right-hand side of the two-mass suspension ODE, for ``solve_ivp``.

        State layout:
            x1 ... zp_BG1      bogie-frame displacement
            x2 ... zp_BG1_dot  bogie-frame velocity
            x3 ... zs_CB       car-body displacement
            x4 ... zs_CB_dot   car-body velocity

        Parameters
        ----------
        t : float
            Time; unused (the right-hand side does not depend on it).
        w : sequence of 4 numbers
            Current state ``[x1, x2, x3, x4]``.
        TrackExci : pair
            Track excitation ``(u, up)``; it enters only the bogie-frame
            equation through the ``cPS``/``dPS`` terms.

        Returns
        -------
        numpy.ndarray
            ``[x1_dot, x2_dot, x3_dot, x4_dot]``.
        """
        x1, x2, x3, x4 = w
        # Masses mBG (bogie frame) and mCB (car body); cPS/dPS and cSS/dSS
        # presumably are primary/secondary suspension stiffness (c) and
        # damping (d) -- naming convention, confirm with the model source.
        mBG, mCB, cPS, dPS, cSS, dSS = self.p
        u, up = TrackExci

        dxdt = [
            x2,
            ( dSS*(x4-x2) + cSS*(x3-x1) - dPS*(x2-up) - cPS*(x1-u) ) / mBG,
            x4,
            (-dSS*(x4-x2) - cSS*(x3-x1)) / mCB,
        ]
        return np.array(dxdt)

    def step(self, u):
        """Integrate the dynamics over one interval of length ``self.dt``.

        Parameters
        ----------
        u : sequence of two numbers
            Track excitation ``(u, up)`` held constant during the step.

        Returns
        -------
        tuple
            ``(observation, reward, done, info)``; the reward is always 0,
            ``done`` is always False and ``info`` is an empty dict.

        Raises
        ------
        RuntimeError
            If ``solve_ivp`` reports that the integration failed (previously
            the partial result was used silently).
        """
        self.last_u = u  # kept from the template; render() does not read it
        costs = 0  # no cost function is defined for this environment

        ivp = solve_ivp(fun=lambda t, y: self.dynamics(t, y, u),
                        t_span=[0, self.dt], y0=self.state)
        if not ivp.success:
            raise RuntimeError("solve_ivp failed: %s" % ivp.message)
        # Keep only the state at the end of the interval.
        self.state = ivp.y[:, -1]

        return self._get_obs(), -costs, False, {}

    def reset(self, p, dt, init_state, scale, primary_offset, secondary_offset):
        """Configure the system, draw a random initial state, return the obs.

        Parameters
        ----------
        p : sequence of 6 numbers
            System parameters ``(mBG, mCB, cPS, dPS, cSS, dSS)`` as unpacked
            in ``dynamics``.
        dt : float
            Length of the integration interval used by each ``step``.
        init_state : pair ``(low, high)``
            Bounds passed to ``self.np_random.uniform``; presumably each bound
            holds 4 entries (one per state) -- the draw takes their shape.
        scale : float
            Factor applied to displacements in ``render``.
        primary_offset, secondary_offset : float
            Vertical screen offsets of the bogie frame / car body in
            ``render``.

        Returns
        -------
        numpy.ndarray
            The initial observation from ``_get_obs``.
        """
        self.state = self.np_random.uniform(low=init_state[0], high=init_state[1])
        self.p = p
        self.dt = dt
        self.scale = scale
        self.primary_offset = primary_offset
        self.secondary_offset = secondary_offset

        return self._get_obs()

    def _get_obs(self):
        """Return ``[x1, x2, x3, x4, p[0], ..., p[5]]`` as a NumPy array.

        Only the first six entries of ``self.p`` are used (``dynamics``
        requires exactly six anyway).
        """
        x1, x2, x3, x4 = self.state
        params = [self.p[i] for i in range(self._N_PARAMS)]
        return np.array([x1, x2, x3, x4] + params)

    def render(self, mode='human'):
        """Draw the current state with ``myenv.rendering``.

        The bogie frame (``state[0]``) is a blue bar and the car body
        (``state[2]``) a red bar; springs, dampers and the wheel are green.
        Displacements are multiplied by ``self.scale`` and shifted by
        ``self.primary_offset`` / ``self.secondary_offset`` (set in reset).

        Returns the result of
        ``self.viewer.render(return_rgb_array=mode == 'rgb_array')``.
        """
        from myenv import rendering

        scale = self.scale  # default: 10->30
        bottom = -1.5  # lower end of the primary spring (default: -1)
        # Bogie frame (default offset: -0.1) and car body (default offset: 1.0).
        pos_p = self.state[0]*scale + self.primary_offset
        pos_s = self.state[2]*scale + self.secondary_offset

        # Each spring is drawn as 4 zig-zag segments, so the vertical span it
        # covers is split into quarters.
        gap_s = (pos_s - pos_p) / 4.0
        gap_p = (pos_p - bottom) / 4.0

        s_L = 0.4  # length of one spring segment
        wheel_size = 0.3
        damper_L = 0.4
        # Horizontal positions (first argument of set_translation) of the
        # spring and of the damper.
        x_spring = -0.2
        x_damper = 0.2

        if self.viewer is None:
            self.viewer = rendering.Viewer(32, 32)
            self.viewer.set_bounds(-2.2, 2.2, -2.2, 2.2)

            def add(geom, color):
                # Colour the geometry, attach a new Transform, add it to the
                # viewer (draw order = call order) and return the Transform.
                geom.set_color(*color)
                transform = rendering.Transform()
                geom.add_attr(transform)
                self.viewer.add_geom(geom)
                return transform

            capsule = rendering.make_capsule
            blue, red, green = (0, 0, 1), (1, 0, 0), (0, 1, 0)

            # Bogie frame (primary stage) and car body (secondary stage).
            self.pole_transform_1 = add(capsule(1, .2), blue)
            self.pole_transform_2 = add(capsule(1, .2), red)

            # Spring segments: secondary (s) first, then primary (p).
            self.spring_s1_transform = add(capsule(s_L, .05), green)
            self.spring_s2_transform = add(capsule(s_L, .05), green)
            self.spring_s3_transform = add(capsule(s_L, .05), green)
            self.spring_s4_transform = add(capsule(s_L, .05), green)
            self.spring_p1_transform = add(capsule(s_L, .05), green)
            self.spring_p2_transform = add(capsule(s_L, .05), green)
            self.spring_p3_transform = add(capsule(s_L, .05), green)
            self.spring_p4_transform = add(capsule(s_L, .05), green)

            self.wheel = add(rendering.make_circle(wheel_size), green)

            # Damper pieces (hand-tuned lengths).
            self.damper_s1 = add(capsule(damper_L, .05), green)
            self.damper_s2 = add(capsule(0.2, .05), green)
            self.damper_s3 = add(capsule(0.3, .05), green)
            self.damper_s4 = add(capsule(0.3, .05), green)
            self.damper_s5 = add(capsule(0.3, .05), green)
            self.damper_s6 = add(capsule(0.4, .05), green)
            self.damper_p1 = add(capsule(damper_L, .05), green)
            self.damper_p2 = add(capsule(0.2, .05), green)
            self.damper_p3 = add(capsule(0.3, .05), green)
            self.damper_p4 = add(capsule(0.3, .05), green)
            self.damper_p5 = add(capsule(0.3, .05), green)
            self.damper_p6 = add(capsule(0.6, .05), green)

        self.pole_transform_1.set_translation(-0.5, pos_p)
        self.pole_transform_2.set_translation(-0.5, pos_s)

        def place_spring(transforms, top, gap):
            # Segments alternate between two slopes: even segments start at
            # x_spring, odd ones start at x_spring shifted by the horizontal
            # extent of an even segment.  The 0.2 is a hand-tuned constant.
            slope_a = np.arctan2(-0.2, -gap)
            slope_b = np.arctan2(-0.2, gap)
            x_odd = s_L*np.cos(slope_a) + x_spring
            for k, transform in enumerate(transforms):
                y = top if k == 0 else top - gap*k
                if k % 2 == 0:
                    transform.set_translation(x_spring, y)
                    transform.set_rotation(slope_a)
                else:
                    transform.set_translation(x_odd, y)
                    transform.set_rotation(slope_b)

        # Secondary spring between car body and bogie frame; primary spring
        # between bogie frame and ``bottom``.
        place_spring([self.spring_s1_transform, self.spring_s2_transform,
                      self.spring_s3_transform, self.spring_s4_transform],
                     pos_s, gap_s)
        place_spring([self.spring_p1_transform, self.spring_p2_transform,
                      self.spring_p3_transform, self.spring_p4_transform],
                     pos_p, gap_p)

        self.wheel.set_translation(0, bottom-0.3)

        # Secondary damper: pieces s1-s2 move with the car body (pos_s),
        # s3-s6 with the bogie frame (pos_p).  Rotations are +/-pi/2.1 rad
        # (about 85.7 degrees).
        self.damper_s1.set_translation(x_damper, pos_s+0.1)
        self.damper_s1.set_rotation(-np.pi/2.1)
        self.damper_s2.set_translation(x_damper-0.05, pos_s+0.1-damper_L)
        self.damper_s3.set_translation(x_damper-0.1, pos_p-0.1+damper_L)
        self.damper_s3.set_rotation(np.pi/2.1)
        self.damper_s4.set_translation(x_damper+0.15, pos_p-0.1+damper_L)
        self.damper_s4.set_rotation(np.pi/2.1)
        self.damper_s5.set_translation(x_damper-0.1, pos_p-0.1+damper_L)
        self.damper_s6.set_translation(x_damper, pos_p-0.1)
        self.damper_s6.set_rotation(np.pi/2.1)

        # Primary damper: pieces p1-p2 move with the bogie frame (pos_p),
        # p3-p6 are fixed relative to ``bottom``.
        self.damper_p1.set_translation(x_damper, pos_p+0.1)
        self.damper_p1.set_rotation(-np.pi/2.1)
        self.damper_p2.set_translation(x_damper-0.05, pos_p+0.1-damper_L)
        self.damper_p3.set_translation(x_damper-0.1, bottom-0.2+damper_L)
        self.damper_p3.set_rotation(np.pi/2.1)
        self.damper_p4.set_translation(x_damper+0.15, bottom-0.2+damper_L)
        self.damper_p4.set_rotation(np.pi/2.1)
        self.damper_p5.set_translation(x_damper-0.1, bottom-0.2+damper_L)
        self.damper_p6.set_translation(x_damper, bottom-0.4)
        self.damper_p6.set_rotation(np.pi/2.1)

        return self.viewer.render(return_rgb_array = mode=='rgb_array')

    def close(self):
        """Close the viewer if one was opened by ``render``."""
        if self.viewer:
            self.viewer.close()
            self.viewer = None

def angle_normalize(x):
    """Wrap angle(s) ``x`` in radians into [-pi, pi) (up to float rounding).

    Works element-wise on NumPy arrays.  Left over from the Pendulum
    template; nothing in ``MassSpringDamperEnv`` calls it.
    """
    full_turn = 2*np.pi
    shifted = (x + np.pi) % full_turn
    return shifted - np.pi



