import gym
import numpy as np
import matplotlib.pyplot as plt
import sys
sys.path.append("../../rl_zoo3/")
from action_space_transform_wrappers import ActionRedundancyWrapper, ActionRedundancyWrapper2nd
# from stable_baselines3.common.buffers import ReplayBuffer, RolloutBuffer, DictRolloutBuffer
import torch
from scipy.special import betainc, betaincinv
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
import torch as th
from gym import spaces
from torch import nn

from stable_baselines3.common.policies import BasePolicy#, ContinuousCritic

from stable_baselines3.common.preprocessing import get_action_dim
from stable_baselines3.common.torch_layers import (
    BaseFeaturesExtractor,
    CombinedExtractor,
    FlattenExtractor,
    NatureCNN,
    create_mlp,
    get_actor_critic_arch,
)
from stable_baselines3.common.type_aliases import Schedule
import numpy as np
from torch.nn import functional as F

from buffers import ActionProbReplayBuffer

# Choose the torch device once at import time, preferring the GPU whenever
# CUDA reports itself usable, and announce the choice on stdout.
_use_cuda = th.cuda.is_available()
device = th.device("cuda" if _use_cuda else "cpu")
print("CUDA is available. Using GPU." if _use_cuda else "CUDA is not available. Using CPU.")



# Step 1: build the wrapped environment and an empty replay buffer.

# Reference list of candidate environment ids; the run itself uses `env_id`.
env_id_list = [
    "Ant-v3",
    "HalfCheetah-v3",
    "Hopper-v3",
    "Humanoid-v3",
    "HumanoidStandup-v4",
    "Walker2d-v3",
    "InvertedDoublePendulum-v2",
]
env_id = "HalfCheetah-v3"  # alternative tried: "Swimmer-v2"

# Forwarded to ActionRedundancyWrapper as its `nleft` / `nright` arguments.
nleft = 2.5
nright = 2.5
# Number of transitions to collect; also passed as the buffer's first argument.
total_timestep = 100000

env = ActionRedundancyWrapper(gym.make(env_id), nleft=nleft, nright=nright)
# Alternative wrapper:
# env = ActionRedundancyWrapper2nd(gym.make(env_id), nleft=nleft, nright=nright)

last_original_obs = env.reset()
num_collected_steps = 0

replay_buffer = ActionProbReplayBuffer(
    total_timestep,
    env.observation_space,
    env.action_space,
    handle_timeout_termination=False,
)

# Step 2: fill the buffer with `total_timestep` transitions, each action drawn
# with `env.action_space.sample()` on the wrapped action space.
while num_collected_steps < total_timestep:
    real_action = env.action_space.sample()
    # No policy produced this action, so 0 is stored as a placeholder log-prob.
    # NOTE(review): presumably consumers of ActionProbReplayBuffer interpret 0
    # as "uniform / no log-prob available" — confirm.
    action_log_prob = 0
    # Alternative sampler that also returns a real log-prob:
    # from biased_sample import biased_sample
    # real_action, action_log_prob = biased_sample(action_dim=env.action_space.shape[-1], return_log_prob=True)

    # `env` comes from a single gym.make (not a VecEnv) and uses the 4-tuple
    # step API, so these values describe exactly one transition.
    new_obs, reward, done, info = env.step(real_action)

    # NOTE(review): the buffer is constructed with
    # handle_timeout_termination=False, so time-limit truncations are
    # presumably recorded as true terminals — confirm that is intended.
    replay_buffer.add(
        last_original_obs,
        new_obs,
        real_action,
        reward,
        done,
        info,
        action_log_prob,
    )

    last_original_obs = new_obs

    if done:
        # The final observation was stored as `new_obs` above; start a new
        # episode so the next transition does not straddle two episodes.
        last_original_obs = env.reset()
        print("Done at:" + str(num_collected_steps))

    num_collected_steps += 1

# Step 3: finalize the buffer and pickle it to disk.
# `calculate_sub_indices` is a project-specific hook on ActionProbReplayBuffer;
# its exact effect is defined in buffers.py.
replay_buffer.calculate_sub_indices()

import os
import time
from stable_baselines3.common.save_util import load_from_pkl, save_to_pkl

time_flag = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())

# Directory that receives the dataset; the file is <save_dir>/<file_name>.
save_dir = "/mnt/zdy"

# Alternative naming without the "_saturation" tag:
# file_name = f"{env_id}_L{nleft}R{nright}_{total_timestep}_{time_flag}"
file_name = f"{env_id}_saturation_L{nleft}R{nright}_{total_timestep}_{time_flag}"

path = os.path.join(save_dir, file_name)

# Create the output directory before writing. Otherwise a missing directory
# makes the save raise after the whole collection loop has run, and the
# in-memory buffer is lost.
# NOTE(review): if /mnt/zdy is a mount point that is not mounted, this creates
# a plain directory in its place — confirm that is acceptable.
os.makedirs(save_dir, exist_ok=True)

# NOTE(review): SB3's save_to_pkl may append ".pkl" when `path` has no
# suffix, so the printed path can differ from the file on disk — verify.
save_to_pkl(path, replay_buffer, verbose=2)
print("Save to " + path)