#!/usr/bin/env python

import collections
import ray
from utils.spaces.space_utils import flatten_to_single_ndarray

from env.qnyh_small.multi_agent_env import QnyhSmallSelfplay
from agents.league import LeagueTrainer, PopulationEntropyTrainer

# Checkpoints of previously trained runs, keyed by a short label used below.
# Paths are relative, so presumably they are resolved against the directory the
# script is launched from — confirm before moving the script.
CHECKPOINTS = {
    "PFSP_10k": "old_results_2/qnyh-self-play-variance-pfsp-uniform/League_env.qnyh_small.multi_agent_env.QnyhSmallSelfplay_fb704_00000_0_2022-04-21_07-30-56/checkpoint_001000/checkpoint-1000",
    "PFSP_20k": "old_results_2/qnyh-self-play-variance-pfsp-uniform/League_env.qnyh_small.multi_agent_env.QnyhSmallSelfplay_fb704_00000_0_2022-04-21_07-30-56/checkpoint_002000/checkpoint-2000",
    "POP_10k": "ray_results/qnyh-self-play-population-entropy-0.0/PopulationEntropy_env.qnyh_small.multi_agent_env.QnyhSmallSelfplay_3cf17_00000_0_2022-04-26_09-57-09/checkpoint_001000/checkpoint-1000",
    "POP_20k": "ray_results/qnyh-self-play-population-entropy-0.0/PopulationEntropy_env.qnyh_small.multi_agent_env.QnyhSmallSelfplay_3cf17_00000_0_2022-04-26_09-57-09/checkpoint_002000/checkpoint-2000",
    "MEP_10k": "ray_results/qnyh-self-play-population-entropy-0.01/PopulationEntropy_env.qnyh_small.multi_agent_env.QnyhSmallSelfplay_4790c_00000_0_2022-04-26_09-57-27/checkpoint_001000/checkpoint-1000",
    "MEP_20k": "ray_results/qnyh-self-play-population-entropy-0.01/PopulationEntropy_env.qnyh_small.multi_agent_env.QnyhSmallSelfplay_4790c_00000_0_2022-04-26_09-57-27/checkpoint_002000/checkpoint-2000",
    "MEPv2_20k": "ray_results/qnyh-self-play-population-entropy-new-0.5/PopulationEntropy_env.qnyh_small.multi_agent_env.QnyhSmallSelfplay_dfd6d_00000_0_2022-04-27_08-27-28/checkpoint_002000/checkpoint-2000",
}

# Evaluation pairings. main() plays every LEFT checkpoint (index i) against
# every RIGHT checkpoint (index j). The lists below are parallel:
#   - LEFT_MODEL_*_POLICY_NAMES[i] are the policy ids registered when
#     restoring LEFT_MODEL_CHECKPOINT_PATHS[i] (the "left" one is played);
#   - RIGHT_MODEL_*_POLICY_NAMES[j] likewise for RIGHT_MODEL_CHECKPOINT_PATHS[j]
#     (the "right" one is played);
#   - CLASSES is shared: CLASSES[i] builds the left trainer, CLASSES[j] the
#     right one, so presumably each index must hold the class the checkpoint
#     at that index was trained with — verify when editing the lists.
# The special path "bot" means "no trained agent on that side"; rollout() then
# passes a rule_team to the env for that side.
LEFT_MODEL_CHECKPOINT_PATHS = [CHECKPOINTS["MEPv2_20k"], CHECKPOINTS["PFSP_10k"]]
# Alternative: evaluate against the environment-controlled bot on the left.
# LEFT_MODEL_CHECKPOINT_PATHS = ["bot"]
RIGHT_MODEL_CHECKPOINT_PATHS = [CHECKPOINTS["MEPv2_20k"], CHECKPOINTS["PFSP_10k"]]
# Alternative: evaluate against the environment-controlled bot on the right.
# RIGHT_MODEL_CHECKPOINT_PATHS = ["bot"]
LEFT_MODEL_LEFT_POLICY_NAMES = ["left_1", "main_left"]
LEFT_MODEL_RIGHT_POLICY_NAMES = ["right_1", "main_right"]
RIGHT_MODEL_LEFT_POLICY_NAMES = ["left_1", "main_left"]
RIGHT_MODEL_RIGHT_POLICY_NAMES = ["right_1", "main_right"]
# Episodes played per (seed, left, right) pairing.
NUM_EPISODES = 100
# Each seed runs the full pairing sweep; it is passed as the trainer "seed".
SEEDS = [123]
CLASSES = [PopulationEntropyTrainer, LeagueTrainer]

import numpy as np
import gym
from policy.policy import PolicySpec

# Observation / action spaces of the two races, as (obs_space, act_space).
TANK_OBS_SPACE, TANK_ACT_SPACE = (
    gym.spaces.Box(low=-np.inf, high=np.inf, shape=(37,)),
    gym.spaces.Discrete(13),
)
SHOOTER_OBS_SPACE, SHOOTER_ACT_SPACE = (
    gym.spaces.Box(low=-np.inf, high=np.inf, shape=(49,)),
    gym.spaces.Discrete(15),
)

# The evaluated env always seats the Shooter on the left and the Tank on the
# right, so each side simply reuses its race's spaces.
LEFT_OBS_SPACE, LEFT_ACT_SPACE = SHOOTER_OBS_SPACE, SHOOTER_ACT_SPACE
RIGHT_OBS_SPACE, RIGHT_ACT_SPACE = TANK_OBS_SPACE, TANK_ACT_SPACE

# Report lines accumulated by main()/rollout() and printed at the end.
results = []


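# Note: if you use any custom models or envs, register them here first
# (this needs `from ray.rllib.models import ModelCatalog` and
# `from ray.tune.registry import register_env`, which are not imported), e.g.: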
#
# from examples.env.parametric_actions_cartpole import \
#     ParametricActionsCartPole
# from examples.model.parametric_actions_model import \
#     ParametricActionsModel
# ModelCatalog.register_custom_model("pa_model", ParametricActionsModel)
# register_env("pa_cartpole", lambda _: ParametricActionsCartPole(10))


def _initial_action(agent, policy_name):
    """Return one random action from *agent*'s *policy_name* policy, flattened.

    It seeds ``prev_action`` for the first step of each episode, before the
    policy has produced an action of its own.
    """
    policy = agent.workers.local_worker().policy_map[policy_name]
    return flatten_to_single_ndarray(policy.action_space.sample())


def _episode_winner(final_reward, has_left, has_right):
    """Decide which side won an episode from its last-step reward dict.

    Args:
        final_reward: reward dict returned by the episode's final ``env.step``.
        has_left: whether the left side is driven by a trained agent.
        has_right: whether the right side is driven by a trained agent.

    Returns:
        ``"left"`` or ``"right"``, or ``None`` when neither side is driven by a
        trained agent. Ties (equal rewards, or a zero reward when only one side
        is an agent) are credited to the right side in every mode.
    """
    if has_left and has_right:
        return "left" if final_reward["left"] > final_reward["right"] else "right"
    if has_left:
        return "left" if final_reward["left"] > 0 else "right"
    if has_right:
        return "right" if final_reward["right"] >= 0 else "left"
    return None


def _win_rate(wins, episodes):
    """Return ``wins / episodes``, or NaN when no episode was played."""
    return wins / episodes if episodes else float("nan")


def rollout(
        agent_left,
        agent_right,
        left_policy_name,
        right_policy_name,
        num_episodes=0,
):
    """Play Shooter (left) vs Tank (right) episodes and record the win rates.

    Args:
        agent_left: trainer driving the left side, or ``None`` to leave that
            side to the environment.
        agent_right: trainer driving the right side, or ``None``.
        left_policy_name: policy id used with ``agent_left``.
        right_policy_name: policy id used with ``agent_right``.
        num_episodes: number of episodes to play. With 0 the reported win
            rates are NaN instead of raising ``ZeroDivisionError``.

    Returns:
        ``{"left": wins, "right": wins}`` over the played episodes.

    Side effects:
        Appends a formatted win-rate report to the module-level ``results``.
        The environment is closed on exit, even if an episode raises.
    """
    has_left = agent_left is not None
    has_right = agent_right is not None
    # rule_team tells the env which side has no trained agent: 0 when the left
    # one is missing, 1 when the right one is, -1 when both are present.
    # Presumably the env then drives that side with its built-in AI.
    # NOTE(review): with both agents missing only the left is handed over.
    if not has_left:
        rule_team = 0
    elif not has_right:
        rule_team = 1
    else:
        rule_team = -1
    env = QnyhSmallSelfplay(
        races={"left": "Shooter", "right": "Tank"},
        rule_team=rule_team,
        # print_game_log=True,
        hard_ai=True,
    )

    # (side, agent, policy id) for every side driven by a trained agent;
    # left before right so actions are sampled/computed in a fixed order.
    controlled = []
    if has_left:
        controlled.append(("left", agent_left, left_policy_name))
    if has_right:
        controlled.append(("right", agent_right, right_policy_name))

    # Keyed by side (not policy id) so two sides sharing a policy id cannot
    # overwrite each other's initial action.
    initial_actions = {
        side: _initial_action(agent, policy_name)
        for side, agent, policy_name in controlled
    }

    win = {"left": 0, "right": 0}
    episodes = 0
    try:
        while episodes < num_episodes:
            obs = env.reset()
            prev_actions = dict(initial_actions)
            prev_rewards = collections.defaultdict(float)
            done = False
            while not done:
                action_dict = {}
                for side, agent, policy_name in controlled:
                    action = agent.compute_single_action(
                        obs[side],
                        prev_action=prev_actions[side],
                        prev_reward=prev_rewards[side],
                        policy_id=policy_name,
                    )
                    action_dict[side] = action
                    prev_actions[side] = action
                obs, reward, dones, _info = env.step(action_dict)
                prev_rewards.update(reward)
                done = dones["__all__"]
            # The loop body runs at least once, so `reward` holds the
            # episode's final-step rewards here.
            winner = _episode_winner(reward, has_left, has_right)
            if winner is not None:
                win[winner] += 1
            episodes += 1
    finally:
        # Release the env (game process/resources) between pairings.
        close = getattr(env, "close", None)
        if callable(close):
            close()

    results.append(
        f"Evaluation results:\n"
        f"\tLeft win-rate: {_win_rate(win['left'], episodes)}\n"
        f"\tRight win-rate: {_win_rate(win['right'], episodes)}\n"
    )
    return win


def _make_policy_mapping_fn(left_policy_name, right_policy_name):
    """Build an RLlib ``policy_mapping_fn`` for a two-sided config.

    Agent id ``"left"`` maps to *left_policy_name*; any other agent id maps to
    *right_policy_name*. The names are bound by this factory rather than by a
    lambda closing over the caller's loop variables, so later loop iterations
    cannot change what an already-built function returns.
    """

    def policy_mapping_fn(agent_id, episode, worker, **kwargs):
        return left_policy_name if agent_id == "left" else right_policy_name

    return policy_mapping_fn


def _load_trainer(trainer_cls, checkpoint_path, seed, left_policy_name,
                  right_policy_name):
    """Create a *trainer_cls* for the Shooter-vs-Tank env and restore a checkpoint.

    The config registers two policies, *left_policy_name* (Shooter spaces) and
    *right_policy_name* (Tank spaces), and maps agent ``"left"`` to the first
    and everything else to the second.

    Returns:
        The restored trainer. If ``restore`` raises, the trainer is stopped
        before the exception propagates so its workers are not leaked.
    """
    config = {
        "seed": seed,
        "env": "env.qnyh_small.multi_agent_env.QnyhSmallSelfplay",
        "env_config": {
            "races": {"left": "Shooter", "right": "Tank"},
        },
        "num_workers": 1,
        "num_cpus_for_driver": 1,
        "league_config": {
            "type": "agents.league.league.League",
            "coordinator": False,
        },
        "create_env_on_driver": True,
        "multiagent": {
            "policies": {
                left_policy_name: PolicySpec(
                    None,
                    LEFT_OBS_SPACE,
                    LEFT_ACT_SPACE,
                    None,
                ),
                right_policy_name: PolicySpec(
                    None,
                    RIGHT_OBS_SPACE,
                    RIGHT_ACT_SPACE,
                    None,
                ),
            },
            "policy_mapping_fn": _make_policy_mapping_fn(
                left_policy_name, right_policy_name
            ),
            "policy_map_capacity": 1000,
        },
    }
    trainer = trainer_cls(config=config)
    try:
        trainer.restore(checkpoint_path)
    except BaseException:
        trainer.stop()
        raise
    return trainer


def _evaluate_pair(seed, i, left_path, j, right_path):
    """Load the trainers for one (left, right) pairing, run ``rollout``, clean up.

    A path equal to ``"bot"`` leaves that side without a trainer (``None``).
    Index *i* selects the left side's class and policy names, *j* the right
    side's. Every trainer that was created is stopped, even when loading the
    other one or the rollout raises.
    """
    agent_left = agent_right = None
    left_policy_name = right_policy_name = None
    try:
        if left_path != "bot":
            left_policy_name = LEFT_MODEL_LEFT_POLICY_NAMES[i]
            agent_left = _load_trainer(
                CLASSES[i],
                left_path,
                seed,
                left_policy_name,
                LEFT_MODEL_RIGHT_POLICY_NAMES[i],
            )
        if right_path != "bot":
            right_policy_name = RIGHT_MODEL_RIGHT_POLICY_NAMES[j]
            agent_right = _load_trainer(
                CLASSES[j],
                right_path,
                seed,
                RIGHT_MODEL_LEFT_POLICY_NAMES[j],
                right_policy_name,
            )

        # Do the actual rollout.
        rollout(
            agent_left,
            agent_right,
            left_policy_name,
            right_policy_name,
            NUM_EPISODES,
        )
    finally:
        for agent in (agent_left, agent_right):
            if agent is not None:
                agent.stop()


def main():
    """Evaluate every left checkpoint against every right checkpoint per seed.

    Each pairing appends a header plus rollout's win-rate report to the
    module-level ``results``, and the collected lines are printed at the end.
    Ray is shut down even if an evaluation fails.
    """
    ray.init(local_mode=False)
    try:
        for seed in SEEDS:
            for i, left_path in enumerate(LEFT_MODEL_CHECKPOINT_PATHS):
                for j, right_path in enumerate(RIGHT_MODEL_CHECKPOINT_PATHS):
                    results.append("=" * 50)
                    results.append(
                        f"seed: {seed}\nleft: {left_path}\nright: {right_path}"
                    )
                    _evaluate_pair(seed, i, left_path, j, right_path)

        for line in results:
            print(line)
    finally:
        ray.shutdown()


# Run the evaluation sweep only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
