from functools import partial
import numpy as np
from . import MultiAgentEnv
from pettingzoo.sisl import multiwalker_v9
from gym.spaces import Box

class MultiWalker(MultiAgentEnv):
    """Expose PettingZoo's ``multiwalker_v9`` AEC environment through ``MultiAgentEnv``.

    Every walker acts once per ``step`` call and the team receives a single scalar
    reward: the sum of the wrapped env's per-agent rewards. Episodes always last
    exactly ``episode_limit`` steps; after the wrapped env flags an agent as done,
    the remaining steps yield zero reward.
    """

    def __init__(self, batch_size=None, **kwargs):
        super().__init__(batch_size, **kwargs)
        env_args = kwargs["env_args"]
        self.scenario = env_args["scenario_name"]
        # Format "<n_walkers>x<...>"; only the part before the 'x' is used here.
        self.agent_conf = env_args["agent_conf"]

        self.n_agents = int(self.agent_conf.split('x')[0])
        # Per-walker continuous action dimensionality.
        self.n_actions = 4
        self.episode_limit = self.args.episode_limit

        self.env_version = env_args.get("env_version", 2)
        self.kaz = multiwalker_v9.env(n_walkers=self.n_agents, fall_reward=0.0,
                                      terminate_reward=0.0, terminate_on_fall=False)
        # Reset once so the wrapped env's agent list is populated before we read it.
        self.kaz.reset()

        # One [-1, 1] Box per walker, sized by n_actions so the advertised action
        # space agrees with get_total_actions() (it used to be hard-coded to 6 dims).
        self.action_space = tuple(
            Box(-1. * np.ones((self.n_actions,), dtype=np.float32),
                1. * np.ones((self.n_actions,), dtype=np.float32))
            for _ in range(self.n_agents)
        )

        # Copy: the wrapped env may mutate its own `agents` list in place.
        self.agent_names = list(self.kaz.agents)
        self.reset()
        self.obs_size = self.get_obs_size()

    def _done_flags(self):
        """Return ``{agent_name: done}`` read from the wrapped env.

        Uses the single ``dones`` dict when the installed PettingZoo provides it,
        otherwise combines the newer ``terminations`` and ``truncations`` dicts.
        """
        dones = getattr(self.kaz, "dones", None)
        if dones is not None:
            return dict(dones)
        terminations = self.kaz.terminations
        truncations = self.kaz.truncations
        return {
            name: bool(terminations.get(name, False) or truncations.get(name, False))
            for name in {**terminations, **truncations}
        }

    def step(self, actions):
        """Apply one action per walker and return ``(reward, done, info)``.

        ``actions[i]`` is the action for ``self.agent_names[i]``. ``reward`` is the
        sum of the wrapped env's per-agent rewards after all walkers have stepped.
        ``done`` becomes True once ``episode_limit`` steps have been taken, at which
        point ``info["episode_limit"]`` is True.
        """
        actions_copy = np.copy(actions)

        if any(self._done_flags().values()):
            # The wrapped env has already flagged an agent as done. Keep the
            # fixed-length episode running with zero reward, but still count the
            # step -- otherwise episode_limit could never be reached.
            reward = 0.0
        else:
            # Assumes the wrapped env selects agents in agent_names order.
            for agent_id, name in enumerate(self.agent_names):
                if self._done_flags().get(name, False):
                    # AEC envs require a None action for agents that are done.
                    self.kaz.step(None)
                else:
                    self.kaz.step(actions_copy[agent_id])
            reward = float(np.sum(list(self.kaz.rewards.values())))

        self.steps += 1
        done = self.steps >= self.episode_limit

        info = {}
        if done:
            # The step limit is the only source of `done` here, so the next state
            # is never masked out.
            info["episode_limit"] = True
        return reward, done, info

    # TODO: decide whether observations should be returned as a numpy array or a list.

    def get_obs(self):
        """Return all agent observations in a list, in ``agent_names`` order."""
        return [self.get_obs_agent(agent_id) for agent_id in range(self.n_agents)]

    def get_obs_agent(self, agent_id):
        """Return the trimmed observation of agent ``agent_id``.

        Keeps raw entries ``[0:14]`` and ``[28:]`` of the wrapped env's observation
        and drops ``[14:28]``.
        """
        raw_obs = self.kaz.observe(self.agent_names[agent_id])
        # NOTE(review): per the PettingZoo multiwalker docs, indices 14-27 hold LIDAR
        # readings and neighbour positions -- confirm against the installed version.
        return np.concatenate((raw_obs[0:14], raw_obs[28:]))

    def get_obs_size(self):
        """Return the largest per-agent observation length."""
        return max(len(self.get_obs_agent(agent_id)) for agent_id in range(self.n_agents))

    def get_state(self, team=None):
        """Return the global state: all agent observations flattened into one list.

        ``team`` is accepted for interface compatibility and ignored.
        """
        state = np.array(self.get_obs())
        return list(state.reshape(-1))

    def get_state_size(self):
        """Return the length of the flattened global state."""
        return len(self.get_state())

    def get_avail_actions(self):
        """Return an all-ones ``(n_agents, n_actions)`` mask: every action is always available."""
        return np.ones(shape=(self.n_agents, self.n_actions,))

    def get_avail_agent_actions(self, agent_id):
        """Return an all-ones ``(n_actions,)`` availability mask for ``agent_id``."""
        return np.ones(shape=(self.n_actions,))

    def get_total_actions(self):
        """Return ``n_actions``.

        CAREFUL: actions are continuous, so this is the per-agent action
        dimensionality rather than a count of discrete actions.
        """
        return self.n_actions

    def get_stats(self):
        """Return environment statistics (none are tracked)."""
        return {}

    def get_agg_stats(self, stats):
        """Aggregate statistics (none are tracked)."""
        return {}

    def reset(self, **kwargs):
        """Reset the step counter and the wrapped env; return the initial observations."""
        self.steps = 0
        self.kaz.reset()
        return self.get_obs()

    def render(self, **kwargs):
        """Rendering is not supported; this is a no-op."""
        pass

    def close(self):
        """Close the wrapped PettingZoo environment."""
        self.kaz.close()

    def seed(self, args):
        """Seeding is not forwarded to the wrapped env; this is a no-op."""
        pass

    def get_env_info(self):
        """Return the environment description dict consumed by the training framework."""
        env_info = {"state_shape": self.get_state_size(),
                    "obs_shape": self.get_obs_size(),
                    "n_actions": self.get_total_actions(),
                    "n_agents": self.n_agents,
                    "episode_limit": self.episode_limit,
                    "action_spaces": self.action_space,
                    "actions_dtype": np.float32,
                    "normalise_actions": False
                    }
        return env_info
