# noqa: D212, D415
"""
# Simple Push

```{figure} mpe_simple_push.gif
:width: 140px
:name: simple_push
```

This environment is part of the <a href='..'>MPE environments</a>. Please read that page first for general information.

| Import             | `from pettingzoo.mpe import simple_push_v3` |
|--------------------|---------------------------------------------|
| Actions            | Discrete/Continuous                         |
| Parallel API       | Yes                                         |
| Manual Control     | No                                          |
| Agents             | `agents= [adversary_0, agent_0]`            |
| Agents             | 2                                           |
| Action Shape       | (5)                                         |
| Action Values      | Discrete(5)/Box(0.0, 1.0, (5,))             |
| Observation Shape  | (19),(19)                                   |
| Observation Values | (-inf,inf)                                  |
| State Shape        | (27,)                                       |
| State Values       | (-inf,inf)                                  |


This environment has 1 good agent, 1 adversary, and 2 landmarks; at each reset one landmark is chosen at random as the shared goal. Each agent is rewarded by the difference between its opponent's distance to the goal landmark and its own distance to it, so the adversary is rewarded for being close to the goal while the good agent is far from it (and the good agent for the reverse). Thus the adversary must learn to
push the good agent away from the goal landmark.

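Agent observation space: `[self_vel, goal_rel_position, self_color, all_landmark_rel_positions, landmark_colors, other_agent_rel_positions]` (the good agent's color is tinted to encode which landmark is the goal)

Adversary observation space: same layout as the agent observation space, `[self_vel, goal_rel_position, self_color, all_landmark_rel_positions, landmark_colors, other_agent_rel_positions]`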

Agent action space: `[no_action, move_left, move_right, move_down, move_up]`

Adversary action space: `[no_action, move_left, move_right, move_down, move_up]`

### Arguments

``` python
simple_push_v3.env(max_cycles=25, continuous_actions=False)
```



`max_cycles`:  number of frames (a step for each agent) until game terminates

"""

import numpy as np
from gymnasium.utils import EzPickle

from pettingzoo.discrete_mpe._mpe_utils.core import Agent, Landmark, World
from pettingzoo.discrete_mpe._mpe_utils.scenario import BaseScenario
from pettingzoo.discrete_mpe._mpe_utils.simple_env import SimpleEnv, make_env
from pettingzoo.utils.conversions import parallel_wrapper_fn


class raw_env(SimpleEnv, EzPickle):
    """AEC environment for the Simple Push scenario.

    A fresh ``Scenario`` and the world it builds are handed to ``SimpleEnv``.
    The constructor arguments are also registered with gymnasium's
    ``EzPickle`` helper so the environment can be re-created from them.

    Args:
        max_cycles: forwarded to ``SimpleEnv`` as ``max_cycles``.
        continuous_actions: forwarded to ``SimpleEnv``; chooses the
            continuous (Box) instead of the discrete action space.
        render_mode: forwarded to ``SimpleEnv`` as ``render_mode``.
    """

    def __init__(self, max_cycles=25, continuous_actions=False, render_mode=None):
        ctor_kwargs = dict(
            max_cycles=max_cycles,
            continuous_actions=continuous_actions,
            render_mode=render_mode,
        )
        EzPickle.__init__(self, **ctor_kwargs)

        push_scenario = Scenario()
        SimpleEnv.__init__(
            self,
            scenario=push_scenario,
            world=push_scenario.make_world(),
            render_mode=render_mode,
            max_cycles=max_cycles,
            continuous_actions=continuous_actions,
        )
        self.metadata["name"] = "simple_push_v3"


# Public constructors: `env` is the AEC factory produced by `make_env` from
# `raw_env`; `parallel_env` converts that factory to the Parallel API via
# `parallel_wrapper_fn`.
env = make_env(raw_env)
parallel_env = parallel_wrapper_fn(env)


class Scenario(BaseScenario):
    def make_world(self):
        """Create the world: 1 adversary, 1 good agent and 2 landmarks.

        Agents are named ``adversary_<k>`` / ``agent_<k>`` (``k`` counted
        within each role), are collidable and silent. Landmarks are
        immovable, non-colliding and named ``landmark <k>``.

        Returns:
            World: the populated world; colors, goals and positions are
            assigned later by ``reset_world``.
        """
        n_agents, n_adversaries, n_landmarks = 2, 1, 2

        world = World()
        world.dim_c = 2  # world-level properties are set before entities

        world.agents = [Agent() for _ in range(n_agents)]
        for idx, agent in enumerate(world.agents):
            is_adversary = idx < n_adversaries
            agent.adversary = is_adversary
            prefix, k = (
                ("adversary", idx) if is_adversary else ("agent", idx - n_adversaries)
            )
            agent.name = f"{prefix}_{k}"
            agent.collide = True
            agent.silent = True

        world.landmarks = [Landmark() for _ in range(n_landmarks)]
        for idx, landmark in enumerate(world.landmarks):
            landmark.name = "landmark %d" % idx
            landmark.collide = False
            landmark.movable = False

        return world

    def reset_world(self, world, np_random):
        """Re-randomize colors, the shared goal landmark and initial positions.

        Landmark ``k`` gets a dark-gray color with channel ``k + 1`` raised
        and stores ``k`` as its ``index``. One landmark is drawn as the goal
        for every agent. The adversary is colored reddish; the good agent is
        gray tinted on the goal landmark's raised channel. Agents and
        landmarks are then placed uniformly in [-1, 1] with zero velocity
        (and zero communication state for agents).

        Args:
            world: the world built by ``make_world``; modified in place.
            np_random: random generator used for the goal choice and all
                positions.
        """
        for idx, landmark in enumerate(world.landmarks):
            landmark.color = np.array([0.1, 0.1, 0.1])
            landmark.color[idx + 1] += 0.8
            landmark.index = idx

        goal = np_random.choice(world.landmarks)
        for agent in world.agents:
            agent.goal_a = goal
            if agent.adversary:
                agent.color = np.array([0.75, 0.25, 0.25])
            else:
                agent.color = np.array([0.25, 0.25, 0.25])
                agent.color[goal.index + 1] += 0.5

        # Initial kinematic state (agents first, then landmarks, as before).
        for agent in world.agents:
            agent.state.p_pos = np_random.uniform(-1, +1, world.dim_p)
            agent.state.p_vel = np.zeros(world.dim_p)
            agent.state.c = np.zeros(world.dim_c)
        for landmark in world.landmarks:
            landmark.state.p_pos = np_random.uniform(-1, +1, world.dim_p)
            landmark.state.p_vel = np.zeros(world.dim_p)

    def reward(self, agent, world):
        """Return ``adversary_reward(agent, world)`` for every agent.

        No role-based dispatch is done: adversary and good agent alike are
        scored by their opponent's goal distance minus their own. With one
        agent per side the two rewards are exact negatives of each other.

        NOTE(review): ``agent_reward`` is not used for the good agent here;
        confirm this zero-sum variant is intended.
        """
        score = self.adversary_reward(agent, world)
        return score

    def agent_reward(self, agent, world):
        """Return minus the Euclidean distance from ``agent`` to its goal landmark.

        ``world`` is unused. ``reward`` in this class does not call this method.
        """
        offset = agent.state.p_pos - agent.goal_a.state.p_pos
        return -np.sqrt(np.sum(np.square(offset)))

    def adversary_reward(self, agent, world):
        """Score ``agent`` by how much farther its opponent is from the goal.

        Returns the smallest goal distance among agents whose ``adversary``
        flag differs from ``agent``'s, minus ``agent``'s own distance to its
        goal landmark. Positive means the opponent is farther from the goal.

        Raises:
            ValueError: if ``world`` has no agent on the opposite side
                (``min`` of an empty list).
        """
        opponent_dists = []
        for other in world.agents:
            if other.adversary != agent.adversary:
                gap = other.state.p_pos - other.goal_a.state.p_pos
                opponent_dists.append(np.sqrt(np.sum(np.square(gap))))
        closest_opponent = min(opponent_dists)

        own_gap = agent.goal_a.state.p_pos - agent.state.p_pos
        own_dist = np.sqrt(np.sum(np.square(own_gap)))
        return closest_opponent - own_dist

    def observation(self, agent, world):
        """Build ``agent``'s flat observation vector (same layout for both roles).

        Layout, positions relative to ``agent``:
            [self_vel, goal_rel_pos, self_color,
             landmark_rel_pos..., landmark_color..., other_agent_rel_pos...]

        The adversary receives this same layout, including the relative
        position of its goal landmark.

        Args:
            agent: the observing agent.
            world: the world containing ``agent``.

        Returns:
            np.ndarray: the concatenated observation.
        """
        own_pos = agent.state.p_pos
        landmark_pos = [lm.state.p_pos - own_pos for lm in world.landmarks]
        landmark_color = [lm.color for lm in world.landmarks]
        # Only other agents' relative positions are observed; their
        # communication state is not part of the observation (agents are
        # created silent in make_world).
        other_pos = [
            other.state.p_pos - own_pos for other in world.agents if other is not agent
        ]
        return np.concatenate(
            [agent.state.p_vel, agent.goal_a.state.p_pos - own_pos, agent.color]
            + landmark_pos
            + landmark_color
            + other_pos
        )


    def all_state(self, world):
        """Return the full world state as one flat vector in absolute coordinates.

        Layout (each group follows ``world.agents`` / ``world.landmarks`` order):
            [agent velocities, agent positions, agent goal positions,
             agent colors, landmark positions, landmark colors]

        ``all_state_add_noise`` perturbs the world following this ordering.
        """
        agents, landmarks = world.agents, world.landmarks
        parts = (
            [a.state.p_vel for a in agents]
            + [a.state.p_pos for a in agents]
            + [a.goal_a.state.p_pos for a in agents]
            + [a.color for a in agents]
            + [lm.state.p_pos for lm in landmarks]
            + [lm.color for lm in landmarks]
        )
        return np.concatenate(parts)
        
    def agent_state(self, world):
        """Return all agent velocities followed by all agent positions, flattened.

        Layout: [vel_0, vel_1, ..., pos_0, pos_1, ...] in ``world.agents`` order.
        """
        velocities = [a.state.p_vel for a in world.agents]
        positions = [a.state.p_pos for a in world.agents]
        return np.concatenate(velocities + positions)
        
        
        
        
        
    def all_state_add_noise(self, sigma, world):
        """Perturb the world in place with Gaussian noise laid out like ``all_state``.

        One ``N(0, sigma**2)`` sample is drawn (NumPy global RNG) per element
        of ``self.all_state(world)``. The samples are added chunk by chunk to
        the arrays that make up that vector, in the same order: agent
        velocities, agent positions, agent goal positions, agent colors,
        landmark positions, landmark colors. Each array receives a chunk of
        its own size, so positions and 3-channel colors stay aligned.

        Args:
            sigma: noise standard deviation. A falsy value (``0``/``None``)
                leaves the world unchanged.
            world: the world to perturb; its arrays are modified in place.

        NOTE(review): ``agent.goal_a.state.p_pos`` is the goal landmark's own
        position array (assigned in ``reset_world``), so that landmark is
        perturbed once per agent and once more as a landmark. Confirm this
        compounding is intended.
        """
        if not sigma:
            # Nothing to add; also avoids slicing a scalar zero.
            return
        agents, landmarks = world.agents, world.landmarks
        targets = (
            [a.state.p_vel for a in agents]
            + [a.state.p_pos for a in agents]
            + [a.goal_a.state.p_pos for a in agents]
            + [a.color for a in agents]
            + [lm.state.p_pos for lm in landmarks]
            + [lm.color for lm in landmarks]
        )
        noise = np.random.randn(*self.all_state(world).shape) * sigma
        offset = 0
        for arr in targets:
            # In-place add mutates the world's array (not just the local name).
            arr += noise[offset : offset + arr.size]
            offset += arr.size
            
    def agent_state_add_noise(self, sigma, world):
        """Perturb agent velocities and positions in place with Gaussian noise.

        One ``N(0, sigma**2)`` sample is drawn (NumPy global RNG) per element
        of ``self.agent_state(world)``. The samples are added to the matching
        arrays in the same order: all velocities, then all positions. Each
        array receives a chunk of its own size.

        Args:
            sigma: noise standard deviation. A falsy value (``0``/``None``)
                leaves the world unchanged.
            world: the world whose agents are perturbed in place.
        """
        if not sigma:
            # Nothing to add; also avoids slicing a scalar zero.
            return
        targets = [a.state.p_vel for a in world.agents] + [
            a.state.p_pos for a in world.agents
        ]
        noise = np.random.randn(*self.agent_state(world).shape) * sigma
        offset = 0
        for arr in targets:
            arr += noise[offset : offset + arr.size]
            offset += arr.size

    def get_dis_state(self, world=None):
        """Return the flat agent state ``[velocities..., positions...]`` of ``world``.

        Equivalent to ``self.agent_state(world)``.

        Args:
            world: the world to read. No world is stored on the scenario in
                this class, so it must be passed explicitly.

        Raises:
            TypeError: if ``world`` is not given.
        """
        if world is None:
            raise TypeError("get_dis_state() requires the `world` argument")
        return self.agent_state(world)

    def get_goal_pos(self, world=None):
        """Return the goal landmark position of the first agent in ``world``.

        ``reset_world`` gives every agent the same goal landmark, so this is
        the shared goal. The landmark's position array itself is returned
        (not a copy). Returns ``None`` when ``world`` has no agents.

        Args:
            world: the world to read. No world is stored on the scenario in
                this class, so it must be passed explicitly.

        Raises:
            TypeError: if ``world`` is not given.
        """
        if world is None:
            raise TypeError("get_goal_pos() requires the `world` argument")
        if not world.agents:
            return None
        return world.agents[0].goal_a.state.p_pos