from .multiagentenv import MultiAgentEnv
import numpy as np
import torch as th
import time


# Grid displacement for each of the five discrete actions, as
# (first-coordinate, second-coordinate) cell offsets. Action 2 is the
# "stay" move. MPESpread.step indexes dx/dy by action id.
_ACTION_DELTAS = ((-1, 0), (0, -1), (0, 0), (0, 1), (1, 0))
dx = [delta[0] for delta in _ACTION_DELTAS]
dy = [delta[1] for delta in _ACTION_DELTAS]


class MPESpread(MultiAgentEnv):
    """Grid-world variant of a cooperative "spread" task.

    ``n_agents`` agents and ``n_agents`` landmarks live on a
    ``grid_size`` x ``grid_size`` grid. Coordinates are cell indices divided
    by ``grid_size - 1``, so every position lies in ``[0, 1]``. Agents occupy
    distinct cells, and landmarks occupy distinct cells. The two sets are
    sampled independently and may overlap.

    Each step every agent picks one of five moves (module-level ``dx``/``dy``).
    The shared reward is:

    * minus the sum, over landmarks, of the L1 distance to the nearest agent;
    * minus 1 for every pair of agents that end up on the same cell.

    An episode ends after ``episode_limit`` steps, or when every landmark has
    some agent within ``neighborhood`` cells (L1 distance).
    """

    def __init__(self, n_agents=3, grid_size=5, episode_limit=50, gamma=.99, neighborhood=1, **kwargs):
        """Create the environment and sample the first episode.

        Args:
            n_agents: number of agents (also the number of landmarks).
            grid_size: side length of the square grid. Must be >= 2 because
                positions are normalized by ``grid_size - 1``.
            episode_limit: maximum number of steps per episode.
            gamma: discount factor used for the ``gamma_reward`` statistic.
            neighborhood: L1 radius, in grid cells, within which a landmark
                counts as reached.
            **kwargs: accepted and ignored.

        Raises:
            ValueError: if ``grid_size`` < 2.
        """
        if grid_size < 2:
            # grid_size == 1 would divide by zero when normalizing positions
            # and silently produce nan/inf coordinates.
            raise ValueError(f'grid_size must be >= 2, got {grid_size}')
        self.grid_size = grid_size
        self.n_agents = n_agents
        # Observation layout (see get_obs_agent) has 4 * n_agents entries:
        # own (x, y), (n_agents - 1) relative agent offsets, and n_agents
        # relative landmark offsets.
        self.obs_dim = n_agents * 2 * 2
        # State layout (see get_state): all agent positions, then all
        # landmark positions.
        self.state_dim = n_agents * 2 * 2
        self.n_actions = 5
        self.episode_limit = episode_limit
        self.gamma = gamma
        self.neighborhood = neighborhood
        self.reset()

    def _sample_positions(self):
        """Return ``n_agents`` distinct random cells as normalized (x, y) rows."""
        cells = np.random.choice(self.grid_size ** 2, self.n_agents, replace=False)
        x, y = np.divmod(cells, self.grid_size)
        return np.stack((x, y), axis=1) / float(self.grid_size - 1)

    def reset(self):
        """Start a new episode.

        Samples fresh landmark and agent positions and clears the per-episode
        bookkeeping.

        Returns:
            ``(obs, state)``. ``obs`` is a float32 array with one row per
            agent (see get_obs_agent). ``state`` is the float32 global state
            (see get_state).
        """
        # Landmarks are sampled before agents; keep this order so seeded runs
        # reproduce the same layouts.
        self.landmark = self._sample_positions()
        self.agent_pos = self._sample_positions()

        self.timestep = 0
        self.gamma_reward = 0
        self.rewards = []
        # Entry 0 is the initial state. Each step appends
        # (reward, actions, state).
        self.state_traj = [self.get_state()]

        return np.array(self.get_obs(), dtype=np.float32), self.get_state()

    def get_obs_size(self):
        """Returns the size of the observation."""
        return self.obs_dim

    def get_obs_agent(self, agent_id):
        """Build the observation vector of one agent.

        The layout is its own normalized position, then the offsets of every
        other agent (in index order), then the offsets of every landmark.
        All offsets are relative to this agent.
        """
        ret = self.agent_pos[agent_id].tolist()
        for i in range(self.n_agents):
            if i != agent_id:
                ret += (self.agent_pos[i] - self.agent_pos[agent_id]).tolist()
        for i in range(self.n_agents):
            ret += (self.landmark[i] - self.agent_pos[agent_id]).tolist()
        return np.array(ret, dtype=np.float32)

    def get_obs(self):
        """Returns all agent observations in a list."""
        return [self.get_obs_agent(i) for i in range(self.n_agents)]

    def get_state(self):
        """Return all agent positions followed by all landmark positions (float32)."""
        return np.concatenate((self.agent_pos.ravel(), self.landmark.ravel())).astype(np.float32)

    def get_state_size(self):
        """Returns the size of the global state."""
        return self.state_dim

    def get_avail_actions(self):
        """Returns the available actions of all agents in a list (always all 1s)."""
        return [[1] * self.n_actions for _ in range(self.n_agents)]

    def get_avail_agent_actions(self, agent_id):
        """Returns the available actions for agent_id."""
        return self.get_avail_actions()[agent_id]

    def get_total_actions(self):
        """Returns the total number of actions an agent could ever take."""
        return self.n_actions

    def close(self):
        return

    def step(self, _actions):
        """Advance the environment by one step.

        Args:
            _actions: one action index per agent, as a torch tensor, numpy
                array or plain sequence. Action ``a`` moves an agent by
                ``(dx[a], dy[a])`` cells. Positions are clipped to the grid.

        Returns:
            ``(reward, terminated, info)``. ``info`` is empty while the
            episode runs. On the terminating step it holds episode statistics
            and ``reach`` (1.0 if every landmark was reached, else 0.0).
        """
        if th.is_tensor(_actions):
            _actions = _actions.cpu().numpy()
        # np.asarray also accepts plain Python lists (calling .tolist() on the
        # raw input did not). The int cast makes float-typed actions usable as
        # indices into dx/dy.
        actions = np.asarray(_actions).astype(np.int64).tolist()

        for i in range(self.n_agents):
            self.agent_pos[i] += np.array([dx[actions[i]], dy[actions[i]]]) / float(self.grid_size - 1)
        self.agent_pos = np.clip(self.agent_pos, 0., 1.)

        # The small epsilon absorbs float drift from repeated fractional moves.
        reach_radius = self.neighborhood / float(self.grid_size - 1) + 1e-6
        reach = 0
        reward = 0.
        for landmark in self.landmark:
            # L1 distance from this landmark to its closest agent.
            min_dist = np.abs(self.agent_pos - landmark).sum(axis=1).min()
            if min_dist < reach_radius:
                reach += 1
            reward -= min_dist
        # Penalize each unordered pair of agents occupying the same cell.
        for i in range(self.n_agents):
            for j in range(i):
                if np.sum(np.abs(self.agent_pos[i] - self.agent_pos[j])) < 1e-6:
                    reward -= 1

        self.timestep += 1
        self.rewards.append(reward)
        self.gamma_reward += (self.gamma ** (self.timestep - 1)) * reward
        self.state_traj.append((reward, actions, self.get_state()))

        success = reach == self.n_agents
        # '>=' so stepping past the limit still reports termination.
        done = (self.timestep >= self.episode_limit) or success

        if not done:
            return reward, done, {}
        return reward, done, {
            'avg_reward': np.mean(self.rewards),
            'sum_reward': np.sum(self.rewards),
            'gamma_reward': self.gamma_reward,
            'reach': 1. if success else 0.,
        }

    def get_stats(self):
        """Return reward statistics for the current episode.

        ``avg_reward`` is nan before the first step.
        """
        # np.mean([]) emits a RuntimeWarning; return nan explicitly instead.
        avg_reward = np.mean(self.rewards) if self.rewards else np.nan
        return {
            'avg_reward': avg_reward,
            'sum_reward': np.sum(self.rewards),
            'gamma_reward': self.gamma_reward,
        }

    def save_replay(self):
        """Write the current episode trajectory to a timestamped text file."""
        file_name = f'spread_replay_{time.time()}.txt'
        print('save replay at', file_name)
        # The context manager flushes and closes the file right away. Passing
        # an anonymous open() handle left that up to the garbage collector.
        with open(file_name, 'w') as replay_file:
            print(self.state_traj, file=replay_file)

