from .multiagentenv import MultiAgentEnv
import numpy as np
import torch as th


# Grid displacement for each discrete action id: dx[a] / dy[a] is the number
# of cells agent positions shift along the first / second coordinate when
# action `a` is taken.  Each row below is one action's (dx, dy) pair; zip()
# transposes the rows into the two per-axis lists.
dx, dy = map(list, zip(
    (-1, 0),  # action 0: first coordinate - 1
    (0, -1),  # action 1: second coordinate - 1
    (0, 0),   # action 2: stay in place
    (0, 1),   # action 3: second coordinate + 1
    (1, 0),   # action 4: first coordinate + 1
))


class MPEFormation(MultiAgentEnv):
    """Grid-world formation task around a single landmark.

    Agents and one landmark live on a ``grid_size x grid_size`` grid whose
    coordinates are normalised to ``[0, 1]``.  At every step each agent picks
    one of five discrete moves (see the module-level ``dx`` / ``dy`` tables).
    The reward is the negative sum of two error terms computed from the
    agents' polar coordinates relative to the landmark:

    * angular error: deviation of the sorted agents' angular offsets from an
      even ``2 * pi / n_agents`` spacing (mean-centred), and
    * distance error: deviation of each agent's distance to the landmark from
      the mean distance.

    An episode ends when both maximum errors fall below ``alpha_gap`` and
    ``dist_gap`` respectively, or when ``episode_limit`` steps have elapsed.
    """

    def __init__(self, n_agents=4, grid_size=5, episode_limit=50, gamma=.99, alpha_gap=.3, dist_gap=.1, **kwargs):
        """Configure the environment and sample an initial layout.

        Args:
            n_agents: number of agents placed on the grid.
            grid_size: number of cells per side; must be at least 2 because
                positions are normalised by ``grid_size - 1``.
            episode_limit: maximum number of steps per episode.
            gamma: discount factor used only for the ``gamma_reward`` statistic.
            alpha_gap: success threshold on the maximum absolute angular error.
            dist_gap: success threshold on the maximum absolute distance error.
            **kwargs: accepted for interface compatibility and ignored.

        Raises:
            ValueError: if ``grid_size`` is smaller than 2.
        """
        if grid_size < 2:
            # grid_size == 1 would divide by zero when normalising positions.
            raise ValueError("grid_size must be at least 2, got {}".format(grid_size))
        self.grid_size = grid_size
        self.n_agents = n_agents
        # Own position + relative positions of the other agents + landmark.
        self.obs_dim = (n_agents + 1) * 2
        # All agent positions + landmark position.
        self.state_dim = (n_agents + 1) * 2
        self.n_actions = 5
        self.episode_limit = episode_limit
        self.gamma = gamma
        self.alpha_gap = alpha_gap
        self.dist_gap = dist_gap
        self.reset()

    def _sample_cells(self, count):
        """Draw ``count`` distinct grid cells and return normalised (x, y) rows."""
        cells = np.random.choice(self.grid_size ** 2, count, replace=False)
        xs = cells // self.grid_size
        ys = cells % self.grid_size
        return np.stack((xs, ys), axis=1) / float(self.grid_size - 1)

    @staticmethod
    def _to_polar(x, y):
        """Return ``(angle, distance)`` of the offset ``(x, y)``."""
        return np.arctan2(y, x), ((x ** 2) + (y ** 2)) ** .5

    def _formation_errors(self):
        """Return the mean-centred angular and distance errors of the agents.

        Agents are ordered by their angle around the landmark (ties broken by
        distance).  The i-th agent in that order is expected to sit
        ``2 * pi * i / n_agents`` after the first one.
        """
        offsets = self.agent_pos - self.landmark
        polar = sorted(self._to_polar(x, y) for x, y in offsets)
        angles = np.array([angle for angle, _ in polar])
        dist = np.array([radius for _, radius in polar])

        ideal = 2 * np.pi * np.arange(self.n_agents) / float(self.n_agents)
        alpha = angles - angles[0] - ideal
        # Centring removes the dependence on which agent happens to be first
        # and on the (free) formation radius.
        alpha -= alpha.mean()
        dist -= dist.mean()
        return alpha, dist

    def reset(self):
        """Returns initial observations and states."""
        # The landmark is sampled before the agents; the two draws are
        # independent, so an agent may start on the landmark's cell.
        self.landmark = self._sample_cells(1)
        self.agent_pos = self._sample_cells(self.n_agents)

        self.timestep = 0
        self.gamma_reward = 0
        self.rewards = []
        self.state_traj = [self.get_state()]

        return np.array(self.get_obs(), dtype=np.float32), self.get_state()

    def get_obs_size(self):
        """Returns the size of the observation."""
        return self.obs_dim

    def get_obs_agent(self, agent_id):
        """Returns the observation for agent_id.

        Layout: own position, then the positions of every other agent
        relative to this agent (in agent-index order), then the landmark
        position relative to this agent.
        """
        own = self.agent_pos[agent_id]
        parts = [own]
        parts.extend(self.agent_pos[i] - own for i in range(self.n_agents) if i != agent_id)
        parts.append(self.landmark[0] - own)
        return np.concatenate(parts).astype(np.float32)

    def get_obs(self):
        """Returns all agent observations in a list."""
        return [self.get_obs_agent(i) for i in range(self.n_agents)]

    def get_state(self):
        """Returns the global state: all agent positions followed by the landmark."""
        parts = [self.agent_pos[i] for i in range(self.n_agents)]
        parts.append(self.landmark[0])
        return np.concatenate(parts).astype(np.float32)

    def get_state_size(self):
        """Returns the size of the global state."""
        return self.state_dim

    def get_avail_actions(self):
        """Returns the available actions of all agents in a list."""
        # Every action is always available; moves off the grid are clipped in step().
        return [[1] * self.n_actions for _ in range(self.n_agents)]

    def get_avail_agent_actions(self, agent_id):
        """Returns the available actions for agent_id."""
        return self.get_avail_actions()[agent_id]

    def get_total_actions(self):
        """Returns the total number of actions an agent could ever take."""
        return self.n_actions

    def close(self):
        """Nothing to release."""
        return

    def step(self, _actions):
        """Returns reward, terminated, info.

        Args:
            _actions: one integer action index per agent, given as a torch
                tensor, a numpy array, or a plain list/tuple.

        Returns:
            A ``(reward, done, info)`` tuple.  ``info`` is empty until the
            episode ends, at which point it carries episode statistics.
        """
        actions = _actions
        if th.is_tensor(actions):
            actions = actions.detach().cpu().numpy()
        # np.asarray lets plain lists/tuples through as well as arrays.
        actions = np.asarray(actions).reshape(-1).tolist()

        for i in range(self.n_agents):
            move = np.array([dx[actions[i]], dy[actions[i]]])
            self.agent_pos[i] += move / float(self.grid_size - 1)
        # Moves that would leave the grid leave the agent on the border.
        self.agent_pos = np.clip(self.agent_pos, 0., 1.)

        alpha, dist = self._formation_errors()
        reward = -np.sum(np.abs(alpha)) - np.sum(np.abs(dist))

        self.timestep += 1
        self.rewards.append(reward)
        self.gamma_reward += (self.gamma ** (self.timestep - 1)) * reward

        reached = bool((np.max(np.abs(alpha)) < self.alpha_gap) and
                       (np.max(np.abs(dist)) < self.dist_gap))
        # ">=" keeps reporting termination if the caller steps past the limit.
        done = reached or (self.timestep >= self.episode_limit)

        if not done:
            return reward, done, {}

        info = self.get_stats()
        info['avg_alpha'] = np.mean(np.abs(alpha))
        info['avg_dist'] = np.mean(np.abs(dist))
        info['reach'] = 1. if reached else 0.
        return reward, done, info

    def get_stats(self):
        """Returns reward statistics accumulated since the last reset."""
        return {
            'avg_reward': np.mean(self.rewards),
            'sum_reward': np.sum(self.rewards),
            'gamma_reward': self.gamma_reward,
        }

