import pickle as pkl
import numpy as np
import argparse


def parser_args(argv=None):
    """Parse the command-line options of this conversion script.

    Args:
        argv: Argument list to parse, without the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace with a ``name`` attribute holding the environment
        name given via ``-n``/``--name``.

    Exits through argparse's ``SystemExit`` when ``--name`` is missing.
    """
    parser = argparse.ArgumentParser(
        description="Convert expert data format from the custom format to stable baseline format")
    # required: without it ``name`` would silently be None and the script
    # would look for "None_expert_demo.pkl" instead of reporting the problem.
    parser.add_argument("-n", "--name", type=str, required=True,
                        help="name of environment")
    return parser.parse_args(argv)

def _as_step_rows(values, n_steps, label):
    """Reshape *values* to ``(n_steps, -1)`` so that row ``i`` belongs to step ``i``.

    Raises:
        ValueError: if *values* is empty or its element count cannot be split
            evenly into ``n_steps`` rows, i.e. it does not line up with the
            observations.
    """
    values = np.asarray(values)
    if values.size == 0 or values.size % n_steps != 0:
        raise ValueError('{} has {} elements, which cannot be split across {} steps'
                         .format(label, values.size, n_steps))
    return values.reshape(n_steps, -1)


def _episode_starts(dones):
    """Return a bool array that is True on the first step of every episode.

    Step 0 always starts an episode; any later step starts one exactly when
    the step before it was terminal. Computing this as a shift of *dones*
    (instead of toggling flags in place) stays correct for length-1 episodes
    and for a terminal first step.
    """
    starts = np.empty(len(dones), dtype=bool)
    starts[0] = True
    starts[1:] = dones[:-1]
    return starts


def _episode_returns(step_rewards, episode_starts, last_step_done):
    """Return the summed reward of every completed episode.

    Episodes are the segments delimited by *episode_starts*. The final segment
    is kept only when *last_step_done* is true; otherwise it is a truncated
    episode and is left out.
    """
    returns = np.add.reduceat(step_rewards, np.flatnonzero(episode_starts))
    return returns if last_step_done else returns[:-1]


def convert_expert_demo(name, experts_dir='./rl_baselines_zoo/experts'):
    """Convert ``<experts_dir>/<name>_expert_demo.pkl`` into ``<experts_dir>/<name>.npz``.

    The pickle must contain the tuple ``(obs, actions, rewards, next_obs, dones)``;
    ``next_obs`` is read but not used. The ``.npz`` gets the keys ``actions``,
    ``obs``, ``rewards``, ``episode_returns`` and ``episode_starts``.

    Args:
        name: Environment name used to build both file names.
        experts_dir: Directory holding the input pickle and receiving the
            output archive. The default is the path the script always used.

    Returns:
        The dict of arrays that was written to the ``.npz`` file.

    Raises:
        ValueError: if there are no observations, or if actions, rewards or
            dones do not line up step-for-step with the observations.
        OSError: if the input pickle cannot be opened.
    """
    src = '{}/{}_expert_demo.pkl'.format(experts_dir, name)
    # NOTE: pickle.load can execute arbitrary code; only convert demo files
    # that come from a trusted source.
    with open(src, 'rb') as f:
        obs_expert, actions_expert, rewards_expert, _next_obs, dones_expert = pkl.load(f)

    observations = np.concatenate(obs_expert)
    n_steps = len(observations)
    if n_steps == 0:
        raise ValueError('{} contains no observations'.format(src))

    # One row per step: scalar actions become (N, 1), as before; vector
    # actions keep their components together instead of being flattened.
    actions = _as_step_rows(actions_expert, n_steps, 'actions')
    rewards = np.stack(rewards_expert)
    step_rewards = _as_step_rows(rewards, n_steps, 'rewards').sum(axis=1)

    dones = np.asarray(dones_expert).reshape(-1).astype(bool)
    if len(dones) != n_steps:
        raise ValueError('dones has {} entries but there are {} observations'
                         .format(len(dones), n_steps))

    episode_starts = _episode_starts(dones)
    episode_returns = _episode_returns(step_rewards, episode_starts, dones[-1])

    numpy_dict = {
        'actions': actions,
        'obs': observations,
        'rewards': rewards,
        'episode_returns': episode_returns,
        'episode_starts': episode_starts,
    }
    # np.savez appends the ".npz" extension itself.
    np.savez('{}/{}'.format(experts_dir, name), **numpy_dict)
    return numpy_dict


if __name__ == '__main__':
    args = parser_args()
    convert_expert_demo(args.name)
    print("DONE!!")