import os
import time
import pickle
import numpy as np
import ray
import sys
import traceback

from pathlib import Path
from collections import namedtuple

from expground import settings
from expground.logger import Log
from expground.envs.agent_interface import AgentInterface
from expground.envs.gathering import env_desc_gen
from expground.common.policy_pool import PolicyPool
from expground.utils.data import EpisodeKeys
from expground.envs.vector_env import VectorEnv


# Root directory under the user's home below which per-game model folders are looked up.
DATA_DIR = os.path.expanduser("~/dataset/expground/gathering")
# Timestamp captured once at import time; it names the results directory of this run.
time_string = time.strftime("%Y%m%d-%H%M%S", time.localtime())
# Simple (left agent, right agent) pair.
Agents = namedtuple("Agents", ["lagent", "ragent"])
# Evaluation config file that sits next to this script.
YAML = Path(__file__).with_name("gathering_eval.yaml")


def detect_model_num(model_path, max_num=100):
    """Return the ``.pt`` file names found in ``model_path``, ordered by index.

    The index of a file is the integer located between the last ``-`` of its
    name and the first ``.`` that follows it (e.g. ``policy-12.pt`` -> 12).

    Args:
        model_path: Directory to scan, given as ``str`` or ``pathlib.Path``.
        max_num: Maximum number of names to return (lowest indices first).
            ``None`` disables the cap.

    Returns:
        A list of bare file names (no directory part), sorted by index.

    Raises:
        ValueError: If a ``.pt`` file name carries no integer at the expected
            position.
    """

    def _index(name):
        # "policy-12.pt" -> "12.pt" -> "12" -> 12
        return int(name.split("-")[-1].split(".")[0])

    # Path.name / Path.suffix avoid hand-splitting on "/" and therefore work
    # with any path separator.
    names = [
        entry.name for entry in Path(model_path).iterdir() if entry.suffix == ".pt"
    ]
    return sorted(names, key=_index)[:max_num]


def load_model_and_dist(algo, game_name, r_or_c):
    """Build an inactive agent interface backed by the stored policies of ``algo``.

    Args:
        algo: Algorithm name; selects the ``<DATA_DIR>/<game_name>/<algo>/models``
            folder and prefixes the interface's policy name.
        game_name: Game whose models are loaded.
        r_or_c: ``"r"`` loads the ``agent_0`` policies (named ``<algo>_left``);
            any other value loads the ``agent_1`` policies (named
            ``<algo>_right``).

    Returns:
        A tuple ``(agent_interface, model_names)`` where ``model_names`` is the
        list of detected model file names used as keys of the policy pool.
    """
    if r_or_c == "r":
        agent_id, side = "agent_0", "left"
    else:
        agent_id, side = "agent_1", "right"

    agent_dir = Path(os.path.join(DATA_DIR, game_name, algo, "models")) / agent_id
    model_keys = detect_model_num(agent_dir)

    pool = PolicyPool.load_from_config(
        agent_dir, yaml_path=YAML, model_keys=model_keys, aid=agent_id
    )
    interface = AgentInterface(
        policy_name=f"{algo}_{side}",
        policy=pool,
        action_space=pool._action_space,
        observation_space=pool._observation_space,
        is_active=False,
    )
    return interface, model_keys


def eval_run(agent_interfaces, env, agent_policy_mapping, eval_num=50, max_step=100):
    """Play one evaluation rollout and return each agent's mean accumulated reward.

    The environment is re-seeded with a random seed, reset, and stepped until
    either ``max_step`` steps were taken or the environment reports done. For a
    ``VectorEnv`` the rollout stops as soon as any entry of the first agent's
    done value is true; otherwise it stops when any agent is done.

    Args:
        agent_interfaces: Mapping from agent (group) id to the interface used to
            transform observations and compute actions.
        env: Environment (plain or ``VectorEnv``) to evaluate in.
        agent_policy_mapping: Mapping from interface key to the policy id that
            interface is reset with.
        eval_num: Unused; kept so existing callers keep working.
        max_step: Upper bound on the number of environment steps.

    Returns:
        Dict mapping every id in ``env.possible_agents`` to ``np.mean`` of the
        reward that agent accumulated over the rollout.
    """

    def _transform(raw_observations):
        # Each observation is transformed by the interface of the agent's group.
        return {
            aid: agent_interfaces[env.agent_to_group(aid)].transform_observation(obs)
            for aid, obs in raw_observations.items()
        }

    env.seed(np.random.randint(0, 1000))

    episode_reward = dict.fromkeys(env.possible_agents, 0.0)
    rets = env.reset()
    # Plain loop: reset() is called only for its side effect.
    for aid, interface in agent_interfaces.items():
        interface.reset(policy_id=agent_policy_mapping[aid])

    step = 0
    done = False
    observations = _transform(rets[EpisodeKeys.OBSERVATION.value])

    reward = dict.fromkeys(env.possible_agents, 0.0)
    actions = dict.fromkeys(env.possible_agents, None)
    while step < max_step and not done:
        # The mask entry (if any) is the same for every agent in this step.
        masks = rets.get(EpisodeKeys.ACTION_MASK.value)
        for aid, observation in observations.items():
            action_mask = masks[aid] if masks is not None else None
            # Use the same group lookup as for the observation transformation so
            # the action comes from the interface that produced the observation.
            interface = agent_interfaces[env.agent_to_group(aid)]
            actions[aid], _, _ = interface.compute_action(
                observation,
                action_mask,
                evaluate=True,
            )
        rets = env.step(actions)
        observations = _transform(rets[EpisodeKeys.OBSERVATION.value])

        dones = list(rets[EpisodeKeys.DONE.value].values())
        done = any(dones[0]) if isinstance(env, VectorEnv) else any(dones)
        step += 1

        for aid, r in rets[EpisodeKeys.REWARD.value].items():
            reward[aid] += r

    # np.mean collapses the accumulated reward to a single number per agent.
    for aid, r in reward.items():
        episode_reward[aid] += np.mean(r)
    return episode_reward


@ray.remote
def sub(eval_algo, env_desc, game_name, agent_ids, RESULTS_DIR, tag):
    """Evaluate ``eval_algo`` against ``psro`` in one seat and save the payoff matrix.

    With ``tag == "r"`` the evaluated algorithm takes the row seat and ``psro``
    the column seat; otherwise ``psro`` takes the row seat and the evaluated
    algorithm the column seat. Every row model is played against every column
    model and the evaluated side's reward is stored at ``matrix[i, j]``.

    Args:
        eval_algo: Name of the algorithm under evaluation.
        env_desc: Environment description forwarded to ``VectorEnv``.
        game_name: Game whose stored models are loaded.
        agent_ids: Two agent ids; the first is paired with the row agent, the
            second with the column agent.
        RESULTS_DIR: Existing directory the matrix is pickled into.
        tag: ``"r"`` or ``"c"``; also names the output file ``<tag>.pkl``.

    Returns:
        The ``(100, 100)`` payoff matrix. Entries that were not evaluated
        (fewer models, or an error mid-way) stay zero.
    """
    # The row seat must always hold row ("r") policies and the column seat
    # column ("c") policies; only the algorithm occupying each seat depends on
    # ``tag``.
    if tag == "r":
        row_agent, r_model_names = load_model_and_dist(eval_algo, game_name, r_or_c="r")
        col_agent, c_model_names = load_model_and_dist("psro", game_name, r_or_c="c")
    else:
        row_agent, r_model_names = load_model_and_dist("psro", game_name, r_or_c="r")
        col_agent, c_model_names = load_model_and_dist(eval_algo, game_name, r_or_c="c")

    # 100 matches the maximum number of models detect_model_num returns by default.
    matrix = np.zeros((100, 100))
    # NOTE(review): the vector env is never released here; if VectorEnv offers
    # a close/shutdown call it probably belongs in the ``finally`` below - TODO confirm.
    env = VectorEnv(
        env_desc, num_envs=50, use_remote=True, resource_config={"num_cpus": 0.5}
    )
    # The interface mapping does not change between match-ups, so build it once.
    agent_interfaces = dict(zip(agent_ids, [row_agent, col_agent]))
    # Reward of the side under evaluation.
    reward_key = "agent_0" if tag == "r" else "agent_1"
    try:
        Log.info("eval {} as {} player".format(eval_algo, tag))
        for i, rmodel_name in enumerate(r_model_names):
            for j, cmodel_name in enumerate(c_model_names):
                start = time.time()
                res = eval_run(
                    agent_interfaces,
                    env=env,
                    agent_policy_mapping=dict(
                        zip(agent_ids, [rmodel_name, cmodel_name])
                    ),
                )
                reward = res[reward_key]
                Log.info(
                    "\t* {} model={} reward={} time={}".format(
                        tag, (rmodel_name, cmodel_name), reward, time.time() - start
                    )
                )
                matrix[i, j] = reward
    except Exception:
        # Best effort: log the failure and still persist what was computed so far.
        Log.error(traceback.format_exc())
    finally:
        path = os.path.join(RESULTS_DIR, "{}.pkl".format(tag))
        with open(path, "wb") as f:
            pickle.dump(matrix, f)
        Log.info("save eval matrix dict to: {}".format(path))
    return matrix


def play_against(eval_algo, game_name):
    """Evaluate ``eval_algo`` as row and as column player and store both matrices.

    Two remote ``sub`` tasks are launched (tags ``"r"`` and ``"c"``); their
    results are pickled together as ``{"r": ..., "c": ...}`` into
    ``matrix_dict.pkl`` inside a timestamped results directory below
    ``settings.BASE_DIR``.

    Args:
        eval_algo: Name of the algorithm under evaluation.
        game_name: Game name; also used as the scenario's ``map_name``.
    """
    env_desc = env_desc_gen(
        game_name, scenario_config={"n_agents": 2, "map_name": game_name}
    )

    agent_ids = env_desc["config"]["possible_agents"]

    RESULTS_DIR = os.path.join(
        settings.BASE_DIR,
        "results/gathering_{}_{}/{}".format(game_name, eval_algo, time_string),
    )

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(RESULTS_DIR, exist_ok=True)

    # Both seats are evaluated in parallel; ray.get preserves submission order.
    res = ray.get(
        [
            sub.remote(eval_algo, env_desc, game_name, agent_ids, RESULTS_DIR, "r"),
            sub.remote(eval_algo, env_desc, game_name, agent_ids, RESULTS_DIR, "c"),
        ]
    )
    path = os.path.join(RESULTS_DIR, "matrix_dict.pkl")
    with open(path, "wb") as f:
        pickle.dump({"r": res[0], "c": res[1]}, f)
    Log.info("save eval matrix dict to: {}".format(path))


if __name__ == "__main__":
    # Validate the command line before starting ray so a bad invocation fails
    # fast with a readable message instead of an IndexError.
    if len(sys.argv) < 3:
        sys.exit("usage: {} <eval_algo> <game_name>".format(sys.argv[0]))
    eval_algo = sys.argv[1]
    game_name = sys.argv[2]

    if not ray.is_initialized():
        ray.init()
    try:
        play_against(eval_algo, game_name)
    finally:
        # Always release ray, even when the evaluation raised.
        ray.shutdown()
