from gym_multigrid.multigrid import *


class SoccerGameEnv(MultiGridEnv):
    """
    Environment in which the agents have to fetch the balls and drop them in their respective goals.

    Agents and goals both carry a team index; the goal whose index equals an
    agent's index is treated as that agent's own goal. Every reward produced
    by this class (scoring, pickup bonus, distance shaping) is added to every
    entry of ``rewards``, i.e. it is shared by all agents.
    """

    # Shared reward when a ball is dropped into the scorer's own goal.
    _SCORE_REWARD = 10.0
    # Shared penalty when a ball is dropped into another team's goal.
    _WRONG_GOAL_PENALTY = 5.0
    # Shared one-off bonus when a ball starts being carried.
    _PICKUP_BONUS = 2.0
    # Per-step penalty proportional to the holder's distance to its goal, capped.
    _DIST_PENALTY_RATE = 0.02
    _DIST_PENALTY_CAP = 3.0
    # Gain applied to the per-step reduction of the holder's distance to its goal.
    _PROGRESS_GAIN = 0.2

    def __init__(
            self,
            size=10,
            view_size=3,
            width=None,
            height=None,
            goal_pst=[],
            goal_index=[],
            num_balls=[],
            agents_index=[],
            balls_index=[],
            zero_sum=False,
    ):
        """
        Build the agents and initialise the underlying MultiGridEnv.

        :param size: forwarded to the base class as ``grid_size``
        :param view_size: view size given to every agent and to the base class
        :param width: grid width forwarded to the base class
        :param height: grid height forwarded to the base class
        :param goal_pst: goal positions, one ``[x, y]`` per goal
        :param goal_index: team index of each goal, parallel to ``goal_pst``
        :param num_balls: number of balls to place per entry of ``balls_index``
        :param agents_index: team index of each agent, in agent order
        :param balls_index: index given to the balls, paired with ``num_balls``
        :param zero_sum: stored on the instance; not read by this class
        """
        # Copy the list arguments: the defaults are mutable objects shared by
        # every call, so storing them directly would let one instance's
        # mutation leak into all later instances (and into the caller's lists).
        self.num_balls = list(num_balls)
        self.goal_pst = list(goal_pst)
        self.goal_index = list(goal_index)
        self.balls_index = list(balls_index)
        self.zero_sum = zero_sum

        self.world = World

        agents = [Agent(self.world, team, view_size=view_size) for team in agents_index]

        super().__init__(
            grid_size=size,
            width=width,
            height=height,
            max_steps=512,
            see_through_walls=False,
            agents=agents,
            agent_view_size=view_size
        )

    def _gen_grid(self, width, height):
        """
        Create a walled grid and place goals, balls and agents on it.

        Also discards the per-agent/per-env reward-shaping state, because it
        refers to positions on the previous layout.
        """
        self.grid = Grid(width, height)

        # Generate the surrounding walls
        self.grid.horz_wall(self.world, 0, 0)
        self.grid.horz_wall(self.world, 0, height - 1)
        self.grid.vert_wall(self.world, 0, 0)
        self.grid.vert_wall(self.world, width - 1, 0)

        # Each goal is pinned to its configured position (1x1 placement area)
        for g, goal_top in enumerate(self.goal_pst):
            self.place_obj(ObjectGoal(self.world, self.goal_index[g], 'ball'), top=goal_top, size=[1, 1])

        # zip() stops at the shorter list, so surplus entries in either
        # num_balls or balls_index are ignored
        for number, index in zip(self.num_balls, self.balls_index):
            for _ in range(number):
                self.place_obj(Ball(self.world, index))

        # Randomize the player start position and orientation
        for a in self.agents:
            self.place_agent(a)
            # A cached distance from the previous layout would produce a
            # spurious progress reward on the next carry
            if hasattr(a, "last_dist_to_goal"):
                del a.last_dist_to_goal

        self._pickup_rewarded = False

    def _team_goal_pos(self, team_index):
        """Return the position of the goal whose index is ``team_index``, or None if there is none."""
        for goal_team, goal_pos in zip(self.goal_index, self.goal_pst):
            if goal_team == team_index:
                return goal_pos
        return None

    def _reward(self, i, rewards, reward=1, scorer=None):
        """
        Apply the shared scoring reward or penalty to every entry of ``rewards``.

        :param i: team index of the goal that received the ball
        :param rewards: per-agent reward container, modified in place
        :param reward: accepted for interface compatibility; the fixed
            class-level amounts are used instead
        :param scorer: position in ``self.agents`` of the agent that dropped
            the ball. When omitted, ``i`` is used as the position as well,
            which reproduces the behaviour of earlier versions of this method.
        """
        scorer_pos = i if scorer is None else scorer
        scored_in_own_goal = self.agents[scorer_pos].index == i

        delta = self._SCORE_REWARD if scored_in_own_goal else -self._WRONG_GOAL_PENALTY
        for k in range(len(rewards)):
            rewards[k] += delta

    def _handle_pickup(self, i, rewards, fwd_pos, fwd_cell):
        """
        Let agent ``i`` pick up the object in front of it, or take the object
        carried by the agent in front of it.

        Nothing happens if the cell is empty or agent ``i`` already carries
        something. ``rewards`` is not modified here.
        """
        if not fwd_cell:
            return

        me = self.agents[i]
        if fwd_cell.can_pickup():
            if me.carrying is None:
                me.carrying = fwd_cell
                # Carried objects are taken off the grid
                me.carrying.cur_pos = np.array([-1, -1])
                self.grid.set(*fwd_pos, None)
        elif fwd_cell.type == 'agent':
            # Steal from the facing agent if it holds something and we do not
            if fwd_cell.carrying and me.carrying is None:
                me.carrying = fwd_cell.carrying
                fwd_cell.carrying = None

    def _handle_drop(self, i, rewards, fwd_pos, fwd_cell):
        """
        Let agent ``i`` drop what it carries onto the cell in front of it.

        * onto a goal: the scoring reward/penalty is applied and a new ball
          with the same index is placed on the grid;
        * onto an empty-handed agent: the object is handed over;
        * onto an empty cell: the object is put down there;
        * onto anything else: nothing happens.
        """
        me = self.agents[i]
        if not me.carrying:
            return

        if fwd_cell:
            if fwd_cell.type == 'objgoal':
                # Pass the dropping agent so the reward depends on who scored
                self._reward(fwd_cell.index, rewards, fwd_cell.reward, scorer=i)
                ball_index = me.carrying.index
                me.carrying = None
                self.place_obj(Ball(self.world, ball_index))
            elif fwd_cell.type == 'agent':
                if fwd_cell.carrying is None:
                    fwd_cell.carrying = me.carrying
                    me.carrying = None
        else:
            self.grid.set(*fwd_pos, me.carrying)
            me.carrying.cur_pos = fwd_pos
            me.carrying = None

    def step(self, actions):
        """
        Advance the base environment one step and add the shared shaping reward.

        Shaping, all of it added to every entry of ``rewards``:

        * a capped penalty proportional to the distance between the first
          carrying agent and its own goal;
        * a bonus/penalty proportional to how much that distance shrank/grew
          since the previous step of the same carry;
        * a one-off bonus when a ball starts being carried while no ball was
          carried on the previous step.

        :return: ``(obs, rewards, done, info)`` as returned by
            ``MultiGridEnv.step``, with ``rewards`` modified in place
        """
        obs, rewards, done, info = MultiGridEnv.step(self, actions)

        # Shared reward accumulated for this step
        shared_reward = 0.0

        # Only the first carrying agent (in agent order) is tracked
        carriers = [agent for agent in self.agents if agent.carrying]
        holder = carriers[0] if carriers else None

        # A cached distance is only meaningful across consecutive steps of the
        # tracked holder; drop it for everybody else so a later carry does not
        # start from a stale value
        for agent in self.agents:
            if agent is not holder and hasattr(agent, "last_dist_to_goal"):
                del agent.last_dist_to_goal

        if holder is not None:
            # Look the goal up by team index rather than by list position, so
            # team 0 or a non-sequential goal_index cannot select a wrong goal
            goal_pos = self._team_goal_pos(holder.index)
            if goal_pos is not None:
                now_dist = np.linalg.norm(np.array(holder.pos) - np.array(goal_pos))
                # First step of a carry: no progress yet
                last_dist = getattr(holder, "last_dist_to_goal", now_dist)

                # Distance penalty (at most -3.0)
                shared_reward -= min(self._DIST_PENALTY_RATE * now_dist, self._DIST_PENALTY_CAP)

                # Reward for moving the ball towards the goal
                shared_reward += self._PROGRESS_GAIN * (last_dist - now_dist)

                # Remember the distance for the next step
                holder.last_dist_to_goal = now_dist

        # Pickup bonus: paid once when carrying starts, re-armed only after
        # no agent carries any more (re-arming while the ball is still held
        # would pay the bonus again every other step)
        if carriers:
            if not getattr(self, "_pickup_rewarded", False):
                shared_reward += self._PICKUP_BONUS
                self._pickup_rewarded = True
        else:
            self._pickup_rewarded = False

        # Every agent receives the shared reward
        for k in range(len(rewards)):
            rewards[k] += shared_reward

        return obs, rewards, done, info


class SoccerGame6HEnv15x15N3(SoccerGameEnv):
    """
    Fixed SoccerGameEnv configuration: a 15x15 grid with three goals
    (team indices 1-3) and nine agents, three per team.
    """

    def __init__(self):
        settings = {
            # Explicit width/height are supplied, so no square size is given
            "size": None,
            "width": 15,
            "height": 15,
            # One goal per team, listed in the same order as goal_index
            "goal_pst": [[1, 6], [13, 6], [7, 13]],
            "goal_index": [1, 2, 3],
            "num_balls": [4],
            "agents_index": [1, 1, 1, 2, 2, 2, 3, 3, 3],
            # NOTE(review): two entries here but only one in num_balls; the
            # parent pairs the two lists with zip(), so only the first entry
            # is used - confirm whether a second ball group was intended
            "balls_index": [0, 0],
            "zero_sum": False,
        }
        super().__init__(**settings)