import argparse
import h5py
import numpy as np
import os

# Command-line interface of the script.
# Help strings use argparse's %(default)s placeholder so the advertised default
# is always the real one (the hand-written defaults had drifted out of sync).
# NOTE(review): in the visible part of this script only --env-name and
# --num_modes are read; the other options are parsed but not used.
parser = argparse.ArgumentParser(description='PyTorch Soft Actor-Critic Args')
parser.add_argument('--env-name', default="CheetahJump-v2",
                    help='Mujoco Gym environment (default: %(default)s)')
parser.add_argument('--env-param', type=float, default=30,
                    help='Environment hyper-parameter (default: %(default)s)')
parser.add_argument('--path', default="",
                    help='Offline dataset path')
parser.add_argument('--num_modes', type=int, default=5, metavar='N',
                    help='number of modes (default: %(default)s)')
parser.add_argument('--suffix', default="",
                    help='Relabeled offline dataset saving path, appended with this suffix')
parser.add_argument('--target_size', type=int, default=1000000, metavar='G',
                    help='target number of transitions for collection (default: %(default)s)')
# Parses sys.argv at import time; exits the process on invalid arguments.
args = parser.parse_args()

# Root of the data tree, taken from the UDG_DATA_PATH environment variable
# (a KeyError is raised here if the variable is not set).
path_prefix = os.environ['UDG_DATA_PATH']
# os.path.join supplies the separator when UDG_DATA_PATH has no trailing slash;
# plain string concatenation produced a wrong path in that case. When the
# trailing slash is present the resulting strings are the same as before.
model_prefix = os.path.join(path_prefix, "url-data/models/")
# Identifiers of the trained model whose collected experience is evaluated.
model_type = "wasserstein"
model_time = "2023-05-06_10-53-36"
model_step = "final"

# Common stem of the per-mode dataset directories; a "_xm<mode>" suffix is
# appended to it when each mode's file is opened.
data_dir = os.path.join(
    path_prefix,
    "url-data/data/{}/{}/{}/{}".format(model_type, args.env_name, model_time, model_step),
)
def _shaped_episode_returns(first_coords, rewards, terminals, coef):
    """Return the shaped return of every completed episode in a transition log.

    For each transition the contribution is
    ``reward + coef * (coord - coord_at_episode_start)``, where ``coord`` is
    the first observation component of that transition. Contributions are
    summed until a truthy terminal flag closes the episode.

    Args:
        first_coords: per-transition first observation component.
        rewards: per-transition rewards.
        terminals: per-transition flags; a truthy value ends the episode.
        coef: weight of the displacement term.

    Returns:
        A list with one accumulated value per completed episode. Transitions
        after the last terminal flag belong to an unfinished episode and are
        not reported. The three inputs are walked in lockstep.
    """
    returns = []
    running = 0.0
    start_coord = None  # None means "the next transition opens a new episode"
    for coord, reward, done in zip(first_coords, rewards, terminals):
        if start_coord is None:
            start_coord = coord
        running += reward + coef * (coord - start_coord)
        if done:
            returns.append(running)
            running = 0.0
            start_coord = None
    return returns


# One average shaped return per mode.
reward_matrix = np.zeros((args.num_modes,))
# Weight of the displacement term added to the raw reward.
# NOTE(review): the original code names observation[0] "z" — presumably a
# height/position coordinate; confirm against the environment definition.
z_coef = -15.0
for m in range(args.num_modes):
    data_path = data_dir+"_xm{}/experience.h5".format(m)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not os.path.exists(data_path):
        raise FileNotFoundError("Offline dataset not found: {}".format(data_path))
    with h5py.File(data_path, "r") as f:
        # Read each dataset into memory once; indexing an h5py dataset inside
        # the per-transition loop issues a separate HDF5 read every time.
        first_coords = f['observations'][:, 0]
        rewards = f['rewards'][:]
        terminals = f['terminals'][:]

    episode_rewards = _shaped_episode_returns(first_coords, rewards, terminals, z_coef)
    episodes = len(episode_rewards)
    if episode_rewards:
        mean_reward = np.mean(episode_rewards)
    else:
        # No terminal flag in the file: report NaN explicitly rather than
        # letting np.mean([]) emit a RuntimeWarning.
        mean_reward = float("nan")

    print("Mode {}: total episodes {}, avg reward {}".format(m, episodes, round(mean_reward, 2)))
    reward_matrix[m] = mean_reward

# Destination for the per-mode average-return vector.
# NOTE(review): unlike data_dir, this path is not rooted at UDG_DATA_PATH, so
# it resolves against the current working directory — confirm that is intended.
matrix_path = "url-data/data/{}/{}/{}/reward_matrix_minus.np".format(model_type, args.env_name, model_time)
# Writing the matrix was disabled (commented out) in the original script; the
# flag keeps that behaviour while making it explicit and easy to re-enable.
save_matrix = False
if save_matrix:
    reward_matrix.dump(matrix_path)
    print("Reward matrix saved to path: "+matrix_path)
else:
    # Do not report a successful save for a write that never happened.
    print("Reward matrix NOT saved (saving is disabled); target path: "+matrix_path)