#!/usr/bin/env python

import collections
import ray
from utils.spaces.space_utils import flatten_to_single_ndarray
from ray.tune.registry import register_trainable

from agents.registry import ALGORITHMS, get_trainer_class
from env.football.self_play_env import SelfPlayFootballEnv

# Register every project algorithm with Ray Tune under its registry key so the
# trainers can be referenced by name. The keys are snapshotted into a list
# first so registration cannot interfere with iterating ALGORITHMS.
for key in [*ALGORITHMS.keys()]:
    trainer_class = get_trainer_class(key)
    register_trainable(key, trainer_class)

EXAMPLE_USAGE = """
Example usage via executable:
    python evaluate.py
"""

# Known trained checkpoints, keyed by "<method>_<training-iteration>[<run>]".
CHECKPOINTS = {
    "ALP_10k1": "ray_results/gfootball-1v1-self-play-alp/seed123_iter100/checkpoint_010000/checkpoint-10000",
    "ALP_10k2": "ray_results/gfootball-1v1-self-play-alp/seed123_iter200/checkpoint_010000/checkpoint-10000",
    "ALP_15k1": "ray_results/gfootball-1v1-self-play-alp/seed123_iter100/checkpoint_015000/checkpoint-15000",
    "HFSP_10k1": "ray_results/gfootball-1v1-self-play-hfsp/seed123_iter100/checkpoint_010000/checkpoint-10000",
    "HFSP_10k2": "ray_results/gfootball-1v1-self-play-hfsp/seed123_iter200/checkpoint_010000/checkpoint-10000",
    "HFSP_23k1": "ray_results/gfootball-1v1-self-play-hfsp/seed123_iter100/checkpoint_023000/checkpoint-23000",
    "HFSP_23k2": "ray_results/gfootball-1v1-self-play-hfsp/seed123_iter200/checkpoint_023000/checkpoint-23000",
    "PFSP_10k": "ray_results/gfootball-1v1-self-play-pfsp/seed123/checkpoint_010000/checkpoint-10000",
    "PFSP_15k": "ray_results/gfootball-1v1-self-play-pfsp/seed123/checkpoint_015000/checkpoint-15000",
    "PFSP_33k": "ray_results/gfootball-1v1-self-play-pfsp/seed123/checkpoint_033000/checkpoint-33000",
}

# Checkpoint path for each side, or "bot" to run that side without a trained
# policy (the env then controls 0 players on that side — presumably leaving
# them to the game's built-in AI).
CHECKPOINT_LEFT_PATH = CHECKPOINTS["PFSP_10k"]
# To evaluate against a trained right side instead, use a key that exists in
# CHECKPOINTS, e.g. CHECKPOINTS["HFSP_23k1"] (there is no plain "HFSP_23k").
CHECKPOINT_RIGHT_PATH = "bot"
# Number of evaluation episodes to play.
NUM_EPISODES = 100


# Note: if you use any custom models or envs, register them here first, e.g.:
#
# from examples.env.parametric_actions_cartpole import \
#     ParametricActionsCartPole
# from examples.model.parametric_actions_model import \
#     ParametricActionsModel
# ModelCatalog.register_custom_model("pa_model", ParametricActionsModel)
# register_env("pa_cartpole", lambda _: ParametricActionsCartPole(10))


def _play_episode(env, sides, action_init):
    """Run one episode to completion and return the final step's reward dict.

    Args:
        env: The multi-agent football env (``reset``/``step`` with dict I/O).
        sides: ``(agent_id, trainer, policy_id)`` tuples for each side driven
            by a trained policy, in the order actions should be computed.
        action_init: Per-agent-id action used as the "previous action" on the
            first step of the episode.

    Returns:
        The reward dict from the last ``env.step`` call. Rewards from earlier
        steps are not summed; the caller's win tally assumes the episode ends
        right after a score — TODO confirm for this env.
    """
    obs = env.reset()
    prev_actions = {agent_id: action_init[agent_id] for agent_id, _, _ in sides}
    prev_rewards = collections.defaultdict(lambda: 0.0)
    reward = {"left": 0.0, "right": 0.0}
    done = False
    while not done:
        action_dict = {}
        for agent_id, agent, policy_id in sides:
            action = agent.compute_single_action(
                obs[agent_id],
                prev_action=prev_actions[agent_id],
                prev_reward=prev_rewards[agent_id],
                policy_id=policy_id,
            )
            action_dict[agent_id] = action
            prev_actions[agent_id] = action

        obs, reward, dones, _info = env.step(action_dict)
        for agent_id, r in reward.items():
            prev_rewards[agent_id] = r
        done = dones["__all__"]
    return reward


def rollout(
    agent_left,
    agent_right,
    num_episodes=0,
    number_left_agents=1,
    number_right_agents=1,
):
    """Play evaluation episodes between two sides and print success rates.

    A side whose agent is ``None`` is not driven by any RLlib policy: the env
    is told to control 0 players on that side (presumably leaving them to the
    game's built-in AI — TODO confirm in SelfPlayFootballEnv).

    Args:
        agent_left: Trainer whose ``"main_left"`` policy plays the left side
            (reported as the attacker), or ``None``.
        agent_right: Trainer whose ``"main_right"`` policy plays the right
            side (reported as the defender), or ``None``.
        num_episodes: Number of episodes to play. With 0 no episodes run and
            both reported rates are 0.0.
        number_left_agents: Left players controlled by ``agent_left``; forced
            to 0 when ``agent_left`` is ``None``.
        number_right_agents: Right players controlled by ``agent_right``;
            forced to 0 when ``agent_right`` is ``None``.

    Returns:
        ``{"left": wins, "right": wins}`` episode counts. Ties count as a
        right-side (defender) win.
    """
    if agent_left is None:
        number_left_agents = 0
    if agent_right is None:
        number_right_agents = 0
    env = SelfPlayFootballEnv(
        env_name="1_vs_1_half_field",
        stacked=False,
        rewards="scoring",
        write_goal_dumps=False,
        write_full_episode_dumps=True,
        render=False,
        write_video=True,
        dump_frequency=20,
        representation="simple115v2",
        number_of_left_players_agent_controls=number_left_agents,
        number_of_right_players_agent_controls=number_right_agents,
        logdir="alp_vs_hfsp",
        other_config_options={
            "action_set": "default"  # "default" = action_set_v1 (19), "v2" = action_set_v2 (19 + 1 built-in ai)
        },
        court_range=0.3,
        in_evaluation=True,
    )

    has_left = agent_left is not None
    has_right = agent_right is not None
    # Sides driven by a trained policy, left first: (agent id, trainer, policy id).
    sides = [
        (agent_id, agent, policy_id)
        for agent_id, agent, policy_id in (
            ("left", agent_left, "main_left"),
            ("right", agent_right, "main_right"),
        )
        if agent is not None
    ]
    # One random action per side, sampled once and reused as the initial
    # "previous action" of every episode.
    action_init = {
        agent_id: flatten_to_single_ndarray(
            agent.workers.local_worker().policy_map[policy_id].action_space.sample()
        )
        for agent_id, agent, policy_id in sides
    }

    episodes = 0
    win = {"left": 0, "right": 0}
    try:
        while episodes < num_episodes:
            reward = _play_episode(env, sides, action_init)
            if has_left or has_right:
                # A single-agent run only sees its own reward; the other side's
                # is reported as its negation (zero-sum score).
                left_r = reward["left"] if has_left else -reward["right"]
                right_r = reward["right"] if has_right else -reward["left"]
                print(f"Episode #{episodes}: left reward {left_r}, right reward {right_r}")
                # Ties go to the right side (the defender).
                if left_r > right_r:
                    win["left"] += 1
                else:
                    win["right"] += 1
            episodes += 1
    finally:
        # Release the game engine even if a policy or env step raises.
        close = getattr(env, "close", None)
        if callable(close):
            close()

    # Guard against num_episodes == 0 (the default), which used to divide by zero.
    left_rate = win["left"] / episodes if episodes else 0.0
    right_rate = win["right"] / episodes if episodes else 0.0
    print(f"Evaluation results:\n"
          f"\tattacker success-rate: {left_rate}\n"
          f"\tdefender success-rate: {right_rate}\n")
    return win


def _make_eval_config():
    """Return a fresh RLlib trainer config used to restore an evaluation checkpoint.

    A new dict (and new mapping lambda) is built on every call so the left and
    right trainers never share one config object.
    """
    return {
        "seed": 123,
        # NOTE(review): this path says "env.gfootball" while this file imports
        # SelfPlayFootballEnv from "env.football.self_play_env" — confirm
        # which module is intended.
        "env": "env.gfootball.self_play_env.SelfPlayFootballEnv",
        "env_config": {
            "env_name": "1_vs_1_half_field",
            "stacked": False,
            "rewards": "scoring",
            "write_goal_dumps": False,
            "write_full_episode_dumps": False,
            "render": False,
            "write_video": False,
            "dump_frequency": 0,
            "representation": "simple115v2",
            "number_of_left_players_agent_controls": 1,
            "number_of_right_players_agent_controls": 1,
            "logdir": "test",
            "court_range": 0.3,
            "in_evaluation": True,
        },
        "num_workers": 1,
        "num_cpus_for_driver": 1,
        "league_config": {
            "type": "agents.league.league.League",
            "coordinator": False,
        },
        "create_env_on_driver": True,
        "multiagent": {
            "policies": ["main_left", "main_right"],
            "policy_mapping_fn": lambda agent_id, episode, worker, **kwargs: "main_left" if agent_id == "left" else "main_right",
            "policy_map_capacity": 1000,
        },
    }


def _load_agent(trainer_cls, checkpoint_path):
    """Build a trainer and restore ``checkpoint_path``; return None for "bot"."""
    if checkpoint_path == "bot":
        return None
    agent = trainer_cls(config=_make_eval_config())
    agent.restore(checkpoint_path)
    return agent


def main():
    """Restore the configured left/right checkpoints and run the evaluation."""
    ray.init(local_mode=False)

    # Create the Trainer and load state from checkpoint, if provided.
    from agents.league import LeagueTrainer

    agent_left = agent_right = None
    try:
        agent_left = _load_agent(LeagueTrainer, CHECKPOINT_LEFT_PATH)
        agent_right = _load_agent(LeagueTrainer, CHECKPOINT_RIGHT_PATH)

        # Do the actual rollout.
        rollout(agent_left, agent_right, NUM_EPISODES)
    finally:
        # Stop trainer workers even when restoring or the rollout fails, then
        # tear down the Ray runtime started above.
        for agent in (agent_left, agent_right):
            if agent is not None:
                agent.stop()
        ray.shutdown()


if __name__ == "__main__":
    main()
