# %%
import os
import sys

UE4Binary_SLEEPTIME = 30  # presumably seconds to wait for the UE4 binary — TODO confirm

# Repository root: climb three directory levels up from this script's path.
ROOT_DIR = os.path.abspath(__file__)
for _level in range(3):
    ROOT_DIR = os.path.dirname(ROOT_DIR)

# Make repo-local packages importable and resolve relative paths from the root.
sys.path.insert(0, ROOT_DIR)
os.chdir(ROOT_DIR)
# Exported through the process environment so child processes inherit it.
os.environ['UE4Binary_SLEEPTIME'] = str(UE4Binary_SLEEPTIME)
print(f'ROOT DIR: {ROOT_DIR}')

# %%
import subprocess

# %%
from unrealpose.control.envs import make_train_env
# %%
import argparse
import pickle as pkl
import numpy as np
import ray

def parse_args(argv=None):
    """Parse command-line options for evaluating/rendering a trained checkpoint.

    Args:
        argv: Optional list of argument strings. When ``None`` argparse falls
            back to ``sys.argv[1:]``, so existing callers are unaffected.

    Returns:
        argparse.Namespace with the parsed options. ``--ckpt`` is required
        because every later path (params.pkl, checkpoint dirs) is built from it.
        All override options default to ``None`` meaning "keep the value
        stored in the checkpoint config".
    """
    parser = argparse.ArgumentParser(description='eval')
    parser.add_argument('--num-humans', type=int,
                        help='override args.num_humans from the saved config')
    parser.add_argument('--max-num-humans', type=int,
                        help='override args.max_num_humans from the saved config')
    parser.add_argument('--ckpt', type=str, required=True,
                        help='trial directory containing params.pkl and checkpoint_* dirs')
    parser.add_argument('--ckpt-num', type=str, default=None,
                        help='checkpoint number to load, e.g. 100 for checkpoint_000100/checkpoint-100')
    parser.add_argument('--env-name', type=str,
                        help='override args.env_name from the saved config')
    parser.add_argument('--render-steps', type=int, default=200,
                        help='maximum number of environment steps to roll out')
    parser.add_argument('--use-numerical', type=int,
                        help='override env_config use_numerical (0 or 1)')
    return parser.parse_args(argv)

# %%
from pathlib import Path

cmd_args = parse_args()

# Resolve which checkpoint to load. Layout expected (as built below):
#   <ckpt>/params.pkl
#   <ckpt>/checkpoint_<zero-padded num>/checkpoint-<num>
if cmd_args.ckpt_num is None:
    # No --ckpt-num given: default to the latest numbered checkpoint directory.
    # Only "checkpoint_<digits>" names qualify, and the maximum is chosen
    # numerically, so stray entries (e.g. "checkpoint_tmp...") cannot win a
    # string sort and then crash int().
    _prefix = 'checkpoint_'
    numbered_dirs = [
        (int(name[len(_prefix):]), name)
        for name in os.listdir(cmd_args.ckpt)
        if name.startswith(_prefix) and name[len(_prefix):].isdecimal()
    ]
    if not numbered_dirs:
        raise FileNotFoundError(f'no checkpoint_<N> directories found in {cmd_args.ckpt!r}')
    latest_num, cp_dir = max(numbered_dirs)
    cp_num = str(latest_num)
else:
    # Normalise so that e.g. "100" and "000100" name the same checkpoint.
    cp_num = str(int(cmd_args.ckpt_num))
    cp_dir = "checkpoint" + "_" + cp_num.zfill(6)

cp_path = os.path.join(cmd_args.ckpt, cp_dir)
config_path = os.path.join(cmd_args.ckpt, 'params.pkl')
# NOTE(review): pickle can execute arbitrary code on load -- only point --ckpt
# at trusted training output.
with open(config_path, 'rb') as f:
    cp_config = pkl.load(f)

checkpoint_path = os.path.join(cp_path, "checkpoint-" + cp_num)
with open(checkpoint_path, mode='rb') as file:
    checkpoint = pkl.load(file)

# %%
from ray.rllib.agents import ppo, dqn, sac
import pprint

# Apply command-line overrides on top of the pickled training config.
if cmd_args.use_numerical is not None:
    cp_config['env_config']['use_numerical'] = bool(cmd_args.use_numerical)

# These CLI options map one-to-one onto attributes of env_config['args'].
for arg_name in ('env_name', 'num_humans', 'max_num_humans'):
    arg_value = getattr(cmd_args, arg_name)
    if arg_value is not None:
        setattr(cp_config['env_config']['args'], arg_name, arg_value)

env_config = cp_config['env_config']
env_config['UE4Binary_SLEEPTIME'] = str(UE4Binary_SLEEPTIME)
pprint.pprint(env_config)
env = make_train_env(cp_config.get('env_config', dict()), is_training=False)

# Built-in RLlib Torch policy classes, resolved lazily as (module, attribute).
_BUILTIN_POLICIES = {
    'PPO': (ppo, 'PPOTorchPolicy'),
    'DQN': (dqn, 'DQNTorchPolicy'),
    'SAC': (sac, 'SACTorchPolicy'),
}
algo = cp_config['env_config'].get('algo', 'PPO').upper()
if algo in _BUILTIN_POLICIES:
    policy_module, policy_attr = _BUILTIN_POLICIES[algo]
    policy_class = getattr(policy_module, policy_attr)
elif cp_config['model']['custom_model'] is not None:
    # Unknown algo but a custom model is configured: look up its policy,
    # falling back to PPO when the model name has no registered policy.
    from unrealpose.custom import load_custom_models, CUSTOM_POLICIES_FROM_MODEL_NAMES
    load_custom_models()
    custom_model_name = cp_config['model']['custom_model']
    policy_class = CUSTOM_POLICIES_FROM_MODEL_NAMES.get(custom_model_name, ppo.PPOTorchPolicy)
else:
    raise NotImplementedError


def _default_policy_mapping(agent_id):
    """Map every agent to the single 'default_policy'."""
    return 'default_policy'


def _new_policy():
    """Instantiate one policy of the selected class for the env's spaces."""
    return policy_class(env.observation_space, env.action_space, cp_config)


policy_mapping_fn = _default_policy_mapping
if 'multiagent' in cp_config:
    multiagent_config = cp_config['multiagent']
    policies = {policy_id: _new_policy() for policy_id in multiagent_config['policies']}
    if 'policy_mapping_fn' in multiagent_config:
        policy_mapping_fn = multiagent_config['policy_mapping_fn']
else:
    policies = {'default_policy': _new_policy()}

# %%
import traceback
from unrealpose.utils.file import create_logger
from tqdm import tqdm
from unrealpose.config import config

# Restore each constructed policy from the serialized worker state.
# NOTE(review): pickle.loads executes arbitrary code -- only load trusted checkpoints.
worker = pkl.loads(checkpoint['worker'])
# Loop-invariant: decides between set_state() and set_weights() for every policy.
is_multi_agent_env = env.unwrapped.MULTI_AGENT
for policy_id, policy in policies.items():
    state = worker['state'][policy_id]
    # Drop '_optimizer_variables' (presumably optimizer state, not needed for
    # explore=False evaluation) before restoring.
    state.pop('_optimizer_variables', None)

    try:
        if is_multi_agent_env:
            policy.set_state(state)
        else:
            # Weights are either nested under 'weights' or are the state itself.
            policy.set_weights(state.get('weights', state))
    except RuntimeError as e:
        # Best effort: continue with the remaining policies, but report which
        # one failed so a partially-restored run is noticeable.
        print(f'failed to restore policy {policy_id!r}: {e}', file=sys.stderr)

# Roll out the restored policies, rendering every step, then always close the
# env and package the rendered data.
# Initialised up front: if create_logger() fails, the cleanup below must not
# crash with a NameError on top of the original error.
timestap = None
try:
    logger, final_output_dir, timestap = create_logger(config, 'online')
    render_type = 'offline-save'

    if env.unwrapped.MULTI_AGENT and not cp_config['env_config'].get('force_single_agent', False):
        # Multi-agent rollout: each agent is driven by the policy its id maps to.
        observations = env.reset()
        agent_ids = list(observations.keys())
        policies_by_agent_id = {agent_id: policies[policy_mapping_fn(agent_id)] for agent_id in agent_ids}
        infos = {agent_id: {} for agent_id in agent_ids}
        states = {agent_id: policy.model.get_initial_state() for agent_id, policy in policies_by_agent_id.items()}
        actions = {}

        # Per-step metrics, read from the first agent's info dict.
        mpjpe_3d_list = []
        pck3d_20_list = []
        env.render(mode=render_type, timestap=timestap)
        for _step in tqdm(range(cmd_args.render_steps)):
            for agent_id, policy in policies_by_agent_id.items():
                results = policy.compute_single_action(observations[agent_id],
                                                       state=states[agent_id],
                                                       info=infos[agent_id],
                                                       explore=False)
                actions[agent_id], states[agent_id], *_ = results

            observations, _, dones, infos = env.step(actions)
            # Scaled by 10 so it prints in mm (assumes the env reports cm -- TODO confirm).
            mpjpe_3d_list.append(10 * infos[agent_ids[0]]['mpjpe_3d'])
            pck3d_20_list.append(infos[agent_ids[0]]['pck3d_20'])
            print(
                f'MPJPE = {mpjpe_3d_list[-1]:.2f}mm (mean={np.mean(mpjpe_3d_list):.2f}mm, successrate={100 * (np.array(mpjpe_3d_list) < 200).sum() / len(mpjpe_3d_list):.2f}%) \t '
                f'PCK3D(200mm) = {pck3d_20_list[-1] * 100:.2f} (mean={np.mean(pck3d_20_list) * 100:.2f})'
            )
            env.render(mode=render_type, timestap=timestap)
            if dones['__all__']:
                break
    else:
        # Single-agent rollout with the default policy.
        # observation = env.reset(init_loc_list=[[-300.0, -300.0, 300.0]])
        policy = policies['default_policy']
        observation = env.reset()
        state = policy.model.get_initial_state()
        # The previous action is deliberately not fed back to the policy.
        prev_action = None
        info = {}

        env.render(mode=render_type, timestap=timestap)
        for _step in tqdm(range(cmd_args.render_steps)):
            results = policy.compute_single_action(observation,
                                                   state=state,
                                                   info=info,
                                                   prev_action=prev_action,
                                                   explore=False)
            action, state, *_ = results

            observation, _, done, info = env.step(action)
            env.render(mode=render_type, timestap=timestap)
            if done:
                break
except Exception:
    # Top-level boundary: report the failure, then fall through to cleanup.
    traceback.print_exc()
finally:
    try:
        env
    except NameError:
        # env is created earlier in this script; this guard presumably covers
        # interactive (# %% cell) runs where it may not exist.
        print('not find object: env')
    else:
        env.close()
        if timestap is not None:
            # Argument list instead of a shell string: no quoting problems if
            # sys.executable or the path contains spaces.
            subprocess.run([sys.executable, '-m', 'run.scripts.copy2zip_zarr',
                            '--zarr-path', f'render_data/{timestap}/3d.zarr'])

print(sys.argv)
