from pettingzoo.utils.wrappers import BaseParallelWraper
from pettingzoo.mpe import simple_spread_v2

import numpy as np
from gymnasium.spaces import Box

class Spread_Indp(BaseParallelWraper):
    """Parallel ``simple_spread_v2`` wrapper that gives every agent its own landmark.

    For each agent the wrapper:

    * removes four entries from the raw observation (see ``_LANDMARK_LAYOUT``),
      which is why ``observation_space`` reports 14 entries;
    * replaces a move action with action ``0`` when the agent is already past
      the ``[-1, 1]`` bound on the matching axis and the move would push it
      further out;
    * reshapes the reward to
      ``(env_reward + 1) * param - sum(obs[mark:mark + 2] ** 2)``.

    NOTE(review): the index tables assume the observation layout documented
    for PettingZoo's simple_spread (velocity, position, landmark offsets,
    other-agent offsets, communication) and the discrete action order
    (no-op, left, right, down, up) -- confirm against the installed version.

    Args:
        render: when true the wrapped environment is created with
            ``render_mode="human"``; otherwise rendering is disabled.
        param: multiplier applied to ``env_reward + 1`` before the
            squared-offset penalty is subtracted.
    """

    # Keyed by the last character of the agent name ("agent_0" -> "0").
    # Each value is (start index of the two entries used for the reward
    # penalty, indices deleted from the raw observation).
    # Presumably the kept pair is the agent's own landmark offset and the
    # deleted ones belong to the other two landmarks -- TODO confirm.
    _LANDMARK_LAYOUT = {
        "0": (4, [6, 7, 8, 9]),
        "1": (6, [4, 5, 8, 9]),
        "2": (8, [4, 5, 6, 7]),
    }

    def __init__(self, render=False, param=10):
        self.env = simple_spread_v2.parallel_env(
            render_mode="human" if render else None,
            local_ratio=1,
            max_cycles=100,
        )
        self.agents = [f"agent_{i}" for i in range(3)]
        self.param = param
        # Last trimmed observations per agent; empty until reset() is called.
        self.obs = {}
        # One Box per agent, created lazily by observation_space().
        self._observation_spaces = {}
        super().__init__(self.env)

    def observation_space(self, agent):
        """Return the unbounded 14-entry ``Box`` for ``agent``.

        The same ``Box`` instance is returned on every call for a given agent
        instead of allocating a fresh one each time.
        """
        space = self._observation_spaces.get(agent)
        if space is None:
            space = Box(float("-inf"), float("inf"), (14,))
            self._observation_spaces[agent] = space
        return space

    def _trim_observation(self, agent, observation):
        """Drop the entries listed for ``agent`` in ``_LANDMARK_LAYOUT``."""
        _, removed = self._LANDMARK_LAYOUT[agent[-1]]
        return np.delete(observation, removed)

    def _clamp_action(self, agent, act):
        """Return ``0`` instead of ``act`` if it would move further out of bounds.

        Uses the agent's previous trimmed observation: entry 2 is compared
        against actions 1/2 and entry 3 against actions 3/4. When no previous
        observation is stored for ``agent`` the action is returned unchanged.
        """
        previous = self.obs.get(agent)
        if previous is None:
            return act
        first, second = previous[2], previous[3]
        blocked = (
            (first < -1 and act == 1)
            or (first > 1 and act == 2)
            or (second < -1 and act == 3)
            or (second > 1 and act == 4)
        )
        return 0 if blocked else act

    def reset(self, seed=None, return_info=False, options=None):
        """Reset the wrapped environment and return ``(observations, {})``.

        ``seed`` and ``options`` are forwarded to the parent ``reset``.
        ``return_info`` is accepted for signature compatibility only; an empty
        info dict is always returned as the second element.
        """
        obs = super().reset(seed=seed, options=options)
        for agent, raw in list(obs.items()):
            obs[agent] = self._trim_observation(agent, raw)
        self.obs = obs
        return obs, {}

    def step(self, action):
        """Advance one step with boundary-clamped actions and a shaped reward.

        Args:
            action: mapping from agent name to action. The mapping passed in
                is left untouched; clamping is applied to a new dict.

        Returns:
            ``(obs, rew, done, trunc, info)`` from the parent ``step``, with
            observations trimmed and rewards reshaped for every agent that
            appears in ``rew``.
        """
        # Build a new dict so the caller's actions are not overwritten.
        clamped = {
            agent: self._clamp_action(agent, act)
            for agent, act in action.items()
        }
        obs, rew, done, trunc, info = super().step(clamped)
        for agent in rew:
            mark, _ = self._LANDMARK_LAYOUT[agent[-1]]
            # The penalty must be read from the raw observation, before the
            # deletion below shifts the indices.
            offset = obs[agent][mark:mark + 2]
            penalty = np.sum(np.square(offset))
            rew[agent] = (rew[agent] + 1) * self.param - penalty
            obs[agent] = self._trim_observation(agent, obs[agent])
        self.obs = obs
        return obs, rew, done, trunc, info

    def render(self):
        """Delegate rendering to the parent wrapper."""
        return super().render()