import collections
import time
import ray
import numpy as np

from collections import defaultdict

from expground.logger import Log, monitor
from expground.types import Dict, Tuple, PolicyID, AgentID, Any, Sequence, List
from expground.utils.sampler import Sampler
from expground.utils.data import Episode, EpisodeKeys
from expground.common.policy_pool import PolicyPool
from expground.envs.agent_interface import AgentInterface
from expground.envs import vector_env, Environment


# standard time step description, for sequential rollout
_TimeStep = collections.namedtuple(
    "_TimeStep", "observation, action_mask, reward, action, done, action_dist, logits"
)


def simultaneous_rollout(
    sampler: Sampler,
    agent_interfaces: Dict[str, AgentInterface],
    env_description: Dict[str, Any],
    fragment_length: int,
    max_step: int,
    agent_policy_mapping: Dict[str, PolicyID] = None,
    evaluate: bool = False,
    agent_filter: Sequence = None,
    episodic: bool = True,
    train_every: int = 1,
    render: bool = False,
    seed: int = None,
    env: type = None,
    max_episode: int = 10,
):
    """Rollout generator for simultaneous-move games. Currently, support only atari games.

    All agents act at every step. When ``sampler`` is given, every step is pushed to it
    via ``sampler.add_transition``. In non-episodic mode the generator yields
    ``{"timesteps": cnt}`` whenever the sampler is ready and its size is a multiple of
    ``train_every``. The rollout statistics are the generator's return value
    (``StopIteration.value``).

    Args:
        sampler (Sampler): Collects training transitions; may be None (e.g. evaluation).
        agent_interfaces (Dict[str, AgentInterface]): Agent interfaces, keyed by agent id,
            or by group id when the environment config defines ``group``.
        env_description (Dict[str, Any]): Environment description with ``creator`` and ``config``.
        fragment_length (int): No new episode is started once this many steps were collected.
        max_step (int): Maximum number of steps of one episode.
        agent_policy_mapping (Dict[str, PolicyID], optional): Policy id passed to each
            interface's ``reset``. Defaults to an empty mapping.
        evaluate (bool): Evaluation mode; also selects which statistics are returned.
        agent_filter (Sequence, optional): Agents (or groups) that act with ``evaluate`` and whose
            statistics are reported; all others act with ``evaluate=True``. Defaults to the
            sampler's trainable agents (or the env's possible agents), or the group keys.
        episodic (bool): If False, yield control in the middle of episodes (see above).
        train_every (int): Yield frequency in terms of sampler size (non-episodic mode only).
        render (bool): Unused.
        seed (int, optional): Seed passed to ``env.seed``.
        env (type, optional): A pre-built environment; created from ``env_description`` if None.
        max_episode (int): Maximum number of episodes to run.

    Returns:
        Dict: In evaluation mode, mean ``reward`` and ``win_rate`` per filtered agent, mean
        ``episode_len`` and ``fps``; otherwise ``total_timesteps``, ``num_episode`` and ``FPS``.
    """

    groups = env_description["config"].get("group", None)

    if env is None:
        Log.debug("Creating game %s", env_description["config"]["env_id"])
        env = env_description["creator"](**env_description["config"])
    env.seed(seed)

    is_vector_env = isinstance(env, vector_env.VectorEnv)

    mean_episode_reward = defaultdict(list)
    mean_episode_len = []
    episode_num = 0
    total_cnt = 0
    win_rate = defaultdict(list)

    agent_policy_mapping = agent_policy_mapping or {}

    if agent_filter is None:
        if sampler is not None:
            agent_filter = (
                sampler.trainable_agents if groups is None else list(groups.keys())
            )
        else:
            agent_filter = (
                env.possible_agents if groups is None else list(groups.keys())
            )

    start = time.time()
    while total_cnt < fragment_length and episode_num < max_episode:
        rets = env.reset(limits=max_episode)
        for aid, interface in agent_interfaces.items():
            interface.reset(policy_id=agent_policy_mapping.get(aid))

        done = False
        cnt = 0

        observations = {
            aid: agent_interfaces[env.agent_to_group(aid)].transform_observation(obs)
            for aid, obs in rets[EpisodeKeys.OBSERVATION.value].items()
        }

        episode_reward = defaultdict(lambda: 0.0)
        while not done and cnt < max_step:
            actions, action_dists, logits, action_masks = {}, {}, {}, {}
            if groups is not None:
                # each group acts once on the stacked observations of its members
                for k, _agents in groups.items():
                    grouped_obs = np.stack([observations[agent] for agent in _agents])
                    a, adist, ls = agent_interfaces[k].compute_action(
                        grouped_obs,
                        None,
                        evaluate=evaluate if k in agent_filter else True,
                    )
                    actions.update(dict(zip(_agents, a)))
                    action_dists.update(dict(zip(_agents, adist)))
                    logits.update(dict(zip(_agents, ls)))
            else:
                for aid, observation in observations.items():
                    action_masks[aid] = (
                        rets[EpisodeKeys.ACTION_MASK.value][aid]
                        if rets.get(EpisodeKeys.ACTION_MASK.value) is not None
                        else None
                    )
                    actions[aid], action_dists[aid], logits[aid] = agent_interfaces[
                        aid
                    ].compute_action(
                        observation,
                        action_masks[aid],
                        evaluate=evaluate if aid in agent_filter else True,
                    )
            rets = env.step(actions)
            # ============== handle next_frame ================
            next_observations = {
                aid: agent_interfaces[env.agent_to_group(aid)].transform_observation(
                    obs
                )
                for aid, obs in rets[EpisodeKeys.OBSERVATION.value].items()
            }
            # record time step
            if sampler is not None:
                transition = {
                    "observation": observations,
                    "reward": rets[EpisodeKeys.REWARD.value],
                    "action": actions,
                    "done": rets[EpisodeKeys.DONE.value],
                    "action_distribution": action_dists,
                    "next_observation": next_observations,
                    "logits": logits,
                }
                sampler.add_transition(vector_mode=is_vector_env, **transition)
            for aid, r in rets[EpisodeKeys.REWARD.value].items():
                if isinstance(r, np.ndarray):
                    # vectorized rewards are averaged across environments
                    r = np.mean(r)
                episode_reward[aid] += r

            # without a sampler (e.g. evaluation) there is nothing to train on
            if (
                not episodic
                and sampler is not None
                and sampler.is_ready()
                and sampler.size % train_every == 0
            ):
                yield {"timesteps": cnt}
            observations = next_observations
            done = (
                any(list(rets[EpisodeKeys.DONE.value].values())[0])
                if is_vector_env
                else any(list(rets[EpisodeKeys.DONE.value].values()))
            )
            cnt += 1
            total_cnt += 1 if isinstance(env, Environment) else env.limits

            if cnt % 100 == 0:
                cur_time = time.time()
                fps = total_cnt / (cur_time - start)
                Log.debug(
                    "FPS: {:.3} TOTAL_CNT: {} MAX_STEP: {} FRAGMENT_LENGTH: {}".format(
                        fps,
                        total_cnt,
                        max_step,
                        fragment_length,
                    )
                )

        episode_num += 1 if isinstance(env, Environment) else env.limits

        winner, max_reward = None, -float("inf")
        for aid, v in episode_reward.items():
            mean_episode_reward[aid].append(v)
            if v > max_reward:
                winner = env.agent_to_group(aid)
                max_reward = v
        for k in agent_filter:
            win_rate[k].append(1 if k == winner else 0)

        mean_episode_len.append(cnt)

    # Measured over the whole call so `fps` is always defined, even when no
    # episode reached the 100-step logging interval above.
    elapsed = time.time() - start
    fps = total_cnt / elapsed if elapsed > 0 else 0.0

    mean_episode_reward = {
        aid: sum(v) / len(v)
        for aid, v in mean_episode_reward.items()
        if aid in agent_filter
    }
    # guard against division by zero when no episode was run
    mean_episode_len = (
        sum(mean_episode_len) / len(mean_episode_len) if mean_episode_len else 0.0
    )
    win_rate = {k: sum(v) / len(v) for k, v in win_rate.items() if k in agent_filter}

    if evaluate:
        return {
            "reward": mean_episode_reward,
            "episode_len": mean_episode_len,
            "win_rate": win_rate,
            "fps": fps,
        }
    else:
        return {"total_timesteps": total_cnt, "num_episode": episode_num, "FPS": fps}


def sequential_rollout(
    sampler: Sampler,
    agent_interfaces: Dict[AgentID, AgentInterface],
    env_description: Dict[str, Any],
    fragment_length: int,
    max_step: int,
    agent_policy_mapping: Dict[AgentID, PolicyID] = None,
    evaluate: bool = False,
    agent_filter: Sequence = None,
    episodic: bool = False,
    train_every: int = 1,
    render: bool = False,
    seed: int = None,
    env: type = None,
    max_episode: int = 10,
):
    """Rollout generator for sequential (turn-based) games. Currently, this function doesn't
    support episode keys customization.

    Agents act one at a time, following ``env.agent_iter()``. Time steps of the agents in
    ``sampler.trainable_agents`` are buffered and pushed to ``sampler`` at the end of each
    episode. In non-episodic mode the generator then yields ``{"timesteps": ...}`` when the
    sampler is ready and its size is a multiple of ``train_every``. When not evaluating, agents
    outside ``agent_filter`` whose policy is a ``PolicyPool`` receive their episode reward via
    ``policy.add_transition``. The rollout statistics are the generator's return value
    (``StopIteration.value``).

    Args:
        sampler (Sampler): The sampler to collect training data; may be None.
        agent_interfaces (Dict[AgentID, AgentInterface]): The dict of agent interfaces.
        env_description (Dict[str, Any]): The environment description.
        fragment_length (int): The maximum of rollout length. Episodes keep starting while any
            filtered agent has collected fewer steps than this. Close to num_episode * max_step.
        max_step (int): The maximum of one episode length, counted per agent.
        agent_policy_mapping (Dict[AgentID, PolicyID], optional): The agent policy mapping.
            Defaults to None, meaning an empty mapping.
        evaluate (bool): Evaluate mode or not; also selects which statistics are returned.
        agent_filter (Sequence, optional): Agents that act with ``evaluate`` and whose statistics
            are reported; all others act with ``evaluate=True``. Defaults to the sampler's
            trainable agents (or the env's possible agents), or the group keys.
        episodic (bool): If False, yield control after episodes (see above).
        train_every (int): Yield frequency in terms of sampler size.
        render (bool): Unused.
        seed (int, optional): Seed passed to ``env.seed``.
        env (type, optional): A pre-built environment; created from ``env_description`` if None.
        max_episode (int): Maximum number of episodes to run.

    Returns:
        Dict: A dict of rollout statistics: mean ``reward``, ``episode_len`` and ``win_rate`` per
        filtered agent in evaluation mode, otherwise ``total_timesteps`` and ``num_episode``.
    """

    groups = env_description["config"].get("group", None)

    if env is None:
        Log.debug("Creating game %s", env_description["config"]["env_id"])
        env = env_description["creator"](**env_description["config"])
    env.seed(seed)

    # normalised here rather than using a mutable `{}` default shared across calls
    agent_policy_mapping = agent_policy_mapping or {}

    if agent_filter is None:
        if sampler is not None:
            agent_filter = (
                sampler.trainable_agents if groups is None else list(groups.keys())
            )
        else:
            agent_filter = (
                env.possible_agents if groups is None else list(groups.keys())
            )

    total_cnt = {agent: 0 for agent in agent_filter}
    mean_episode_reward = defaultdict(list)
    mean_episode_len = defaultdict(list)
    episode_num = 0
    win_rate = defaultdict(list)

    while (
        any(agent_total_cnt < fragment_length for agent_total_cnt in total_cnt.values())
        and episode_num < max_episode
    ):
        env.reset()
        for aid, interface in agent_interfaces.items():
            interface.reset(policy_id=agent_policy_mapping.get(aid, None))
        cnt = defaultdict(lambda: 0)
        agent_episode = defaultdict(list)
        episode_reward = defaultdict(lambda: 0.0)

        Log.debug(
            "\t++ [sequential rollout {}/{}] start new episode ++".format(
                list(total_cnt.values()), fragment_length
            )
        )
        # fallbacks for steps where an agent is done and therefore does not act
        last_action_dist = {player_id: None for player_id in env.agents}
        last_logits = {player_id: None for player_id in env.agents}
        for player_id in env.agent_iter():
            observation, pre_reward, done, info = env.last()
            action_mask = agent_interfaces[env.agent_to_group(player_id)].action_mask(
                observation
            )
            observation = agent_interfaces[
                env.agent_to_group(player_id)
            ].transform_observation(observation)
            if not done:
                # NOTE(review): looked up by player_id, while the two calls above use
                # env.agent_to_group(player_id); presumably these agree only when
                # agent_to_group is the identity — confirm before using groups here.
                action, action_dist, logits = agent_interfaces[
                    player_id
                ].compute_action(
                    observation,
                    action_mask=action_mask,
                    evaluate=evaluate if player_id in agent_filter else True,
                )
            else:
                action = None
                action_dist = None
                logits = None
            env.step(action)
            Log.debug(
                "\t\t[agent={}] action={} action_dist={} logits={} pre_reward={}".format(
                    player_id, action, action_dist, logits, pre_reward
                )
            )
            if sampler is not None and player_id in sampler.trainable_agents:
                agent_episode[player_id].append(
                    _TimeStep(
                        observation,
                        action_mask,
                        pre_reward,
                        # done steps take no action; store a sampled placeholder
                        action
                        if action is not None
                        else env.action_spaces[player_id].sample(),
                        done,
                        action_dist
                        if action_dist is not None
                        else last_action_dist[player_id],
                        logits if logits is not None else last_logits[player_id],
                    )
                )
            last_action_dist[player_id] = action_dist
            last_logits[player_id] = logits
            episode_reward[player_id] += pre_reward
            cnt[player_id] += 1

            if all(agent_cnt >= max_step for agent_cnt in cnt.values()):
                break
        Log.debug(
            "\t++ [sequential rollout] episode end at step={} ++".format(dict(cnt))
        )
        episode_num += 1

        # add transition for non-ego agents
        if not evaluate:
            for agent, interface in agent_interfaces.items():
                if agent not in agent_filter and isinstance(
                    interface.policy, PolicyPool
                ):
                    interface.policy.add_transition(episode_reward[agent])

        if sampler is not None:
            buffers = {}
            for player, data_tups in agent_episode.items():
                (
                    observations,
                    action_masks,
                    pre_rewards,
                    actions,
                    dones,
                    action_dists,
                    logits,
                ) = tuple(map(np.stack, list(zip(*data_tups))))

                # Shift by one step: reward, done, observation, mask and logits seen at
                # step t+1 are attributed to the transition taken at step t.
                rewards = pre_rewards[1:].copy()
                dones = dones[1:].copy()
                next_observations = observations[1:].copy()
                next_action_masks = action_masks[1:].copy()
                next_logits = logits[1:].copy()

                observations = observations[:-1].copy()
                action_masks = action_masks[:-1].copy()
                actions = actions[:-1].copy()
                action_dists = action_dists[:-1].copy()
                logits = logits[:-1].copy()

                buffers[player] = Episode(
                    observations,
                    actions,
                    rewards,
                    next_observations,
                    action_masks,
                    dones,
                    action_dists,
                    logits,
                    extras={
                        "next_action_mask": next_action_masks,
                        "next_logits": next_logits,
                    },
                ).clean_data()
            sampler.add_batches(buffers)
            if not episodic and sampler.is_ready() and sampler.size % train_every == 0:
                yield {"timesteps": cnt[agent_filter[0]]}

        # accumulate per-agent step counts
        for agent in agent_filter:
            total_cnt[agent] += cnt[agent]

        winner, max_reward = None, -float("inf")
        for k, v in episode_reward.items():
            mean_episode_reward[k].append(v)
            mean_episode_len[k].append(cnt[k])
            if v > max_reward:
                winner = k
                max_reward = v
        for k in agent_filter:
            win_rate[k].append(1 if k == winner else 0)

    mean_episode_reward = {
        k: sum(v) / len(v) for k, v in mean_episode_reward.items() if k in agent_filter
    }
    mean_episode_len = {
        k: sum(v) / len(v) for k, v in mean_episode_len.items() if k in agent_filter
    }
    win_rate = {k: sum(v) / len(v) for k, v in win_rate.items() if k in agent_filter}

    if evaluate:
        return {
            "reward": mean_episode_reward,
            "episode_len": mean_episode_len,
            "win_rate": win_rate,
        }
    else:
        res = {
            "total_timesteps": total_cnt[agent_filter[0]],
            "num_episode": episode_num,
        }
        return res


def get_rollout_func(
    type_name: str, ray_mode: bool = False, resources: Dict = None
) -> type:
    """Return rollout function by name.

    Args:
        type_name (str): The type name of rollout function. Could be `sequential` or `simultaneous`.
        ray_mode (bool, optional): Enable ray mode or not. Defaults to False.
        resources (Dict): Computing resources, forwarded as the ``resources`` option of
            ``ray.remote``. Only used when ``ray_mode`` is True.

    Raises:
        TypeError: Unsupported rollout func type.

    Returns:
        type: Rollout caller: the plain rollout function, or its ray remote
        wrapper when ``ray_mode`` is True.
    """

    if type_name == "simultaneous":
        handler = simultaneous_rollout
    elif type_name == "sequential":
        handler = sequential_rollout
    else:
        raise TypeError("Unsupported rollout func type: %s" % type_name)
    if ray_mode:
        # `ray.remote` accepts a bare function only when called without options;
        # with options it returns a decorator that must then be applied.
        handler = ray.remote(resources=resources or {})(handler)
    return handler


class Evaluator:
    """Evaluate policy mappings with a rollout function on an environment that is
    built once in ``__init__`` and reused for every evaluation."""

    def __init__(self, env_desc, n_env: int = 1, use_remote_env: bool = False) -> None:
        """Initialize an evaluator with given environment configuration. Specifically, `env_desc` for instance generation,
        `n_env` indicates the number of environments, >1 triggers VectorEnv. `use_remote_env` for VectorEnv mode only.
        """

        if n_env > 1:
            self.env = vector_env.VectorEnv(env_desc, n_env, use_remote=use_remote_env)
        else:
            self.env = env_desc["creator"](**env_desc["config"])

        self.env_desc = env_desc

    def terminate(self):
        """Close the underlying environment."""
        self.env.close()

    def _run_rollout(self, rollout_caller, **rollout_kwargs):
        """Run one monitored evaluation rollout and return its statistics.

        The rollout callers are generator functions whose statistics are the generator's
        return value, so the generator is exhausted and ``StopIteration.value`` returned.
        """
        rollout = monitor(enable_time=True, enable_returns=True, prefix="\t")(
            rollout_caller
        )(
            None,
            env_description=self.env_desc,
            evaluate=True,
            env=self.env,
            **rollout_kwargs,
        )
        try:
            while True:
                next(rollout)
        except StopIteration as e:
            return e.value

    def run(
        self,
        policy_mappings: List[Dict[AgentID, PolicyID]],
        max_step: int,
        fragment_length: int,
        agent_interfaces: Dict[AgentID, object],
        rollout_caller: type,
        seed: int = None,
        max_episode: int = 10,
    ) -> Sequence[Tuple[Dict, Dict]]:
        """Accept a sequence of policy mapping descriptions, then evaluate them sequentially.

        Args:
            policy_mappings (List[Dict[AgentID, PolicyID]]): A sequence of policy mapping, describes
                the policy selection by all agents. If None, one rollout is run without passing
                ``agent_policy_mapping`` (the rollout function's default applies).
            max_step (int): Max step of one episode.
            fragment_length (int): Fragment length of a data batch.
            agent_interfaces (Dict[AgentID, object]): A dict of agent interfaces.
            rollout_caller (type): Rollout function, could be `simultaneous_rollout` or `sequential_rollout`.
            seed (int, optional): Seed forwarded to the rollout function.
            max_episode (int): Maximum number of episodes per evaluation; for a VectorEnv it is
                also assigned to ``env._limits``.

        Returns:
            Sequence[Tuple[Dict, Dict]]: ``(policy_mapping, statistics)`` pairs in the order of
            ``policy_mappings``, or ``[(None, statistics)]`` when ``policy_mappings`` is None.
        """

        if isinstance(self.env, vector_env.VectorEnv):
            self.env._limits = max_episode

        common_kwargs = dict(
            agent_interfaces=agent_interfaces,
            max_step=max_step,
            fragment_length=fragment_length,
            seed=seed,
            # previously hard-coded to 10 in the no-mapping branch, ignoring the argument
            max_episode=max_episode,
        )

        if policy_mappings is None:
            # agent_policy_mapping is deliberately not passed here
            return [(None, self._run_rollout(rollout_caller, **common_kwargs))]

        return [
            (
                policy_mapping,
                self._run_rollout(
                    rollout_caller,
                    agent_policy_mapping=policy_mapping,
                    **common_kwargs,
                ),
            )
            for policy_mapping in policy_mappings
        ]
