import os
# Select the "osmesa" backend for both MuJoCo and PyOpenGL. These variables
# are set immediately after importing os, i.e. before any of the imports
# below execute -- presumably so the rendering stack sees them at import
# time (TODO confirm).
os.environ["MUJOCO_GL"] = "osmesa"
os.environ["PYOPENGL_PLATFORM"] = "osmesa"

import argparse


# from envs.reach_obstacle import FetchReachObstacleEnv
# from envs.push_obstacle import FetchPushObstacleEnv
# from envs.slide_obstacle import FetchSlideObstacleEnv
# from envs.pick_and_place_obstacle import FetchPickAndPlaceObstacleEnv
import torch
from rl_modules.gcsl_agent import GCSL
import gymnasium as gym
from envs import register_envs
from mpi_utils.normalizer import normalizer
import numpy as np
from raDT.baselines.envs.maps_obstacle import *
from raDT.constants import *

# Root directory for the generated offline datasets. HOME_PATH_BASELINES is
# provided by one of the star imports above.
OFFLINE_DATA_PATH = HOME_PATH_BASELINES + "offline_data/"

# Command-line options for the random-policy data collection run.
parser = argparse.ArgumentParser()

parser.add_argument('--env_name', type=str, default='PointMazeObstacle', help='the environment name')
# The default is an int literal on purpose: argparse applies `type` only to
# string defaults, so the former `2e6` reached the script as a float.
parser.add_argument('--num_timesteps', type=int, default=2_000_000, help='the number of random actions to take')
parser.add_argument('--num_avoid', type=int, default=1, help='the number of avoid states in maze')
parser.add_argument('--maze', type=str, default="U_MAZE", help='the maze type')
parser.add_argument('--max_episode_steps', type=int, default=300, help='number of steps per episode in training')
parser.add_argument('--suffix', type=str, default="", help='suffix to dataset')

args = parser.parse_args()

# The two options that later parts of the script refer to by bare name.
env_name, num_timesteps = args.env_name, args.num_timesteps

# Presumably registers the project's custom environment ids with gymnasium
# so that gym.make below can resolve them (verify in envs.register_envs).
register_envs()

# NOTE(review): eval() turns the --maze string into an object by executing it
# as Python (by default the name U_MAZE, expected to come from the star
# imports). Only pass trusted values on the command line.
env = gym.make(
    f'{env_name}-v0',
    maze_map=eval(args.maze),
    render_mode='rgb_array',
    num_avoid=args.num_avoid,
    max_episode_steps=args.max_episode_steps,
)

# Finished episodes, one stacked array per episode and per field:
#   o = obs["observation"], ag = obs["achieved_goal"], g = obs["desired_goal"],
#   u = action, r = reward, c = cost.
o_ll, ag_ll, g_ll, u_ll, r_ll, c_ll = [], [], [], [], [], []
# Per-step buffers of the episode that is currently being collected.
o_list, ag_list, g_list, u_list, r_list, c_list = [], [], [], [], [], []

# Start the first episode and record its initial observation. The environment
# is reset exactly once here; an earlier extra reset whose result was thrown
# away has been removed, together with locals that were never read.
obs, info = env.reset()
o_list.append(obs["observation"])
ag_list.append(obs["achieved_goal"])
g_list.append(obs["desired_goal"])

# Random-policy rollout: sample an action every step, log the transition and
# close out the episode whenever the environment reports `done`.
for step_idx in range(int(num_timesteps)):
    # Progress indicator, printed once every 100000 steps.
    if step_idx % 100000 == 0:
        print(step_idx)

    assert env.observation_space.contains(obs)
    action = env.action_space.sample()
    u_list.append(action)
    assert env.action_space.contains(action)

    # step() is unpacked as (obs, reward, cost, done, info): this environment
    # was modified for Safe RL and reports a cost next to the reward.
    obs, reward, cost, done, info = env.step(action)
    r_list.append(reward)
    c_list.append(cost)
    o_list.append(obs["observation"])
    ag_list.append(obs["achieved_goal"])
    g_list.append(obs["desired_goal"])

    if not done:
        continue

    # Episode finished: freeze the per-step buffers into arrays. The desired
    # goal buffer drops its last entry, so it ends up with as many rows as
    # there were actions, while o/ag keep the extra initial observation.
    o_ll.append(np.stack(o_list))
    ag_ll.append(np.stack(ag_list))
    g_ll.append(np.stack(g_list[:-1]))
    u_ll.append(np.stack(u_list))
    r_ll.append(np.stack(r_list))
    c_ll.append(np.stack(c_list))

    # Fresh buffers and a new episode, seeded with its initial observation.
    o_list, ag_list, g_list, u_list, r_list, c_list = [], [], [], [], [], []

    obs, info = env.reset()
    o_list.append(obs["observation"])
    ag_list.append(obs["achieved_goal"])
    g_list.append(obs["desired_goal"])

    
    # env.render()
# env.close_video_recorder()
# env.close()

# Fail with an actionable message instead of numpy's generic "need at least
# one array to stack" when not a single episode finished during the run.
if not o_ll:
    raise ValueError(
        "no complete episode was collected, nothing to save "
        "(is --num_timesteps smaller than one episode?)"
    )

# Stack the episodes along a new leading axis and drop the last entry along
# the time axis of every field. np.stack requires all episodes to have the
# same length. Rewards and costs get a trailing axis of size 1 (assumes they
# are scalars per step -- TODO confirm). Steps of an episode that was still
# running when the loop ended are not part of the dataset.
d = {'o': np.stack(o_ll)[:,:-1,:],
     'ag': np.stack(ag_ll)[:,:-1,:],
     'g': np.stack(g_ll)[:,:-1,:],
     'u': np.stack(u_ll)[:,:-1,:],
     'r': np.stack(r_ll)[:,:-1,np.newaxis],
     'c': np.stack(c_ll)[:,:-1,np.newaxis]}

import pickle

# Destination of the serialized buffer dictionary.
buffer_path = OFFLINE_DATA_PATH + f'random/{env_name}/buffer_numavoid{args.num_avoid}{args.suffix}.pkl'

# Create the target directory first, so a long collection run does not end in
# a FileNotFoundError when the per-environment folder does not exist yet.
os.makedirs(os.path.dirname(buffer_path), exist_ok=True)

with open(buffer_path, 'wb') as handle:
    pickle.dump(d, handle)
