import os
import pickle
import numpy as np
import argparse


def _episode_rewards(data):
    """Return the summed reward of every episode in ``data``.

    ``data`` is iterated as a sequence of episodes, each a sequence of step
    records; the reward is read from index ``-3`` of every step record
    (presumably the tuple layout written by the evaluation run -- confirm
    against the producer of the results file).
    """
    return [np.sum([step[-3] for step in episode]) for episode in data]


def _failure_flags(data):
    """Return, per episode, a list with one failure flag per step.

    Each step record is indexed as ``step['infos'][step['ego_agent_id']]``;
    a step is a failure when that info mapping reports ``out_of_lane``,
    ``exceed_max_rot`` or ``crashed``. The stored flag is the raw result of
    the ``or`` chain.
    """
    flags = []
    for episode in data:
        episode_flags = []
        for step in episode:
            infos = step['infos'][step['ego_agent_id']]
            episode_flags.append(
                infos['out_of_lane'] or infos['exceed_max_rot'] or infos['crashed'])
        flags.append(episode_flags)
    return flags


def _complete_ratios(failure_flags, max_episode_len):
    """Return a numpy array with the completion ratio of every episode.

    An episode whose last step is not a failure counts as complete (1.0).
    An episode that ends in a failure scores ``len(episode) / max_episode_len``.
    An empty episode scores 0.0. The ratio is not clamped, so a failing
    episode longer than ``max_episode_len`` scores above 1.0.
    """
    ratios = []
    for episode_flags in failure_flags:
        # An empty episode has no last step; treat it as zero progress
        # instead of raising IndexError on ``episode_flags[-1]``.
        if not episode_flags or episode_flags[-1]:
            ratios.append(len(episode_flags) / max_episode_len)
        else:
            ratios.append(1.)
    return np.array(ratios)


def _crash_rate(failure_flags, per_episode):
    """Return the mean failure rate as a numpy scalar.

    With ``per_episode`` False: the fraction of all steps that are failures.
    With ``per_episode`` True: the fraction of episodes that contain at least
    one failure. Empty episodes count as failure-free, because ``np.max`` of
    an empty list would raise.
    """
    if per_episode:
        return np.array(
            [np.max(flags) if len(flags) else False for flags in failure_flags]).mean()
    return np.array([flag for flags in failure_flags for flag in flags]).mean()


def main(argv=None):
    """Print a summary metric computed from a pickled evaluation results file.

    Command-line options (parsed from ``argv``, or ``sys.argv[1:]`` when None):
      --results-path     path to the pickled results (required).
      --env              environment name selecting the metric.
      --max-episode-len  normaliser for the driving complete ratio (default 100).
      --metric           driving metric: ``complete_ratio`` (default),
                         ``crash_rate`` (fraction of failed steps) or
                         ``episode_crash_rate`` (fraction of episodes with a failure).

    For ``cartpole``/``halfcheetah``/``pendulum`` it prints the mean
    per-episode summed reward. For ``driving`` it prints the selected metric.

    Raises:
        ValueError: if ``--env`` is not one of the recognised environments.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--results-path', type=str, required=True)
    # NOTE(review): 'canyonrun' is not handled below, so running with the
    # default --env raises ValueError.
    parser.add_argument('--env', type=str, default='canyonrun')
    parser.add_argument('--max-episode-len', type=int, default=100,
                        help='episode length used to normalise the driving complete ratio')
    parser.add_argument('--metric', default='complete_ratio',
                        choices=['complete_ratio', 'crash_rate', 'episode_crash_rate'],
                        help='metric reported for the driving environment')
    args = parser.parse_args(argv)
    if args.max_episode_len <= 0:
        parser.error('--max-episode-len must be a positive integer')

    # SECURITY: pickle.load can execute arbitrary code; only load results
    # files from trusted sources.
    with open(args.results_path, 'rb') as f:
        data = pickle.load(f)

    if args.env in ['cartpole', 'halfcheetah', 'pendulum']:
        print(f'Average episode reward: {np.mean(_episode_rewards(data))}')
    elif args.env == 'driving':
        failure_flags = _failure_flags(data)
        if args.metric == 'complete_ratio':
            complete_ratio = _complete_ratios(failure_flags, args.max_episode_len)
            print(f'Average complete ratio: {complete_ratio.mean()}')
        else:
            crash_rate = _crash_rate(
                failure_flags, per_episode=args.metric == 'episode_crash_rate')
            print(f'Average crash rate: {crash_rate}')
    else:
        raise ValueError(f'Unrecognized environment {args.env}')


# Entry point: run the CLI only when executed as a script, not on import.
if __name__ == '__main__':
    main()
