import argparse
import h5py
import numpy as np
import os

# Command-line interface. The "(default: ...)" notes in the help strings are
# kept in sync with the actual `default=` values by hand.
parser = argparse.ArgumentParser(description='PyTorch Soft Actor-Critic Args')
parser.add_argument('--env-name', default="AntCustom-v2",
                    help='Mujoco Gym environment (default: AntCustom-v2)')
parser.add_argument('--env-param', type=float, default=30,
                    help='Environment hyper-parameter (default: 30)')
parser.add_argument('--path', default="",
                    help='Offline dataset path')
parser.add_argument('--num_modes', type=int, default=10, metavar='N',
                    help='number of modes (default: 10)')
parser.add_argument('--suffix', default="",
                    help='Relabeled offline dataset saving path, appended with this suffix')
parser.add_argument('--target_size', type=int, default=1000000, metavar='G',
                    help='target number of transitions for collection (default: 1000000)')
# Parses sys.argv at import time; this file is meant to be run as a script.
args = parser.parse_args()

# Root directory of the data tree; must be provided through the environment
# (a missing variable raises KeyError here).
path_prefix = os.environ['UDG_DATA_PATH']
# os.path.join keeps the result identical when UDG_DATA_PATH already ends with
# a separator and inserts one when it does not (plain concatenation silently
# produced a wrong path in that case).
model_prefix = os.path.join(path_prefix, "url-data/models/")
model_type = "diayn2"
model_time = "2023-04-29_09-12-18"
model_step = "final"

# Common prefix of the per-mode dataset directories; the mode suffix
# ("_m<index>") is appended where the datasets are opened.
data_dir = os.path.join(
    path_prefix,
    "url-data/data/{}/{}/{}/{}".format(model_type, args.env_name, model_time, model_step),
)
# Angular step, in degrees, between the evaluated directions.
interval = 30
# One row per mode, one column per direction bin over the full circle.
reward_matrix = np.zeros((args.num_modes, 360 // interval))
def _mean_episode_rewards(positions, terminals, directions):
    """Return (episode count, mean per-episode reward for each direction).

    An episode ends at every index whose terminal flag is truthy and at the
    last sample. The reward of one episode along a unit direction ``d`` is
    ``100 * (p_end - p_start) . d``: the per-step progress terms
    ``(p[i+1] . d - p[i] . d) * 100`` telescope, so only the first and last
    position of each episode are needed. The step leaving an episode's last
    sample is never counted, because the following sample already belongs to
    the next episode.

    Args:
        positions: array of shape (n_samples, 2) with one planar position per
            sample.
        terminals: array with n_samples truthy/falsy end-of-episode flags
            (any shape that flattens to n_samples entries).
        directions: array of shape (n_directions, 2) of direction vectors.

    Returns:
        A tuple ``(episodes, means)`` where ``means`` has shape
        (n_directions,). For an empty dataset ``episodes`` is 0 and ``means``
        is all NaN.

    Raises:
        ValueError: if the number of terminal flags differs from the number
            of samples.
    """
    positions = np.asarray(positions, dtype=np.float64)
    n_samples = positions.shape[0]
    if n_samples == 0:
        return 0, np.full(directions.shape[0], np.nan)

    # astype() copies, so forcing the last flag does not touch the input.
    is_end = np.ravel(terminals).astype(bool)
    if is_end.shape[0] != n_samples:
        raise ValueError(
            "Expected {} terminal flags, got {}".format(n_samples, is_end.shape[0]))
    is_end[-1] = True  # the final sample always closes the running episode

    ends = np.flatnonzero(is_end)
    starts = np.concatenate(([0], ends[:-1] + 1))
    displacement = positions[ends] - positions[starts]
    episode_rewards = 100 * (displacement @ directions.T)
    return ends.size, episode_rewards.mean(axis=0)


# Directions are evaluated over half a circle only; the opposite half follows
# by symmetry (projecting on -d negates the reward).
half_circle_bins = 180 // interval
radians = (np.arange(half_circle_bins) * interval / 180.0) * np.pi
directions = np.stack([np.cos(radians), np.sin(radians)], axis=1)

for m in range(args.num_modes):
    data_path = data_dir + "_m{}/experience.h5".format(m)
    if not os.path.exists(data_path):
        # Explicit exception instead of `assert`, which is stripped under -O.
        raise FileNotFoundError("Missing experience file: " + data_path)

    # Read what is needed into memory once: indexing an h5py dataset element
    # by element issues one HDF5 read per access.
    with h5py.File(data_path, "r") as f:
        # The first two observation components are used as the planar
        # position (presumably x/y of the agent -- confirm against the env).
        positions = f['observations'][:, :2]
        terminals = f['terminals'][:]

    episodes, mean_rewards = _mean_episode_rewards(positions, terminals, directions)

    for a, mean_reward in enumerate(mean_rewards):
        # NOTE: the value logged after "angle" is the bin index, not degrees.
        print("Mode {} in angle {}: total episodes {}, avg reward {}".format(m, a, episodes, round(mean_reward, 2)))

    reward_matrix[m, :half_circle_bins] = mean_rewards
    reward_matrix[m, half_circle_bins:2 * half_circle_bins] = -mean_rewards

# Save next to the per-mode dataset directories. The path is anchored at
# path_prefix like the input paths; a bare relative path would depend on the
# current working directory and may point at a directory that does not exist.
matrix_path = os.path.join(
    path_prefix,
    "url-data/data/{}/{}/{}/reward_matrix.np".format(model_type, args.env_name, model_time),
)
# ndarray.dump writes a pickle of the array; read it back with
# numpy.load(matrix_path, allow_pickle=True) or pickle.load.
reward_matrix.dump(matrix_path)
print("Reward matrix saved to path: "+matrix_path)