# Cheetah transfer to jumping
import subprocess
import os
import numpy as np
import h5py

# Maps a gym environment id to the name of the factory function that is
# written as the entry point of the generated `register(...)` call.
get_env = {
    env_id: 'get_{}_env'.format(short_name)
    for env_id, short_name in (
        ('HalfCheetah-v2', 'cheetah'),
        ('Walker2d-v2', 'walker'),
        ('Hopper-v2', 'hopper'),
        ('AntCustom-v2', 'ant_custom'),
        ('CheetahJump-v2', 'cheetah_jump'),
    )
}

# Root under which the datasets live; a missing variable raises KeyError.
path_prefix = os.environ['UDG_DATA_PATH']
# Copy of the current environment for the training subprocess, with the
# visible CUDA device overridden.
environ = dict(os.environ, CUDA_VISIBLE_DEVICES="4")

# Target environment and its lower/upper-case name without the version suffix.
env = 'CheetahJump-v2'
env_name, _ = env.split('-')
l_env = env_name.lower()
u_env = l_env.upper()

# Which pretrained run to load the experience from.
model_type = 'wasserstein'
model_path = '2023-05-06_10-53-36'
model_code = 'c65'
model_step = 'final'
model_file = 'experience.h5'

# Values written into the generated training config.
pool_size = 1000000
rollout_length, penalty_coeff = 5, 5

# Select the mode whose entry in the saved reward matrix is largest
# (np.argmax without an axis returns an index into the flattened array).
datasets_path = "{}url-data/data/{}/{}/{}/".format(path_prefix, model_type, env, model_path)
reward_matrix = np.load(datasets_path + "reward_matrix_minus.np", allow_pickle=True)
num_modes = reward_matrix.shape
# Coefficient applied to the z displacement when rewards are relabeled.
z_coef = -15
best_index = np.argmax(reward_matrix[:])
mode = int(best_index)
# To force a particular mode instead, assign it here (e.g. mode = 0).

# Tag the step directory and the experiment code with the chosen mode and
# the rollout/penalty settings.
mode_suffix = "_xm{}".format(mode)
model_step += mode_suffix
model_code += "{}_{}_{}".format(mode_suffix, rollout_length, penalty_coeff)
# relabel data
# Every reward is shifted by z_coef * (z - z_at_episode_start), where z is
# observation component 0 (presumably the torso height -- confirm against the
# environment's observation layout). An episode ends at each set terminal flag.
dataset_path = datasets_path+model_step+"/"+model_file
new_dataset_path = datasets_path+model_step+"/experience_z_coef_{}.h5".format(z_coef)
if not os.path.exists(new_dataset_path):
    # Relabel a temporary copy and move it into place only once it is complete.
    # Otherwise an interrupted run would leave a partially relabeled file that
    # the existence check above silently accepts on the next run.
    tmp_dataset_path = new_dataset_path + ".tmp"
    # check=True: abort instead of relabeling a missing/incomplete copy.
    subprocess.run(["cp", dataset_path, tmp_dataset_path], check=True)

    print("Start dataset relabeling...")
    with h5py.File(tmp_dataset_path, "r+") as f:
        # Pull the needed columns into memory once; indexing the HDF5 datasets
        # sample by sample issues a separate read/write for every element.
        z_values = f['observations'][:, 0]
        terminals = f['terminals'][:]
        rewards = f['rewards'][:]
        n_samples = z_values.shape[0]
        episodes = 0
        episode_rewards = []
        episode_reward = 0.0
        init = True  # True on the first sample of each episode
        for i in range(n_samples):
            if init:
                init_z = z_values[i]
            rewards[i] += (z_coef * (z_values[i] - init_z))
            episode_reward += rewards[i]
            if terminals[i]:
                episode_rewards.append(episode_reward)
                episodes += 1
                episode_reward = 0.0
                init = True
                print("Episode {}: reward {}".format(episodes, episode_rewards[-1]))
            else:
                init = False
        # Single bulk write of the relabeled rewards back into the file.
        f['rewards'][...] = rewards
    os.replace(tmp_dataset_path, new_dataset_path)
    # Samples after the last terminal flag belong to an unfinished episode and
    # are not counted; with no finished episode there is no average to report.
    avg_reward = round(np.mean(episode_rewards), 2) if episode_rewards else float("nan")
    print("Dataset relabeling complete. {} samples, {} episodes, avg reward {}".format(n_samples, episodes, avg_reward))

# register env
# Appends a gym `register(...)` call for the relabeled dataset to the package
# __init__ unless an entry with exactly this id is already present.
dataset_id = l_env+"-"+model_code+"-"+str(z_coef)
registry_file = 'url_offline/gym_mujoco/__init__.py'
registered_id = dataset_id + "-v0"
with open(registry_file, 'r') as f:
    registry_source = f.read()
# Match the complete quoted id: a bare substring test would also hit longer
# ids that merely start with this one (e.g. z_coef -15 vs -150).
env_exist = any(
    quote + registered_id + quote in registry_source for quote in ("'", '"')
)
if not env_exist:
    # Build the whole entry before touching the file so that a failure while
    # assembling it (e.g. an unknown env in get_env) cannot leave a truncated
    # register(...) call behind in the package __init__.
    registration_lines = [
        "register(\n",
        "\tid='" + registered_id + "',\n",
        "\tentry_point='url_offline.gym_mujoco.gym_envs:" + get_env[env] + "',\n",
        "\tmax_episode_steps=1000,\n",
        "\tkwargs={\n",
        "\t\t'ref_min_score': " + u_env + "_RANDOM_SCORE,\n",
        "\t\t'ref_max_score': " + u_env + "_EXPERT_SCORE,\n",
        # repr() emits a valid string literal even if the path contains
        # quotes or backslashes.
        "\t\t'dataset_path': " + repr(new_dataset_path) + ",\n",
        "\t\t'z_coef': " + str(z_coef) + "\n",
        "\t}\n",
        ")\n",
        "\n",
    ]
    with open(registry_file, 'a') as f:
        f.writelines(registration_lines)
# Generate the MOPO config module for this dataset; the module name is the
# dataset id with dashes turned into underscores.
config_name = dataset_id.replace('-', '_')
config_text = "".join([
    "from .base_mopo import mopo_params, deepcopy\n",
    "\n",
    "params = deepcopy(mopo_params)\n",
    "params.update({\n",
    "\t'domain': '{}',\n".format(l_env),
    "\t'task': '{}-{}-v0',\n".format(model_code, z_coef),
    "\t'exp_name': '{}',\n".format(config_name),
    "\t'log_dir': 'mopo-local/logs/',\n",
    "})\n",
    "params['kwargs'].update({\n",
    "\t'pool_load_path': 'url/{}-v0',\n".format(dataset_id),
    "\t'pool_load_max_size': {},\n".format(pool_size),
    "\t'rollout_length': {},\n".format(rollout_length),
    "\t'penalty_coeff': {}\n".format(penalty_coeff),
    "})\n",
])
with open('examples/config/url/' + config_name + '.py', 'w') as f:
    f.write(config_text)
# Start training on the generated config, passing the environment copy that
# carries the CUDA device override.
config_module = "examples.config.url." + dataset_id.replace('-', '_')
train_command = [
    "python",
    "examples/development/main.py",
    "--config=" + config_module,
    "--gpus=1",
    "--trial-gpus=1",
]
subprocess.run(train_command, env=environ)
