import random

import gfootball.env as football_env
from gym import spaces
from debug import debug_print
import numpy as np


class FootballEnv(object):
    '''Wrap a Google Research Football environment behind list-based spaces.

    The wrapper exposes ``observation_space`` / ``share_observation_space`` /
    ``action_space`` as lists. Observations are returned as a one-element list
    of flattened arrays, rewards as one row per controlled player, and the
    ``info`` dict is enriched with the raw per-player state from the
    underlying environment.
    '''

    def __init__(self, args):
        """Create the underlying gfootball environment from ``args``.

        Args:
            args: config object; the attributes read here include
                ``num_agents``, ``rnum_agents``, ``scenario_name``,
                ``use_render``, ``save_videos``, ``save_gifs``,
                ``use_stacked_frames``, ``representation``, ``rewards``,
                ``smm_width``, ``smm_height``, ``video_dir``,
                ``remove_redundancy``, ``zero_feature`` and ``share_reward``.
        """
        self.num_agents = args.num_agents
        self.rnum_agents = args.rnum_agents
        self.scenario_name = args.scenario_name

        # Settings shared by the plain and the video-recording environment.
        # The agent controls ``rnum_agents`` left-team players and no
        # right-team players.
        env_kwargs = dict(
            env_name=args.scenario_name,
            stacked=args.use_stacked_frames,
            representation=args.representation,
            rewards=args.rewards,
            number_of_left_players_agent_controls=args.rnum_agents,
            number_of_right_players_agent_controls=0,
            channel_dimensions=(args.smm_width, args.smm_height),
        )
        if args.use_render and args.save_videos:
            # Render the env and dump a full video of every episode.
            env_kwargs.update(
                write_full_episode_dumps=True,
                render=True,
                write_video=True,
                dump_frequency=1,
                logdir=args.video_dir,
            )
        else:
            env_kwargs["render"] = args.use_render and args.save_gifs
        self.env = football_env.create_environment(**env_kwargs)

        # Episode length, read once from the first player's initial state.
        self.max_steps = self.env.unwrapped.observation()[0]["steps_left"]
        self.remove_redundancy = args.remove_redundancy
        self.zero_feature = args.zero_feature
        self.share_reward = args.share_reward

        # reset()/step() return the per-player features concatenated into a
        # single flat vector, hence feature_size * rnum_agents.
        obs_dim = self.env.observation_space.shape[1] * self.rnum_agents
        self.observation_space = [[obs_dim]]
        self.share_observation_space = [[obs_dim]]
        # One Discrete space per controlled player, grouped into a single
        # inner list (the outer list has exactly one entry).
        self.action_space = [[
            spaces.Discrete(n=self.env.action_space.nvec[idx])
            for idx in range(self.rnum_agents)
        ]]
        debug_print(self.action_space, self.observation_space)

    def reset(self):
        """Reset the environment.

        Returns:
            A one-element list holding the first entry of the wrapped
            observation, flattened to 1-D.
        """
        obs = self._obs_wrapper(self.env.reset())
        return [obs[0].flatten()]

    def step(self, action):
        """Advance the environment by one step.

        Args:
            action: sequence whose first element is forwarded unchanged to
                the underlying ``env.step``.

        Returns:
            ``(obs, reward, done, info)`` where ``obs`` is a one-element list
            with the flattened observation, ``reward`` has one row per
            controlled player (an ``(rnum_agents, 1)`` array, or a list of
            ``[global_reward]`` rows when ``share_reward`` is set), ``done``
            is an array of ``num_agents`` copies of the env's done flag, and
            ``info`` is enriched by ``_info_wrapper``.
        """
        obs, reward, done, info = self.env.step(action[0])
        obs = self._obs_wrapper(obs)
        reward = reward.reshape(self.rnum_agents, 1)
        if self.share_reward:
            global_reward = np.sum(reward)
            # Build a distinct row per player: ``[[r]] * n`` would alias the
            # same inner list n times, so editing one row would edit all.
            reward = [[global_reward] for _ in range(self.rnum_agents)]

        done = np.array([done] * self.num_agents)
        info = self._info_wrapper(info)
        return [obs[0].flatten()], reward, done, info

    def seed(self, seed=None):
        """Seed Python's ``random`` module (with 1 when ``seed`` is None).

        Only the stdlib ``random`` module is seeded here; numpy and the
        underlying environment are not touched.
        """
        random.seed(1 if seed is None else seed)

    def close(self):
        """Close the underlying environment."""
        self.env.close()

    def _obs_wrapper(self, obs):
        """Add a leading axis to ``obs`` when ``num_agents == 1``."""
        if self.num_agents == 1:
            return obs[np.newaxis, :]
        return obs

    def _info_wrapper(self, info):
        """Merge the raw environment state into ``info`` and return it.

        The first player's state dict is copied into ``info`` wholesale.
        ``active``, ``designated`` and ``sticky_actions`` are then replaced
        by arrays gathered from the first ``num_agents`` state entries.
        ``max_steps`` is also added.
        """
        state = self.env.unwrapped.observation()
        info.update(state[0])
        info["max_steps"] = self.max_steps
        info["active"] = np.array([state[i]["active"] for i in range(self.num_agents)])
        info["designated"] = np.array([state[i]["designated"] for i in range(self.num_agents)])
        info["sticky_actions"] = np.stack([state[i]["sticky_actions"] for i in range(self.num_agents)])
        return info
