import gym
import numpy as np
import torch
from torch import Tensor



# Short alias for NumPy arrays, used in the type hints of the wrappers below.
from numpy import ndarray as Array


class PendulumEnv(gym.Wrapper):
    """Gym ``Pendulum`` wrapper exposing the env metadata used by the training code.

    States are returned as ``float32`` arrays, and every action is multiplied by 2
    before being passed to the wrapped env.
    """

    def __init__(self, gym_env_id='Pendulum-v1', target_return=-200):
        """Create the wrapped Pendulum env.

        Args:
            gym_env_id: env id to create. It is replaced by 'Pendulum-v0' for
                gym < 0.18.0 and by 'Pendulum-v1' for gym >= 0.20.0.
            target_return: episode return regarded as the training target.
        """
        gym_version = self._version_tuple(gym.__version__)
        # Compare versions numerically: a plain string comparison gets e.g.
        # '0.9.0' < '0.18.0' wrong, because '9' > '1' lexicographically.
        if gym_version < (0, 18, 0):
            gym_env_id = 'Pendulum-v0'
        elif gym_version >= (0, 20, 0):
            gym_env_id = 'Pendulum-v1'
        gym.logger.set_level(40)  # Block warning
        super(PendulumEnv, self).__init__(env=gym.make(gym_env_id))

        self.env_num = 1  # the env number of VectorEnv is greater than 1
        self.env_name = gym_env_id  # the name of this env.
        self.max_step = 200  # the max step of each episode
        self.state_dim = 3  # feature number of state
        self.action_dim = 1  # feature number of action
        self.if_discrete = False  # discrete action or continuous action
        self.target_return = target_return  # episode return is between (-1600, 0)

    @staticmethod
    def _version_tuple(version):
        """Parse the numeric release part of a version string into a tuple of ints.

        '0.21.0' -> (0, 21, 0); '0.18' -> (0, 18, 0); '0.26.0rc1' -> (0, 26, 0).
        Parsing stops at the first component carrying a non-numeric suffix, and the
        result is zero-padded to three components so '0.18' equals (0, 18, 0).
        """
        numbers = []
        for piece in version.split('.'):
            digits = ''
            for char in piece:
                if not char.isdigit():
                    break
                digits += char
            if not digits:
                break
            numbers.append(int(digits))
            if len(digits) < len(piece):  # pre-release / local suffix such as '0rc1'
                break
        numbers.extend([0] * (3 - len(numbers)))
        return tuple(numbers)

    def reset(self):
        """Reset the wrapped env and return the initial state as a float32 array."""
        return self.env.reset().astype(np.float32)

    def step(self, action: np.ndarray):
        """Step the wrapped env with ``action * 2``.

        NOTE(review): presumably the agent emits actions in [-1, 1] and the factor 2
        maps them onto Pendulum's torque range -- TODO confirm.

        Returns:
            (state as float32 array, reward, done, info_dict) from the wrapped env.
        """
        state, reward, done, info_dict = self.env.step(action * 2)
        return state.astype(np.float32), reward, done, info_dict


class GymNormaEnv(gym.Wrapper):
    """Gym MuJoCo wrapper that returns states normalized with precomputed statistics.

    For the env ids handled explicitly in ``__init__``, the per-dimension
    ``state_avg`` / ``state_std`` were measured offline (the comments give the number
    of runs). Any other env id prints a warning and uses the identity normalization
    (avg 0, std 1). Its metadata is then read from the wrapped env's spaces and spec.

    ``reset`` and ``step`` return ``torch.float32`` tensors computed as
    ``(state - state_avg) / state_std``. ``state_std`` is clamped to
    ``[2 ** -4, 2 ** 4]``.
    """

    def __init__(self, env_name: str = 'Hopper-v3'):
        gym.logger.set_level(40)  # Block warning
        super(GymNormaEnv, self).__init__(env=gym.make(env_name))

        if env_name == 'Hopper-v3':
            self.env_num = 1
            self.env_name = env_name
            self.max_step = 1000
            self.state_dim = 11
            self.action_dim = 3
            self.if_discrete = False
            self.target_return = 3000

            # 4 runs
            self.state_avg = torch.tensor([1.3819, -0.0105, -0.3804, -0.1759, 0.1959, 2.4185, -0.0406, -0.0172,
                                           -0.1465, -0.0450, -0.1616], dtype=torch.float32)
            self.state_std = torch.tensor([0.1612, 0.0747, 0.2357, 0.1889, 0.6431, 0.6253, 1.4806, 1.1569, 2.2850,
                                           2.2124, 6.5147], dtype=torch.float32)
        elif env_name == 'Swimmer-v3':
            self.env_num = 1
            self.env_name = env_name
            self.max_step = 1000
            self.state_dim = 8
            self.action_dim = 2
            self.if_discrete = False
            self.target_return = 360.0

            # 6 runs
            self.state_avg = torch.tensor([0.5877, -0.2745, -0.2057, 0.0802, 0.0105, 0.0158, -0.0047, -0.0057],
                                          dtype=torch.float32)
            self.state_std = torch.tensor([0.5324, 0.5573, 0.5869, 0.4787, 0.5617, 0.8538, 1.2658, 1.4649],
                                          dtype=torch.float32)
        elif env_name == 'Ant-v3':
            self.env_num = 1
            self.env_name = env_name
            self.max_step = 1000
            # Must match the length of state_avg / state_std below (27 + 84 = 111);
            # the previous value 17 (and action_dim 6) were copied from HalfCheetah.
            self.state_dim = 111
            self.action_dim = 8  # gym's Ant-v3 action space is Box(8,)
            self.if_discrete = False
            self.target_return = 5000
            # 27 measured dimensions followed by 84 dimensions whose recorded avg/std
            # are 0 (their std is raised to 2 ** -4 by the clamp at the end of __init__).
            self.state_avg = torch.cat([
                torch.tensor([6.3101e-01, 9.3039e-01, 1.1357e-02, -6.0412e-02, -1.9220e-01,
                              1.4675e-01, 6.7936e-01, -1.2429e-01, -6.3794e-01, -2.9083e-02,
                              -6.0464e-01, 1.0855e-01, 6.5904e-01, 5.2163e+00, 7.5811e-02,
                              8.2149e-03, -3.0893e-02, -4.0532e-02, -4.5461e-02, 3.8929e-03,
                              7.3546e-02, -5.1845e-02, -2.2415e-02, 7.4109e-03, -4.0126e-02,
                              7.2162e-02, 3.4596e-02], dtype=torch.float32),
                torch.zeros(84, dtype=torch.float32),
            ])
            self.state_std = torch.cat([
                torch.tensor([0.1170, 0.0548, 0.0683, 0.0856, 0.1434, 0.3606, 0.2035, 0.4071, 0.1488,
                              0.3565, 0.1285, 0.4071, 0.1953, 1.2645, 1.0212, 1.1494, 1.6127, 1.8113,
                              1.3163, 4.3250, 3.2312, 5.4796, 2.4919, 4.3622, 2.3617, 5.3836, 3.0482],
                             dtype=torch.float32),
                torch.zeros(84, dtype=torch.float32),
            ])
        elif env_name == 'HalfCheetah-v3':
            self.env_num = 1
            self.env_name = env_name
            self.max_step = 1000
            self.state_dim = 17
            self.action_dim = 6
            self.if_discrete = False
            self.target_return = 5000

            # 2 runs
            self.state_avg = torch.tensor([-0.1786, 0.8515, 0.0683, 0.0049, 0.0143, -0.1074, -0.1226, -0.1223,
                                           3.2042, -0.0244, 0.0103, 0.0679, -0.1574, 0.0661, -0.0098, 0.0513,
                                           -0.0142], dtype=torch.float32)
            self.state_std = torch.tensor([0.1224, 0.6781, 0.3616, 0.3545, 0.3379, 0.4800, 0.3575, 0.3372,
                                           1.3460, 0.7967, 2.2092, 9.1078, 9.4349, 9.4631, 11.0645, 9.3995,
                                           8.6867], dtype=torch.float32)
        elif env_name == 'Walker2d-v3':
            self.env_num = 1
            self.env_name = env_name
            self.max_step = 1000
            self.state_dim = 17
            self.action_dim = 6
            self.if_discrete = False
            self.target_return = 8000

            # 6 runs
            self.state_avg = torch.tensor([1.2954, 0.4176, -0.0995, -0.2242, 0.2234, -0.2319, -0.3035, -0.0614,
                                           3.7896, -0.1081, 0.1643, -0.0470, -0.1533, -0.0410, -0.1140, -0.2981,
                                           -0.6278], dtype=torch.float32)
            self.state_std = torch.tensor([0.1095, 0.1832, 0.1664, 0.2951, 0.6291, 0.2582, 0.3270, 0.6931, 1.1162,
                                           1.0560, 2.7070, 3.1108, 4.4344, 6.4363, 3.1945, 4.4594, 6.0115],
                                          dtype=torch.float32)
        else:
            # Unknown env: previously none of the metadata attributes were set here,
            # so any later access raised AttributeError. Derive them from the wrapped
            # env instead, and normalize with the identity (avg 0, std 1).
            spec = getattr(self.env, 'spec', None)
            max_episode_steps = getattr(spec, 'max_episode_steps', None)
            reward_threshold = getattr(spec, 'reward_threshold', None)
            action_space = self.env.action_space

            self.env_num = 1
            self.env_name = env_name
            # 1000 matches the explicit MuJoCo entries above when the spec gives no limit.
            self.max_step = max_episode_steps if max_episode_steps is not None else 1000
            self.state_dim = int(np.prod(self.env.observation_space.shape))
            self.if_discrete = isinstance(action_space, gym.spaces.Discrete)
            self.action_dim = int(action_space.n) if self.if_discrete else int(np.prod(action_space.shape))
            # Without a known threshold the target can never be reached (avoids None comparisons).
            self.target_return = reward_threshold if reward_threshold is not None else float('inf')

            self.state_avg = torch.zeros(1, dtype=torch.float32)
            self.state_std = torch.ones(1, dtype=torch.float32)
            print(f"{self.__class__.__name__} WARNING: env_name not found {env_name}")

        # Bound the divisor: zero std entries become 2 ** -4, very large ones 2 ** 4.
        self.state_std = torch.clamp(self.state_std, 2 ** -4, 2 ** 4)  # todo

    def get_state_norm(self, state: Array) -> Tensor:
        """Convert ``state`` to a float32 tensor and normalize it with the stored statistics."""
        state = torch.tensor(state, dtype=torch.float32)
        return (state - self.state_avg) / self.state_std

    def reset(self) -> Tensor:
        """Reset the wrapped env and return the normalized initial state."""
        state = self.env.reset()
        return self.get_state_norm(state)

    def step(self, action: Array) -> (Tensor, float, bool, dict):
        """Step the wrapped env with ``action`` unchanged.

        Returns:
            (normalized state tensor, reward, done, info_dict).
        """
        state, reward, done, info_dict = self.env.step(action)
        return self.get_state_norm(state), reward, done, info_dict


class HumanoidEnv(gym.Wrapper):
    """Gym ``Humanoid`` wrapper that returns normalized states and rescales actions.

    States are normalized with per-dimension statistics measured offline (5 runs;
    ``state_avg`` and ``state_std`` have 376 entries each). ``reset`` and ``step``
    return ``torch.float32`` tensors computed as ``(state - state_avg) / state_std``,
    where ``state_std`` is clamped to ``[2 ** -4, 2 ** 4]``. Every action is
    multiplied by 2.5 before being passed to the wrapped env.
    """

    def __init__(self, gym_env_id='Humanoid-v3', target_return=8000):
        gym.logger.set_level(40)  # Block warning
        super(HumanoidEnv, self).__init__(env=gym.make(gym_env_id))

        self.env_num = 1  # the env number of VectorEnv is greater than 1
        self.env_name = gym_env_id  # the name of this env.
        self.max_step = 1000  # the max step of each episode
        self.state_dim = 376  # feature number of state (length of state_avg / state_std below)
        self.action_dim = 17  # feature number of action
        self.if_discrete = False  # discrete action or continuous action
        self.target_return = target_return  # episode return used as the training target (default 8000)

        # 5 runs
        self.state_avg = torch.tensor([1.2027e+00, 9.0388e-01, -1.0409e-01, 4.4935e-02, -2.8785e-02,
                                       2.9601e-01, -3.1656e-01, 3.0909e-01, -4.3196e-02, -1.2750e-01,
                                       -2.6788e-01, -1.1086e+00, -1.1024e-01, 1.2908e-01, -5.8439e-01,
                                       -1.6043e+00, 8.1362e-02, -7.7958e-01, -4.3869e-01, -4.9594e-02,
                                       6.4827e-01, -3.0660e-01, 3.4619e+00, -5.2682e-02, -7.4712e-02,
                                       -5.4782e-02, 4.0784e-02, 1.3942e-01, 1.1000e-01, -1.3992e-02,
                                       9.3216e-02, -1.3473e-01, -7.6183e-02, -3.0072e-01, -1.3914e+00,
                                       -7.6460e-02, 1.6543e-02, -2.1907e-01, -3.8219e-01, -1.0018e-01,
                                       -1.5629e-01, -1.0627e-01, -3.7252e-03, 2.1453e-01, 2.7610e-02,
                                       0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
                                       0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
                                       1.5680e+00, 1.5862e+00, 1.9913e-01, 9.9125e-03, 5.5228e-03,
                                       -1.0950e-01, -1.2668e-01, 2.9367e-01, 3.5102e+00, 8.4719e+00,
                                       6.4141e-02, 6.6425e-02, 2.3180e-02, -2.3346e-03, 4.5395e-03,
                                       8.0720e-03, -3.0787e-02, -5.2109e-02, 3.2192e-01, 2.0724e+00,
                                       5.9864e-02, 5.0491e-02, 7.7832e-02, -3.2226e-03, 1.7504e-04,
                                       -1.9180e-03, -8.2688e-02, -1.9763e-01, 1.0849e-02, 5.9581e+00,
                                       2.5272e-01, 2.6957e-01, 1.1540e-01, 1.6143e-02, 2.7386e-02,
                                       -6.4959e-02, 2.4176e-01, -4.1101e-01, -8.2298e-01, 4.6070e+00,
                                       6.3743e-01, 7.0587e-01, 1.2301e-01, -4.3697e-04, -4.5899e-02,
                                       -6.8465e-02, 2.5412e-02, -1.7718e-01, -1.2062e+00, 2.6798e+00,
                                       6.8834e-01, 7.6378e-01, 1.2859e-01, -8.0863e-03, -1.0989e-01,
                                       -4.6906e-02, -1.4599e-01, -1.0927e-01, -1.0181e+00, 1.7989e+00,
                                       1.9099e-01, 2.0230e-01, 9.9341e-02, -1.5814e-02, 1.5009e-02,
                                       5.1159e-02, 1.6290e-01, 3.2563e-01, -6.0960e-01, 4.6070e+00,
                                       4.5602e-01, 4.9681e-01, 1.0787e-01, -5.9067e-04, -3.5140e-02,
                                       7.0788e-02, 2.5216e-02, 2.1480e-01, -9.1849e-01, 2.6798e+00,
                                       4.6612e-01, 5.2530e-01, 9.9732e-02, 1.3496e-02, -8.3317e-02,
                                       4.6769e-02, -1.8264e-01, 1.1677e-01, -7.7112e-01, 1.7989e+00,
                                       2.9806e-01, 2.7976e-01, 1.1250e-01, 3.8320e-03, 1.4312e-03,
                                       9.2314e-02, -2.9700e-02, -2.5973e-01, 5.9897e-01, 1.6228e+00,
                                       2.1239e-01, 1.6878e-01, 1.8192e-01, 6.9662e-03, -2.5374e-02,
                                       7.5638e-02, 3.0046e-02, -3.1797e-01, 2.8894e-01, 1.2199e+00,
                                       2.5424e-01, 2.0008e-01, 1.0215e-01, 1.6763e-03, -1.8978e-03,
                                       -8.9815e-02, -5.8642e-03, 3.2081e-01, 4.9344e-01, 1.6228e+00,
                                       1.8071e-01, 1.4553e-01, 1.4435e-01, -1.2074e-02, -1.3314e-02,
                                       -3.5878e-02, 5.3603e-02, 2.7511e-01, 2.0549e-01, 1.2199e+00,
                                       0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
                                       0.0000e+00, 4.2954e-02, 6.7016e-02, -2.8482e-02, 3.5978e+00,
                                       -4.8962e-02, -5.6775e-02, -4.4155e-02, -1.1466e-01, 1.6261e-01,
                                       3.5054e+00, -5.0701e-02, -4.9236e-02, -4.1256e-02, 1.1351e-01,
                                       1.3945e-01, 3.4389e+00, -4.3797e-02, -3.3252e-02, 2.8187e-02,
                                       -3.3888e-02, -3.5859e-01, 3.5962e+00, -3.8793e-02, -2.0773e-02,
                                       -2.4524e-02, 1.1582e+00, -4.5108e-02, 5.1413e+00, -8.7558e-02,
                                       -5.7185e-01, -2.4524e-02, 1.1582e+00, -4.5108e-02, 5.1413e+00,
                                       -8.7558e-02, -5.7185e-01, 9.9391e-02, -2.4059e-02, -1.7425e-01,
                                       3.4541e+00, -8.4718e-02, 1.8192e-02, 4.4070e-01, 3.9781e-01,
                                       3.5545e-01, 4.3428e+00, -1.8370e-01, -6.5439e-01, 4.4070e-01,
                                       3.9781e-01, 3.5545e-01, 4.3428e+00, -1.8370e-01, -6.5439e-01,
                                       1.5922e-01, 2.0918e-01, -9.8105e-02, 3.7604e+00, -2.9619e-02,
                                       -5.8485e-02, 1.0385e-01, 2.1228e-01, -1.7878e-01, 3.7999e+00,
                                       -7.4080e-02, -5.3348e-02, -2.6477e-01, 4.1909e-01, 2.9927e-02,
                                       3.6885e+00, -1.1708e-01, -6.7030e-02, -2.1599e-01, 3.9669e-01,
                                       6.0856e-03, 3.8305e+00, -8.3960e-02, -1.1403e-01, 0.0000e+00,
                                       0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
                                       1.6677e+01, 4.9107e+01, -9.6274e+00, -2.9728e+01, -5.9374e+01,
                                       7.3201e+01, -5.8161e+01, -3.6315e+01, 2.7580e+01, 4.1244e+00,
                                       1.1711e+02, -8.4357e+00, -1.0379e+01, 1.0683e+01, 3.3124e+00,
                                       5.4840e+00, 8.2456e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
                                       0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
                                       0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
                                       0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
                                       0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
                                       0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
                                       0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
                                       0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
                                       0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
                                       0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
                                       0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
                                       0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
                                       0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
                                       0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
                                       0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
                                       0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
                                       0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
                                       0.0000e+00], dtype=torch.float32)
        self.state_std = torch.tensor([3.7685e-02, 4.5415e-02, 6.8201e-02, 9.5235e-02, 1.2801e-01, 2.2247e-01,
                                       2.2774e-01, 1.9151e-01, 1.0900e-01, 1.8950e-01, 3.8430e-01, 6.4591e-01,
                                       1.1708e-01, 1.7833e-01, 4.0411e-01, 6.1461e-01, 2.8869e-01, 3.0227e-01,
                                       4.4105e-01, 3.1090e-01, 3.5227e-01, 2.9399e-01, 8.6883e-01, 3.8865e-01,
                                       4.2435e-01, 2.4784e+00, 3.5310e+00, 4.3277e+00, 8.6461e+00, 6.9988e+00,
                                       7.2420e+00, 8.6105e+00, 9.3459e+00, 2.6776e+01, 4.3671e+01, 7.4211e+00,
                                       1.0446e+01, 1.4800e+01, 2.2152e+01, 5.7955e+00, 6.3750e+00, 7.0280e+00,
                                       6.4058e+00, 9.1694e+00, 7.0480e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
                                       0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
                                       0.0000e+00, 1.5156e-01, 1.4699e-01, 7.1690e-02, 3.6331e-02, 1.4871e-01,
                                       1.1873e-01, 3.9063e-01, 3.2367e-01, 1.9474e-01, 0.0000e+00, 1.1069e-02,
                                       1.1819e-02, 5.2432e-03, 3.1321e-03, 8.8166e-03, 6.7725e-03, 5.4790e-02,
                                       4.3172e-02, 3.4676e-02, 0.0000e+00, 1.0203e-02, 1.2745e-02, 1.4526e-02,
                                       9.4642e-03, 5.2404e-03, 5.6170e-03, 1.5328e-01, 1.1638e-01, 1.0253e-01,
                                       0.0000e+00, 4.8770e-02, 4.3080e-02, 4.4482e-02, 2.2124e-02, 4.9892e-02,
                                       2.2123e-02, 2.4277e-01, 1.0974e-01, 1.2796e-01, 0.0000e+00, 1.5967e-01,
                                       1.5963e-01, 6.8688e-02, 3.1619e-02, 1.2107e-01, 5.2330e-02, 2.8835e-01,
                                       1.1818e-01, 1.9899e-01, 0.0000e+00, 2.0831e-01, 2.2797e-01, 9.6549e-02,
                                       3.5202e-02, 1.2134e-01, 5.9960e-02, 2.1897e-01, 1.0345e-01, 2.1384e-01,
                                       0.0000e+00, 4.7938e-02, 4.4530e-02, 3.8997e-02, 2.2406e-02, 4.1815e-02,
                                       2.0735e-02, 2.1493e-01, 1.0405e-01, 1.4387e-01, 0.0000e+00, 1.5225e-01,
                                       1.6402e-01, 6.2498e-02, 3.1570e-02, 1.1685e-01, 4.3421e-02, 2.8339e-01,
                                       1.0626e-01, 2.1353e-01, 0.0000e+00, 1.9867e-01, 2.2000e-01, 8.5643e-02,
                                       3.0187e-02, 1.2717e-01, 5.0311e-02, 2.2468e-01, 9.0330e-02, 2.1959e-01,
                                       0.0000e+00, 4.6455e-02, 4.4841e-02, 2.4198e-02, 1.8876e-02, 3.3907e-02,
                                       2.6701e-02, 9.6149e-02, 7.2464e-02, 6.3727e-02, 0.0000e+00, 6.9340e-02,
                                       6.5581e-02, 5.0208e-02, 3.8457e-02, 3.7162e-02, 3.9005e-02, 1.2357e-01,
                                       9.5124e-02, 1.0308e-01, 0.0000e+00, 4.5508e-02, 4.2817e-02, 2.3776e-02,
                                       2.1004e-02, 3.2342e-02, 2.5299e-02, 1.0703e-01, 7.1359e-02, 6.8018e-02,
                                       0.0000e+00, 5.5628e-02, 5.4957e-02, 4.5547e-02, 3.1943e-02, 3.2783e-02,
                                       2.8549e-02, 1.1968e-01, 9.6011e-02, 9.6069e-02, 0.0000e+00, 0.0000e+00,
                                       0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 2.7535e+00,
                                       3.3878e+00, 4.3144e+00, 1.3207e+00, 8.2827e-01, 4.9262e-01, 4.4686e+00,
                                       3.7037e+00, 6.4924e+00, 9.9583e-01, 6.8091e-01, 4.7597e-01, 4.3920e+00,
                                       4.7409e+00, 6.0906e+00, 9.3958e-01, 4.9473e-01, 4.9569e-01, 7.5115e+00,
                                       1.5371e+01, 1.1053e+01, 1.2450e+00, 7.6206e-01, 1.0601e+00, 9.2410e+00,
                                       2.3707e+01, 1.0356e+01, 6.6857e+00, 2.4551e+00, 2.8653e+00, 9.2410e+00,
                                       2.3707e+01, 1.0356e+01, 6.6857e+00, 2.4551e+00, 2.8653e+00, 5.3753e+00,
                                       8.6029e+00, 8.1809e+00, 1.1586e+00, 5.8827e-01, 8.2327e-01, 6.3651e+00,
                                       1.1362e+01, 8.7067e+00, 4.3533e+00, 1.4509e+00, 2.1305e+00, 6.3651e+00,
                                       1.1362e+01, 8.7067e+00, 4.3533e+00, 1.4509e+00, 2.1305e+00, 4.5383e+00,
                                       5.4198e+00, 5.3263e+00, 2.0749e+00, 1.5746e+00, 8.2220e-01, 5.7299e+00,
                                       6.2163e+00, 6.0368e+00, 2.1437e+00, 1.8280e+00, 1.2940e+00, 5.5326e+00,
                                       5.0856e+00, 5.3383e+00, 1.7817e+00, 1.5361e+00, 8.9927e-01, 6.1037e+00,
                                       6.5608e+00, 6.2712e+00, 1.9360e+00, 1.6504e+00, 1.1001e+00, 0.0000e+00,
                                       0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.7051e+02,
                                       1.5721e+02, 1.7507e+02, 1.7297e+02, 1.3840e+02, 5.2837e+02, 2.8931e+02,
                                       1.6753e+02, 1.6898e+02, 5.0561e+02, 3.0826e+02, 2.2299e+01, 2.6949e+01,
                                       2.4568e+01, 2.5537e+01, 2.9878e+01, 2.6547e+01, 0.0000e+00, 0.0000e+00,
                                       0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
                                       0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
                                       0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
                                       0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
                                       0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
                                       0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
                                       0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
                                       0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
                                       0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
                                       0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
                                       0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
                                       0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
                                       0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
                                       0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], dtype=torch.float32)
        # Bound the divisor: zero std entries become 2 ** -4, very large ones 2 ** 4.
        self.state_std = torch.clamp(self.state_std, 2 ** -4, 2 ** 4)

    def get_state_norm(self, state: Array) -> Tensor:
        """Convert ``state`` to a float32 tensor and normalize it with the stored statistics."""
        state = torch.tensor(state, dtype=torch.float32)
        return (state - self.state_avg) / self.state_std

    def reset(self) -> Tensor:
        """Reset the wrapped env and return the normalized initial state."""
        state = self.env.reset()
        return self.get_state_norm(state)

    def step(self, action: Array) -> (Tensor, float, bool, dict):
        """Step the wrapped env with ``action * 2.5``.

        NOTE(review): if the agent emits actions in [-1, 1], the env receives values
        in [-2.5, 2.5]. If the Humanoid actuator range is narrower (commonly quoted
        as +-0.4), most of that range is clipped by the simulator. Scaling by 0.4
        (i.e. dividing by 2.5) may have been intended -- TODO confirm.

        Returns:
            (normalized state tensor, reward, done, info_dict).
        """
        state, reward, done, info_dict = self.env.step(action * 2.5)
        return self.get_state_norm(state), reward, done, info_dict
