import gym, os
import numpy as np


class GymGame:
    """Wrap a Gym MuJoCo environment as a two-player, fixed-horizon game.

    Player ``u`` supplies the environment's control action. Player ``v``
    supplies a one-dimensional action that, depending on ``env_name``, is
    either the last actuator of HalfCheetah-v4 or a direct perturbation of
    the simulator state (see ``change_qpos_qvel``). The game state is the
    environment observation with the elapsed game time prepended.
    """

    def __init__(self, env_name, dt=0.2, terminal_time=4, u_action_radius=1,
                 v_action_radius=1, render_mode=''):
        """Create the wrapped environment and derive action bounds and timing.

        Args:
            env_name: Gym environment id passed to ``gym.make``.
            dt: Duration of one game step. Each game step runs
                ``inner_step_n`` environment steps of length ``env.dt``.
            terminal_time: Game time after which ``step`` reports ``done``.
            u_action_radius: Scale applied to the u-player's action bounds.
            v_action_radius: Scale applied to the v-player's action bounds.
            render_mode: Forwarded to ``gym.make``.

        Raises:
            ValueError: If ``dt`` is shorter than one environment step, which
                would leave ``step`` with no inner step to produce a state.
        """
        self.env_name = env_name
        self.env = gym.make(self.env_name, render_mode=render_mode)
        self.u_action_radius = u_action_radius
        self.v_action_radius = v_action_radius

        # One extra leading component holds the elapsed game time.
        self.state_dim = self.env.observation_space.shape[0] + 1

        low = self.env.action_space.low
        high = self.env.action_space.high
        if self.env_name == 'HalfCheetah-v4':
            # Actuators 0-4 belong to player u, actuator 5 to player v.
            self.u_action_dim = 5
            self.u_action_min = self.u_action_radius * low[[0, 1, 2, 3, 4]]
            self.u_action_max = self.u_action_radius * high[[0, 1, 2, 3, 4]]
            self.v_action_dim = 1
            self.v_action_min = self.v_action_radius * low[[5]]
            self.v_action_max = self.v_action_radius * high[[5]]
        else:
            # Player u owns the whole action space; player v gets a scalar
            # in [-v_action_radius, v_action_radius].
            self.u_action_dim = self.env.action_space.shape[0]
            self.u_action_min = self.u_action_radius * low
            self.u_action_max = self.u_action_radius * high
            self.v_action_dim = 1
            self.v_action_min = - self.v_action_radius * np.ones(1)
            self.v_action_max = self.v_action_radius * np.ones(1)

        self.terminal_time = terminal_time
        self.max_episode_steps = self.env._max_episode_steps
        self.dt = dt
        self.inner_dt = self.env.dt
        # The 1e-3 presumably guards against floor truncation when dt is an
        # exact multiple of env.dt (e.g. 0.2 / 0.04 landing just under 5).
        self.inner_step_n = int((self.dt + 1e-3) / self.inner_dt)
        if self.inner_step_n < 1:
            self.env.close()
            raise ValueError(
                f"dt={self.dt} is shorter than the environment step "
                f"env.dt={self.inner_dt}; at least one inner step is required")

    def reset(self):
        """Reset the environment and force the simulator to the all-zero state.

        Returns:
            The initial game state: a zero vector of length ``state_dim``.
        """
        # NOTE(review): Gym MuJoCo v4 envs appear to read the private
        # ``_reset_noise_scale``; this public attribute is probably unused —
        # confirm. It is harmless either way because qpos/qvel are
        # overwritten with zeros below.
        self.env.reset_noise_scale = 0
        self.env.reset()
        qpos = np.zeros(self.env.data.qpos.size)
        qvel = np.zeros(self.env.data.qvel.size)
        self.state = np.zeros(self.state_dim)
        self.env.set_state(qpos, qvel)
        return self.state

    def step(self, u_action, v_action):
        """Advance the game by ``dt`` using the two players' actions.

        Both actions are clipped to their bounds and held constant for the
        ``inner_step_n`` environment steps that make up one game step.

        Args:
            u_action: Action of player u (length ``u_action_dim``).
            v_action: Action of player v (length ``v_action_dim``).

        Returns:
            Tuple ``(state, reward, done, info)``:

            - ``state`` is the new game state (time followed by the last
              observation).
            - ``reward`` is the sum of ``get_reward`` over the inner steps.
            - ``done`` is true when the new time plus ``dt / 2`` exceeds
              ``terminal_time``.
            - ``info`` is an empty dict.
        """
        u_action = np.clip(u_action, self.u_action_min, self.u_action_max)
        v_action = np.clip(v_action, self.v_action_min, self.v_action_max)

        is_cheetah = self.env_name == 'HalfCheetah-v4'
        if is_cheetah:
            # The joint action does not change across inner steps; build once.
            cheetah_action = np.array([u_action[0], u_action[1], u_action[2],
                                       u_action[3], u_action[4], v_action[0]])

        reward = 0
        for _ in range(self.inner_step_n):
            if is_cheetah:
                env_state, env_reward, env_done, _, _ = self.env.step(
                    cheetah_action)
            else:
                # Copy so the perturbation edits our arrays rather than the
                # simulator's live buffers; set_state then writes them back.
                qpos = self.env.data.qpos.copy()
                qvel = self.env.data.qvel.copy()
                qpos, qvel = self.change_qpos_qvel(qpos, qvel, v_action)
                self.env.set_state(qpos, qvel)
                env_state, env_reward, env_done, _, _ = self.env.step(u_action)

            reward += self.get_reward(env_reward, env_done)

        self.state = np.concatenate(([self.state[0] + self.dt], env_state))
        # Half-step slack presumably absorbs float drift in the summed time.
        done = self.state[0] + self.dt / 2 > self.terminal_time

        return self.state, reward, done, {}

    def render(self):
        """Delegate rendering to the wrapped environment."""
        self.env.render()

    def change_qpos_qvel(self, qpos, qvel, v_action):
        """Apply player v's perturbation to the simulator state arrays.

        The adjustment depends on ``env_name``:

        - InvertedPendulum-v4: adds ``v_action[0] * inner_dt`` to ``qvel[1]``.
        - Swimmer-v4: adds ``v_action[0] * inner_dt`` to ``qpos[4]``.
        - Any other environment: both arrays are left unchanged.

        The arrays are modified in place and also returned.
        """
        if self.env_name == 'InvertedPendulum-v4':
            qvel[1] += v_action[0] * self.inner_dt
        elif self.env_name == 'Swimmer-v4':
            qpos[4] += v_action[0] * self.inner_dt

        return qpos, qvel

    def get_reward(self, env_reward, env_done):
        """Convert one environment step's outcome into the game reward.

        The reward depends on ``env_name``:

        - InvertedPendulum-v4: ignores ``env_reward`` and yields ``-inner_dt``
          while ``env_done`` is false, and zero once it is true.
        - Swimmer-v4 and HalfCheetah-v4: yield ``-env_reward * inner_dt``.

        Raises:
            ValueError: If ``env_name`` has no reward definition. Previously
                this surfaced as an ``UnboundLocalError``.
        """
        if self.env_name == 'InvertedPendulum-v4':
            return - int(1 - env_done) * self.inner_dt
        if self.env_name in ('Swimmer-v4', 'HalfCheetah-v4'):
            return - env_reward * self.inner_dt
        raise ValueError(
            f"No reward defined for environment {self.env_name!r}")
        