import gym
from stable_baselines3 import PPO
import numpy as np
from gym_macro_overcooked.overcooked_V1 import Overcooked_V1
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.logger import configure
import matplotlib.pyplot as plt
from gym_macro_overcooked.items import Tomato, Lettuce, Onion, Plate, Knife, Delivery, Agent, Food, DirtyPlate
import random
import time
import torch
import os
import torch
import torch.nn as nn
from collections import deque
import math

import matplotlib
matplotlib.use("Agg")  # 非交互后端，防止子进程导入时卡住/崩掉
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F


# train_parallel.py
import os
# Pin native thread pools to one thread each so parallel env workers do not
# oversubscribe CPU cores. setdefault keeps any value the caller already exported.
# NOTE(review): numpy and torch are already imported earlier in this file, so these
# variables may be read too late to affect their thread pools; the explicit
# torch.set_num_threads(1) call further down covers torch either way.
for _thread_env_var in (
    "OMP_NUM_THREADS",
    "MKL_NUM_THREADS",
    "OPENBLAS_NUM_THREADS",
    "NUMEXPR_NUM_THREADS",
):
    os.environ.setdefault(_thread_env_var, "1")
del _thread_env_var

import gym
import numpy as np
import random
import torch
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import SubprocVecEnv
from stable_baselines3.common.callbacks import BaseCallback
from collections import deque
import time
import matplotlib.pyplot as plt

# ====== Required definitions: keep them in this file or import them correctly ======
# - ABIGatedExtractor
# - LightweightTransformerABIModel
# - EpisodeRewardCallback
# - SingleAgentWrapper_Without_Latent
# - SingleAgentWrapper
# - the registered Overcooked environment id, ITEMNAME, the macro-action table, etc.
# Reuse the existing definitions directly; do not delete them.

# ====== Global random seeds ======
SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)
del _seed_fn
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# Optional: restrict PyTorch to a single intra-op and a single inter-op CPU thread.
torch.set_num_threads(1)
torch.set_num_interop_threads(1)


from stable_baselines3.common.torch_layers import BaseFeaturesExtractor

class ABIGatedExtractor(BaseFeaturesExtractor):
    """
    Feature extractor for observations whose last three entries are A, B, I
    (expected in [-1, +1], ideally exactly -1 / +1).

    f = MLP(observation without its last three entries), of width ``base_dim``.
    Output = concat[f*relu(A), f*relu(-A), f*relu(B), f*relu(-B),
                    f*relu(I), f*relu(-I), A, B, I]
    with total width 6 * base_dim + 3.
    """
    def __init__(self, observation_space: gym.spaces.Box, base_dim: int = 64):
        obs_dim = observation_space.shape[0]
        assert obs_dim >= 4, "需要: 非ABI>=1 + ABI 3 维"
        super().__init__(observation_space, features_dim=6 * base_dim + 3)
        self.obs_dim = obs_dim
        self.base_dim = base_dim

        # Encoder for the non-ABI part of the observation.
        self.backbone = nn.Sequential(
            nn.Linear(obs_dim - 3, 256), nn.ReLU(),
            nn.Linear(256, base_dim), nn.ReLU()
        )

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        split = self.obs_dim - 3
        features = self.backbone(obs[..., :split])       # [B, base_dim]
        abi = obs[..., split:self.obs_dim]                # [B, 3]
        A, B, I = abi[..., 0:1], abi[..., 1:2], abi[..., 2:3]

        # Positive / negative scalar gates scale the shared features element-wise.
        pieces = []
        for signal in (A, B, I):
            pieces.append(features * torch.relu(signal))
            pieces.append(features * torch.relu(-signal))
        pieces.extend((A, B, I))
        return torch.cat(pieces, dim=-1)



class ABIGatedExtractorWithConf(BaseFeaturesExtractor):
    """
    Feature extractor for observations whose last six entries are
        [A, B, I, conf_A, conf_B, conf_I]
    where A/B/I are in [-1, 1] (currently ±1) and conf_* are in [0, 1]
    (1 = very certain, 0 = very uncertain; values are clamped to [0, 1]).

    Gating (default, ``use_conf_gating=False``):
        gate_A_pos = relu(A), gate_A_neg = relu(-A)     (B and I likewise)
    With ``use_conf_gating=True`` each gate is additionally multiplied by
    conf ** conf_power, so uncertain estimates weaken the gating
    (conf_power > 1 is more conservative). ``conf_power`` has no effect otherwise.

    The gates scale the shared features f = MLP(non-ABI observation). The output is
        concat[f*gate_A_pos, f*gate_A_neg, f*gate_B_pos, f*gate_B_neg,
               f*gate_I_pos, f*gate_I_neg, A, B, I, conf_A, conf_B, conf_I]
    with width 6 * base_dim + 6. The trailing conf_* are the clamped raw values,
    not raised to conf_power.
    """
    def __init__(self, observation_space: gym.spaces.Box, base_dim: int = 64, conf_power: float = 1.0,
                 use_conf_gating: bool = False):
        self.obs_dim = observation_space.shape[0]
        assert self.obs_dim >= 6, "obs 需要包含至少 6 个 ABI 相关维度: [A,B,I,conf_A,conf_B,conf_I]"
        super().__init__(observation_space, features_dim=6 * base_dim + 6)

        self.base_dim = base_dim
        # Exponent applied to conf_* before gating; only used when use_conf_gating=True.
        self.conf_power = conf_power
        # Off by default so existing trained policies see exactly the same features as before.
        self.use_conf_gating = use_conf_gating

        # Encoder for the non-ABI part of the observation.
        self.backbone = nn.Sequential(
            nn.Linear(self.obs_dim - 6, 256), nn.ReLU(),
            nn.Linear(256, base_dim), nn.ReLU()
        )

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        d = self.obs_dim
        # Split the observation.
        x = obs[..., : d - 6]                                 # non-ABI part
        A = obs[..., d - 6 : d - 5]
        B = obs[..., d - 5 : d - 4]
        I = obs[..., d - 4 : d - 3]
        cA = obs[..., d - 3 : d - 2].clamp(0.0, 1.0)
        cB = obs[..., d - 2 : d - 1].clamp(0.0, 1.0)
        cI = obs[..., d - 1 : d    ].clamp(0.0, 1.0)

        # Shared features.
        f = self.backbone(x)  # [B, base_dim]

        # Positive / negative gates.
        gate_A_pos, gate_A_neg = torch.relu(A), torch.relu(-A)
        gate_B_pos, gate_B_neg = torch.relu(B), torch.relu(-B)
        gate_I_pos, gate_I_neg = torch.relu(I), torch.relu(-I)

        if self.use_conf_gating:
            # Weaken the gates where the ABI estimate is uncertain.
            wA = cA.pow(self.conf_power)
            wB = cB.pow(self.conf_power)
            wI = cI.pow(self.conf_power)
            gate_A_pos, gate_A_neg = gate_A_pos * wA, gate_A_neg * wA
            gate_B_pos, gate_B_neg = gate_B_pos * wB, gate_B_neg * wB
            gate_I_pos, gate_I_neg = gate_I_pos * wI, gate_I_neg * wI

        # Apply the gates (element-wise scaling) and append the raw context.
        out = torch.cat([
            f * gate_A_pos, f * gate_A_neg,
            f * gate_B_pos, f * gate_B_neg,
            f * gate_I_pos, f * gate_I_neg,
            A, B, I,       # raw ABI signs
            cA, cB, cI     # confidences (0..1)
        ], dim=-1)
        return out
    


# Persistence helpers for the ABI model's training data buffer.
def save_dataset_to_disk(data_buffer, filepath="abi_training_data.pt"):
    """Serialize ``data_buffer`` to ``filepath`` with ``torch.save`` and print a confirmation."""
    destination = filepath
    torch.save(data_buffer, destination)
    print(f"✅ Data buffer saved to {destination}")

def load_dataset_from_disk(filepath="abi_training_data.pt"):
    """
    Load a buffer previously written by ``torch.save`` (e.g. via save_dataset_to_disk).

    NOTE: depending on the installed PyTorch version's ``weights_only`` default,
    ``torch.load`` may unpickle arbitrary objects; only load files you trust.
    """
    restored = torch.load(filepath)
    return restored





# ===== Helper: KL between a Beta and the uniform prior, used to discourage over-confidence =====
def kl_beta_to_uniform(alpha, beta):
    """
    Element-wise KL( Beta(alpha, beta) || Beta(1, 1) ).

    Closed form (log B(1, 1) = 0):
        -log B(a, b) + (a - 1) psi(a) + (b - 1) psi(b) - (a + b - 2) psi(a + b)
    """
    log_beta_fn = torch.lgamma(alpha) + torch.lgamma(beta) - torch.lgamma(alpha + beta)
    total = alpha + beta
    return (
        -log_beta_fn
        + (alpha - 1) * torch.digamma(alpha)
        + (beta - 1) * torch.digamma(beta)
        - (total - 2) * torch.digamma(total)
    )

# ===== Binary evidential loss: CE + lambda * KL(Beta || Uniform) =====
def evidential_binary_loss(alpha, beta, y, lam=1e-3):
    """
    Binary cross-entropy on the Beta mean plus a KL regulariser toward Beta(1, 1).

    alpha, beta: positive Beta parameters; y: targets in {0, 1} (shape [B, 1] in this file).
    The Beta mean alpha / (alpha + beta), clamped to [1e-6, 1 - 1e-6], is used as P(y = 1).

    Returns (loss, metrics), where metrics holds floats 'ce', 'kl' and 'S_mean'
    (the mean evidence strength alpha + beta).
    """
    strength = alpha + beta
    p_positive = (alpha / strength).clamp(1e-6, 1 - 1e-6)
    ce = F.binary_cross_entropy(p_positive, y, reduction='mean')
    kl = kl_beta_to_uniform(alpha, beta).mean()
    metrics = {'ce': ce.item(), 'kl': kl.item(), 'S_mean': strength.mean().item()}
    return ce + lam * kl, metrics

class LightweightTransformerABIModel_adaptivelength(nn.Module):
    """
    Variable-length Transformer that outputs Beta parameters (alpha, beta) for each of
    ability (A), benevolence (B) and integrity (I).

    A shared Transformer encoder processes a padded batch of state histories. Each dimension
    attention-pools only its own window of the most recent k steps, using one attention
    vector shared by all three, and then goes through its own head.

    - Prediction: p = alpha / (alpha + beta) (used as P(y = 1) by evidential_binary_loss).
    - Uncertainty: S = alpha + beta (smaller = less certain), or the Beta variance.
    """
    def __init__(self, state_dim, hidden_dim=32, nhead=2, num_layers=1,
                 lr=1e-3, evidential_lambda=1e-3):
        super().__init__()
        self.input_proj = nn.Linear(state_dim, hidden_dim)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=nhead,
            dim_feedforward=hidden_dim * 2,
            batch_first=True,
            activation='relu'
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Attention vector shared by the three pooling windows.
        self.attn_vec = nn.Parameter(torch.randn(hidden_dim))

        # Per-dimension heads (shared trunk, separate heads).
        self.head_A = nn.Sequential(nn.Linear(hidden_dim, 64), nn.ReLU())
        self.head_B = nn.Sequential(nn.Linear(hidden_dim, 64), nn.ReLU())
        self.head_I = nn.Sequential(nn.Linear(hidden_dim, 64), nn.ReLU())

        # Beta alpha / beta per dimension (made positive with softplus in forward).
        self.fc_A_alpha = nn.Linear(64, 1)
        self.fc_A_beta  = nn.Linear(64, 1)
        self.fc_B_alpha = nn.Linear(64, 1)
        self.fc_B_beta  = nn.Linear(64, 1)
        self.fc_I_alpha = nn.Linear(64, 1)
        self.fc_I_beta  = nn.Linear(64, 1)

        self.evidential_lambda = evidential_lambda
        self.optimizer = torch.optim.Adam(self.parameters(), lr=lr)
        # Entries are (state_history tensor, ability, benevolence, integrity); see store_data.
        self.data_buffer = []

    @staticmethod
    def _make_pad_mask(lengths, max_len):
        """Return a [B, T] bool mask that is True at padding positions (index >= length)."""
        device = lengths.device
        rng = torch.arange(max_len, device=device).unsqueeze(0)  # [1, T]
        pad_mask = rng >= lengths.unsqueeze(1)                   # [B, T], True = padding
        return pad_mask

    @staticmethod
    def _make_recent_window_mask(lengths, max_len, k_recent):
        """Return a [B, T] bool mask that is False only for the last k_recent valid steps (k >= 1)."""
        device = lengths.device
        k = torch.clamp(torch.as_tensor(k_recent, device=device), min=1)
        start = torch.clamp(lengths - k, min=0)                          # [B]
        idx   = torch.arange(max_len, device=device).unsqueeze(0)        # [1, T]
        # Inside the window -> False (kept); everything else -> True (masked).
        in_window = (idx >= start.unsqueeze(1)) & (idx < lengths.unsqueeze(1))
        mask = ~in_window                                                 # [B, T]
        return mask

    def _attn_pool(self, h, attn_vec, mask_bool, lengths):
        """
        Attention-pool ``h`` over time, ignoring masked positions.

        h: [B, T, H]; attn_vec: [H]; mask_bool: [B, T] with True = masked; lengths: [B].
        Returns [B, H].

        A row whose positions are ALL masked falls back to its last valid token
        (index lengths - 1, clamped to 0). Every other row keeps its own softmax weights.
        """
        mask = mask_bool
        all_masked = mask.all(dim=1)                                      # [B]
        if all_masked.any():
            # Unmask exactly one position, and only in the fully-masked rows. Previously the
            # one-hot fallback replaced the weights of the WHOLE batch as soon as any single
            # row was fully masked.
            rows = all_masked.nonzero(as_tuple=True)[0]
            last_idx = (lengths[rows] - 1).clamp(min=0)
            mask = mask.clone()
            mask[rows, last_idx] = False

        logits = torch.matmul(h, attn_vec).masked_fill(mask, float('-inf'))  # [B, T]
        weights = torch.softmax(logits, dim=1)                                # [B, T]
        pooled = torch.sum(h * weights.unsqueeze(-1), dim=1)                  # [B, H]
        return pooled

    def forward(self, x, lengths=None, k_ability=15, k_benevolence=30, k_integrity=30):
        """
        x: [B, T, state_dim] (may contain right padding)
        lengths: [B] true length of each sequence; if None, every sequence is treated as full length
        Returns:
          (alpha_A, beta_A), (alpha_B, beta_B), (alpha_I, beta_I), each of shape [B, 1]
        """
        B, T, _ = x.shape
        device = x.device
        if lengths is None:
            lengths = torch.full((B,), T, dtype=torch.long, device=device)

        # 1) Input projection.
        x_proj = self.input_proj(x)  # [B, T, H]

        # 2) Transformer encoding with a padding mask.
        pad_mask = self._make_pad_mask(lengths, T)  # [B, T], True = padding
        h = self.transformer(x_proj, src_key_padding_mask=pad_mask)  # [B, T, H]

        # 3) One "most recent K steps" window per dimension, combined with the padding mask.
        mask_A = pad_mask | self._make_recent_window_mask(lengths, T, k_ability)
        mask_B = pad_mask | self._make_recent_window_mask(lengths, T, k_benevolence)
        mask_I = pad_mask | self._make_recent_window_mask(lengths, T, k_integrity)

        # 4) Pool with the shared attention vector.
        pooled_A = self._attn_pool(h, self.attn_vec, mask_A, lengths)  # [B, H]
        pooled_B = self._attn_pool(h, self.attn_vec, mask_B, lengths)
        pooled_I = self._attn_pool(h, self.attn_vec, mask_I, lengths)

        # 5) Per-dimension heads.
        zA = self.head_A(pooled_A)
        zB = self.head_B(pooled_B)
        zI = self.head_I(pooled_I)

        # softplus + eps keeps alpha and beta strictly positive.
        eps = 1e-4
        alpha_A = F.softplus(self.fc_A_alpha(zA)) + eps
        beta_A  = F.softplus(self.fc_A_beta(zA))  + eps
        alpha_B = F.softplus(self.fc_B_alpha(zB)) + eps
        beta_B  = F.softplus(self.fc_B_beta(zB))  + eps
        alpha_I = F.softplus(self.fc_I_alpha(zI)) + eps
        beta_I  = F.softplus(self.fc_I_beta(zI))  + eps

        return (alpha_A, beta_A), (alpha_B, beta_B), (alpha_I, beta_I)

    @staticmethod
    def beta_mean_strength(alpha, beta):
        """Return (clamped Beta mean alpha / (alpha + beta), strength alpha + beta)."""
        S = alpha + beta
        p = (alpha / S).clamp(1e-6, 1-1e-6)
        return p, S

    def store_data(self, state_history, ability, benevolence, integrity):
        """Append one training sample; tensors are detached and moved to CPU, labels cast to float."""
        if isinstance(state_history, torch.Tensor):
            sh = state_history.detach().cpu()
        else:
            sh = torch.as_tensor(state_history)
        self.data_buffer.append((sh, float(ability), float(benevolence), float(integrity)))

    def _collate_batch(self, batch):
        """Pad variable-length histories into [B, T_max, D] and stack the labels into [B, 1] tensors."""
        seqs, A, B, I = [], [], [], []
        for sh, a, b, i in batch:
            seqs.append(torch.as_tensor(sh, dtype=torch.float32))
            A.append([a]); B.append([b]); I.append([i])
        lengths = torch.tensor([s.shape[0] for s in seqs], dtype=torch.long)
        states = pad_sequence(seqs, batch_first=True)  # [B, T_max, D]
        abilitys     = torch.tensor(A, dtype=torch.float32)
        benevolences = torch.tensor(B, dtype=torch.float32)
        integritys   = torch.tensor(I, dtype=torch.float32)
        return states, lengths, abilitys, benevolences, integritys

    def train_step(self, batch_size=None, k_ability=15, k_benevolence=30, k_integrity=30,
                   device=None, loss_weights=(1.0, 1.0, 1.0)):
        """
        Run one optimizer step on the stored data.

        batch_size=None uses the whole buffer; otherwise the FIRST ``batch_size`` samples are
        used (no shuffling). Returns None if the buffer is empty, else a dict of float metrics.
        """
        if not self.data_buffer:
            return None

        batch = self.data_buffer if batch_size is None else self.data_buffer[:batch_size]
        states, lengths, abilitys, benevolences, integritys = self._collate_batch(batch)

        if device is None:
            device = next(self.parameters()).device
        states       = states.to(device)
        lengths      = lengths.to(device)
        abilitys     = abilitys.to(device)
        benevolences = benevolences.to(device)
        integritys   = integritys.to(device)

        self.optimizer.zero_grad()
        (aA, bA), (aB, bB), (aI, bI) = self.forward(
            states, lengths=lengths,
            k_ability=k_ability, k_benevolence=k_benevolence, k_integrity=k_integrity
        )

        # Evidential loss for each of A / B / I.
        lam = self.evidential_lambda
        wA, wB, wI = loss_weights
        loss_A, mA = evidential_binary_loss(aA, bA, abilitys,     lam=lam)
        loss_B, mB = evidential_binary_loss(aB, bB, benevolences, lam=lam)
        loss_I, mI = evidential_binary_loss(aI, bI, integritys,   lam=lam)

        total = wA*loss_A + wB*loss_B + wI*loss_I
        total.backward()
        self.optimizer.step()

        return {
            'total_loss': total.item(),
            'A_ce': mA['ce'], 'A_kl': mA['kl'], 'A_S_mean': mA['S_mean'],
            'B_ce': mB['ce'], 'B_kl': mB['kl'], 'B_S_mean': mB['S_mean'],
            'I_ce': mI['ce'], 'I_kl': mI['kl'], 'I_S_mean': mI['S_mean'],
        }

    def predict_beta_params(self, states_np, lengths_np=None,
                            k_ability=15, k_benevolence=30, k_integrity=30, device=None):
        """
        Inference helper: return per-dimension alpha / beta together with p / S.

        states_np: [B, T, D] (numpy array or tensor)
        lengths_np: [B] (optional; defaults to full length)

        The module is switched to eval mode for the call, which disables the encoder's dropout,
        and its previous train/eval mode is restored afterwards. Without this, predictions made
        while the model is in training mode would be random.
        """
        if not torch.is_tensor(states_np):
            states = torch.tensor(states_np, dtype=torch.float32)
        else:
            states = states_np.float()
        B, T, _ = states.shape

        if lengths_np is None:
            lengths = torch.full((B,), T, dtype=torch.long)
        else:
            lengths = torch.as_tensor(lengths_np, dtype=torch.long)

        if device is None:
            device = next(self.parameters()).device
        states  = states.to(device)
        lengths = lengths.to(device)

        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                (aA, bA), (aB, bB), (aI, bI) = self.forward(
                    states, lengths=lengths,
                    k_ability=k_ability, k_benevolence=k_benevolence, k_integrity=k_integrity
                )
                pA, SA = self.beta_mean_strength(aA, bA)
                pB, SB = self.beta_mean_strength(aB, bB)
                pI, SI = self.beta_mean_strength(aI, bI)
        finally:
            self.train(was_training)

        out = {
            'A': {'alpha': aA.squeeze(-1), 'beta': bA.squeeze(-1), 'p': pA.squeeze(-1), 'S': SA.squeeze(-1)},
            'B': {'alpha': aB.squeeze(-1), 'beta': bB.squeeze(-1), 'p': pB.squeeze(-1), 'S': SB.squeeze(-1)},
            'I': {'alpha': aI.squeeze(-1), 'beta': bI.squeeze(-1), 'p': pI.squeeze(-1), 'S': SI.squeeze(-1)},
        }
        return out

    def save_model(self, path):
        """Save model weights, optimizer state and evidential_lambda to ``path``."""
        torch.save({
            'model_state_dict': self.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'evidential_lambda': self.evidential_lambda
        }, path)
        print(f"Model saved to {path}")

    def load_model(self, path, map_location=None):
        """Load a checkpoint written by save_model (weights, optimizer state, evidential_lambda)."""
        checkpoint = torch.load(path, map_location=map_location)
        self.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.evidential_lambda = checkpoint.get('evidential_lambda', self.evidential_lambda)
        print(f"Model loaded from {path}")






class LightweightTransformerABIModel(nn.Module):
    """
    Fixed-length Transformer regressor for (ability, benevolence, integrity).

    Input x: [B, T, state_dim]. Stored histories are stacked without padding, so every
    history in ``data_buffer`` must have the same shape. Each output goes through a sigmoid,
    so predictions lie in (0, 1). Training minimises the sum of three MSE losses.
    """
    def __init__(self, state_dim, seq_len=15, hidden_dim=32, nhead=2, num_layers=1):
        super().__init__()
        # NOTE: seq_len is accepted for backward compatibility but is not used; the
        # sequence length is taken from the input tensor.
        self.input_proj = nn.Linear(state_dim, hidden_dim)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=nhead,
            dim_feedforward=hidden_dim * 2,
            batch_first=True,
            activation='relu'  # lighter weight
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Learnable attention vector for pooling over time.
        self.attention_vector = nn.Parameter(torch.randn(hidden_dim))

        self.fc_ability = nn.Linear(hidden_dim, 1)
        self.fc_benevolence = nn.Linear(hidden_dim, 1)
        self.fc_integrity = nn.Linear(hidden_dim, 1)

        self.criterion = nn.MSELoss()
        self.optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        # Entries are (state_history, ability, benevolence, integrity) as passed to store_data.
        self.data_buffer = []

    def forward(self, x):
        """x: [B, T, state_dim] -> (ability, benevolence, integrity), each [B, 1] in (0, 1)."""
        x = self.input_proj(x)           # [B, T, H]
        h = self.transformer(x)          # [B, T, H]

        # Attention pooling over time.
        attn_weights = torch.softmax(torch.matmul(h, self.attention_vector), dim=1)  # [B, T]
        pooled = torch.sum(h * attn_weights.unsqueeze(-1), dim=1)  # [B, H]

        ability     = torch.sigmoid(self.fc_ability(pooled))
        benevolence = torch.sigmoid(self.fc_benevolence(pooled))
        integrity   = torch.sigmoid(self.fc_integrity(pooled))
        return ability, benevolence, integrity

    def store_data(self, state_history, ability, benevolence, integrity):
        """Append one training sample unchanged (no copying or conversion)."""
        self.data_buffer.append((state_history, ability, benevolence, integrity))

    def train_step(self):
        """
        Run one full-batch optimizer step over the whole buffer (the buffer is not cleared).

        Returns None if the buffer is empty, otherwise a dict with the total and per-head losses.
        """
        if not self.data_buffer:
            return None

        states, abilitys, benevolences, integritys = zip(*self.data_buffer)
        # Stack through NumPy: torch.tensor() on a sequence of ndarrays is very slow
        # (PyTorch warns about it) and fails outright on a sequence of multi-element tensors.
        states = torch.as_tensor(np.asarray(states, dtype=np.float32))
        abilitys = torch.tensor(abilitys, dtype=torch.float32).unsqueeze(-1)
        benevolences = torch.tensor(benevolences, dtype=torch.float32).unsqueeze(-1)
        integritys = torch.tensor(integritys, dtype=torch.float32).unsqueeze(-1)

        self.optimizer.zero_grad()
        pred_ability, pred_benevolence, pred_integrity = self.forward(states)

        # Separate loss per head.
        loss_ability = self.criterion(pred_ability, abilitys)
        loss_benevolence = self.criterion(pred_benevolence, benevolences)
        loss_integrity = self.criterion(pred_integrity, integritys)

        total_loss = loss_ability + loss_benevolence + loss_integrity
        total_loss.backward()
        self.optimizer.step()

        return {
            'total_loss': total_loss.item(),
            'loss_ability': loss_ability.item(),
            'loss_benevolence': loss_benevolence.item(),
            'loss_integrity': loss_integrity.item()
        }

    def save_model(self, path):
        """Save model weights and optimizer state to ``path``."""
        torch.save({
            'model_state_dict': self.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
        }, path)
        print(f"Model saved to {path}")

    def load_model(self, path, map_location=None):
        """
        Load a checkpoint written by save_model.

        ``map_location`` is forwarded to torch.load (e.g. "cpu" to load a GPU checkpoint
        on a CPU-only machine), matching LightweightTransformerABIModel_adaptivelength.
        """
        checkpoint = torch.load(path, map_location=map_location)
        self.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        print(f"Model loaded from {path}")








class EpisodeRewardCallback(BaseCallback):
    """
    Track episode rewards of the first vectorized env and periodically checkpoint.

    Every ``save_freq`` calls of ``_on_step`` the model is saved to
    ``<save_path>/model_<n>.zip``. When at least one episode has finished, a plot of the
    running mean episode reward (window of up to 100 episodes) is also written to
    ``<save_path>/avg_reward_<n>.png``.

    Only index 0 of ``rewards`` / ``dones`` is read, so with several envs only the first
    is tracked. ``n`` counts ``_on_step`` calls (in SB3, once per vectorized env step).
    """
    def __init__(self, save_path, save_freq=100000, verbose=0):
        super().__init__(verbose)
        self.save_path = save_path
        self.save_freq = save_freq
        self.step_counter = 0
        self.episode_rewards = []
        self.current_episode_reward = 0.0
        os.makedirs(save_path, exist_ok=True)

    def _on_step(self) -> bool:
        self.step_counter += 1
        self.current_episode_reward += self.locals['rewards'][0]

        episode_finished = self.locals['dones'][0]
        if episode_finished:
            self.episode_rewards.append(self.current_episode_reward)
            self.current_episode_reward = 0.0

        if self.step_counter % self.save_freq == 0:
            self._checkpoint()
        return True

    def _checkpoint(self):
        """Save the model and, if any episode has finished, the moving-average reward plot."""
        tag = self.step_counter
        model_path = os.path.join(self.save_path, f'model_{tag}.zip')
        self.model.save(model_path)
        print(f"Step {tag}: Model saved at {model_path}")
        if self.episode_rewards:
            self._plot_moving_average()

    def _plot_moving_average(self):
        """Plot the trailing mean of up to the last 100 episode rewards at every episode."""
        history = self.episode_rewards
        window = min(100, len(history))
        moving_avg = []
        for end in range(len(history)):
            start = max(0, end - window + 1)
            moving_avg.append(sum(history[start:end + 1]) / (end - start + 1))

        plt.figure(figsize=(10, 5))
        plt.plot(moving_avg, label=f"Moving Avg (last {window} episodes)")
        plt.xlabel("Episode")
        plt.ylabel("Average Reward")
        plt.title("Training Progress")
        plt.legend()
        plt.grid()
        plt.tight_layout()
        plt.savefig(os.path.join(self.save_path, f'avg_reward_{self.step_counter}.png'))
        plt.close()








class SingleAgentWrapper_Without_Latent(gym.Wrapper):
    """
    Single-agent view of a two-agent environment with a fixed partner policy.

    The partner (index ``1 - agent_index``) acts via ``other_agent_model.predict``, which must
    therefore be provided before ``step`` is called. Macro actions are converted to low-level
    actions by the env, and observations are re-read from ``env._get_macro_obs()``.
    """
    def __init__(self, env, agent_index, other_agent_model=None):
        super().__init__(env)
        self.agent_index = agent_index
        self.observation_space = env.observation_space
        self.action_space = env.action_space
        self.other_agent_model = other_agent_model
        self._reset_first_time_flags()
        self.obs = None

    def _reset_first_time_flags(self):
        """Set all 'first time' shaping flags back to True (used by the disabled reward shaping)."""
        self.firsttime_left_go_to_counter = True
        self.firsttime_left_get_lettuce = True
        self.firsttime_right_go_to_knife = True
        self.firsttime_right_go_to_counter = True

    def reset(self):
        self.obs = self.env.reset()
        self._reset_first_time_flags()
        return self.obs[self.agent_index]

    def step(self, action):
        ego = self.agent_index
        partner = 1 - ego

        partner_action = self.other_agent_model.predict(self.obs[partner])[0]
        actions = [0, 0]
        actions[ego] = action
        actions[partner] = partner_action

        # Cooperative reward shaping (benevolence bonus / inconsistency penalty) is currently
        # disabled, so this bonus stays 0.
        benevolence_reward = 0

        primary_actions = self.env._computeLowLevelActions(actions, 0)
        self.obs, rewards, dones, info = self.env.step(primary_actions)
        self.obs = self.env._get_macro_obs()

        if ego in (0, 1):
            return self.obs[ego], rewards[ego] + benevolence_reward, dones, info





def get_boltzmann_probs_from_policy(model, obs, beta: float = 1.0):
    """
    Return Boltzmann action probabilities for ``obs`` under an SB3 policy (numpy array).

    The policy's categorical logits are multiplied by ``beta`` (the rationality / inverse
    temperature: 0 gives uniform, larger values are greedier) and soft-maxed. Only the first
    row of the batch is returned.
    """
    obs_tensor = model.policy.obs_to_tensor(obs)[0]
    with torch.no_grad():
        logits = model.policy.get_distribution(obs_tensor).distribution.logits  # [B, n_actions]
        tempered = torch.softmax(beta * logits, dim=-1)
    return tempered.cpu().numpy()[0]

def sample_action_from_probs(probs: np.ndarray):
    """Draw one action index from the categorical distribution ``probs`` using NumPy's global RNG."""
    n_actions = len(probs)
    sampled = np.random.choice(n_actions, p=probs)
    return int(sampled)





class SingleAgentWrapper(gym.Wrapper):
    """
    Single-agent view of a two-agent environment with a randomly drawn, pre-trained partner.

    On every reset one of the partner policies in ``self.models`` is drawn uniformly at random.
    Its (ability, benevolence, integrity) labels are read from ``_PARTNER_ABI`` and stored as
    ``partner_ability`` / ``partner_benevolence`` / ``partner_integrity`` (each ±1).
    Low-ability partners (ability == -1) have their macro actions perturbed by Boltzmann
    sampling (beta = 0.3). Both agents receive the team reward (own + partner's).

    ``partner_seed`` seeds the private RNG used for partner selection. With the default None it
    is seeded from OS entropy, so separate worker processes draw different partners without
    touching the process-wide ``random`` state.
    """

    # Partner checkpoints. Entries 5-9 reuse the policies of entries 0-4; they are the
    # low-ability variants (see _PARTNER_ABI). Forward slashes work on every OS; the previous
    # backslash separators were invalid escape sequences and broke loading on POSIX.
    _PARTNER_MODEL_PATHS = (
        "final_trained_models/[MapA_lowuncertainty]trustee_agent_highB_highI_version1_agent1/model_4100000",
        "final_trained_models/[MapA_lowuncertainty]trustee_agent_highB_highI_version2_agent1/model_4100000",
        "final_trained_models/[MapA_lowuncertainty]trustee_agent_highB_lowI_agent1/model_4100000",
        "final_trained_models/[MapA_lowuncertainty]trustee_agent_lowB_highI_agent1/model_4100000",
        "final_trained_models/[MapA_lowuncertainty]trustee_agent_lowB_lowI_agent1/model_4100000",

        "final_trained_models/[MapA_lowuncertainty]trustee_agent_highB_highI_version1_agent1/model_4100000",
        "final_trained_models/[MapA_lowuncertainty]trustee_agent_highB_highI_version2_agent1/model_4100000",
        "final_trained_models/[MapA_lowuncertainty]trustee_agent_highB_lowI_agent1/model_4100000",
        "final_trained_models/[MapA_lowuncertainty]trustee_agent_lowB_highI_agent1/model_4100000",
        "final_trained_models/[MapA_lowuncertainty]trustee_agent_lowB_lowI_agent1/model_4100000",
    )

    # (ability, benevolence, integrity) of partner i; indices without an entry default to (1, 1, 1).
    _PARTNER_ABI = (
        (1, 1, 1),
        (1, 1, 1),
        (1, 1, -1),
        (1, -1, 1),
        (1, -1, -1),
        (-1, 1, 1),
        (-1, 1, 1),
        (-1, 1, -1),
        (-1, -1, 1),
        (-1, -1, -1),
    )

    def __init__(self, env, partner_env, agent_index, partner_seed=None):
        super().__init__(env)
        self.agent_index = agent_index
        self.observation_space = env.observation_space
        self.action_space = env.action_space

        self.firsttime_left_go_to_counter = True
        self.firsttime_left_get_lettuce = True
        self.firsttime_right_go_to_knife = True
        self.firsttime_right_go_to_counter = True

        self.obs = None
        self.partner_position_history = []
        self.action_probabilities = []

        # Private RNG. Re-seeding the global `random` module on every reset (as before) undid
        # the file-level SEED for everything else in the process.
        self._rng = random.Random(partner_seed)

        # Load each distinct checkpoint once; duplicate entries share the loaded model.
        loaded_by_path = {}
        self.models = {}
        for i, path in enumerate(self._PARTNER_MODEL_PATHS):
            if path not in loaded_by_path:
                loaded_by_path[path] = PPO.load(path, env=partner_env)
            model_name = f"model{i}"
            self.models[model_name] = loaded_by_path[path]
            print(f"Loaded {model_name} from {path}")

        print(self.models)

    def reset(self):
        self.obs = self.env.reset()

        self.partner_position_history = []
        self.action_probabilities = []

        self.firsttime_left_go_to_counter = True
        self.firsttime_left_get_lettuce = True
        self.firsttime_right_go_to_knife = True
        self.firsttime_right_go_to_counter = True

        partner_idx = self._rng.randrange(len(self.models))
        self.other_agent_model = self.models[f"model{partner_idx}"]

        if partner_idx < len(self._PARTNER_ABI):
            abi = self._PARTNER_ABI[partner_idx]
        else:
            abi = (1, 1, 1)
        self.partner_ability, self.partner_benevolence, self.partner_integrity = abi

        return self.obs[self.agent_index]

    def _maybe_unstick_partner(self, low_level_actions, partner):
        """
        If the partner's last five recorded positions are identical, overwrite its low-level
        action in place with a random move onto an adjacent "space" cell.
        """
        # NOTE(review): partner_position_history is reset but never appended to in this class,
        # so this currently never triggers. The lookups below also hard-code env.agent[1].
        history = self.partner_position_history
        if len(history) <= 5:
            return
        recent = history[-5:]
        if not all(pos == recent[0] for pos in recent):
            return

        partner_agent = self.env.agent[1]
        x, y, pomap = partner_agent.x, partner_agent.y, partner_agent.pomap
        neighbours = ((0, x, y + 1), (1, x + 1, y), (2, x, y - 1), (3, x - 1, y))
        free_moves = [move for move, nx, ny in neighbours if ITEMNAME[pomap[nx][ny]] == "space"]
        if free_moves:  # choice() on an empty list would raise IndexError
            low_level_actions[partner] = self._rng.choice(free_moves)

    def step(self, action):
        ego = self.agent_index
        partner = 1 - ego

        actions = [0, 0]
        actions[ego] = action
        # predict() returns (action, state); only the action is used.
        actions[partner] = self.other_agent_model.predict(self.obs[partner])[0]

        low_ability_partner = self.partner_ability == -1

        # Refresh the partner's Boltzmann distribution only when it is about to choose a new
        # macro action. This must be read BEFORE _computeLowLevelActions runs.
        if low_ability_partner and self.env.macroAgent[1].cur_macro_action_done == True:
            self.action_probabilities = get_boltzmann_probs_from_policy(
                self.other_agent_model, self.obs[partner], beta=0.3
            )

        primary_actions, _ = self.env._computeLowLevelActions(actions)
        low_level_action_to_execute = primary_actions

        # While a macro action is in progress, replace the partner's low-level action with one
        # derived from a macro action sampled from the stored Boltzmann distribution.
        # NOTE(review): indices 0/1 are hard-coded here (assumes agent_index == 0), and this
        # assumes action_probabilities was filled on an earlier step (reset() empties it).
        if low_ability_partner and self.env.macroAgent[1].cur_macro_action_done != True:
            boltzmann_actions = [0, 0]
            boltzmann_actions[ego] = action
            boltzmann_actions[partner] = sample_action_from_probs(self.action_probabilities)
            primary_actions_boltzmann, _ = self.env._computeLowLevelActions_boltzmaanlowlevel(
                boltzmann_actions
            )
            low_level_action_to_execute = [primary_actions[0], primary_actions_boltzmann[1]]

        self._maybe_unstick_partner(low_level_action_to_execute, partner)

        self.obs, rewards, dones, info = self.env.step(low_level_action_to_execute)
        self.obs = self.env._get_macro_obs()

        # Team reward: own reward plus the partner's.
        return self.obs[ego], rewards[ego] + rewards[partner], dones, info






# Map-cell code -> item name (used as ITEMNAME[pomap[x][y]]).
ITEMNAME = [
    "space", "counter", "agent", "tomato", "lettuce", "plate",
    "knife", "delivery", "onion", "dirtyplate", "badlettuce",
]

# Macro-action name -> macro-action index.
# NOTE(review): intelligently_find_item_number looks up "get dirty plate", which has no entry
# here and would raise KeyError — confirm the intended index.
macroActionDict = {
    "stay": 0,
    "get lettuce 1": 1,
    "get lettuce 2": 2,
    "get badlettuce": 3,
    "get plate 1": 4,
    "get plate 2": 5,
    "go to knife 1": 6,
    "go to knife 2": 7,
    "deliver 1": 8,
    "deliver 2": 9,
    "chop": 10,
    "go to counter": 11,
    "right": 12,
    "down": 13,
    "left": 14,
    "up": 15,
}


def find_best_reachable_index(can_reach_1, can_reach_2, can_reach_3, distance_1, distance_2, distance_3, unreachable=4):
    """Return the index (0, 1 or 2) of the nearest reachable candidate.

    Args:
        can_reach_1, can_reach_2, can_reach_3: navigation results for the three
            candidates; a value equal to ``unreachable`` means that candidate
            cannot be reached.
        distance_1, distance_2, distance_3: distance to each candidate.
        unreachable: navigation result that marks "cannot reach". Defaults to 4,
            the value used throughout this file (presumably the "stay"
            primitive action -- TODO confirm against the env).

    Returns:
        The index of the reachable candidate with the smallest distance. Ties
        go to the lower index. Returns ``None`` when no candidate is reachable.
        The old sentinel ``False`` compared equal to index 0, so callers that
        tested ``== 0`` silently chose candidate 0.
    """
    pairs = (
        (can_reach_1, distance_1),
        (can_reach_2, distance_2),
        (can_reach_3, distance_3),
    )
    reachable = [
        (index, distance)
        for index, (reach, distance) in enumerate(pairs)
        if reach != unreachable
    ]
    if not reachable:
        return None
    # min() returns the first minimal element, so ties keep the lower index.
    best_index, _ = min(reachable, key=lambda item: item[1])
    return best_index


def _nearest_reachable_action(env, agent_item, candidate_actions):
    """Return the reachable candidate macro action whose target is nearest, else "stay".

    Candidates are evaluated in order, and on equal distance the earlier one
    wins. A ``env._navigate`` result of 4 marks a target as unreachable. This
    is the convention used throughout this file (presumably the "stay"
    primitive -- TODO confirm against the env).
    """
    best_action = "stay"
    best_distance = None
    for action_name in candidate_actions:
        target_x, target_y = env._findPOitem(agent_item, macroActionDict[action_name])
        if env._navigate(agent_item, target_x, target_y) == 4:
            continue
        distance = env._calDistance(target_x, target_y, agent_item.x, agent_item.y)
        if best_distance is None or distance < best_distance:
            best_action, best_distance = action_name, distance
    return best_action


def intelligently_find_item_number(env, agent_item, raw_name):
    """Choose which numbered variant of ``raw_name`` the agent should execute.

    Considers "<raw_name> 1" and "<raw_name> 2". For ``raw_name == "get plate"``
    it also considers "get dirty plate", but only when that action exists in
    ``macroActionDict``. The current table has no such key, and looking it up
    unconditionally raised KeyError.

    Args:
        env: the Overcooked environment (uses ``_findPOitem``, ``_navigate``,
            ``_calDistance``).
        agent_item: the agent whose position/navigation is evaluated.
        raw_name: macro-action prefix, e.g. "get lettuce".

    Returns:
        The name of the nearest reachable candidate action (ties favour the
        lower number), or "stay" when no candidate is reachable.

    Raises:
        KeyError: if "<raw_name> 1" or "<raw_name> 2" is not a known macro action.
    """
    candidates = [raw_name + " 1", raw_name + " 2"]
    if raw_name == "get plate" and "get dirty plate" in macroActionDict:
        candidates.append("get dirty plate")
    return _nearest_reachable_action(env, agent_item, candidates)



def check_benevolence(env, best_action, action):
    """Return 20 if ``action`` is the (non-"stay") ``best_action``, otherwise 0.

    ``best_action`` is a macro-action name. ``action`` is compared against its
    integer id in ``macroActionDict``. Id 0 ("stay") never earns the bonus.

    Side effect: ``env.reward`` is reset to 0 and then holds the returned value.
    NOTE(review): this uses the env's ``reward`` attribute as scratch space --
    confirm nothing else relies on ``env.reward`` between steps.
    """
    env.reward = 0
    wanted_id = macroActionDict[best_action]
    if action == wanted_id and wanted_id != 0:
        env.reward += 20
    return env.reward



def check_action_benevolence(env, action_left, action_right, firsttime_left_go_to_counter, firsttime_left_get_lettuce, firsttime_right_go_to_knife, firsttime_right_go_to_counter, *, counter_x=4, counter_ys=range(2, 14)):
    """Compute one-shot reward-shaping bonuses for the left and right agents.

    Shaping applies only while none of the counter cells
    ``(counter_x, y) for y in counter_ys`` (by default (4, 2) .. (4, 13)) holds
    lettuce. In that case:

    * If the right agent (``env.agent[1]``) is empty-handed, it earns +20 the
      first time ``action_right`` matches the nearest reachable
      "get lettuce N". This is gated by ``firsttime_right_go_to_knife``.
    * If the right agent holds a ``Lettuce``, it earns +20 the first time
      ``action_right`` is "go to counter" (gated by
      ``firsttime_right_go_to_counter``). The left agent earns +100 the first
      time ``action_left`` is "go to counter" (gated by
      ``firsttime_left_go_to_counter``).

    A bonus clears its flag, so each is paid at most once while the caller
    keeps the returned flags. ``firsttime_left_get_lettuce`` is not used here
    and is returned unchanged. Note that each ``check_benevolence`` call made
    here overwrites ``env.reward``.

    Args:
        counter_x, counter_ys: keyword-only; which map cells form the shared
            counter that is scanned for lettuce.

    Returns:
        (reward_bonus_left, reward_bonus_right, firsttime_left_go_to_counter,
         firsttime_left_get_lettuce, firsttime_right_go_to_knife,
         firsttime_right_go_to_counter)
    """
    agent_item = env.agent[1]

    # Item names currently on the shared counter column.
    counters = [ITEMNAME[env.map[counter_x][y]] for y in counter_ys]

    reward_bonus_left = 0
    reward_bonus_right = 0

    # Right-hand agent, high benevolence. Use list membership: a bare
    # ("lettuce") is a string, so `x in ("lettuce")` would be a substring test.
    if "lettuce" not in counters:
        if not agent_item.holding:
            best_action = intelligently_find_item_number(env, agent_item, "get lettuce")
            # NOTE(review): the "get lettuce" bonus is gated by the
            # firsttime_right_go_to_knife flag, despite the flag's name.
            if firsttime_right_go_to_knife and check_benevolence(env, best_action, action_right) == 20:
                reward_bonus_right = 20
                firsttime_right_go_to_knife = False

        if agent_item.holding and isinstance(agent_item.holding, Lettuce):
            best_action = "go to counter"

            if firsttime_right_go_to_counter and check_benevolence(env, best_action, action_right) == 20:
                reward_bonus_right = 20
                firsttime_right_go_to_counter = False

            # The left agent is also rewarded (100) for heading to the counter
            # while the right agent is carrying lettuce.
            if firsttime_left_go_to_counter and check_benevolence(env, best_action, action_left) == 20:
                reward_bonus_left = 100
                firsttime_left_go_to_counter = False

    return reward_bonus_left, reward_bonus_right, firsttime_left_go_to_counter, firsttime_left_get_lettuce, firsttime_right_go_to_knife, firsttime_right_go_to_counter



# ====== Environment and parameters ======
mac_env_id = 'Overcooked-MA-v2'

# One reward table per agent (presumably indexed by agent id -- TODO confirm
# against the env). The two tables are identical except that agent 1 gets no
# bad-lettuce penalties.
rewardList = [{
    "minitask finished": 0,
    "minitask failed": 0,
    "metatask finished": 0,
    "metatask failed": 0,
    "goodtask finished": 10,
    "goodtask failed": 0,
    "subtask finished": 20,
    "subtask failed": 0,
    "correct delivery": 200,
    "wrong delivery": -50,
    "step penalty": -1,
    "penalize using dirty plate": 0,
    "penalize using bad lettuce": -20,
    "pick up bad lettuce": -100,
}]
# Overriding existing keys in a copy keeps the original key order.
rewardList.append({
    **rewardList[0],
    "penalize using bad lettuce": 0,
    "pick up bad lettuce": 0,
})

env_params = dict(
    grid_dim=[15, 15],
    task=["lettuce salad"],
    rewardList=rewardList,
    map_type="A_lowuncertainty",
    n_agent=2,
    obs_radius=0,
    mode="vector",
    debug=False,  # turning debug off noticeably speeds up parallel sampling
)

# ====== SubprocVecEnv factory ======
def make_env(rank: int):
    """Return a zero-argument factory that builds the env for worker ``rank``.

    Each SubprocVecEnv worker calls the returned factory and gets its own,
    unshared stack:
    base env -> partner wrapper (agent_index=1) -> training wrapper (agent_index=0).
    Python, NumPy and the base env are seeded with ``SEED + rank`` after the
    wrappers are built.
    """
    def _build_worker_env():
        # Base env: created per worker, never shared across processes.
        base_env = gym.make(mac_env_id, **env_params)

        # Partner view: exists only so the trained partner PPO models have an
        # env to bind to when loaded.
        partner_view = SingleAgentWrapper_Without_Latent(base_env, agent_index=1, other_agent_model=None)

        # Training view for agent 0. Per the original note, it loads the
        # partner models internally and rotates between them.
        training_env = SingleAgentWrapper(base_env, partner_view, agent_index=0)

        # Per-worker seeding so workers do not produce identical trajectories.
        worker_seed = SEED + rank
        base_env.seed(worker_seed)
        np.random.seed(worker_seed)
        random.seed(worker_seed)

        return training_env

    return _build_worker_env

# ====== Policy / PPO hyper-parameters ======
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
# NOTE(review): ABIGatedExtractorWithConf is not defined in this chunk (the
# header shows ABIGatedExtractor); presumably it is defined elsewhere in the file.
policy_kwargs = {
    "features_extractor_class": ABIGatedExtractorWithConf,
    "features_extractor_kwargs": {"base_dim": 64},  # use 32 for a faster extractor
    # NOTE(review): list-of-dict net_arch is the pre-1.8 SB3 form; newer SB3
    # versions expect dict(pi=..., vf=...) directly -- confirm the installed version.
    "net_arch": [{"pi": [128, 64], "vf": [128, 64]}],
}

# SB3's n_steps is the rollout length *per environment*: one update collects
# n_steps * n_envs samples (3600 * 8 = 28800 with n_envs = 8 in __main__),
# split into minibatches of batch_size.
ppo_params = dict(
    learning_rate=3e-4,
    n_steps=3600,
    batch_size=600,
    n_epochs=4,
    gamma=0.95,
    gae_lambda=0.95,
    clip_range=0.3,
    ent_coef=0.02,  # previously tried 0.05
    vf_coef=0.5,
    max_grad_norm=0.5,
    verbose=0,
)


if __name__ == "__main__":
    import multiprocessing as mp
    import time

    # Make "spawn" the global default for any multiprocessing in this script.
    mp.set_start_method("spawn", force=True)

    n_envs = 8

    # SubprocVecEnv chooses its own start method ("forkserver" where available)
    # unless one is given, ignoring set_start_method above. Pass "spawn"
    # explicitly so the workers really use it.
    venv = SubprocVecEnv([make_env(i) for i in range(n_envs)], start_method="spawn")

    try:
        model = PPO("MlpPolicy", venv, policy_kwargs=policy_kwargs, **ppo_params)

        callback = EpisodeRewardCallback(
            save_path="final_trained_models/[MapA_lowuncertainty]POMDP",
            save_freq=100_000,
        )

        def format_time(s):
            """Format a non-negative duration in seconds as minutes and seconds."""
            m, ss = divmod(int(s), 60)
            return f"{m}分{ss}秒"

        start = time.time()
        total_loops = 300
        steps_per_loop = 1_000_000

        for i in range(1, total_loops + 1):
            # reset_num_timesteps=False keeps the timestep counter continuous, so
            # each call trains roughly steps_per_loop additional steps.
            model.learn(total_timesteps=steps_per_loop, callback=callback, reset_num_timesteps=False)
            # Report the real counter. learn() always completes whole rollouts
            # (n_steps * n_envs samples), so it overshoots the nominal target.
            print(f"[Loop {i}/{total_loops}] 已训练 {model.num_timesteps:,} steps | 用时 {format_time(time.time()-start)}")

        print("✅ 训练完成")
    finally:
        # Shut down the worker processes even if training fails or is interrupted.
        venv.close()

