import numpy as np
import gfootball.env as football_env
from gfootball.env import observation_preprocessing
from ..multiagentenv import MultiAgentEnv
import gym
import torch as th


class Five_vs_Five(MultiAgentEnv):

    def __init__(
        self,
        dense_reward=False,
        write_full_episode_dumps=False,
        write_goal_dumps=False,
        dump_freq=1000,
        render=False,
        n_agents=4,
        time_limit=500,
        time_step=0,
        obs_dim=48,
        env_name='5_vs_5',
        stacked=False,
        representation="simple115",
        rewards='scoring',
        logdir='football_dumps',
        write_video=True,
        number_of_right_players_agent_controls=0,
        seed=0
    ):
        """Create the underlying football environment and the per-agent spaces.

        Args:
            dense_reward: Stored on the instance; not read elsewhere in this class.
            write_full_episode_dumps: Forwarded to ``create_environment``.
            write_goal_dumps: Forwarded to ``create_environment``.
            dump_freq: Forwarded to ``create_environment`` as ``dump_frequency``.
            render: Forwarded to ``create_environment``.
            n_agents: Number of controlled left-team players; forwarded as
                ``number_of_left_players_agent_controls``.
            time_limit: Stored as ``episode_limit``; ``step`` forces termination
                once ``time_step`` reaches it.
            time_step: Initial value of the step counter.
            obs_dim: Length used to slice the observation-space bounds; also
                reported as the observation and state size.
            env_name: Forwarded to ``create_environment``.
            stacked: Forwarded to ``create_environment``.
            representation: Forwarded to ``create_environment``.
            rewards: Forwarded to ``create_environment``.
            logdir: Forwarded to ``create_environment``.
            write_video: Forwarded to ``create_environment``.
            number_of_right_players_agent_controls: Forwarded to
                ``create_environment``.
            seed: Passed to ``env.seed`` of the created environment.
        """
        self.dense_reward = dense_reward
        self.write_full_episode_dumps = write_full_episode_dumps
        self.write_goal_dumps = write_goal_dumps
        self.dump_freq = dump_freq
        # Stored under private names: assigning ``self.render`` / ``self.seed``
        # would shadow the ``render()`` / ``seed()`` methods of this class and
        # make them uncallable on the instance.
        self._render = render
        self.n_agents = n_agents
        self.episode_limit = time_limit
        self.time_step = time_step
        self.obs_dim = obs_dim
        self.env_name = env_name
        self.stacked = stacked
        self.representation = representation
        self.rewards = rewards
        self.logdir = logdir
        self.write_video = write_video
        self.number_of_right_players_agent_controls = number_of_right_players_agent_controls
        self._seed = seed

        self.env = football_env.create_environment(
            write_full_episode_dumps=self.write_full_episode_dumps,
            write_goal_dumps=self.write_goal_dumps,
            env_name=self.env_name,
            stacked=self.stacked,
            representation=self.representation,
            rewards=self.rewards,
            logdir=self.logdir,
            render=self._render,
            write_video=self.write_video,
            dump_frequency=self.dump_freq,
            number_of_left_players_agent_controls=self.n_agents,
            number_of_right_players_agent_controls=self.number_of_right_players_agent_controls,
            channel_dimensions=(observation_preprocessing.SMM_WIDTH, observation_preprocessing.SMM_HEIGHT))
        self.env.seed(self._seed)

        # Per-agent bounds: first agent's bounds truncated to obs_dim entries.
        obs_space_low = self.env.observation_space.low[0][:self.obs_dim]
        obs_space_high = self.env.observation_space.high[0][:self.obs_dim]

        self.action_space = [gym.spaces.Discrete(
            self.env.action_space.nvec[1]) for _ in range(self.n_agents)]
        self.observation_space = [
            gym.spaces.Box(low=obs_space_low, high=obs_space_high, dtype=self.env.observation_space.dtype) for _ in range(self.n_agents)
        ]

        self.n_actions = self.action_space[0].n

        self.unit_dim = self.obs_dim  # QPLEX unit_dim for cds_gfootball
        # self.unit_dim = 8  # QPLEX unit_dim set like that in Starcraft II

        # [goals scored by us, goals conceded], as counted in step().
        self.reward_record = [0, 0]
        # Game-mode tracking used by check_if_ball_out().
        self.game_mode = 0
        self.old_game_mode = 0
        # Last team seen owning the ball (-1 until a team owns it).
        self.goal_owned_team = -1

    def check_if_ball_out(self):
        # 维护一个变量记录球在哪一方
        # 当球出界时判断最后触球的是我方还是对方

        if self.old_game_mode != self.game_mode and self.game_mode != 0:
            # 判断球出界一瞬间的step
            # -1 = ball not owned, 0 = left team, 1 = right team.
            if self.goal_owned_team == 0:
                # 历史持球是我方（left team）
                return -10
            elif self.goal_owned_team == 1:
                # 历史持球是敌方（right team）
                return 0

        return 0

    def check_if_players_out(self):
        cur_obs = self.env.unwrapped.observation()[0]
        ours_loc = cur_obs['left_team'][-self.n_agents:]
        # field scale: [-1, 1], [-0.4, 0.4]

        ourside_num = np.where(ours_loc[:, 0] > 1)[0].shape[0] + np.where(ours_loc[:, 0] < -1)[
            0].shape[0] + np.where(ours_loc[:, 1] < -0.4)[0].shape[0] + np.where(ours_loc[:, 1] > 0.4)[0].shape[0]

        return -ourside_num

    def get_simple_obs(self, index=-1):
        full_obs = self.env.unwrapped.observation()[0]
        self.old_game_mode = self.game_mode
        self.game_mode = full_obs['game_mode']
        if full_obs['ball_owned_team'] != -1:
            self.goal_owned_team = full_obs['ball_owned_team']
        simple_obs = []

        if index == -1:
            # global state, absolute position
            # cannot control our goal keeper
            simple_obs.append(full_obs['left_team'].reshape(-1))
            simple_obs.append(
                full_obs['left_team_direction'].reshape(-1))

            simple_obs.append(full_obs['right_team'].reshape(-1))
            simple_obs.append(full_obs['right_team_direction'].reshape(-1))

            simple_obs.append(full_obs['ball'])
            simple_obs.append(full_obs['ball_direction'])
            simple_obs.append(np.array(full_obs['game_mode']).reshape(-1))
            simple_obs.append(
                np.array(full_obs['ball_owned_team']).reshape(-1))

        else:
            # local state, relative position
            ego_position = full_obs['left_team'][-self.n_agents +
                                                 index].reshape(-1)
            simple_obs.append(ego_position)
            simple_obs.append((np.delete(
                full_obs['left_team'], index + 1, axis=0) - ego_position).reshape(-1))

            simple_obs.append(
                full_obs['left_team_direction'][-self.n_agents + index].reshape(-1))
            simple_obs.append(np.delete(
                full_obs['left_team_direction'], index + 1, axis=0).reshape(-1))

            simple_obs.append(
                (full_obs['right_team'] - ego_position).reshape(-1))
            simple_obs.append(full_obs['right_team_direction'].reshape(-1))

            simple_obs.append(full_obs['ball'][:2] - ego_position)
            simple_obs.append(full_obs['ball'][-1].reshape(-1))
            simple_obs.append(full_obs['ball_direction'])
            simple_obs.append(np.array(full_obs['game_mode']).reshape(-1))
            simple_obs.append(
                np.array(full_obs['ball_owned_team']).reshape(-1))

        simple_obs = np.concatenate(simple_obs)
        return simple_obs

    def get_global_state(self):
        return self.get_simple_obs(-1)

    def step(self, _actions):
        """Returns reward, terminated, info."""
        if th.is_tensor(_actions):
            actions = _actions.cpu().numpy()
        else:
            actions = _actions
        self.time_step += 1
        _, original_rewards, done, infos = self.env.step(actions.tolist())

        if self.time_step >= self.episode_limit:
            done = True

        if sum(original_rewards) < 0:
            self.reward_record[1] += 1
        if sum(original_rewards) > 0:
            self.reward_record[0] += 1

        dense_reward = self.check_if_ball_out() + self.check_if_players_out()

        if done:
            infos["score"] = self.reward_record[0] - self.reward_record[1]
            return -100 * (self.reward_record[1] - self.reward_record[0]), done, infos

        if sum(original_rewards) == 0:
            return dense_reward, done, infos

        if sum(original_rewards) < 0:
            return -100, done, infos

        return 1000, done, infos

    def get_obs(self):
        """Returns all agent observations in a list."""
        # obs = np.array([self.get_simple_obs(i) for i in range(self.n_agents)])
        obs = [self.get_simple_obs(i) for i in range(self.n_agents)]
        return obs

    def get_obs_agent(self, agent_id):
        """Returns observation for agent_id."""
        return self.get_simple_obs(agent_id)

    def get_obs_size(self):
        """Returns the size of the observation."""
        return self.obs_dim

    def get_state(self):
        """Returns the global state."""
        return self.get_global_state()

    def get_state_size(self):
        """Returns the size of the global state."""
        # TODO: in wrapper_grf_3vs1.py, author set state_shape=obs_shape
        return self.obs_dim

    def get_avail_actions(self):
        """Returns the available actions of all agents in a list."""
        return [[1 for _ in range(self.n_actions)] for agent_id in range(self.n_agents)]

    def get_avail_agent_actions(self, agent_id):
        """Returns the available actions for agent_id."""
        return self.get_avail_actions()[agent_id]

    def get_total_actions(self):
        """Returns the total number of actions an agent could ever take."""
        return self.action_space[0].n

    def reset(self):
        """Returns initial observations and states."""
        self.time_step = 0
        self.env.reset()
        obs = np.array([self.get_simple_obs(i) for i in range(self.n_agents)])

        self.reward_record = [0, 0]

        self.game_mode = 0
        self.old_game_mode = 0
        self.goal_owned_team = -1

        return obs, self.get_global_state()

    def render(self):
        pass

    def close(self):
        self.env.close()

    def seed(self):
        pass

    def save_replay(self):
        """Save a replay."""
        pass

    def get_stats(self):
        stats = {
        }
        return  stats
