import gym
from stable_baselines3 import PPO
import numpy as np
from gym_macro_overcooked.overcooked_V1 import Overcooked_V1
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.logger import configure
import matplotlib.pyplot as plt
from gym_macro_overcooked.items import Tomato, Lettuce, Onion, Plate, Knife, Delivery, Agent, Food, DirtyPlate, BadLettuce
import random
import time
import torch
import os
import torch
import torch.nn as nn
from collections import deque
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
import math


# ====== Global random seed ======
SEED = 42  # change this number to change the run's randomness

# Seed the Python, NumPy and PyTorch (CPU) generators, in that order.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)
del _seed_fn

# CUDA generators are seeded separately, and only when a GPU is available.
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)



# Event names accepted by OneHotEventReward, in a fixed order.
EVENT_TYPES = (
    # hand an item over via the counter
    [
        "pass_plate",
        "pass_lettuce",
        "pass_chopped_lettuce",
        "pass_plated_lettuce",
        "pass_dirty_lettuce",
        "pass_chopped_dirty_lettuce",
        "pass_plated_dirty_lettuce",
    ]
    # finish a salad without help
    + ["make_clean_salad_alone", "make_dirty_salad_alone"]
    # perform only one stage of the recipe
    + ["only_take_lettuce", "only_chop_lettuce", "only_plate_salad"]
)


class OneHotEventReward:
    """Scripted policy for agent 1 that pursues a single kitchen "event".

    Despite the class name, ``generate_action`` returns a macro-action id
    looked up in the module-level ``macroActionDict`` (not a reward).
    Every handler is a list of ``if`` rules evaluated in order; the *last*
    rule that matches decides the action, and ``"stay"`` is used when none
    match.
    """

    # Event type -> name of the handler method. Handlers are resolved with
    # getattr at call time so that an event without a handler fails with a
    # clear NotImplementedError (instead of an AttributeError), and so that
    # subclasses can provide the missing handlers.
    _HANDLERS = {
        "pass_plate": "_is_passing_plate",
        "pass_lettuce": "_is_passing_lettuce",
        "pass_chopped_lettuce": "_is_passing_chopped_lettuce",
        "pass_plated_lettuce": "_is_passing_plated_lettuce",
        "pass_dirty_lettuce": "_is_passing_dirty_lettuce",
        "pass_chopped_dirty_lettuce": "_is_passing_chopped_dirty_lettuce",
        "pass_plated_dirty_lettuce": "_is_passing_plated_dirty_lettuce",
        "make_clean_salad_alone": "_is_making_clean_salad_alone",
        "make_dirty_salad_alone": "_is_making_dirty_salad_alone",
        # NOTE(review): the three handlers below are not defined in this class.
        "only_take_lettuce": "_is_taking_but_not_chopping",
        "only_chop_lettuce": "_is_chopping_but_not_plating",
        "only_plate_salad": "_is_plating_but_not_serving",
    }

    def __init__(self, active_event_type: str, reward_value: float = 1.0):
        """
        :param active_event_type: str, one of the EVENT_TYPES
        :param reward_value: float, stored as ``self.reward_value``; no method
            of this class reads it.
        """
        assert active_event_type in EVENT_TYPES, f"Invalid event type: {active_event_type}"
        self.active_event = active_event_type
        self.reward_value = reward_value

    def generate_action(self, env):
        """Return agent 1's next macro-action id for the active event."""
        return self._event_triggered(env, self.active_event)

    def _event_triggered(self, env, event_type: str):
        """Dispatch to the handler for `event_type` and return its macro-action id.

        Returns False for an unknown event type (legacy fallback). Raises
        NotImplementedError when the event is known but its handler method
        does not exist.
        """
        handler_name = self._HANDLERS.get(event_type)
        if handler_name is None:
            return False
        handler = getattr(self, handler_name, None)
        if handler is None:
            raise NotImplementedError(
                f"Event type '{event_type}' has no handler: method '{handler_name}' is not implemented"
            )
        return handler(env)

    # ------------------------------------------------------------------
    # Helpers shared by the chop-based handlers
    # ------------------------------------------------------------------

    @staticmethod
    def _knife_ready_to_chop(env, agent, knife):
        """Truthy when `knife` holds unchopped food, the agent's hands are empty
        and the agent is at distance 1 from the knife (per ``env._calDistance``)."""
        return (knife.holding and not agent.holding and not knife.holding.chopped
                and env._calDistance(agent.x, agent.y, knife.x, knife.y) == 1)

    @staticmethod
    def _knife_has_chopped_food(agent, knife):
        """Truthy when `knife` holds chopped food and the agent's hands are empty."""
        return knife.holding and not agent.holding and knife.holding.chopped

    def _chop_stage(self, env, food_cls, fetch_action):
        """Rules for fetching a `food_cls` item, taking it to a knife and chopping it.

        Returns ``(agent, best_action)`` so callers can append later-stage rules,
        which then override these (last matching rule wins).
        """
        agent = env.agent[1]
        knife0 = env.knife[0]
        knife1 = env.knife[1]

        best_action = "stay"
        if not agent.holding and not knife0.holding and not knife1.holding:
            best_action = fetch_action
        if agent.holding and isinstance(agent.holding, food_cls) and not agent.holding.chopped:
            best_action = "go to knife 2"
        if self._knife_ready_to_chop(env, agent, knife0) or self._knife_ready_to_chop(env, agent, knife1):
            best_action = "chop"
        if self._knife_has_chopped_food(agent, knife0) or self._knife_has_chopped_food(agent, knife1):
            # Chopped food is waiting on a knife: re-issue the fetch action.
            best_action = fetch_action
        return agent, best_action

    def _chopped_food_plan(self, env, food_cls, fetch_action, final_action, plate):
        """Chop stage, then `final_action` once the chopped food is in hand
        (``plate=False``) or once it has been plated (``plate=True``)."""
        agent, best_action = self._chop_stage(env, food_cls, fetch_action)
        held = agent.holding
        holding_chopped = held and isinstance(held, food_cls) and held.chopped

        if not plate:
            if holding_chopped:
                best_action = final_action
        else:
            if holding_chopped:
                best_action = "get plate 1"
            if (held and isinstance(held, Plate) and held.containing
                    and isinstance(held.containing[0], food_cls)):
                best_action = final_action
        return macroActionDict[best_action]

    # ------------------------------------------------------------------
    # Event handlers
    # ------------------------------------------------------------------

    def _is_passing_plate(self, env):
        """Fetch a plate ("get plate 1"); while holding an empty plate, "go to counter"."""
        best_action = "stay"
        agent = env.agent[1]
        if not agent.holding:
            best_action = "get plate 1"
        if agent.holding and isinstance(agent.holding, Plate) and not agent.holding.containing:
            best_action = "go to counter"
        return macroActionDict[best_action]

    def _is_passing_lettuce(self, env):
        """Fetch a lettuce ("get lettuce 2"); while holding it unchopped, "go to counter"."""
        best_action = "stay"
        agent = env.agent[1]
        if not agent.holding:
            best_action = "get lettuce 2"
        if agent.holding and isinstance(agent.holding, Lettuce) and not agent.holding.chopped:
            best_action = "go to counter"
        return macroActionDict[best_action]

    def _is_passing_chopped_lettuce(self, env):
        """Chop a lettuce, then "go to counter" holding it."""
        return self._chopped_food_plan(env, Lettuce, "get lettuce 2", "go to counter", plate=False)

    def _is_passing_plated_lettuce(self, env):
        """Chop a lettuce, plate it, then "go to counter" with the plate."""
        return self._chopped_food_plan(env, Lettuce, "get lettuce 2", "go to counter", plate=True)

    def _is_passing_dirty_lettuce(self, env):
        """Fetch a bad lettuce ("get badlettuce"); while holding it unchopped, "go to counter"."""
        best_action = "stay"
        agent = env.agent[1]
        if not agent.holding:
            best_action = "get badlettuce"
        if agent.holding and isinstance(agent.holding, BadLettuce) and not agent.holding.chopped:
            best_action = "go to counter"
        return macroActionDict[best_action]

    def _is_passing_chopped_dirty_lettuce(self, env):
        """Chop a bad lettuce, then "go to counter" holding it."""
        return self._chopped_food_plan(env, BadLettuce, "get badlettuce", "go to counter", plate=False)

    def _is_passing_plated_dirty_lettuce(self, env):
        """Chop a bad lettuce, plate it, then "go to counter" with the plate."""
        return self._chopped_food_plan(env, BadLettuce, "get badlettuce", "go to counter", plate=True)

    def _is_making_clean_salad_alone(self, env):
        """Chop a lettuce, plate it and deliver it ("deliver 2") without help."""
        return self._chopped_food_plan(env, Lettuce, "get lettuce 2", "deliver 2", plate=True)

    def _is_making_dirty_salad_alone(self, env):
        """Chop a bad lettuce, plate it and deliver it ("deliver 2") without help."""
        return self._chopped_food_plan(env, BadLettuce, "get badlettuce", "deliver 2", plate=True)









# Helpers for persisting the ABI model's training data.
def save_dataset_to_disk(data_buffer, filepath="abi_training_data.pt"):
    """Serialize `data_buffer` to `filepath` with ``torch.save`` and print the path."""
    destination = filepath
    torch.save(data_buffer, destination)
    print(f"✅ Data buffer saved to {destination}")

def load_dataset_from_disk(filepath="abi_training_data.pt"):
    """Load and return an object written with ``torch.save`` (e.g. a training buffer)."""
    dataset = torch.load(filepath)
    return dataset






# ===== Utility: KL between a Beta and the uniform prior, used to curb over-confidence =====
def kl_beta_to_uniform(alpha, beta):
    """Elementwise KL( Beta(alpha, beta) || Beta(1, 1) ).

    Beta(1, 1) has log-normaliser 0, so the closed form reduces to
        -log B(a, b) + (a - 1) psi(a) + (b - 1) psi(b) - (a + b - 2) psi(a + b)
    where B is the Beta function and psi the digamma function.
    """
    total = alpha + beta
    log_beta_fn = torch.lgamma(alpha) + torch.lgamma(beta) - torch.lgamma(total)
    return (
        -log_beta_fn
        + (alpha - 1) * torch.digamma(alpha)
        + (beta - 1) * torch.digamma(beta)
        - (total - 2) * torch.digamma(total)
    )

# ===== Evidential loss for binary targets: CE + lambda * KL(Beta || Uniform) =====
def evidential_binary_loss(alpha, beta, y, lam=1e-3):
    """Binary cross-entropy on the Beta mean plus a KL pull toward the uniform Beta.

    alpha, beta: positive Beta parameters; y: targets in {0, 1} with the same
    shape (the callers in this file pass [B, 1]).
    Returns ``(loss, metrics)`` where metrics holds the floats 'ce', 'kl'
    and 'S_mean' (mean Beta strength alpha + beta).
    """
    strength = alpha + beta
    # The Beta mean serves as P(y = 1); clamp it away from 0/1 for a finite BCE.
    prob = torch.clamp(alpha / strength, 1e-6, 1 - 1e-6)
    ce = F.binary_cross_entropy(prob, y, reduction='mean')
    kl = kl_beta_to_uniform(alpha, beta).mean()
    metrics = dict(ce=ce.item(), kl=kl.item(), S_mean=strength.mean().item())
    return ce + lam * kl, metrics



class LightweightTransformerABIModel_adaptivelength(nn.Module):
    """
    Variable-length transformer that predicts Ability / Benevolence / Integrity
    (ABI) as Beta distributions, using a shared attention-pooling vector.

    Each dimension pools over its own "most recent K steps" window of the
    (possibly padded) sequence and outputs Beta parameters (alpha, beta):
    - prediction: p = alpha / (alpha + beta)
    - uncertainty: S = alpha + beta (smaller = less certain), or the Beta variance.
    """
    def __init__(self, state_dim, hidden_dim=32, nhead=2, num_layers=1,
                 lr=1e-3, evidential_lambda=1e-3):
        super().__init__()
        self.input_proj = nn.Linear(state_dim, hidden_dim)

        # Dropout is left at the PyTorch default (non-zero), so train()/eval()
        # mode affects outputs; predict_beta_params evaluates in eval mode.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=nhead,
            dim_feedforward=hidden_dim * 2,
            batch_first=True,
            activation='relu'
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Attention vector shared by all three dimensions.
        self.attn_vec = nn.Parameter(torch.randn(hidden_dim))

        # Per-dimension intermediate heads (shared trunk, separate heads).
        self.head_A = nn.Sequential(nn.Linear(hidden_dim, 64), nn.ReLU())
        self.head_B = nn.Sequential(nn.Linear(hidden_dim, 64), nn.ReLU())
        self.head_I = nn.Sequential(nn.Linear(hidden_dim, 64), nn.ReLU())

        # Beta alpha / beta per dimension (made positive with softplus in forward).
        self.fc_A_alpha = nn.Linear(64, 1)
        self.fc_A_beta  = nn.Linear(64, 1)
        self.fc_B_alpha = nn.Linear(64, 1)
        self.fc_B_beta  = nn.Linear(64, 1)
        self.fc_I_alpha = nn.Linear(64, 1)
        self.fc_I_beta  = nn.Linear(64, 1)

        self.evidential_lambda = evidential_lambda
        self.optimizer = torch.optim.Adam(self.parameters(), lr=lr)
        # Entries: (state_history tensor, ability, benevolence, integrity).
        self.data_buffer = []

    @staticmethod
    def _make_pad_mask(lengths, max_len):
        """Return a [B, T] bool mask that is True at padding positions (index >= length)."""
        device = lengths.device
        rng = torch.arange(max_len, device=device).unsqueeze(0)  # [1, T]
        pad_mask = rng >= lengths.unsqueeze(1)                   # [B, T], True=padding
        return pad_mask

    @staticmethod
    def _make_recent_window_mask(lengths, max_len, k_recent):
        """Return a [B, T] bool mask that is False only on the last
        ``max(k_recent, 1)`` valid positions of each sequence."""
        device = lengths.device
        k = torch.clamp(torch.as_tensor(k_recent, device=device), min=1)
        start = torch.clamp(lengths - k, min=0)                          # [B]
        idx   = torch.arange(max_len, device=device).unsqueeze(0)        # [1, T]
        # Inside the window -> False (kept); everything else -> True (masked).
        in_window = (idx >= start.unsqueeze(1)) & (idx < lengths.unsqueeze(1))
        mask = ~in_window                                                 # [B, T]
        return mask

    def _attn_pool(self, h, attn_vec, mask_bool, lengths):
        """
        Attention-pool `h` over time with a single learned vector.

        h: [B, T, H]
        attn_vec: [H]
        mask_bool: [B, T] True=mask
        lengths: [B]
        Returns pooled features [B, H].

        A row whose positions are all masked (e.g. a zero-length sequence)
        would give NaN softmax weights; such a row falls back to a one-hot
        weight on index ``max(length - 1, 0)``. The fallback is applied per
        row, so other rows in the batch keep their softmax attention.
        """
        B = h.shape[0]
        logits = torch.matmul(h, attn_vec)                          # [B, T]
        logits = logits.masked_fill(mask_bool, float('-inf'))

        all_masked = torch.isinf(logits).all(dim=1, keepdim=True)   # [B, 1]
        # Neutralise fully-masked rows before softmax so no NaNs are produced.
        safe_logits = logits.masked_fill(all_masked, 0.0)
        weights = torch.softmax(safe_logits, dim=1)                 # [B, T]
        if all_masked.any():
            fallback = torch.zeros_like(weights)
            last_idx = (lengths - 1).clamp(min=0)                   # [B]
            fallback[torch.arange(B, device=h.device), last_idx] = 1.0
            weights = torch.where(all_masked, fallback, weights)

        pooled = torch.sum(h * weights.unsqueeze(-1), dim=1)        # [B, H]
        return pooled

    def forward(self, x, lengths=None, k_ability=15, k_benevolence=30, k_integrity=30):
        """
        x: [B, T, state_dim] (may contain padding)
        lengths: [B] true length of each sequence; if None, all T steps are valid
        k_ability / k_benevolence / k_integrity: recent-window size per dimension
        Returns:
          (alpha_A, beta_A), (alpha_B, beta_B), (alpha_I, beta_I), each of shape [B, 1]
        """
        B, T, _ = x.shape
        device = x.device
        if lengths is None:
            lengths = torch.full((B,), T, dtype=torch.long, device=device)

        # 1) Input projection.
        x_proj = self.input_proj(x)  # [B, T, H]

        # 2) Transformer encoding with the padding mask.
        pad_mask = self._make_pad_mask(lengths, T)  # [B, T], True=padding
        h = self.transformer(x_proj, src_key_padding_mask=pad_mask)  # [B, T, H]

        # 3) Per-dimension "last K steps" window, combined with the padding mask.
        mask_A = pad_mask | self._make_recent_window_mask(lengths, T, k_ability)
        mask_B = pad_mask | self._make_recent_window_mask(lengths, T, k_benevolence)
        mask_I = pad_mask | self._make_recent_window_mask(lengths, T, k_integrity)

        # 4) Pool with the shared attention vector.
        pooled_A = self._attn_pool(h, self.attn_vec, mask_A, lengths)  # [B, H]
        pooled_B = self._attn_pool(h, self.attn_vec, mask_B, lengths)
        pooled_I = self._attn_pool(h, self.attn_vec, mask_I, lengths)

        # 5) Per-dimension heads.
        zA = self.head_A(pooled_A)
        zB = self.head_B(pooled_B)
        zI = self.head_I(pooled_I)

        # softplus + eps keeps every Beta parameter strictly positive.
        eps = 1e-4
        alpha_A = F.softplus(self.fc_A_alpha(zA)) + eps
        beta_A  = F.softplus(self.fc_A_beta(zA))  + eps
        alpha_B = F.softplus(self.fc_B_alpha(zB)) + eps
        beta_B  = F.softplus(self.fc_B_beta(zB))  + eps
        alpha_I = F.softplus(self.fc_I_alpha(zI)) + eps
        beta_I  = F.softplus(self.fc_I_beta(zI))  + eps

        return (alpha_A, beta_A), (alpha_B, beta_B), (alpha_I, beta_I)

    @staticmethod
    def beta_mean_strength(alpha, beta):
        """Return (clamped Beta mean alpha/(alpha+beta), strength alpha+beta)."""
        S = alpha + beta
        p = (alpha / S).clamp(1e-6, 1-1e-6)
        return p, S

    def store_data(self, state_history, ability, benevolence, integrity):
        """Append one (history, A, B, I) sample; tensors are detached and moved to CPU."""
        if isinstance(state_history, torch.Tensor):
            sh = state_history.detach().cpu()
        else:
            sh = torch.as_tensor(state_history)
        self.data_buffer.append((sh, float(ability), float(benevolence), float(integrity)))

    def _collate_batch(self, batch):
        """Pad variable-length histories into [B, T_max, D] plus lengths and [B, 1] targets."""
        seqs, A, B, I = [], [], [], []
        for sh, a, b, i in batch:
            t = torch.as_tensor(sh, dtype=torch.float32)
            seqs.append(t)
            A.append([a]); B.append([b]); I.append([i])
        lengths = torch.tensor([s.shape[0] for s in seqs], dtype=torch.long)
        states = pad_sequence(seqs, batch_first=True)  # [B, T_max, D]
        abilitys     = torch.tensor(A, dtype=torch.float32)
        benevolences = torch.tensor(B, dtype=torch.float32)
        integritys   = torch.tensor(I, dtype=torch.float32)
        return states, lengths, abilitys, benevolences, integritys

    def train_step(self, batch_size=None, k_ability=15, k_benevolence=30, k_integrity=30,
                   device=None, loss_weights=(1.0, 1.0, 1.0)):
        """Run one optimizer step with the evidential loss on each dimension.

        Uses the whole buffer, or its first `batch_size` entries (no shuffling);
        the buffer is not cleared. Returns None when the buffer is empty,
        otherwise a dict of float diagnostics.
        """
        if not self.data_buffer:
            return None

        batch = self.data_buffer if batch_size is None else self.data_buffer[:batch_size]
        states, lengths, abilitys, benevolences, integritys = self._collate_batch(batch)

        if device is None:
            device = next(self.parameters()).device
        states       = states.to(device)
        lengths      = lengths.to(device)
        abilitys     = abilitys.to(device)
        benevolences = benevolences.to(device)
        integritys   = integritys.to(device)

        self.optimizer.zero_grad()
        (aA, bA), (aB, bB), (aI, bI) = self.forward(
            states, lengths=lengths,
            k_ability=k_ability, k_benevolence=k_benevolence, k_integrity=k_integrity
        )

        # Evidential loss, computed separately for A / B / I.
        lam = self.evidential_lambda
        wA, wB, wI = loss_weights
        loss_A, mA = evidential_binary_loss(aA, bA, abilitys,     lam=lam)
        loss_B, mB = evidential_binary_loss(aB, bB, benevolences, lam=lam)
        loss_I, mI = evidential_binary_loss(aI, bI, integritys,   lam=lam)

        total = wA*loss_A + wB*loss_B + wI*loss_I
        total.backward()
        self.optimizer.step()

        return {
            'total_loss': total.item(),
            'A_ce': mA['ce'], 'A_kl': mA['kl'], 'A_S_mean': mA['S_mean'],
            'B_ce': mB['ce'], 'B_kl': mB['kl'], 'B_S_mean': mB['S_mean'],
            'I_ce': mI['ce'], 'I_kl': mI['kl'], 'I_S_mean': mI['S_mean'],
        }

    def predict_beta_params(self, states_np, lengths_np=None,
                            k_ability=15, k_benevolence=30, k_integrity=30, device=None):
        """
        Inference helper: return alpha/beta plus p/S for each dimension.

        states_np: [B, T, D] (numpy array or tensor)
        lengths_np: [B] (optional; defaults to all T steps valid)
        The model is evaluated in eval mode (dropout off, deterministic) and
        its previous train/eval mode is restored afterwards.
        """
        if not torch.is_tensor(states_np):
            states = torch.tensor(states_np, dtype=torch.float32)
        else:
            states = states_np.float()
        B, T, _ = states.shape

        if lengths_np is None:
            lengths = torch.full((B,), T, dtype=torch.long)
        else:
            lengths = torch.as_tensor(lengths_np, dtype=torch.long)

        if device is None:
            device = next(self.parameters()).device
        states  = states.to(device)
        lengths = lengths.to(device)

        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                (aA, bA), (aB, bB), (aI, bI) = self.forward(
                    states, lengths=lengths,
                    k_ability=k_ability, k_benevolence=k_benevolence, k_integrity=k_integrity
                )
                pA, SA = self.beta_mean_strength(aA, bA)
                pB, SB = self.beta_mean_strength(aB, bB)
                pI, SI = self.beta_mean_strength(aI, bI)
        finally:
            self.train(was_training)

        out = {
            'A': {'alpha': aA.squeeze(-1), 'beta': bA.squeeze(-1), 'p': pA.squeeze(-1), 'S': SA.squeeze(-1)},
            'B': {'alpha': aB.squeeze(-1), 'beta': bB.squeeze(-1), 'p': pB.squeeze(-1), 'S': SB.squeeze(-1)},
            'I': {'alpha': aI.squeeze(-1), 'beta': bI.squeeze(-1), 'p': pI.squeeze(-1), 'S': SI.squeeze(-1)},
        }
        return out

    def save_model(self, path):
        """Save model weights, optimizer state and evidential_lambda to `path`."""
        torch.save({
            'model_state_dict': self.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'evidential_lambda': self.evidential_lambda
        }, path)
        print(f"Model saved to {path}")

    def load_model(self, path, map_location=None):
        """Load a checkpoint written by save_model (map_location is passed to torch.load)."""
        checkpoint = torch.load(path, map_location=map_location)
        self.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.evidential_lambda = checkpoint.get('evidential_lambda', self.evidential_lambda)
        print(f"Model loaded from {path}")







class LightweightTransformerABIModel(nn.Module):
    """
    Fixed-length transformer regressor for Ability / Benevolence / Integrity
    scores, each squashed to (0, 1) with a sigmoid and trained with MSE.

    There is no padding mask, so all stored state histories must have the
    same length and feature size.
    """
    def __init__(self, state_dim, seq_len=15, hidden_dim=32, nhead=2, num_layers=1):
        # NOTE: seq_len is accepted for API compatibility but is not used.
        super().__init__()
        self.input_proj = nn.Linear(state_dim, hidden_dim)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=nhead,
            dim_feedforward=hidden_dim * 2,
            batch_first=True,
            activation='relu'  # lighter-weight activation
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Learned attention vector for attention pooling over time.
        self.attention_vector = nn.Parameter(torch.randn(hidden_dim))

        self.fc_ability = nn.Linear(hidden_dim, 1)
        self.fc_benevolence = nn.Linear(hidden_dim, 1)
        self.fc_integrity = nn.Linear(hidden_dim, 1)

        self.criterion = nn.MSELoss()
        self.optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        # Entries: (state_history, ability, benevolence, integrity).
        self.data_buffer = []

    def forward(self, x):
        """x: [batch, seq_len, state_dim] -> (ability, benevolence, integrity), each [batch, 1]."""
        x = self.input_proj(x)           # [B, T, H]
        h = self.transformer(x)          # [B, T, H]

        # Attention pooling over the time axis.
        attn_weights = torch.softmax(torch.matmul(h, self.attention_vector), dim=1)  # [B, T]
        pooled = torch.sum(h * attn_weights.unsqueeze(-1), dim=1)  # [B, H]

        ability     = torch.sigmoid(self.fc_ability(pooled))
        benevolence = torch.sigmoid(self.fc_benevolence(pooled))
        integrity   = torch.sigmoid(self.fc_integrity(pooled))
        return ability, benevolence, integrity

    def store_data(self, state_history, ability, benevolence, integrity):
        """Append one training sample to the buffer as-is."""
        self.data_buffer.append((state_history, ability, benevolence, integrity))

    def train_step(self):
        """Run one optimizer step on the whole buffer (the buffer is not cleared).

        Returns None if the buffer is empty, otherwise a dict of float losses.
        """
        if not self.data_buffer:
            return None

        states, abilitys, benevolences, integritys = zip(*self.data_buffer)
        # Stack through NumPy: torch.tensor() over a sequence of ndarrays is
        # extremely slow. Every stored history must have the same shape.
        states = torch.as_tensor(np.asarray(states, dtype=np.float32))
        abilitys = torch.tensor(abilitys, dtype=torch.float32).unsqueeze(-1)
        benevolences = torch.tensor(benevolences, dtype=torch.float32).unsqueeze(-1)
        integritys = torch.tensor(integritys, dtype=torch.float32).unsqueeze(-1)

        self.optimizer.zero_grad()
        pred_ability, pred_benevolence, pred_integrity = self.forward(states)

        # Separate MSE per dimension, summed with equal weight.
        loss_ability = self.criterion(pred_ability, abilitys)
        loss_benevolence = self.criterion(pred_benevolence, benevolences)
        loss_integrity = self.criterion(pred_integrity, integritys)

        total_loss = loss_ability + loss_benevolence + loss_integrity
        total_loss.backward()
        self.optimizer.step()

        return {
            'total_loss': total_loss.item(),
            'loss_ability': loss_ability.item(),
            'loss_benevolence': loss_benevolence.item(),
            'loss_integrity': loss_integrity.item()
        }

    def save_model(self, path):
        """Save model and optimizer state dicts to `path`."""
        torch.save({
            'model_state_dict': self.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
        }, path)
        print(f"Model saved to {path}")

    def load_model(self, path, map_location=None):
        """Load a checkpoint written by save_model.

        map_location is forwarded to torch.load, e.g. "cpu" to load a GPU
        checkpoint on a CPU-only machine; the default None keeps the old behavior.
        """
        checkpoint = torch.load(path, map_location=map_location)
        self.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        print(f"Model loaded from {path}")








class EpisodeRewardCallback(BaseCallback):
    """Track per-episode reward of environment 0 and periodically checkpoint.

    Every `save_freq` calls to `_on_step`, the model is saved to
    ``<save_path>/model_<step>.zip`` and, once at least one episode has
    finished, a moving-average reward curve is written to
    ``<save_path>/avg_reward_<step>.png``.
    """

    def __init__(self, save_path, save_freq=100000, verbose=0):
        super().__init__(verbose)
        self.save_path = save_path
        self.save_freq = save_freq
        self.step_counter = 0
        self.episode_rewards = []
        self.current_episode_reward = 0.0

        os.makedirs(self.save_path, exist_ok=True)

    def _on_step(self) -> bool:
        self.step_counter += 1
        # Only the first environment's reward/done signal is tracked.
        self.current_episode_reward += self.locals['rewards'][0]

        if self.locals['dones'][0]:
            self.episode_rewards.append(self.current_episode_reward)
            self.current_episode_reward = 0.0

        if self.step_counter % self.save_freq != 0:
            return True

        model_path = os.path.join(self.save_path, f'model_{self.step_counter}.zip')
        self.model.save(model_path)
        print(f"Step {self.step_counter}: Model saved at {model_path}")

        if self.episode_rewards:
            self._plot_moving_average()
        return True

    def _plot_moving_average(self):
        """Save a plot of the trailing moving average (up to 100 episodes) of episode rewards."""
        rewards = self.episode_rewards
        window = min(100, len(rewards))
        moving_avg = []
        for end in range(1, len(rewards) + 1):
            chunk = rewards[max(0, end - window):end]
            moving_avg.append(sum(chunk) / len(chunk))

        plt.figure(figsize=(10, 5))
        plt.plot(moving_avg, label=f"Moving Avg (last {window} episodes)")
        plt.xlabel("Episode")
        plt.ylabel("Average Reward")
        plt.title("Training Progress")
        plt.legend()
        plt.grid()
        plt.tight_layout()
        plt.savefig(os.path.join(self.save_path, f'avg_reward_{self.step_counter}.png'))
        plt.close()








class SingleAgentWrapper_Without_Latent(gym.Wrapper):
    """
    A wrapper to extract a single agent's perspective from a multi-agent environment.

    The controlled agent (``agent_index``, 0 or 1) takes the action passed to
    ``step``; the other agent's action is ``other_agent_model.predict(...)[0]``
    computed from that agent's latest macro observation.
    """
    def __init__(self, env, agent_index, other_agent_model=None):
        super().__init__(env)
        self.agent_index = agent_index
        self.observation_space = env.observation_space
        self.action_space = env.action_space
        self.other_agent_model = other_agent_model

        # One-shot flags for the (currently disabled) benevolence reward
        # shaping; re-armed on every reset.
        self.firsttime_left_go_to_counter = True
        self.firsttime_left_get_lettuce = True
        self.firsttime_right_go_to_knife = True
        self.firsttime_right_go_to_counter = True

        self.obs = None

    def reset(self):
        """Reset the env and shaping flags; return the controlled agent's observation."""
        self.obs = self.env.reset()

        self.firsttime_left_go_to_counter = True
        self.firsttime_left_get_lettuce = True
        self.firsttime_right_go_to_knife = True
        self.firsttime_right_go_to_counter = True

        return self.obs[self.agent_index]

    def step(self, action):
        """Step both agents; return (obs, reward, dones, info) for the controlled agent."""
        partner_index = 1 - self.agent_index
        other_agent_action = self.other_agent_model.predict(self.obs[partner_index])

        actions = [0, 0]
        actions[self.agent_index] = action
        actions[partner_index] = other_agent_action[0]

        # Cooperative (benevolence) reward shaping and the inconsistency
        # penalty are temporarily disabled, so the bonus is always 0.
        # TODO: to re-enable, compute it from the left/right agents' actions
        # via check_action_benevolence(...) together with the firsttime_* flags.
        benevolence_reward = 0

        primary_actions = self.env._computeLowLevelActions(actions, 0)
        self.obs, rewards, dones, info = self.env.step(primary_actions)
        # Replace the observation returned by env.step with the macro-level one.
        self.obs = self.env._get_macro_obs()

        return self.obs[self.agent_index], rewards[self.agent_index] + benevolence_reward, dones, info







class SingleAgentWrapper(gym.Wrapper):
    """
    Expose agent ``agent_index`` of a two-agent Overcooked env as a single-agent env.

    The partner (the other agent) is driven by a scripted ``OneHotEventReward``
    rule chosen with ``set_partner_idx()``. Every observation is the agent's
    own macro observation with 6 values appended that describe the partner as
    estimated by an ABI (ability / benevolence / integrity) model::

        [A, B, I, A_confidence, B_confidence, I_confidence]

    A/B/I start at 0.5 with confidence 0 at every reset. After the model has
    produced an estimate, A/B/I become +1 / -1.

    The reward returned by ``step()`` is the sum of both agents' env rewards.
    """

    # Number of partner-estimate features appended to the env observation.
    NUM_EXTRA_DIMS = 6
    # Number of most recent partner positions that must all coincide before the
    # partner is considered stuck and nudged with a random move.
    STUCK_WINDOW = 5
    # Length of the partner state window fed to the ABI model.
    HISTORY_LEN = 30

    def __init__(self, env, partner_env, agent_index, state_dim, model_path=None, verbose=True):
        """
        :param env: shared two-agent environment; must have a Box observation space
        :param partner_env: not used by this wrapper (kept for call-site compatibility)
        :param agent_index: index (0 or 1) of the agent controlled through this wrapper
        :param state_dim: per-step feature size passed to the ABI model
        :param model_path: optional checkpoint loaded into the ABI model
        :param verbose: print the current A/B/I values after every step
        :raises NotImplementedError: if ``env.observation_space`` is not a Box
        """
        super().__init__(env)
        if not isinstance(env.observation_space, gym.spaces.Box):
            raise NotImplementedError("This wrapper only works with Box observation spaces.")

        self.agent_index = agent_index
        self.verbose = verbose
        self.action_space = env.action_space

        # One-shot reward-shaping flags (the shaping itself is currently disabled).
        self.firsttime_left_go_to_counter = True
        self.firsttime_left_get_lettuce = True
        self.firsttime_right_go_to_knife = True
        self.firsttime_right_go_to_counter = True

        self.obs = None
        # Scripted partner rule name; must be set via set_partner_idx() before reset().
        self.rule_name = None
        self.rewarder = None

        # Partner [x, y] recorded at reset and after every step of the episode.
        self.partner_position_history = []

        # Ground-truth / externally predicted trait slots; never updated by this class.
        self.partner_ability = None
        self.partner_benevolence = None
        self.partner_integrity = None
        self.predicted_ability = None
        self.predicted_benevolence = None
        self.predicted_integrity = None

        # Beta(alpha, beta) pseudo-counts of the smoothed trait posteriors.
        self._reset_trait_posteriors()

        # Hyper-parameters of the trait estimate.
        self.posterior_decay = 0.999  # forgetting factor applied to the smoothed posterior
        self.kappa_max = 2.0          # cap on the soft-evidence weight of one update
        self.conf_mode = "var"        # confidence measure: "var" | "strength" | "hybrid"
        self.S_max = 50.0             # strength at which the "strength" confidence is 0.5

        self.model = LightweightTransformerABIModel_adaptivelength(state_dim)
        if model_path:
            self.model.load_model(model_path)

        # Sliding window of partner-centred observations fed to the ABI model.
        self.state_history = deque(maxlen=self.HISTORY_LEN)
        self.partner_key_event_happen = False

        # Extend the Box bounds with unbounded entries for the appended features.
        space = env.observation_space
        low = np.append(space.low, [-np.inf] * self.NUM_EXTRA_DIMS)
        high = np.append(space.high, [np.inf] * self.NUM_EXTRA_DIMS)
        self.observation_space = gym.spaces.Box(low=low, high=high, dtype=space.dtype)

    def set_partner_idx(self, rule_name):
        """Select the scripted partner rule (one of EVENT_TYPES) used from the next reset() on."""
        self.rule_name = rule_name

    def _reset_trait_posteriors(self):
        """Reset every smoothed trait posterior to the uniform Beta(1, 1) prior."""
        for name in ("A", "B", "I"):
            setattr(self, f"partner_{name}_alpha", 1)
            setattr(self, f"partner_{name}_beta", 1)

    def _augment_obs(self, base_obs):
        """Append the current partner estimate and its confidences to ``base_obs``."""
        return np.concatenate([
            base_obs,
            [self.partner_A, self.partner_B, self.partner_I,
             self.partner_A_confidence, self.partner_B_confidence, self.partner_I_confidence],
        ])

    def reset(self):
        """
        Start a new episode with the partner rule chosen by ``set_partner_idx()``.

        :return: the controlled agent's observation extended with the initial
            partner estimate (A = B = I = 0.5, confidences 0)
        :raises RuntimeError: if ``set_partner_idx()`` has not been called yet
        """
        if self.rule_name is None:
            raise RuntimeError("set_partner_idx(rule_name) must be called before reset().")

        self.obs = self.env.reset()
        partner = 1 - self.agent_index

        partner_agent = self.env.agent[partner]
        self.partner_position_history = [[partner_agent.x, partner_agent.y]]

        self.firsttime_left_go_to_counter = True
        self.firsttime_left_get_lettuce = True
        self.firsttime_right_go_to_knife = True
        self.firsttime_right_go_to_counter = True

        self._reset_trait_posteriors()

        self.state_history.clear()
        self.state_history.append(self.env._get_macro_vector_obs_for_ABImodel()[partner])
        self.partner_key_event_happen = False

        self.rewarder = OneHotEventReward(active_event_type=self.rule_name, reward_value=400)

        # No model prediction yet: neutral trait values with zero confidence.
        self.partner_A = 0.5
        self.partner_B = 0.5
        self.partner_I = 0.5
        self.partner_A_confidence = 0
        self.partner_B_confidence = 0
        self.partner_I_confidence = 0

        return self._augment_obs(self.obs[self.agent_index])

    def _partner_unstick_action(self):
        """
        Return a random primitive move into a free neighbouring cell when the
        partner has not moved for the last STUCK_WINDOW recorded positions.

        Returns None otherwise. It also returns None when every neighbour is
        blocked; previously this case raised IndexError from random.choice([]).
        """
        history = self.partner_position_history
        if len(history) <= self.STUCK_WINDOW:
            return None
        recent = history[-self.STUCK_WINDOW:]
        if any(position != recent[-1] for position in recent):
            return None

        agent = self.env.agent[1 - self.agent_index]
        grid = agent.pomap
        # (primitive action, neighbour x, neighbour y), checked in this order.
        neighbours = (
            (0, agent.x, agent.y + 1),
            (1, agent.x + 1, agent.y),
            (2, agent.x, agent.y - 1),
            (3, agent.x - 1, agent.y),
        )
        free_moves = [move for move, x, y in neighbours if ITEMNAME[grid[x][y]] == "space"]
        return random.choice(free_moves) if free_moves else None

    def _beta_stats(self, a, b):
        """
        Summarise Beta(a, b).

        :return: ``(mean, var, std, confidence, strength)`` where strength is ``a + b``.
            The confidence measure is selected by ``self.conf_mode``.
        """
        strength = a + b
        if strength <= 0.0:
            # Degenerate parameters: fall back to the Beta(1, 1) moments.
            mean, var = 0.5, 1.0 / 12.0
        else:
            mean = a / strength
            var = (a * b) / (strength * strength * (strength + 1.0))
        std = math.sqrt(var)
        # 1 - std / 0.5, clipped to [0, 1]; smaller spread means more confident.
        conf_var = max(0.0, min(1.0, 1.0 - 2.0 * std))
        # Saturating map of the evidence strength; larger strength means more confident.
        conf_strength = strength / (strength + self.S_max)
        if self.conf_mode == "var":
            conf = conf_var
        elif self.conf_mode == "strength":
            conf = conf_strength
        else:  # "hybrid"
            conf = 0.5 * conf_var + 0.5 * conf_strength
        return mean, var, std, conf, strength

    def _update_posterior_for_smoothing(self, name, p, strength):
        """
        Fold the model's instantaneous mean ``p`` into the smoothed Beta posterior of trait ``name``.

        The pseudo-counts are first decayed by ``posterior_decay``. Then
        ``min(strength, kappa_max)`` units of soft evidence are added.

        Stores the smoothed mean as ``partner_<name>_prob_smooth`` and returns it.
        """
        a_name = f"partner_{name}_alpha"
        b_name = f"partner_{name}_beta"
        kappa = min(strength, self.kappa_max)
        alpha = getattr(self, a_name) * self.posterior_decay + kappa * p
        beta = getattr(self, b_name) * self.posterior_decay + kappa * (1.0 - p)
        setattr(self, a_name, alpha)
        setattr(self, b_name, beta)
        total = alpha + beta
        p_smooth = float(alpha / total if total > 0 else 0.5)
        setattr(self, f"partner_{name}_prob_smooth", p_smooth)
        return p_smooth

    def _update_partner_trust_estimate(self):
        """Run the ABI model on the buffered partner states and refresh the A/B/I features."""
        state_np = np.asarray(self.state_history, dtype=np.float32)
        # Keep the input on the same device as the model.
        device = next(self.model.parameters()).device
        state_tensor = torch.from_numpy(state_np).unsqueeze(0).to(device)  # [1, T, D]
        T_cur = state_np.shape[0]

        with torch.no_grad():
            if hasattr(self.model, "predict_beta_params"):
                out = self.model.predict_beta_params(state_tensor, lengths_np=[T_cur])
                beta_params = {
                    name: (float(out[name]['alpha']), float(out[name]['beta']))
                    for name in ("A", "B", "I")
                }
            else:
                (aA_t, bA_t), (aB_t, bB_t), (aI_t, bI_t) = self.model(state_tensor, lengths=torch.tensor([T_cur]))
                beta_params = {
                    "A": (float(aA_t.item()), float(bA_t.item())),
                    "B": (float(aB_t.item()), float(bB_t.item())),
                    "I": (float(aI_t.item()), float(bI_t.item())),
                }

        for name, (alpha, beta) in beta_params.items():
            mean, _var, std, conf, strength = self._beta_stats(alpha, beta)
            # Instantaneous model estimate; spread and confidence come straight from the model's (alpha, beta).
            setattr(self, f"partner_{name}_prob", float(mean))
            setattr(self, f"partner_{name}_SD", float(std))
            setattr(self, f"partner_{name}_conf", float(conf))
            # The temporally smoothed mean decides the +1 / -1 label given to the policy.
            p_smooth = self._update_posterior_for_smoothing(name, mean, strength)
            setattr(self, f"partner_{name}", 1 if p_smooth >= 0.5 else -1)
            setattr(self, f"partner_{name}_confidence", float(conf))

    def step(self, action):
        """
        Advance the shared env one macro step with the scripted partner.

        :param action: macro action of the controlled agent
        :return: ``(obs, reward, dones, info)``; ``obs`` carries the partner
            estimate and ``reward`` is the sum of both agents' rewards
        """
        partner = 1 - self.agent_index

        actions = [0, 0]
        partner_action = self.rewarder.generate_action(self.env)
        actions[self.agent_index] = action
        actions[partner] = partner_action

        primary_actions, real_execute_macro_actions = self.env._computeLowLevelActions(actions)

        # NOTE(review): the partner's primitive action is replaced by 4 while the
        # bad lettuce has x == 4; 4 is presumably the "stay" primitive — confirm.
        if (self.env.macroActionName[real_execute_macro_actions[partner]] == "get badlettuce"
                and self.env.badlettuce[0].x == 4):
            primary_actions[partner] = 4

        previous_partner_holding = self.env.agent[partner].holding

        unstick_action = self._partner_unstick_action()
        if unstick_action is not None:
            primary_actions[partner] = unstick_action

        self.obs, rewards, dones, info = self.env.step(primary_actions)
        # The policy works on macro observations, not the raw step observation.
        self.obs = self.env._get_macro_obs()

        partner_agent = self.env.agent[partner]
        self.partner_position_history.append([partner_agent.x, partner_agent.y])
        after_partner_holding = partner_agent.holding

        self.state_history.append(self.env._get_macro_vector_obs_for_ABImodel()[partner])

        # Key event: the partner was holding something before this step and nothing after it.
        self.partner_key_event_happen = (
            previous_partner_holding is not None and after_partner_holding is None
        )

        # Only predict (never train) once the window is full and a key event happened.
        if len(self.state_history) == self.state_history.maxlen and self.partner_key_event_happen:
            self._update_partner_trust_estimate()

        if self.verbose:
            print('Ability: ', self.partner_A)
            print('Benevolence: ', self.partner_B)
            print('Integrity: ', self.partner_I)

        team_reward = rewards[self.agent_index] + rewards[partner]
        return self._augment_obs(self.obs[self.agent_index]), team_reward, dones, info
    




# Grid cell type names; a map / pomap cell value is an index into this list.
ITEMNAME = [
    "space", "counter", "agent", "tomato", "lettuce", "plate",
    "knife", "delivery", "onion", "dirtyplate", "badlettuce",
]

# Macro-action name -> macro-action id, numbered in declaration order.
macroActionDict = {
    macro_name: macro_id
    for macro_id, macro_name in enumerate([
        "stay",
        "get lettuce 1", "get lettuce 2", "get badlettuce",
        "get plate 1", "get plate 2",
        "go to knife 1", "go to knife 2",
        "deliver 1", "deliver 2",
        "chop", "go to counter",
        "right", "down", "left", "up",
    ])
}


def find_best_reachable_index(can_reach_1, can_reach_2, can_reach_3, distance_1, distance_2, distance_3, unreachable=4):
    """
    Return the index (0, 1 or 2) of the closest reachable of three candidates.

    ``can_reach_k`` equal to ``unreachable`` marks candidate ``k`` as
    unreachable. Callers pass results of ``env._navigate()``, where 4 is
    presumably the "cannot reach" code — confirm.

    Ties are resolved in favour of the lower index.

    :return: the chosen index, or ``None`` when no candidate is reachable.
        The previous ``False`` sentinel compared equal to index 0.
    """
    candidates = [
        (distance, index)
        for index, (can_reach, distance) in enumerate(
            ((can_reach_1, distance_1), (can_reach_2, distance_2), (can_reach_3, distance_3))
        )
        if can_reach != unreachable
    ]
    if not candidates:
        return None
    # min() with a key returns the first minimal element, so ties keep the lowest index.
    return min(candidates, key=lambda candidate: candidate[0])[1]


def intelligently_find_item_number(env, agent_item, raw_name):
    """
    Choose which numbered copy of an item ``agent_item`` should go for.

    ``raw_name`` is a macro-action prefix such as ``"get lettuce"``. The
    variants ``raw_name + " 1"`` and ``raw_name + " 2"`` are probed in that
    order. For ``"get plate"``, ``"get dirty plate"`` is probed as well, but
    only when that macro exists in ``macroActionDict``. It currently does not,
    and indexing it unconditionally raised KeyError.

    A candidate is skipped when ``env._navigate()`` returns 4, which is
    presumably the "cannot reach" code — confirm. Among the rest, the smallest
    ``env._calDistance()`` wins, and ties go to the earlier candidate.

    :return: the chosen macro-action name, or ``"stay"`` if none is reachable
    """
    candidates = [raw_name + " 1", raw_name + " 2"]
    if raw_name == "get plate" and "get dirty plate" in macroActionDict:
        candidates.append("get dirty plate")

    best_action = "stay"
    best_distance = None
    for action_name in candidates:
        target_x, target_y = env._findPOitem(agent_item, macroActionDict[action_name])
        can_reach = env._navigate(agent_item, target_x, target_y)
        distance = env._calDistance(target_x, target_y, agent_item.x, agent_item.y)
        if can_reach == 4:
            continue
        # Strict "<" keeps the earlier candidate on ties.
        if best_distance is None or distance < best_distance:
            best_action, best_distance = action_name, distance
    return best_action



def check_benevolence(env, best_action, action):
    """
    Return 50 when ``action`` equals the macro id of ``best_action`` and that
    id is not 0 ("stay"); otherwise return 0.

    NOTE(review): the score is also written to ``env.reward``, overwriting
    whatever the env stored there. Confirm nothing else reads ``env.reward``
    expecting the env's own value.
    """
    env.reward = 0
    expected_action = macroActionDict[best_action]
    # Matching "stay" (id 0) never earns the bonus.
    if action == expected_action and expected_action != 0:
        env.reward += 50
    return env.reward



def check_action_benevolence(env, action_left, action_right, firsttime_left_go_to_counter, firsttime_left_get_lettuce, firsttime_right_go_to_knife, firsttime_right_go_to_counter):
    """
    One-shot reward shaping for cooperative ("benevolent") macro actions.

    Shaping is only evaluated while neither of the hard-coded counter cells
    (12, 7) and (13, 7) holds a lettuce. Each bonus is paid at most once; the
    matching ``firsttime_*`` flag is cleared when it is paid:

    * right agent (``env.agent[1]``) empty-handed and choosing the closest
      reachable lettuce -> right bonus 20, gated by ``firsttime_right_go_to_knife``;
    * right agent holding a Lettuce and choosing "go to counter" -> right bonus
      20, gated by ``firsttime_right_go_to_counter``;
    * left agent choosing "go to counter" -> left bonus 1000, gated by
      ``firsttime_left_go_to_counter``.

    ``firsttime_left_get_lettuce`` is passed through unchanged.
    NOTE(review): each check_benevolence() call also overwrites ``env.reward``.

    :return: a 6-tuple ``(reward_bonus_left, reward_bonus_right,
        firsttime_left_go_to_counter, firsttime_left_get_lettuce,
        firsttime_right_go_to_knife, firsttime_right_go_to_counter)``;
        callers must unpack all six values
    """
    right_agent = env.agent[1]

    # Item names currently on the two counter cells.
    counters = [ITEMNAME[env.map[x][y]] for x, y in ((12, 7), (13, 7))]

    reward_bonus_left = 0
    reward_bonus_right = 0

    # The original `counter not in ("lettuce")` was a substring test against the
    # string "lettuce" (not tuple membership); this states the intent directly.
    if "lettuce" not in counters:
        if not right_agent.holding:
            # Computed even when the flag is spent, matching the original call pattern.
            best_action = intelligently_find_item_number(env, right_agent, "get lettuce")
            if firsttime_right_go_to_knife and check_benevolence(env, best_action, action_right) == 50:
                reward_bonus_right = 20
                firsttime_right_go_to_knife = False

        if isinstance(right_agent.holding, Lettuce):
            if firsttime_right_go_to_counter and check_benevolence(env, "go to counter", action_right) == 50:
                reward_bonus_right = 20
                firsttime_right_go_to_counter = False

        if firsttime_left_go_to_counter and check_benevolence(env, "go to counter", action_left) == 50:
            reward_bonus_left = 1000
            firsttime_left_go_to_counter = False

    return (
        reward_bonus_left,
        reward_bonus_right,
        firsttime_left_go_to_counter,
        firsttime_left_get_lettuce,
        firsttime_right_go_to_knife,
        firsttime_right_go_to_counter,
    )




# Per-agent reward tables (one independent but identical dict per agent).
rewardList = [
    {
        "minitask finished": 0,
        "minitask failed": 0,
        "metatask finished": 0,
        "metatask failed": 0,
        "goodtask finished": 10,
        "goodtask failed": 0,
        "subtask finished": 20,
        "subtask failed": 0,
        "correct delivery": 200,
        "wrong delivery": -50,
        "step penalty": -1,
        "penalize using dirty plate": 0,
        "penalize using bad lettuce": 0,
        "pick up bad lettuce": 0,
    }
    for _agent in range(2)
]

# Registered gym id and constructor kwargs of the macro-action Overcooked env.
mac_env_id = 'Overcooked-MA-v0'
env_params = dict(
    grid_dim=[15, 15],
    task=["lettuce salad"],
    rewardList=rewardList,
    map_type="A",
    n_agent=2,
    obs_radius=0,
    mode="vector",
    debug=True,
)


# One shared two-agent environment; both wrappers below wrap this same instance.
shared_env = gym.make(mac_env_id, **env_params)

# Wrapper for agent 1. It is handed to SingleAgentWrapper as `partner_env`.
env_agent_1 = SingleAgentWrapper_Without_Latent(shared_env, agent_index=1)

# Agent 0 is the evaluated policy; its observations carry the ABI-model partner estimate.
env_agent_0 = SingleAgentWrapper(
    shared_env,
    env_agent_1,
    agent_index=0,
    state_dim=6,
    model_path="[MapA]abi_model_outputs_Boltzmaan_10times_adaptivelength_FFFFFFFFFFFFFFINAL_with_uncertainty/abi_model_step_10000.pt",
)

# PPO hyper-parameters and network layout. Neither dict is passed to
# PPO.load() below; they are kept as a record of the training configuration.
# (An ent_coef of 0.05 was also tried.)
ppo_params = dict(
    learning_rate=3e-4,
    n_steps=256,
    batch_size=128,
    n_epochs=10,
    gamma=0.95,
    gae_lambda=0.95,
    clip_range=0.3,
    ent_coef=0.02,
    vf_coef=0.5,
    max_grad_norm=0.5,
    verbose=0,
)

policy_kwargs = dict(net_arch=[dict(pi=[256, 128, 64], vf=[256, 128, 64])])

# Trained policy to evaluate, bound to the agent-0 wrapper.
model = PPO.load("final_trained_models/[MapA]TrustPOMDP/model_8000000", env=env_agent_0)






# Partner rules to evaluate against. Some rules are listed several times,
# which weights them more heavily in each iteration's mean below.
# This list has its own name on purpose. Rebinding EVENT_TYPES here would replace
# the module-level list that OneHotEventReward validates rule names against,
# silently rejecting the "only_*" rules afterwards.
EVAL_EVENT_TYPES = [
    "pass_plate",
    "pass_lettuce",
    "pass_lettuce",
    "pass_lettuce",
    "pass_chopped_lettuce",
    "pass_chopped_lettuce",
    "pass_chopped_lettuce",
    "pass_plated_lettuce",
    "pass_dirty_lettuce",
    "pass_chopped_dirty_lettuce",
    "pass_plated_dirty_lettuce",
    "make_clean_salad_alone",
    "make_dirty_salad_alone"
]

# Mean episode return over EVAL_EVENT_TYPES, one entry per evaluation iteration.
all_reward_list = []

for iteration in range(10):
    cur_reward_list = []
    for rule_name in EVAL_EVENT_TYPES:
        # Swap the scripted partner rule in place; the wrapper is not rebuilt.
        env_agent_0.set_partner_idx(rule_name)
        obs = env_agent_0.reset()

        cumulative_reward = 0
        for _step in range(400):
            action, _states = model.predict(obs, deterministic=True)
            obs, rewards, dones, info = env_agent_0.step(action)
            cumulative_reward += rewards
            if dones:
                break

        cur_reward_list.append(cumulative_reward)
    all_reward_list.append(float(np.mean(cur_reward_list)))

print(all_reward_list)


# [1345.3846153846155, 1491.5384615384614, 1309.2307692307693, 1363.076923076923, 1273.8461538461538, 1533.076923076923, 1346.923076923077, 1382.3076923076924, 1343.076923076923, 1327.6923076923076]

