import os
import subprocess
import sys
# Gym environment id -> name of the factory function in
# url_offline.gym_mujoco.gym_envs that the generated register(...) entry
# uses as its entry_point.
get_env = dict([
    ('HalfCheetah-v2', 'get_cheetah_env'),
    ('Walker2d-v2', 'get_walker_env'),
    ('Hopper-v2', 'get_hopper_env'),
    ('AntCustom-v2', 'get_ant_custom_env'),
])
# Environment handed to the training subprocess: a plain-dict snapshot of
# ours with CUDA restricted to device 3.
environ = dict(os.environ)
environ.update(CUDA_VISIBLE_DEVICES="3")
# Base location of the dataset files; raises KeyError when the variable is unset.
path_prefix = os.environ['UDG_DATA_PATH']

# Gym environment to train on; must be a key of ``get_env``.
env = 'HalfCheetah-v2'
# Base name without the trailing "-vN" version suffix, in lower and upper case.
# rsplit with maxsplit=1 strips only the version, so ids whose base name itself
# contains '-' work too (a plain split('-') + 2-way unpack would raise there).
l_env = env.rsplit('-', 1)[0].lower()
u_env = l_env.upper()

# Identify the model checkpoint whose experience file is used as the dataset.
model_code = 'f7'                   # short tag; part of dataset_id and the task name
model_path = '2023-04-19_15-09-54'  # run directory name (path component)
model_type = 'wasserstein'          # path component under url-data/data/
model_step = 'step2250'             # checkpoint directory; also part of dataset_id
model_file = 'experience.h5'        # file name inside the checkpoint directory
# Values written into the generated config's params['kwargs'].
pool_size = 1000000                 # -> 'pool_load_max_size'
rollout_length = 1                  # -> 'rollout_length'
penalty_coeff = 1                   # -> 'penalty_coeff'
# os.path.join inserts the separator after path_prefix when it is missing;
# plain concatenation silently produced ".../dataurl-data/..." in that case.
dataset_path = os.path.join(
    path_prefix, "url-data", "data", model_type, env,
    model_path, model_step, model_file,
)
# e.g. "halfcheetah-f7-step2250"; registered in Gym with a "-v0" suffix.
dataset_id = l_env + "-" + model_code + "-" + model_step
# Register the dataset as a Gym environment by appending a register(...) call
# to the package __init__.py, unless an entry with this exact id exists.
registry_path = 'url_offline/gym_mujoco/__init__.py'
registered_id = dataset_id + '-v0'
with open(registry_path, 'r', encoding='utf-8') as f:
    registry_source = f.read()
# Look for the quoted id, not a bare substring: a plain ``dataset_id in``
# test reports "...-step2250" as registered whenever "...-step22500" is.
env_exist = any(
    quote + registered_id + quote in registry_source for quote in ("'", '"')
)
if not env_exist:
    # repr() emits valid Python string literals even if a value contains quotes
    # or backslashes. The *_RANDOM_SCORE / *_EXPERT_SCORE names are referenced,
    # not defined, here — presumably they already exist in that __init__.py.
    entry_lines = [
        "register(",
        "\tid=" + repr(registered_id) + ",",
        "\tentry_point=" + repr("url_offline.gym_mujoco.gym_envs:" + get_env[env]) + ",",
        "\tmax_episode_steps=1000,",
        "\tkwargs={",
        "\t\t'ref_min_score': " + u_env + "_RANDOM_SCORE,",
        "\t\t'ref_max_score': " + u_env + "_EXPERT_SCORE,",
        "\t\t'dataset_path': " + repr(dataset_path),
        "\t}",
        ")",
        "",
    ]
    entry = "\n".join(entry_lines) + "\n"
    # Without this, a file lacking a final newline would get "register(" glued
    # onto its last line and become a syntax error.
    if registry_source and not registry_source.endswith("\n"):
        entry = "\n" + entry
    with open(registry_path, 'a', encoding='utf-8') as f:
        f.write(entry)
# Write the experiment config module that main.py loads via --config. It
# copies the shared MOPO defaults and overrides domain/task/logging and the
# offline-pool / rollout settings for this dataset.
config_module = dataset_id.replace('-', '_')
config_lines = [
    "from .base_mopo import mopo_params, deepcopy",
    "",
    "params = deepcopy(mopo_params)",
    "params.update({",
    f"\t'domain': '{l_env}',",
    f"\t'task': '{model_code}-{model_step}-v0',",
    f"\t'exp_name': '{config_module}',",
    "\t'log_dir': 'mopo-local/logs/',",
    "})",
    "params['kwargs'].update({",
    f"\t'pool_load_path': 'url/{dataset_id}-v0',",
    f"\t'pool_load_max_size': {pool_size},",
    f"\t'rollout_length': {rollout_length},",
    f"\t'penalty_coeff': {penalty_coeff}",
    "})",
]
with open('examples/config/url/' + config_module + '.py', 'w') as f:
    f.write('\n'.join(config_lines) + '\n')
# Launch training on the generated config with the prepared child environment.
# sys.executable runs the child under this same interpreter/virtualenv (a bare
# "python" is resolved through PATH and may be a different one), and a failed
# run is propagated as this script's exit status instead of being ignored.
result = subprocess.run(
    [
        sys.executable,
        "examples/development/main.py",
        "--config=examples.config.url." + dataset_id.replace('-', '_'),
        "--gpus=1",
        "--trial-gpus=1",
    ],
    env=environ,
)
if result.returncode != 0:
    sys.exit(result.returncode)