from .multiagentenv import MultiAgentEnv
import numpy as np
import torch as th


# Per-action displacement in grid cells, indexed by the discrete action id:
#   0 -> (-1, 0), 1 -> (0, -1), 2 -> (0, 0) i.e. stay, 3 -> (0, 1), 4 -> (1, 0).
# dx is added to the first position coordinate and dy to the second.
dx = [-1, 0, 0, 0, 1]
dy = [0, -1, 0, 1, 0]


class MPEGather(MultiAgentEnv):
    """Cooperative grid world in which every agent walks towards one landmark.

    Agents and the landmark sit on a ``grid_size x grid_size`` lattice whose
    coordinates are normalised to ``[0, 1]``, so one cell corresponds to
    ``1 / (grid_size - 1)``.  The shared reward at each step is the negative
    sum of the agents' L1 distances to the landmark.  An episode ends when
    that summed distance is within ``neighborhood`` cells (plus a small
    tolerance) or when ``episode_limit`` steps have been taken.
    """

    def __init__(self, n_agents=3, grid_size=5, episode_limit=50, gamma=.99, neighborhood=0, **kwargs):
        """Store the configuration and sample an initial episode.

        Args:
            n_agents: Number of agents.
            grid_size: Side length of the square grid.  Positions are divided
                by ``grid_size - 1``, so a value of 1 makes that divisor zero.
            episode_limit: Step count at which an episode is terminated.
            gamma: Discount factor used only for the ``gamma_reward`` statistic.
            neighborhood: Success radius in grid cells, applied to the summed
                L1 distance of all agents to the landmark.
            **kwargs: Accepted and ignored.
        """
        self.grid_size = grid_size
        self.n_agents = n_agents
        # Each observation/state holds one (x, y) pair per agent plus the landmark.
        self.obs_dim = (n_agents + 1) * 2
        self.state_dim = (n_agents + 1) * 2
        self.n_actions = 5
        self.episode_limit = episode_limit
        self.gamma = gamma
        self.neighborhood = neighborhood
        self.reset()

    def _sample_positions(self, count):
        """Draw ``count`` distinct grid cells and return normalised (x, y) rows."""
        cells = np.random.choice(self.grid_size ** 2, count, replace=False)
        rows, cols = np.divmod(cells, self.grid_size)
        return np.stack((rows, cols), axis=1) / float(self.grid_size - 1)

    def _reach_threshold(self):
        """Return the success threshold on the summed distance, in normalised units."""
        # 1e-6 absorbs floating-point drift accumulated by repeated moves.
        return self.neighborhood / float(self.grid_size - 1) + 1e-6

    def reset(self):
        """Returns initial observations and states."""
        # The landmark is drawn before the agents (keeps the RNG call order).
        # Agents occupy distinct cells, but the landmark is drawn separately
        # and may therefore coincide with an agent.
        self.landmark = self._sample_positions(1)
        self.agent_pos = self._sample_positions(self.n_agents)

        self.timestep = 0
        self.gamma_reward = 0
        self.rewards = []
        self.state_traj = [self.get_state()]

        return np.array(self.get_obs(), dtype=np.float32), self.get_state()

    def get_obs_size(self):
        """Returns the size of the observation."""
        return self.obs_dim

    def get_obs_agent(self, agent_id):
        """Returns the observation of agent_id as a float32 vector.

        Layout: the agent's own position, then the offset to every other agent
        (in index order), then the offset to the landmark.
        """
        own = self.agent_pos[agent_id]
        parts = [own]
        parts.extend(self.agent_pos[i] - own for i in range(self.n_agents) if i != agent_id)
        parts.append(self.landmark[0] - own)
        return np.concatenate(parts).astype(np.float32)

    def get_obs(self):
        """Returns all agent observations in a list."""
        return [self.get_obs_agent(i) for i in range(self.n_agents)]

    def get_state(self):
        """Returns the global state: all agent positions followed by the landmark."""
        flat_agents = self.agent_pos[:self.n_agents].reshape(-1)
        return np.concatenate((flat_agents, self.landmark[0])).astype(np.float32)

    def get_state_size(self):
        """Returns the size of the global state."""
        return self.state_dim

    def get_avail_actions(self):
        """Returns the available actions of all agents in a list."""
        # Every action is always available; moves into a wall are clipped in step().
        return [[1] * self.n_actions for _ in range(self.n_agents)]

    def get_avail_agent_actions(self, agent_id):
        """Returns the available actions for agent_id."""
        return self.get_avail_actions()[agent_id]

    def get_total_actions(self):
        """Returns the total number of actions an agent could ever take."""
        return self.n_actions

    def close(self):
        """No resources to release."""
        return

    def step(self, _actions):
        """Returns reward, terminated, info.

        Args:
            _actions: One discrete action id per agent, given as a torch
                tensor, a numpy array or a plain sequence.

        Returns:
            A ``(reward, done, info)`` tuple.  ``reward`` is the negative sum
            of the agents' L1 distances to the landmark after the move.
            ``info`` is empty until the episode ends; on the final step it
            holds the episode statistics plus ``'reach'`` (1. if the agents
            are within the success threshold, else 0.).
        """
        if th.is_tensor(_actions):
            # detach() so tensors that still track gradients can be converted.
            actions = _actions.detach().cpu().numpy()
        else:
            # asarray() lets plain Python sequences through as well as arrays.
            actions = np.asarray(_actions)
        # Flatten so both (n_agents,) and (n_agents, 1) layouts are accepted.
        actions = actions.reshape(-1).tolist()

        cell = float(self.grid_size - 1)
        for i in range(self.n_agents):
            self.agent_pos[i] += np.array([dx[actions[i]], dy[actions[i]]]) / cell
        # Moves that would leave the grid are clipped back onto its border.
        self.agent_pos = np.clip(self.agent_pos, 0., 1.)

        total_dist = np.abs(self.agent_pos - self.landmark[0]).sum()
        reward = -total_dist

        self.timestep += 1
        self.rewards.append(reward)
        self.gamma_reward += (self.gamma ** (self.timestep - 1)) * reward

        # Use the same normalised threshold for termination and for the
        # reported 'reach' flag so the two can never disagree.
        reached = bool(total_dist < self._reach_threshold())
        done = reached or (self.timestep == self.episode_limit)

        if not done:
            return reward, done, {}

        info = self.get_stats()
        info['reach'] = 1. if reached else 0.
        return reward, done, info

    def get_stats(self):
        """Returns reward statistics of the current episode.

        Before the first step of an episode the reward list is empty, so
        ``avg_reward`` is NaN.
        """
        return {
            'avg_reward': np.mean(self.rewards),
            'sum_reward': np.sum(self.rewards),
            'gamma_reward': self.gamma_reward,
        }

