import os
import gymnasium as gym
import argparse
#import pybullet_envs
import numpy as np
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
from stable_baselines3.common.utils import set_random_seed
from stable_baselines3 import SAC
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
from itertools import count
from collections import deque
import torch.optim as optim
import random
from torch.distributions import Categorical
from torch.autograd import Variable
from torch.nn.parameter import Parameter
import gym
import numpy as np
import argparse
import os
import warnings
from tqdm import trange
import shelve

def seed_everything(seed):
    """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: Integer seed applied to every generator below.

    Note:
        Setting ``PYTHONHASHSEED`` here does not change hash randomization of
        the already-running interpreter; it only affects child processes that
        inherit this environment.
    """
    for host_seeder in (random.seed, np.random.seed):
        host_seeder(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    for torch_seeder in (torch.manual_seed, torch.cuda.manual_seed):
        torch_seeder(seed)
    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False

def put_transition(buffer, *fields):
    """Append all positional ``fields`` to ``buffer`` as one tuple record.

    ``buffer`` only needs an ``append`` method (e.g. a list or deque).
    """
    record = fields  # *args arrive already packed into a tuple
    buffer.append(record)

# ---------------------------------------------------------------------------
# Roll out a trained SAC policy and record the actions it takes.
#
# Loads "<folder>source/stable-baselines3-1.7.0/logs/source/seed<seed>/best_model",
# steps it in a single-env DummyVecEnv until `sample_num` actions have been
# collected (or 10000 episodes have finished), and saves the actions to
# data_action/seed<seed>/action.npy (relative to the working directory).
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument("--seed", type=int, nargs='?', default=1)
parser.add_argument("--folder", type=str, nargs='?', default='/home/')
parser.add_argument("--env", type=str, nargs='?', default='Hopper-v3')
args = parser.parse_args()

# --folder is concatenated as-is, so it must end with a path separator.
log_folder = args.folder + "source/stable-baselines3-1.7.0/logs/source/seed" + str(args.seed) + "/"

seed_everything(args.seed)
#warnings.filterwarnings("ignore")
vec_env = DummyVecEnv([lambda: gym.make(args.env)])
vec_env.seed(seed=args.seed)
set_random_seed(seed = args.seed)
# In-memory (obs, action, reward, next_obs, done) records. action/reward/done
# keep the VecEnv batch dimension of 1, while obs/next_obs are unwrapped.
final_buffer = deque()

action_list = []
# NOTE(review): despite its name this records `obs` (the state the action was
# taken in, batch dim kept), not the next state -- confirm intent.
next_state_list = []
flatten_list = []
model = SAC.load(log_folder + "best_model")
sample_num = 60000

for i in range(10000):
    obs = vec_env.reset()
    dones = False
    reward = 0.0  # undiscounted episode return

    while not dones:
        action, _states = model.predict(obs)
        #flatten = vec_env.sim.get_state().flatten()
        action_list.append(action[0])
        obs_next, rewards, dones, info = vec_env.step(action)
        next_state_list.append(obs)
        # VecEnvs auto-reset when an episode ends: obs_next is then the first
        # observation of the NEXT episode and the real final observation is
        # in info[0]["terminal_observation"].
        if dones[0]:
            true_next_obs = info[0].get("terminal_observation", obs_next[0])
        else:
            true_next_obs = obs_next[0]
        put_transition(final_buffer, obs[0], action, rewards, true_next_obs, dones)
        obs = obs_next
        reward += float(rewards[0])
        if len(action_list) >= sample_num:
            break
    print(f'episode: {i}, reward: {reward}')
    if len(action_list) >= sample_num:
        break

# np.save does not create missing directories, so make sure the target exists.
action_dir = os.path.join('data_action', 'seed' + str(args.seed))
os.makedirs(action_dir, exist_ok=True)
np.save(os.path.join(action_dir, 'action.npy'), np.array(action_list))