import os
import tempfile
import time
import gym
import gfootball

try:
    from gfootball import env as football_env
except ImportError as e:
    # `e` is an exception *instance*; calling it (`e(...)`) would raise
    # TypeError and hide the install hint, so build a new ImportError and
    # chain the original one to keep the root cause visible.
    # NOTE: the unguarded `import gfootball` earlier in this file fails first
    # when the package is missing entirely; this guard covers a broken or
    # partial install of the `env` submodule.
    raise ImportError(
        "Please install Google football evironment before use: https://github.com/google-research/football"
    ) from e

from expground.types import AgentID, Dict
from expground.logger import Log
from expground.envs import Environment
from expground.utils.data import EpisodeKeys


class GRF(Environment):
    """Multi-agent wrapper around a Google Research Football environment.

    Agents are named ``agent_<index>``; the first ``n_left_agents`` ids form
    the "left" group and the remaining ``n_right_agents`` ids form the "right"
    group. Per-agent action and observation spaces are identical and are
    derived from the first entry of the underlying environment's spaces.
    """

    def __init__(self, **configs):
        """Create the underlying football environment and agent bookkeeping.

        Reads ``scenario_config`` from ``self._configs`` with the keys
        ``n_right_agents``, ``n_left_agents``, ``level``, ``render`` and the
        optional ``stacked`` (default ``False``).

        Args:
            **configs: Forwarded unchanged to ``Environment.__init__``.

        Raises:
            KeyError: If a required ``scenario_config`` entry is missing.
        """
        super().__init__(**configs)

        scenario = self._configs["scenario_config"]
        n_right_agents = scenario["n_right_agents"]
        n_left_agents = scenario["n_left_agents"]
        num_agents = n_right_agents + n_left_agents
        stacked = scenario.get("stacked", False)
        level = scenario["level"]

        self._env = gfootball.env.create_environment(
            env_name=level,
            stacked=stacked,
            logdir=os.path.join(tempfile.gettempdir(), "expground_test"),
            write_goal_dumps=False,
            write_full_episode_dumps=False,
            render=scenario["render"],
            dump_frequency=0,
            number_of_left_players_agent_controls=n_left_agents,
            number_of_right_players_agent_controls=n_right_agents,
            channel_dimensions=(42, 42),
        )

        # NOTE(review): `self.env` is presumably a base-class accessor for
        # `self._env` -- confirm against `Environment`.
        # Every controlled player shares the same spaces, so take entry 0.
        action_space = gym.spaces.Discrete(self.env.action_space.nvec[0])
        observation_space = gym.spaces.Box(
            low=self.env.observation_space.low[0],
            high=self.env.observation_space.high[0],
            dtype=self.env.observation_space.dtype,
        )

        self.num_agents = num_agents

        # GRF index agent from left to right
        self._possible_agents = ["agent_%d" % x for x in range(self.num_agents)]
        # Position of each agent id in the environment's action/observation
        # arrays; used by `step` to order actions numerically.
        self._agent_index = {
            agent: idx for idx, agent in enumerate(self._possible_agents)
        }
        # (sic) attribute name kept as-is in case other code refers to it.
        self._aciton_spaces = dict.fromkeys(self._possible_agents, action_space)
        self._observation_spaces = dict.fromkeys(
            self._possible_agents, observation_space
        )

        self.single_agent_observation_space = observation_space
        self.single_agent_action_space = action_space

        self._trainable_agents = None
        self._team_left = self._possible_agents[:n_left_agents]
        self._team_right = self._possible_agents[n_left_agents:]
        self._agent_to_group = dict.fromkeys(self._team_left, "left")
        self._agent_to_group.update(dict.fromkeys(self._team_right, "right"))
        assert len(self._agent_to_group) == num_agents, len(self._agent_to_group)

    def agent_to_group(self, agent_id) -> str:
        """Return the group name (``"left"`` or ``"right"``) of ``agent_id``.

        Raises:
            KeyError: If ``agent_id`` is not one of the possible agents.
        """
        return self._agent_to_group[agent_id]

    @property
    def possible_agents(self):
        """List of all agent ids, ordered by environment index."""
        return self._possible_agents

    @property
    def action_spaces(self) -> Dict[AgentID, gym.Space]:
        """Mapping from agent id to its (shared) discrete action space."""
        return self._aciton_spaces

    @property
    def observation_spaces(self) -> Dict[AgentID, gym.Space]:
        """Mapping from agent id to its (shared) box observation space."""
        return self._observation_spaces

    def step(self, action_dict):
        """Advance the environment by one step.

        Actions are passed to the underlying environment ordered by agent
        index (the order of ``possible_agents``). A plain lexicographic sort
        of the ids would place ``"agent_10"`` before ``"agent_2"`` and
        misassign actions, observations and rewards once more than ten agents
        are controlled. Ids not present in ``possible_agents`` are placed
        last, in sorted order.

        Args:
            action_dict: Mapping from agent id to that agent's action.

        Returns:
            A dict keyed by ``EpisodeKeys`` values holding per-agent
            observations, rewards, done flags and infos. The same done flag
            and the same info object are shared by all agents.
        """
        agent_index = self._agent_index
        n_known = len(agent_index)
        ordered_agents = sorted(
            action_dict, key=lambda agent: (agent_index.get(agent, n_known), agent)
        )
        actions = [action_dict[agent] for agent in ordered_agents]

        o, r, d, i = self.env.step(actions)

        # With a single controlled agent the environment output is used as-is
        # rather than indexed per agent.
        batched = self.num_agents > 1
        rewards = {}
        obs = {}
        infos = {}
        for pos, agent in enumerate(ordered_agents):
            infos[agent] = i
            rewards[agent] = r[pos] if batched else r
            obs[agent] = o[pos] if batched else o
        dones = dict.fromkeys(self.possible_agents, d)
        return {
            EpisodeKeys.OBSERVATION.value: obs,
            EpisodeKeys.REWARD.value: rewards,
            EpisodeKeys.DONE.value: dones,
            EpisodeKeys.INFO.value: infos,
        }

    def reset(self):
        """Reset the underlying environment.

        Returns:
            A dict with a single ``EpisodeKeys.OBSERVATION`` entry mapping
            each agent id to its initial observation.
        """
        original_obs = self.env.reset()
        if self.num_agents > 1:
            obs = {
                agent: original_obs[idx]
                for idx, agent in enumerate(self._possible_agents)
            }
        else:
            # Single-agent case: the observation is not indexed per agent.
            obs = {agent: original_obs for agent in self._possible_agents}
        return {EpisodeKeys.OBSERVATION.value: obs}

    def write_dump(self, key: str):
        """Forward a dump request named ``key`` to the underlying environment."""
        self.env.write_dump(key)


if __name__ == "__main__":
    # Smoke test / throughput benchmark: two VisionPPO policies (one per team)
    # drive a 5-vs-5 match forever, logging FPS and reward statistics every
    # 10 frames. Interrupt with Ctrl-C to write a dump and exit.
    import sys

    import numpy as np

    from expground.algorithms.ppo.vf_share import VisionPPO

    render = False
    num_seg = 5
    scenario_config = {
        "n_right_agents": num_seg,  # 11,
        "n_left_agents": num_seg,  # 11,
        "level": "5_vs_5",  # "11_vs_11_stochastic",
        "render": render,
    }
    env = GRF(env_id="GRF11v11", scenario_config=scenario_config)

    agents = env.possible_agents
    action_space = env.single_agent_action_space
    observation_space = env.single_agent_observation_space

    rets = env.reset()
    n_frame = 0

    # init two policy for left and right
    policies = {
        "left": VisionPPO(
            observation_space, action_space, {"network": "cnn"}, {"use_cuda": True}
        ),
        "right": VisionPPO(
            observation_space, action_space, {"network": "cnn"}, {"use_cuda": False}
        ),
    }

    def compute_action(observation: Dict[AgentID, np.ndarray]):
        """Return an ``{agent_id: action}`` dict for all agents.

        Observations are stacked per team (first ``num_seg`` agents are left,
        the rest right) and each team's policy is queried once.
        """
        left_obs = np.stack([observation[k] for k in agents[:num_seg]])
        right_obs = np.stack([observation[k] for k in agents[num_seg:]])
        left_action, _, _ = policies["left"].compute_action(left_obs, None, False)
        right_action, _, _ = policies["right"].compute_action(right_obs, None, False)
        # Concatenate explicitly: `+` would add element-wise if the policies
        # return arrays rather than lists.
        return dict(zip(agents, [*left_action, *right_action]))

    try:
        start = time.time()
        while True:
            actions = compute_action(rets[EpisodeKeys.OBSERVATION.value])
            rets = env.step(actions)
            done = rets[EpisodeKeys.DONE.value]
            n_frame += 1
            if n_frame % 10 == 0:
                cur_time = time.time()
                rewards = list(rets[EpisodeKeys.REWARD.value].values())
                Log.info(
                    "render={} FPS: {:.3} reward: {:.3} {:.3} {:.3}".format(
                        render,
                        n_frame / (cur_time - start),
                        np.mean(rewards),
                        np.max(rewards),
                        np.min(rewards),
                    )
                )
            if any(done.values()):
                # Store the fresh observation in `rets` so the next iteration
                # acts on the new episode, not the stale terminal observation.
                rets = env.reset()
    except KeyboardInterrupt:
        Log.warning("Game stopped, writing dump...")
        env.write_dump("shutdown")
        sys.exit(1)
