"""
Play RoboSumo matches between pre-trained policies, either rendering them on screen or saving the rollouts as demonstration data.
"""
import click
import gym
import os, argparse

import numpy as np
import tensorflow as tf

import robosumo.envs

from robosumo.policy_zoo import LSTMPolicy, MLPPolicy
from robosumo.policy_zoo.utils import load_params, set_from_flat

# Maps each accepted --policy-names value to the policy class that builds it.
POLICY_FUNC = dict(
    mlp=MLPPolicy,
    lstm=LSTMPolicy,
)

# Command-line options are declared with click decorators (argparse would work just as well).
@click.command()
@click.option("--env", type=str,
              default="RoboSumo-Ant-vs-Ant-v0", show_default=True,
              help="Name of the environment.")
@click.option("--policy-names", nargs=2, type=click.Choice(list(POLICY_FUNC)),
              default=("mlp", "mlp"), show_default=True,
              help="Policy names.")
@click.option("--param-versions", nargs=2, type=int,
              default=(1, 1), show_default=True,
              help="Policy parameter versions.")
@click.option("--max_episodes", type=int,
              default=5000, show_default=True,
              help="Number of episodes.")
@click.option("--save", type=str, default="", show_default=True,
              help="demonstration save path.")
def main(env, policy_names, param_versions, max_episodes, save=None):
    """Play RoboSumo matches between two pre-trained policies.

    Without --save every episode is rendered on screen and its outcome is
    printed. With --save DIR nothing is rendered; instead the observations,
    actions, rewards and outcome of every episode are written to
    DIR/<policy0><version0>_vs_<policy1><version1>.npz, after episodes
    1, 101, 201, ... and once more at the end. Load that file with
    np.load(path, allow_pickle=True).
    """
    param_paths = _param_paths(env, policy_names, param_versions)
    if save:
        # `save` names a directory; create it up front so the final
        # np.savez_compressed cannot fail after a long rollout.
        os.makedirs(save, exist_ok=True)
        save = os.path.join(
            save,
            f"{policy_names[0]}{param_versions[0]}_vs_"
            f"{policy_names[1]}{param_versions[1]}.npz")

    # Create environment
    env = gym.make(env)
    env.reward_shaping = True  # reward_shaping is a local modification to the env
    # TODO: why is there a _adjust_z??? used in agents.py/get_qpos?
    for agent in env.agents:
        agent._adjust_z = -0.5

    tf_config = tf.ConfigProto(
        inter_op_parallelism_threads=1,
        intra_op_parallelism_threads=1)
    try:
        # Entering the session makes it the default session/graph (what the
        # policies are built and run under) and guarantees it is closed.
        with tf.Session(config=tf_config) as sess:
            # Initialize policies, one variable scope per agent.
            policy = [
                POLICY_FUNC[name](scope="policy" + str(i), reuse=False,
                                  ob_space=env.observation_space.spaces[i],
                                  ac_space=env.action_space.spaces[i],
                                  hiddens=[64, 64], normalize=True)
                for i, name in enumerate(policy_names)
            ]
            sess.run(tf.variables_initializer(tf.global_variables()))

            # Overwrite the freshly initialized variables with the
            # pre-trained parameters.
            params = [load_params(path) for path in param_paths]
            for pi, flat in zip(policy, params):
                set_from_flat(pi.get_variables(), flat)

            _play_matches(env, policy, max_episodes, save)
    finally:
        env.close()


def _param_paths(env_id, policy_names, param_versions):
    """Return the bundled parameter file path for each of the two agents.

    The agent names are the 2nd and 4th '-'-separated tokens of `env_id`,
    lower-cased (e.g. "RoboSumo-Ant-vs-Ant-v0" -> "ant", "ant").
    """
    tokens = env_id.split('-')
    agent_names = [tokens[1].lower(), tokens[3].lower()]
    params_dir = os.path.join(robosumo.__path__[0], "policy_zoo/assets")
    return [
        os.path.join(params_dir, a, p, "agent-params-v%d.npy" % v)
        for a, p, v in zip(agent_names, policy_names, param_versions)
    ]


def _episode_score(infos, total_scores):
    """Return +1 if agent 0 won, -1 if another agent won, 0 on a tie.

    Increments the winner's entry in `total_scores` in place. The winner is
    the first agent whose info dict contains a 'winner' key.
    """
    for i in range(len(total_scores)):
        if 'winner' in infos[i]:
            total_scores[i] += 1
            return 1 if i == 0 else -1
    return 0


def _as_object_array(items):
    """Return a 1-D object array whose i-th element is ``items[i]``.

    Filled element by element so NumPy never tries to stack (or reject)
    per-episode arrays of different lengths or ``None`` placeholders.
    """
    packed = np.empty(len(items), dtype=object)
    for i, item in enumerate(items):
        packed[i] = item
    return packed


def _save_demos(path, all_obs, all_act, all_rew, all_score):
    """Write the collected episodes to the compressed .npz file at `path`."""
    np.savez_compressed(path,
                        obs=_as_object_array(all_obs),
                        act=_as_object_array(all_act),
                        rew=_as_object_array(all_rew),
                        score=all_score)


def _play_matches(env, policy, max_episodes, save):
    """Play `max_episodes` episodes of `policy` against each other in `env`.

    When `save` is falsy the matches are rendered and each result printed;
    otherwise every step is recorded and written to the .npz path `save`.
    """
    num_episodes, nstep = 0, 0
    total_scores = [0 for _ in range(len(policy))]
    observation = env.reset()
    if not save:
        print("-" * 5 + "Episode %d " % (num_episodes + 1) + "-" * 5)
    else:
        # One slot per episode; unfinished episodes stay None / 0.
        all_obs = [None] * max_episodes
        all_act = [None] * max_episodes
        all_rew = [None] * max_episodes
        all_score = np.zeros(max_episodes)
        # Step buffers for the episode in progress (only needed when saving).
        obs, act, rew = [], [], []

        from tqdm import tqdm
        pbar = tqdm(total=max_episodes)

    while num_episodes < max_episodes:
        if not save:
            env.render(mode="human")
        action = tuple(
            pi.act(stochastic=True, observation=observation[i])[0]
            for i, pi in enumerate(policy)
        )
        if save:
            obs.append(observation)
            act.append(action)
        observation, reward, done, infos = env.step(action)
        if save:
            rew.append(reward)
        nstep += 1

        if not done[0]:
            continue

        score = _episode_score(infos, total_scores)
        if save:
            pbar.update()
            all_obs[num_episodes] = np.array(obs)
            all_act[num_episodes] = np.array(act)
            all_rew[num_episodes] = np.array(rew)
            all_score[num_episodes] = score
            obs, act, rew = [], [], []
            # Periodic checkpoint so a long collection run is not lost.
            if num_episodes % 100 == 0:
                _save_demos(save, all_obs, all_act, all_rew, all_score)

        num_episodes += 1
        if not save:
            if score == 0:
                print("Match tied [T={:3d}]: Scores: {}, Total Episodes: {}"
                      .format(nstep, total_scores, num_episodes))
            else:
                print("Winner [T={:3d}]: Agent {}, Scores: {}, Total Episodes: {}"
                      .format(nstep, 0 if score == 1 else 1, total_scores, num_episodes))

        observation = env.reset()
        nstep = 0
        for pi in policy:
            pi.reset()

        if not save and num_episodes < max_episodes:
            print("-" * 5 + "Episode %d " % (num_episodes + 1) + "-" * 5)

    if save:
        pbar.close()
        _save_demos(save, all_obs, all_act, all_rew, all_score)

if __name__ == "__main__":
    # `main` is a click command: calling it with no arguments parses its
    # options from sys.argv. To drive it programmatically (e.g. to collect
    # demonstrations for every pair of policies), call the undecorated
    # function instead:
    #     main.callback(env, policy_names, param_versions, max_episodes, save)
    # where `save` is an output directory, not a file path.
    main()
