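# AntCustom transferred to other angles: relabel a recorded experience dataset
# with a reward for moving along a requested direction, register it as a gym
# environment, write a matching config file and launch training on it.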
import subprocess
import os
import numpy as np
import h5py

# Lookup table: gym environment id -> name of the factory function that is
# written into the `entry_point` string of the generated registration.
get_env = {
    "HalfCheetah-v2": "get_cheetah_env",
    "Walker2d-v2": "get_walker_env",
    "Hopper-v2": "get_hopper_env",
    "AntCustom-v2": "get_ant_custom_env",
}
# Environment handed to the training subprocess: a plain-dict copy of the
# current process environment with CUDA_VISIBLE_DEVICES forced to "0".
environ = dict(os.environ, CUDA_VISIBLE_DEVICES="0")
# Prefix that every dataset path is built from; raises KeyError right here
# when UDG_DATA_PATH is not set.
path_prefix = os.environ["UDG_DATA_PATH"]

# Target environment. The part of the id before the "-" is kept in lower case
# (used in dataset ids and the config 'domain') and in upper case (used in the
# names of the score constants written into the registration).
env = "AntCustom-v2"
base_name, _ = env.split("-")
l_env = base_name.lower()
u_env = l_env.upper()

# Which pretrained model / experience file to use. model_code and model_step
# get a mode suffix appended further down in the script.
model_code = "c96"
model_path = "2023-04-29_09-12-18"
model_type = "diayn2"
model_step = "final"
model_file = "experience.h5"
# Values copied verbatim into the generated config file.
pool_size = 10 ** 6
rollout_length = 2
penalty_coeff = 1

# Locate the dataset directory of the pretrained model and choose the mode.
# os.path.join makes this work whether or not UDG_DATA_PATH ends with a path
# separator; the trailing "" keeps the final separator that the later
# `datasets_path + ...` concatenations rely on.
datasets_path = os.path.join(path_prefix, "url-data", "data", model_type, env, model_path, "")
# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects, so
# this file must come from a trusted run.
reward_matrix = np.load(datasets_path+"reward_matrix.np", allow_pickle=True)
# One row per mode, one column per evaluated angle.
num_modes, num_angles = reward_matrix.shape
min_interval = 360 // num_angles
requested_angle = 180
# Column of reward_matrix closest to requested_angle (rounded to the nearest
# multiple of min_interval, wrapping around at 360 degrees).
requested_index = ((requested_angle + min_interval//2)//min_interval) % num_angles
# The max-reward mode selection is currently disabled; mode 0 is always used.
# mode = int(np.argmax(reward_matrix[:,requested_index]))
mode = 0
model_step = model_step + "_m{}".format(mode)
model_code = model_code + "_m{}_{}_{}".format(mode, rollout_length, penalty_coeff)
# Alternative naming schemes used for other experiment variants:
# model_code = model_code + "_{}_{}".format(rollout_length, penalty_coeff)
# model_step = model_step + "_{}_contra".format(requested_angle)
# model_code = model_code + "_mixed_c_{}_{}".format(rollout_length, penalty_coeff)

# relabel data
# Relabeling is done on a per-angle copy so the source experience file is
# left untouched.
dataset_path = datasets_path+model_step+"/"+model_file
new_dataset_path = datasets_path+model_step+"/experience_angle_{}.h5".format(requested_angle)
# new_dataset_path = datasets_path+model_step+"/experience.h5"
if not os.path.exists(new_dataset_path):
    # check=True stops the script with CalledProcessError when the copy fails
    # (e.g. missing source file), instead of failing later with an unrelated
    # error when the copy is opened.
    subprocess.run(["cp", dataset_path, new_dataset_path], check=True)
# Unit vector pointing along requested_angle (degrees converted to radians).
rad = (requested_angle / 180.0) * np.pi
direction = np.array([np.cos(rad), np.sin(rad)])

# Rewrite the 'rewards' dataset of the copied file: each step is rewarded with
# 100x the displacement of the first two observation components along
# `direction`; the last step of every episode gets reward 0.
#
# Everything is read and written in bulk. Indexing h5py datasets one element
# at a time costs a separate HDF5 access per read/write, which is very slow
# for large datasets.
print("Start dataset relabeling...")
episode_rewards = []
with h5py.File(new_dataset_path, "r+") as f:
    rewards = f['rewards']
    n_samples = f['observations'].shape[0]
    if n_samples > 0:
        # Only the first two observation components are needed.
        positions = np.asarray(f['observations'][:, :2], dtype=np.float64)
        # Flattened so that both (N,) and (N, 1) layouts are accepted.
        terminal_flags = np.asarray(f['terminals'][()], dtype=bool).reshape(-1)

        # Signed distance of every sample along the requested direction.
        progress = np.dot(positions, direction)
        step_rewards = np.zeros(n_samples, dtype=np.float64)
        step_rewards[:-1] = (progress[1:] - progress[:-1]) * 100

        # An episode ends at every terminal flag and at the very last sample.
        # Those steps have no valid successor inside the episode, so their
        # reward is set to zero.
        episode_ends = terminal_flags.copy()
        episode_ends[-1] = True
        step_rewards[episode_ends] = 0.0

        rewards[...] = step_rewards.reshape(rewards.shape)
        # Read the values back so the statistics reflect what is stored
        # (i.e. after conversion to the dataset's dtype).
        stored_rewards = np.asarray(rewards[()], dtype=np.float64).reshape(-1)

        episode_start = 0
        for episode_end in np.flatnonzero(episode_ends):
            episode_rewards.append(float(stored_rewards[episode_start:episode_end + 1].sum()))
            episode_start = episode_end + 1
            print("Episode {}: reward {}".format(len(episode_rewards), episode_rewards[-1]))
episodes = len(episode_rewards)
print("Dataset relabeling complete. {} samples, {} episodes, avg reward {}".format(n_samples, episodes, round(np.mean(episode_rewards), 2)))

# register env and write config file
# The registration is appended to the gym_mujoco package's __init__.py unless
# an environment with this exact id is already present there.
dataset_id = l_env+"-"+model_code+"-"+str(requested_angle)
registered_id = dataset_id+"-v0"
with open('url_offline/gym_mujoco/__init__.py', 'r') as f:
    init_source = f.read()
# Look for the complete quoted id. A bare substring test on dataset_id would
# also match ids that merely start with it (e.g. angle 18 inside an existing
# "...-180-v0" entry) and wrongly skip the registration.
env_exist = any(quote+registered_id+quote in init_source for quote in ("'", '"'))
if not env_exist:
    with open('url_offline/gym_mujoco/__init__.py', 'a') as f:
        f.write("register(\n")
        f.write("\tid='"+registered_id+"',\n")
        f.write("\tentry_point='url_offline.gym_mujoco.gym_envs:"+get_env[env]+"',\n")
        f.write("\tmax_episode_steps=1000,\n")
        f.write("\tkwargs={\n")
        # The score values are written as constant names, not numbers.
        # NOTE(review): presumably <ENV>_RANDOM_SCORE / <ENV>_EXPERT_SCORE are
        # defined in that __init__.py -- confirm for new environments.
        f.write("\t\t'ref_min_score': "+u_env+"_RANDOM_SCORE,\n")
        f.write("\t\t'ref_max_score': "+u_env+"_EXPERT_SCORE,\n")
        f.write("\t\t'dataset_path': '"+new_dataset_path+"',\n")
        f.write("\t\t'angle': "+str(requested_angle)+"\n")
        f.write("\t}\n")
        f.write(")\n")
        f.write("\n")
# Generate the experiment config module for this dataset. The module name is
# the dataset id with dashes replaced by underscores; an existing file of the
# same name is overwritten.
config_name = dataset_id.replace('-', '_')
config_lines = [
    "from .base_mopo import mopo_params, deepcopy\n",
    "\n",
    "params = deepcopy(mopo_params)\n",
    "params.update({\n",
    "\t'domain': '{}',\n".format(l_env),
    "\t'task': '{}-{}-v0',\n".format(model_code, requested_angle),
    "\t'exp_name': '{}',\n".format(config_name),
    "\t'log_dir': 'mopo-local/logs/',\n",
    "})\n",
    "params['kwargs'].update({\n",
    "\t'pool_load_path': 'url/{}-v0',\n".format(dataset_id),
    "\t'pool_load_max_size': {},\n".format(pool_size),
    "\t'rollout_length': {},\n".format(rollout_length),
    "\t'penalty_coeff': {}\n".format(penalty_coeff),
    "})\n",
]
with open('examples/config/url/' + config_name + '.py', 'w') as f:
    f.writelines(config_lines)
# Launch training with the generated config, using the environment prepared
# at the top of the script (CUDA_VISIBLE_DEVICES set there).
training = subprocess.run(["python", "examples/development/main.py",
    "--config=examples.config.url."+dataset_id.replace('-','_'), "--gpus=1", "--trial-gpus=1"], env=environ)
# Do not report success when the training process failed; exit with an error
# message (and a non-zero status) instead.
if training.returncode != 0:
    raise SystemExit("Training failed with exit code {}. Task code: {}".format(training.returncode, model_code))

print("Task complete. Task code: "+model_code)
print("Dataset stats: {} samples, {} episodes, avg reward {}".format(n_samples, episodes, round(np.mean(episode_rewards), 2)))