import gym
from stable_baselines3 import PPO
import numpy as np
from gym_macro_overcooked.overcooked_V1 import Overcooked_V1
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.logger import configure
import matplotlib.pyplot as plt
from gym_macro_overcooked.items import Tomato, Lettuce, Onion, Plate, Knife, Delivery, Agent, Food, DirtyPlate, BadLettuce
import random
import time
import torch
import os
import torch
import torch.nn as nn
from collections import deque




# ====== Global random seed ======
# Change this value to obtain a different (but still reproducible) run.
SEED = 42

# Seed every CPU-side RNG this file relies on: Python's `random`, NumPy and PyTorch.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)
del _seed_fn

# Seed all CUDA devices too when a GPU is available.
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)




# Names of the scripted partner behaviours ("events") understood by
# OneHotEventReward.
# NOTE: the module-level name EVENT_TYPES is rebound to a different list later
# in this file (just before the evaluation loop). OneHotEventReward looks the
# name up when it is instantiated, so at that point it validates against the
# later list, not this one.
EVENT_TYPES = [
    "pass_plate",
    "pass_lettuce",
    "pass_chopped_lettuce",
    "pass_plated_lettuce",
    "pass_dirty_lettuce",
    "pass_chopped_dirty_lettuce",
    "pass_plated_dirty_lettuce",
    "make_clean_salad_alone",
    "make_dirty_salad_alone",
    "only_take_lettuce",
    "only_chop_lettuce",
    "only_plate_salad"
]



class OneHotEventReward:
    """Scripted partner that follows one fixed behaviour ("event") per instance.

    Despite its name, this class does not compute rewards: given the
    environment, :meth:`generate_action` returns the macro-action index (a
    value of ``macroActionDict``) that ``env.agent[1]`` should execute next to
    carry out the configured event.
    """

    # Event names that may be listed in EVENT_TYPES but have no handler in this class.
    _UNIMPLEMENTED_EVENTS = ("only_take_lettuce", "only_chop_lettuce", "only_plate_salad")

    def __init__(self, active_event_type: str, reward_value: float = 1.0):
        """
        :param active_event_type: str, one of the EVENT_TYPES
        :param reward_value: float, stored on the instance; not used by this class
        :raises ValueError: if ``active_event_type`` is not in EVENT_TYPES
        """
        # Explicit raise instead of `assert`, which is stripped under `python -O`.
        if active_event_type not in EVENT_TYPES:
            raise ValueError(f"Invalid event type: {active_event_type}")
        self.active_event = active_event_type
        self.reward_value = reward_value

    def generate_action(self, env) -> int:
        """Return the macro-action index for the configured event in the current state."""
        return self._event_triggered(env, self.active_event)

    def _event_triggered(self, env, event_type: str) -> int:
        """Dispatch ``event_type`` to its handler and return the chosen macro-action index.

        :raises NotImplementedError: for the ``only_*`` events, which have no handler
        """
        if event_type in self._UNIMPLEMENTED_EVENTS:
            raise NotImplementedError(
                f"No scripted behaviour is implemented for event type: {event_type}"
            )

        handlers = {
            "pass_plate": self._is_passing_plate,
            "pass_lettuce": self._is_passing_lettuce,
            "pass_chopped_lettuce": self._is_passing_chopped_lettuce,
            "pass_plated_lettuce": self._is_passing_plated_lettuce,
            "pass_dirty_lettuce": self._is_passing_dirty_lettuce,
            "pass_chopped_dirty_lettuce": self._is_passing_chopped_dirty_lettuce,
            "pass_plated_dirty_lettuce": self._is_passing_plated_dirty_lettuce,
            "make_clean_salad_alone": self._is_making_clean_salad_alone,
            "make_dirty_salad_alone": self._is_making_dirty_salad_alone,
        }
        handler = handlers.get(event_type)
        if handler is None:
            # Unknown event: do nothing this step.
            return macroActionDict["stay"]
        return handler(env)

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _handoff_action(env, fetch_action, is_ready):
        """Fetch an item with empty hands, then carry it to the counter.

        :param fetch_action: macro-action name used while agent 1 holds nothing
        :param is_ready: predicate on the held item; when true the agent goes to the counter
        :return: macro-action index ("stay" when neither rule applies)
        """
        held = env.agent[1].holding
        best_action = "stay"
        if not held:
            best_action = fetch_action
        elif is_ready(held):
            best_action = "go to counter"
        return macroActionDict[best_action]

    @staticmethod
    def _prepare_action(env, food_cls, fetch_action, chopped_action, plated_action=None):
        """Shared fetch -> knife -> chop -> pick up (-> plate) state machine for agent 1.

        The rules are evaluated in order and a later match overrides an earlier one:

        1. hands empty and both knives empty        -> ``fetch_action``
        2. holding an unchopped ``food_cls`` item   -> "go to knife 2"
        3. hands empty, an unchopped item lies on a knife at distance 1 -> "chop"
        4. hands empty, a chopped item lies on a knife -> ``fetch_action``
        5. holding a chopped ``food_cls`` item      -> ``chopped_action``
        6. (only if ``plated_action`` is given) holding a plate whose first
           content is a ``food_cls`` item           -> ``plated_action``

        :return: macro-action index ("stay" when no rule applies)
        """
        agent = env.agent[1]
        held = agent.holding
        knives = (env.knife[0], env.knife[1])

        best_action = "stay"

        if not held and not knives[0].holding and not knives[1].holding:
            best_action = fetch_action

        if held and isinstance(held, food_cls) and not held.chopped:
            best_action = "go to knife 2"

        if not held and any(
            knife.holding
            and not knife.holding.chopped
            and env._calDistance(agent.x, agent.y, knife.x, knife.y) == 1
            for knife in knives
        ):
            best_action = "chop"

        if not held and any(knife.holding and knife.holding.chopped for knife in knives):
            best_action = fetch_action

        if held and isinstance(held, food_cls) and held.chopped:
            best_action = chopped_action

        if (
            plated_action is not None
            and held
            and isinstance(held, Plate)
            and held.containing
            and isinstance(held.containing[0], food_cls)
        ):
            best_action = plated_action

        return macroActionDict[best_action]

    # ----------------------------------------------------------- event handlers

    def _is_passing_plate(self, env):
        """Get plate 1, then bring the empty plate to the counter."""
        return self._handoff_action(
            env,
            "get plate 1",
            lambda item: isinstance(item, Plate) and not item.containing,
        )

    def _is_passing_lettuce(self, env):
        """Get lettuce 2, then bring the unchopped lettuce to the counter."""
        return self._handoff_action(
            env,
            "get lettuce 2",
            lambda item: isinstance(item, Lettuce) and not item.chopped,
        )

    def _is_passing_chopped_lettuce(self, env):
        """Fetch and chop a lettuce, then bring the chopped lettuce to the counter."""
        return self._prepare_action(env, Lettuce, "get lettuce 2", "go to counter")

    def _is_passing_plated_lettuce(self, env):
        """Fetch, chop and plate a lettuce, then bring the plate to the counter."""
        return self._prepare_action(
            env, Lettuce, "get lettuce 2", "get plate 1", plated_action="go to counter"
        )

    def _is_passing_dirty_lettuce(self, env):
        """Get the bad lettuce, then bring it unchopped to the counter."""
        return self._handoff_action(
            env,
            "get badlettuce",
            lambda item: isinstance(item, BadLettuce) and not item.chopped,
        )

    def _is_passing_chopped_dirty_lettuce(self, env):
        """Fetch and chop the bad lettuce, then bring it to the counter."""
        return self._prepare_action(env, BadLettuce, "get badlettuce", "go to counter")

    def _is_passing_plated_dirty_lettuce(self, env):
        """Fetch, chop and plate the bad lettuce, then bring the plate to the counter."""
        return self._prepare_action(
            env, BadLettuce, "get badlettuce", "get plate 1", plated_action="go to counter"
        )

    def _is_making_clean_salad_alone(self, env):
        """Fetch, chop and plate a lettuce, then issue "deliver 2"."""
        return self._prepare_action(
            env, Lettuce, "get lettuce 2", "get plate 1", plated_action="deliver 2"
        )

    def _is_making_dirty_salad_alone(self, env):
        """Fetch, chop and plate the bad lettuce, then issue "deliver 2"."""
        return self._prepare_action(
            env, BadLettuce, "get badlettuce", "get plate 1", plated_action="deliver 2"
        )










# Helpers for persisting the data collected for the ABI model
def save_dataset_to_disk(data_buffer, filepath="abi_training_data.pt"):
    """Serialize ``data_buffer`` to ``filepath`` with ``torch.save`` and print a confirmation."""
    torch.save(obj=data_buffer, f=filepath)
    print(f"✅ Data buffer saved to {filepath}")

def load_dataset_from_disk(filepath="abi_training_data.pt"):
    """Load and return whatever object was stored at ``filepath`` via ``torch.save``."""
    # NOTE(review): torch.load may unpickle arbitrary objects depending on the
    # installed torch version's defaults; only load files from a trusted source.
    data_buffer = torch.load(f=filepath)
    return data_buffer




class LightweightTransformerABIModel(nn.Module):
    """Small Transformer encoder that regresses three scores in (0, 1).

    A sequence of state vectors is projected to ``hidden_dim``, encoded by a
    Transformer encoder, pooled over time with a learnable attention vector and
    fed to three sigmoid heads: ability, benevolence and integrity.

    The module owns its optimizer (Adam, lr=1e-3), its loss (MSE) and an
    in-memory ``data_buffer`` of training samples.
    """

    def __init__(self, state_dim, seq_len=15, hidden_dim=32, nhead=2, num_layers=1):
        """
        :param state_dim: size of each state vector in the input sequence
        :param seq_len: nominal sequence length; stored for reference only, the
            network itself accepts any sequence length
        :param hidden_dim: Transformer model width
        :param nhead: number of attention heads
        :param num_layers: number of encoder layers
        """
        super().__init__()
        self.seq_len = seq_len
        self.input_proj = nn.Linear(state_dim, hidden_dim)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=nhead,
            dim_feedforward=hidden_dim * 2,
            batch_first=True,
            activation='relu'  # keeps the layer lightweight
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Attention pooling: learnable query vector scored against every time step.
        self.attention_vector = nn.Parameter(torch.randn(hidden_dim))

        self.fc_ability = nn.Linear(hidden_dim, 1)
        self.fc_benevolence = nn.Linear(hidden_dim, 1)
        self.fc_integrity = nn.Linear(hidden_dim, 1)

        self.criterion = nn.MSELoss()
        self.optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        self.data_buffer = []

    def forward(self, x):
        """Return ``(ability, benevolence, integrity)``, each of shape ``[batch, 1]``.

        :param x: tensor of shape ``[batch, seq_len, state_dim]``
        """
        x = self.input_proj(x)           # [B, T, H]
        h = self.transformer(x)          # [B, T, H]

        # Attention pooling over the time dimension.
        attn_weights = torch.softmax(torch.matmul(h, self.attention_vector), dim=1)  # [B, T]
        pooled = torch.sum(h * attn_weights.unsqueeze(-1), dim=1)  # [B, H]

        ability     = torch.sigmoid(self.fc_ability(pooled))
        benevolence = torch.sigmoid(self.fc_benevolence(pooled))
        integrity   = torch.sigmoid(self.fc_integrity(pooled))
        return ability, benevolence, integrity

    def store_data(self, state_history, ability, benevolence, integrity):
        """Append one training sample (state sequence plus its three target scores)."""
        self.data_buffer.append((state_history, ability, benevolence, integrity))

    def train_step(self):
        """Run one full-batch optimisation step over everything in ``data_buffer``.

        The buffer is not cleared afterwards, so repeated calls train on the
        same (growing) data set.

        :return: ``None`` if the buffer is empty, otherwise a dict with the
            float values ``total_loss``, ``loss_ability``, ``loss_benevolence``
            and ``loss_integrity``
        """
        if not self.data_buffer:
            return None

        states, abilitys, benevolences, integritys = zip(*self.data_buffer)
        # Stack through NumPy first: calling torch.tensor on a sequence of
        # ndarrays converts element by element and is very slow.
        states = torch.as_tensor(np.asarray(states, dtype=np.float32))
        abilitys = torch.tensor(abilitys, dtype=torch.float32).unsqueeze(-1)
        benevolences = torch.tensor(benevolences, dtype=torch.float32).unsqueeze(-1)
        integritys = torch.tensor(integritys, dtype=torch.float32).unsqueeze(-1)

        self.optimizer.zero_grad()
        pred_ability, pred_benevolence, pred_integrity = self.forward(states)

        # One MSE term per head; the optimised objective is their plain sum.
        loss_ability = self.criterion(pred_ability, abilitys)
        loss_benevolence = self.criterion(pred_benevolence, benevolences)
        loss_integrity = self.criterion(pred_integrity, integritys)

        total_loss = loss_ability + loss_benevolence + loss_integrity
        total_loss.backward()
        self.optimizer.step()

        # Returned as a dict so callers can log the individual terms.
        return {
            'total_loss': total_loss.item(),
            'loss_ability': loss_ability.item(),
            'loss_benevolence': loss_benevolence.item(),
            'loss_integrity': loss_integrity.item()
        }

    def save_model(self, path):
        """Write model and optimizer state dicts to ``path``."""
        torch.save({
            'model_state_dict': self.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
        }, path)
        print(f"Model saved to {path}")

    def load_model(self, path):
        """Restore model and optimizer state from a checkpoint written by ``save_model``."""
        # map_location="cpu" lets a checkpoint saved on a GPU be opened on a
        # CPU-only machine; load_state_dict then copies the values onto
        # whatever device this module's parameters live on.
        checkpoint = torch.load(path, map_location="cpu")
        self.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        print(f"Model loaded from {path}")








class EpisodeRewardCallback(BaseCallback):
    """Track episode returns of the first environment and checkpoint periodically.

    Every ``save_freq`` calls to ``_on_step`` the current model is saved under
    ``save_path`` and, if at least one episode has finished, a moving-average
    plot of the episode returns is written next to it.
    """

    def __init__(self, save_path, save_freq=100000, verbose=0):
        super().__init__(verbose)
        self.save_path = save_path
        self.save_freq = save_freq
        self.step_counter = 0
        self.episode_rewards = []
        self.current_episode_reward = 0.0

        os.makedirs(self.save_path, exist_ok=True)

    def _on_step(self) -> bool:
        """Accumulate the step reward; save and plot on every ``save_freq``-th step."""
        self.step_counter += 1

        # Only index 0 of the rewards/dones in `self.locals` is tracked.
        self.current_episode_reward += self.locals['rewards'][0]
        if self.locals['dones'][0]:
            self.episode_rewards.append(self.current_episode_reward)
            self.current_episode_reward = 0.0

        # Nothing else to do unless this is a checkpoint step.
        if self.step_counter % self.save_freq != 0:
            return True

        model_path = os.path.join(self.save_path, f'model_{self.step_counter}.zip')
        self.model.save(model_path)
        print(f"Step {self.step_counter}: Model saved at {model_path}")

        finished = self.episode_rewards
        if finished:
            # Trailing mean over at most the last 100 episodes, per episode index.
            window = min(100, len(finished))
            moving_avg = []
            for end in range(len(finished)):
                start = max(0, end - window + 1)
                moving_avg.append(sum(finished[start:end + 1]) / (end - start + 1))

            plot_path = os.path.join(self.save_path, f'avg_reward_{self.step_counter}.png')
            plt.figure(figsize=(10, 5))
            plt.plot(moving_avg, label=f"Moving Avg (last {window} episodes)")
            plt.xlabel("Episode")
            plt.ylabel("Average Reward")
            plt.title("Training Progress")
            plt.legend()
            plt.grid()
            plt.tight_layout()
            plt.savefig(plot_path)
            plt.close()

        return True








class SingleAgentWrapper_Without_Latent(gym.Wrapper):
    """
    A wrapper to extract a single agent's perspective from a multi-agent environment.

    The other agent is driven by ``other_agent_model``, whose ``predict`` method is
    called with that agent's observation on every step.
    """
    def __init__(self, env, agent_index, other_agent_model=None):
        """
        :param env: the shared two-agent environment
        :param agent_index: index (0 or 1) of the agent controlled through this wrapper
        :param other_agent_model: object with a ``predict(obs)`` method whose result's
            first element is the other agent's macro action; required by ``step``
        """
        super().__init__(env)
        self.agent_index = agent_index
        self.observation_space = env.observation_space
        self.action_space = env.action_space
        self.other_agent_model = other_agent_model

        # One-shot flags for the (currently disabled) benevolence reward shaping.
        self.firsttime_left_go_to_counter = True
        self.firsttime_left_get_lettuce = True
        self.firsttime_right_go_to_knife = True
        self.firsttime_right_go_to_counter = True

        self.obs = None

    def reset(self):
        """Reset the shared environment and return this agent's observation."""
        self.obs = self.env.reset()

        self.firsttime_left_go_to_counter = True
        self.firsttime_left_get_lettuce = True
        self.firsttime_right_go_to_knife = True
        self.firsttime_right_go_to_counter = True

        return self.obs[self.agent_index]

    def step(self, action):
        """Advance the environment by one step.

        :param action: macro action for the agent at ``agent_index``
        :return: ``(obs, reward, dones, info)`` from this agent's perspective
        :raises RuntimeError: if no ``other_agent_model`` was supplied
        """
        if self.other_agent_model is None:
            raise RuntimeError(
                "SingleAgentWrapper_Without_Latent.step() needs an other_agent_model "
                "to choose the other agent's action"
            )

        # Cooperation reward shaping / inconsistency penalty is disabled for now,
        # so the shaping term is a constant 0.
        benevolence_reward = 0

        other_index = 1 - self.agent_index
        other_agent_action = self.other_agent_model.predict(self.obs[other_index])

        actions = [0, 0]
        actions[self.agent_index] = action
        actions[other_index] = other_agent_action[0]

        # The sibling SingleAgentWrapper unpacks this call on the same
        # environment as (low-level actions, executed macro actions); only the
        # low-level actions are needed here.
        # NOTE(review): confirm against the installed environment version.
        primary_actions, _ = self.env._computeLowLevelActions(actions)

        self.obs, rewards, dones, info = self.env.step(primary_actions)

        # Replace the raw step observation with the macro-level observation.
        self.obs = self.env._get_macro_obs()

        return (
            self.obs[self.agent_index],
            rewards[self.agent_index] + benevolence_reward,
            dones,
            info,
        )







class SingleAgentWrapper(gym.Wrapper):
    """
    A wrapper to extract a single agent's perspective from a multi-agent environment.

    The other agent is a scripted partner: on every ``reset`` a fresh
    ``OneHotEventReward`` is built for the rule set through ``set_partner_idx``,
    and on every ``step`` it picks the partner's macro action. The reward
    returned by ``step`` is the sum of both agents' rewards.
    """
    def __init__(self, env, partner_env, agent_index):
        """
        :param env: the shared two-agent environment
        :param partner_env: accepted for interface compatibility; not used
        :param agent_index: index (0 or 1) of the agent controlled through this wrapper
        """
        super().__init__(env)
        self.agent_index = agent_index
        self.observation_space = env.observation_space
        self.action_space = env.action_space

        # One-shot flags for the (currently disabled) benevolence reward shaping.
        self.firsttime_left_go_to_counter = True
        self.firsttime_left_get_lettuce = True
        self.firsttime_right_go_to_knife = True
        self.firsttime_right_go_to_counter = True

        self.obs = None

        # [x, y] of env.agent[1] after reset and after every step.
        self.partner_position_history = []

        # Set through set_partner_idx(); rewarder is (re)built in reset().
        self.rule_name = None
        self.rewarder = None

    def set_partner_idx(self, rule_name):
        """Select the scripted partner rule used from the next ``reset`` on."""
        self.rule_name = rule_name

    def reset(self):
        """Reset the environment and the scripted partner; return this agent's observation.

        :raises RuntimeError: if no rule was selected with ``set_partner_idx``
        """
        if self.rule_name is None:
            raise RuntimeError("set_partner_idx() must be called before reset()")

        self.obs = self.env.reset()

        self.partner_position_history = []
        self.partner_position_history.append([self.env.agent[1].x, self.env.agent[1].y])

        self.firsttime_left_go_to_counter = True
        self.firsttime_left_get_lettuce = True
        self.firsttime_right_go_to_knife = True
        self.firsttime_right_go_to_counter = True

        self.rewarder = OneHotEventReward(active_event_type=self.rule_name, reward_value=400)

        return self.obs[self.agent_index]

    def _partner_is_stuck(self):
        """True when more than five positions are recorded and the last five are identical."""
        history = self.partner_position_history
        if len(history) <= 5:
            return False
        recent = history[-5:]
        return all(position == recent[0] for position in recent)

    def _free_partner_moves(self):
        """Return the primitive moves (0..3) of agent 1 whose target cell is "space"."""
        partner = self.env.agent[1]
        # (primitive action, dx, dy) for the four neighbouring cells.
        neighbours = ((0, 0, 1), (1, 1, 0), (2, 0, -1), (3, -1, 0))
        return [
            move
            for move, dx, dy in neighbours
            if ITEMNAME[partner.pomap[partner.x + dx][partner.y + dy]] == "space"
        ]

    def step(self, action):
        """Advance the environment by one step with the scripted partner.

        :param action: macro action for the agent at ``agent_index``
        :return: ``(obs, team_reward, dones, info)`` where ``team_reward`` is the
            sum of both agents' rewards
        :raises RuntimeError: if called before ``reset``
        """
        if self.rewarder is None:
            raise RuntimeError("reset() must be called before step()")

        # Benevolence reward shaping is disabled, so the shaping term is a constant 0.
        benevolence_reward = 0
        partner_index = 1 - self.agent_index

        other_agent_action = self.rewarder.generate_action(self.env)

        actions = [0, 0]
        actions[self.agent_index] = action
        actions[partner_index] = other_agent_action

        primary_actions, real_execute_macro_actions = self.env._computeLowLevelActions(actions)

        # Hold the partner in place (primitive action 4, presumably "stay" -
        # TODO confirm) while it executes "get badlettuce" and the bad lettuce
        # is in column x == 4.
        if self.env.macroActionName[real_execute_macro_actions[1]] == "get badlettuce" and self.env.badlettuce[0].x == 4:
            primary_actions[partner_index] = 4

        # Un-stick the partner: if it has not moved for the last five recorded
        # positions, replace its action by a random move into a free cell.
        # With no free neighbouring cell the planned action is kept
        # (previously random.choice raised IndexError on the empty list).
        if self._partner_is_stuck():
            to_choose = self._free_partner_moves()
            if to_choose:
                primary_actions[partner_index] = random.choice(to_choose)

        self.obs, rewards, dones, info = self.env.step(primary_actions)

        # Replace the raw step observation with the macro-level observation.
        self.obs = self.env._get_macro_obs()

        self.partner_position_history.append([self.env.agent[1].x, self.env.agent[1].y])

        team_reward = rewards[self.agent_index] + rewards[partner_index] + benevolence_reward
        return self.obs[self.agent_index], team_reward, dones, info
    




# Item names, indexed by the integer cell codes read from the environment's maps.
ITEMNAME = [
    "space",
    "counter",
    "agent",
    "tomato",
    "lettuce",
    "plate",
    "knife",
    "delivery",
    "onion",
    "dirtyplate",
    "badlettuce",
]

# Macro-action name -> integer id. Ids are consecutive and follow the listing order.
macroActionDict = {
    name: index
    for index, name in enumerate((
        "stay",
        "get lettuce 1",
        "get lettuce 2",
        "get badlettuce",
        "get plate 1",
        "get plate 2",
        "go to knife 1",
        "go to knife 2",
        "deliver 1",
        "deliver 2",
        "chop",
        "go to counter",
        "right",
        "down",
        "left",
        "up",
    ))
}


def find_best_reachable_index(can_reach_1, can_reach_2, can_reach_3, distance_1, distance_2, distance_3):
    """Return the index (0, 1 or 2) of the closest reachable candidate.

    A candidate counts as reachable when its ``can_reach_*`` value is not ``4``.
    Ties on distance go to the lowest index. If no candidate is reachable the
    function returns ``False``.

    NOTE: ``False == 0`` is true in Python, so callers must test the result
    with ``is False`` before comparing it to index ``0``.
    """
    reach_codes = (can_reach_1, can_reach_2, can_reach_3)
    all_distances = (distance_1, distance_2, distance_3)

    # Keep (index, distance) for every candidate that is reachable.
    candidates = [
        (index, distance)
        for index, (code, distance) in enumerate(zip(reach_codes, all_distances))
        if code != 4
    ]

    if not candidates:
        return False  # nothing reachable

    # min() keeps the first of several equal minima, i.e. the lowest index.
    closest_index, _ = min(candidates, key=lambda candidate: candidate[1])
    return closest_index


def intelligently_find_item_number(env, agent_item, raw_name):
    """Pick the concrete macro action ("<raw_name> 1" / "<raw_name> 2" / ...) to use.

    For each numbered variant the target is located with ``env._findPOitem``,
    reachability is taken from ``env._navigate`` (a result of ``4`` is treated
    as "not reachable") and the distance from ``env._calDistance``. The closest
    reachable variant wins; ties go to the lower number.

    For ``raw_name == "get plate"`` the action "get dirty plate" is considered
    as a third candidate, but only when ``macroActionDict`` defines it.

    :param env: environment providing ``_findPOitem``, ``_navigate`` and ``_calDistance``
    :param agent_item: the agent for which the action is chosen
    :param raw_name: macro-action name without its trailing number, e.g. "get lettuce"
    :return: the chosen macro-action name, or "stay" if no candidate is reachable
    """
    unreachable = 4

    def probe(action_name):
        """Return (reachability code, distance) of the target of ``action_name``."""
        target_x, target_y = env._findPOitem(agent_item, macroActionDict[action_name])
        can_reach = env._navigate(agent_item, target_x, target_y)
        distance = env._calDistance(target_x, target_y, agent_item.x, agent_item.y)
        return can_reach, distance

    first_action = raw_name + " 1"
    second_action = raw_name + " 2"

    can_reach_1, distance_1 = probe(first_action)
    can_reach_2, distance_2 = probe(second_action)

    if raw_name == "get plate":
        dirty_action = "get dirty plate"
        if dirty_action in macroActionDict:
            can_reach_3, distance_3 = probe(dirty_action)
        else:
            # The action table has no such entry: treat the dirty plate as
            # unreachable instead of failing with a KeyError.
            can_reach_3, distance_3 = unreachable, float("inf")

        best_index = find_best_reachable_index(
            can_reach_1, can_reach_2, can_reach_3, distance_1, distance_2, distance_3
        )
        # `False` means "nothing reachable". It must be tested by identity,
        # because `False == 0` would otherwise select the first candidate.
        if best_index is False:
            return "stay"
        return (first_action, second_action, dirty_action)[best_index]

    if can_reach_1 != unreachable and (can_reach_2 == unreachable or distance_1 <= distance_2):
        return first_action
    if can_reach_2 != unreachable:
        return second_action
    return "stay"



def check_benevolence(env, best_action, action):
    """Score whether ``action`` equals the recommended macro action.

    ``env.reward`` is overwritten: it is set to 50 when ``action`` is the index
    of ``best_action`` and that index is not 0 (action 0 never earns the
    bonus), otherwise to 0. The new value of ``env.reward`` is returned.
    """
    env.reward = 0
    recommended = macroActionDict[best_action]
    followed_advice = action == recommended and recommended != 0
    if followed_advice:
        env.reward += 50
    return env.reward



def check_action_benevolence(env, action_left, action_right, firsttime_left_go_to_counter, firsttime_left_get_lettuce, firsttime_right_go_to_knife, firsttime_right_go_to_counter):
    """Compute one-shot shaping bonuses for the left and right agent.

    Bonuses are only considered while neither of the two map cells (12, 7) and
    (13, 7) contains a "lettuce". Each bonus is granted at most once per
    episode; the ``firsttime_*`` flags carry that state and are returned
    (cleared where a bonus was granted) so the caller can pass them back in.

    * right agent (``env.agent[1]``), empty-handed, chooses the recommended
      "get lettuce" action                              -> right bonus 20
    * right agent, holding a Lettuce, chooses "go to counter" -> right bonus 20
    * left agent chooses "go to counter"                -> left bonus 1000

    Side effect: ``env.reward`` is overwritten by ``check_benevolence``.

    :return: ``(reward_bonus_left, reward_bonus_right,
        firsttime_left_go_to_counter, firsttime_left_get_lettuce,
        firsttime_right_go_to_knife, firsttime_right_go_to_counter)``
    """
    agent_item = env.agent[1]

    # Map cells inspected for a lettuce (hard-coded coordinates).
    counter_cells = ((12, 7), (13, 7))
    counters = [ITEMNAME[env.map[x][y]] for x, y in counter_cells]

    reward_bonus_left = 0
    reward_bonus_right = 0

    def followed(best_action, action):
        """True when ``action`` matches ``best_action`` according to check_benevolence."""
        return check_benevolence(env, best_action, action) == 50

    # Plain equality: the former `not in ("lettuce")` was a substring test
    # against a string, not membership in a tuple.
    if all(counter != "lettuce" for counter in counters):
        if not agent_item.holding:
            best_action = intelligently_find_item_number(env, agent_item, "get lettuce")
            if firsttime_right_go_to_knife and followed(best_action, action_right):
                reward_bonus_right = 20
                firsttime_right_go_to_knife = False

        if agent_item.holding and isinstance(agent_item.holding, Lettuce):
            if firsttime_right_go_to_counter and followed("go to counter", action_right):
                reward_bonus_right = 20
                firsttime_right_go_to_counter = False

        if firsttime_left_go_to_counter and followed("go to counter", action_left):
            reward_bonus_left = 1000
            firsttime_left_go_to_counter = False

    return reward_bonus_left, reward_bonus_right, firsttime_left_go_to_counter, firsttime_left_get_lettuce, firsttime_right_go_to_knife, firsttime_right_go_to_counter



# Reward tables passed to the environment through env_params: one dict per
# agent. Both agents use identical values; each gets its own dict object.
rewardList = [
    {
        "minitask finished": 0,
        "minitask failed": 0,
        "metatask finished": 0,
        "metatask failed": 0,
        "goodtask finished": 10,
        "goodtask failed": 0,
        "subtask finished": 20,
        "subtask failed": 0,
        "correct delivery": 200,
        "wrong delivery": -50,
        "step penalty": -1,
        "penalize using dirty plate": 0,
        "penalize using bad lettuce": 0,
        "pick up bad lettuce": 0,
    }
    for _ in range(2)
]



mac_env_id = 'Overcooked-MA-v1'
env_params = dict(
    grid_dim=[15, 15],
    task=["lettuce salad"],
    rewardList=rewardList,
    map_type="B_lowuncertainty",
    n_agent=2,
    obs_radius=0,
    mode="vector",
    debug=True,
)


# One environment instance shared by both single-agent views.
shared_env = gym.make(mac_env_id, **env_params)


# View for agent 1. No other_agent_model is supplied here, so this wrapper's
# step() is not usable as constructed; it is only handed to the agent-0 wrapper.
env_agent_1 = SingleAgentWrapper_Without_Latent(shared_env, agent_index=1)

# View for agent 0, the agent controlled by the loaded policy. The second
# argument fills the wrapper's `partner_env` parameter, which it does not use.
env_agent_0 = SingleAgentWrapper(shared_env, env_agent_1, agent_index=0)

# Hyper-parameters for training a fresh PPO model. They are not used by the
# evaluation below, which loads an already trained checkpoint instead.
ppo_params = {
    'learning_rate': 3e-4,
    'n_steps': 256,
    'batch_size': 128,
    'n_epochs': 10,
    'gamma': 0.95,
    'gae_lambda': 0.95,
    'clip_range': 0.3,
    'ent_coef': 0.02,
    # 'ent_coef': 0.05,
    'vf_coef': 0.5,
    'max_grad_norm': 0.5,
    'verbose': 0,
}


policy_kwargs = dict(
    net_arch=[dict(pi=[256, 128, 64], vf=[256, 128, 64])]
)


# To train from scratch with plain PPO instead of loading a checkpoint:
# model = PPO("MlpPolicy", env_agent_0, policy_kwargs=policy_kwargs, **ppo_params)


# Checkpoint to evaluate. Other checkpoints referenced by this script:
#   "[MapB_lowuncertainty]MEP_teamReward" / "model_20000000"
#   "[MapB_lowuncertainty]FCP_teamReward" / "model_20000000"
# The path is assembled with os.path.join: the former literal embedded a
# backslash ("...POMDP\model_8000000"), which is an invalid escape sequence
# and only acts as a path separator on Windows.
model = PPO.load(
    os.path.join("final_trained_models", "[MapB_lowuncertainty]POMDP", "model_8000000"),
    env=env_agent_0,
)




# Partner rules evaluated below, one episode per entry. Repeated entries make
# that rule count several times in the per-iteration mean reward.
# NOTE: this rebinds the module-level EVENT_TYPES defined near the top of the
# file; from here on, code that looks up EVENT_TYPES sees this list.
EVENT_TYPES = [
    "pass_plate",
    "pass_lettuce",
    "pass_lettuce",
    "pass_lettuce",
    "pass_chopped_lettuce",
    "pass_chopped_lettuce",
    "pass_chopped_lettuce",
    "pass_plated_lettuce",
    "pass_dirty_lettuce",
    "pass_chopped_dirty_lettuce",
    "pass_plated_dirty_lettuce",
    "make_clean_salad_alone",
    "make_dirty_salad_alone"
]




# Mean cumulative reward per iteration (averaged over all partner rules).
all_reward_list = []
# Test the loaded model

# 10 evaluation iterations; each runs one episode of at most 400 steps per
# entry of EVENT_TYPES.
for iteration in range(10):
    cur_reward_list = []
    for rule_name in EVENT_TYPES:
        cummulated_reward = 0
        # Running cumulative reward after every step of this episode
        # (collected but not used further in this script).
        thistest_rewardlist = []

        env_agent_0.set_partner_idx(rule_name)  # switch the scripted partner rule in place; the wrapper is not rebuilt

        obs = env_agent_0.reset()

        for step in range(400):
            # print(obs)
            action, _states = model.predict(obs, deterministic=True)

            # print(action)

            obs, rewards, dones, info = env_agent_0.step(action)
            cummulated_reward += rewards
            thistest_rewardlist.append(cummulated_reward)

            # print(cummulated_reward)
            # env_agent_0.render([(0, 0)])

            # time.sleep(0.1)

            # Stop early when the environment reports the episode as done.
            if dones:
                break

        cur_reward_list.append(cummulated_reward)
    all_reward_list.append(float(np.mean(cur_reward_list)))
    
print(all_reward_list)


# FCP: [1310.7692307692307, 1310.7692307692307, 1310.7692307692307, 1310.7692307692307, 1310.7692307692307, 1310.7692307692307, 1310.7692307692307, 1310.7692307692307, 1310.7692307692307, 1310.7692307692307]
# MEP: [1006.1538461538462, 1144.6153846153845, 1169.2307692307693, 1055.3846153846155, 1018.4615384615385, 1070.7692307692307, 1020.0, 1018.4615384615385, 1114.6153846153845, 1075.3846153846155]
# POMDP: [2024.6153846153845, 1938.4615384615386, 1976.923076923077, 1990.7692307692307, 1972.3076923076924, 1903.076923076923, 1924.6153846153845, 1847.6923076923076, 1958.4615384615386, 1952.3076923076924]
