import json

import numpy as np
from pvp_iclr_release.utils.carla.pvp_carla_env import PVPEnv
from ray.rllib.policy.sample_batch import SampleBatch


def process_info(info):
    """Return a shallow copy of ``info`` without its ``"raw_action"`` entry.

    All other key/value pairs are kept as-is, in their original order.
    The input mapping is not modified.
    """
    # NOTE(review): the original comment suggests "raw_action" is dropped
    # because it holds float32 data -- confirm against the env's info dict.
    return {key: value for key, value in info.items() if key != "raw_action"}


if __name__ == '__main__':
    # Roll out `num` episodes in the PVP CARLA environment and dump every
    # transition, plus per-episode statistics, to a JSON file.
    #
    # Each entry of `pool` is a dict keyed by SampleBatch field names:
    # observation, action, next observation, done flag, reward and the
    # step info (with "raw_action" filtered out by process_info).
    num = int(1)  # change eps num here
    pool = []

    def _to_jsonable(value):
        """Convert NumPy values that ``json`` cannot encode natively.

        Used as the ``default`` hook of ``json.dump``; it is only invoked
        for objects the encoder does not already know how to serialize.
        """
        if isinstance(value, np.ndarray):
            return value.tolist()
        if isinstance(value, np.generic):
            return value.item()
        raise TypeError(
            "Object of type {} is not JSON serializable".format(type(value).__name__)
        )

    env_config = dict(
        obs_mode="birdview",
        force_fps=30,  ###
        disable_vis=False,  ###
        debug_vis=False,
        port=9000,
        disable_takeover=False,  ###
        controller="keyboard",
        env={"visualize": {"location": "lower right"}}
    )

    env = PVPEnv(config=env_config)  # for carla
    success = 0
    episode_reward = []
    episode_cost = []

    # Running sums for the episode currently in progress.
    total_reward = 0
    total_cost = 0

    obs = env.reset()

    episode_num = 0
    episode_len = []
    last = 0  # number of steps taken in the current episode
    while episode_num < num:
        last += 1
        # NOTE(review): the action passed to step() is a fixed placeholder;
        # the action that gets recorded is info["raw_action"], presumably
        # the one actually applied (e.g. keyboard takeover) -- confirm
        # against PVPEnv.
        new_obs, reward, done, info = env.step([0, 0, 1, 0, 0])
        action = info["raw_action"]
        total_cost += info["cost"]
        pool.append({SampleBatch.OBS: obs['image'].tolist(), SampleBatch.ACTIONS: action,
                     SampleBatch.NEXT_OBS: new_obs['image'].tolist(),
                     SampleBatch.DONES: done,
                     SampleBatch.REWARDS: reward, SampleBatch.INFOS: process_info(info),
                     })
        obs = new_obs
        total_reward += reward
        if done:
            episode_num += 1
            if info["arrive_dest"]:
                success += 1
            episode_reward.append(total_reward)
            episode_cost.append(total_cost)
            total_reward = 0
            total_cost = 0
            episode_len.append(last)
            print('reset:', episode_num, "this_episode_len:", last, "total_success_rate:", success / episode_num,
                  "mean_episode_reward:{}({})".format(np.mean(episode_reward), np.std(episode_reward)),
                  "mean_episode_cost:{}({})".format(np.mean(episode_cost), np.std(episode_cost)))
            obs = env.reset()
            last = 0
            print('finish {}'.format(episode_num))

    # Guard the ratio so that num == 0 yields an empty data set instead of
    # a ZeroDivisionError.
    success_rate = success / episode_num if episode_num else 0.0
    data_set = {"data": pool, "episode_reward": episode_reward, "episode_cost": episode_cost,
                "success_rate": success_rate, "episode_len": episode_len}
    try:
        with open('human_traj_' + str(num) + '.json', 'w') as f:
            # `default` lets NumPy scalars/arrays (rewards, actions, info
            # values) be written instead of aborting the dump.
            json.dump(data_set, f, default=_to_jsonable)
    except Exception as err:
        # Best-effort fallback: keep the collected data visible on stdout,
        # then report why the dump failed instead of claiming success.
        print(data_set)
        print("Dump failed: {!r}".format(err))
    else:
        print("Dump success")