import os
import gym
from robosuite.wrappers import GymWrapper
import robosuite as suite
import argparse
import numpy as np
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.utils import set_random_seed
from stable_baselines3 import SAC
import torch
import random
import numpy as np
import argparse
import os

def seed_everything(seed):
    """Seed the random number generators used by this script.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (through ``torch.manual_seed`` and ``torch.cuda.manual_seed``), exports
    ``PYTHONHASHSEED`` into the process environment, and configures cuDNN
    to be deterministic with benchmarking switched off.

    Args:
        seed: Integer seed handed to every generator.
    """
    # Environment values must be text, hence the str() conversion.
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Every seeding function takes the seed as its only argument, so they
    # can be applied uniformly.
    for apply_seed in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    ):
        apply_seed(seed)

    # Disable cuDNN benchmarking and request deterministic behaviour.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

# Command-line interface. Every option takes at most one value (nargs='?')
# and falls back to the listed default when the flag is not given:
#   --seed    integer seed used for checkpoint lookup, seeding and output path
#   --folder  path prefix placed in front of the "source/..." directories
#   --env     robosuite environment name
#   --device  torch device string handed to SAC.load
#   --num     number of (state, action) samples to collect
parser = argparse.ArgumentParser()
for _flag, _type, _default in (
    ("--seed", int, 1),
    ("--folder", str, ''),
    ("--env", str, 'Door'),
    ("--device", str, 'cuda:0'),
    ("--num", int, int(1e5)),
):
    parser.add_argument(_flag, type=_type, nargs='?', default=_default)

args = parser.parse_args()

# Restore the SAC checkpoint stored as "best_model" for this env/seed pair.
model1 = SAC.load(
    f'{args.folder}/source/logs/{args.env}/{str(args.seed)}/best_model',
    device=args.device,
)


def _make_env():
    """Create the robosuite task named by ``args.env`` wrapped in GymWrapper.

    The task uses a Panda robot, has camera observations and both renderers
    disabled, enables reward shaping and sets ``control_freq`` to 20.
    """
    raw_env = suite.make(
        args.env,
        robots="Panda",
        use_camera_obs=False,
        has_offscreen_renderer=False,
        has_renderer=False,
        reward_shaping=True,
        control_freq=20,
    )
    return GymWrapper(raw_env)


# A vectorised wrapper holding exactly one environment instance.
vec_env = DummyVecEnv([_make_env])

# Seed the global generators, the environment, and Stable-Baselines3.
seed_everything(args.seed)
vec_env.seed(seed=args.seed)
set_random_seed(seed=args.seed)

# Roll out the loaded policy and record every (observation, action) pair.
state_list, action_list = [], []
sample_num = args.num

for i in range(sample_num):
    obs = vec_env.reset()
    reward = 0

    # Step until the vectorised environment reports the episode as done.
    while True:
        action, _ = model1.predict(obs)
        # Index 0 strips the leading vectorised-env axis before storing.
        state_list.append(obs[0])
        action_list.append(action[0])
        obs, rewards, dones, info = vec_env.step(action)
        reward += rewards
        if dones:
            break

    print(f'episode: {i}, reward: {reward}, {len(action_list)}/{sample_num}')
    # The budget is checked only between episodes, so the last episode is
    # always recorded in full and the final count may exceed sample_num.
    if len(action_list) >= sample_num:
        break

# Persist the collected dataset as two .npy files (states and actions).
# The directory that gets created must be the one np.save writes into:
# creating a different, CWD-relative "data/..." directory leaves the real
# target missing and np.save fails with FileNotFoundError.
out_dir = f'{args.folder}/source/data/{args.env}/{args.seed}'
os.makedirs(out_dir, exist_ok=True)
np.save(f'{out_dir}/state.npy', np.array(state_list))
np.save(f'{out_dir}/action.npy', np.array(action_list))