import gym
from stable_baselines3 import PPO
import numpy as np
from gym_macro_overcooked.overcooked_V1 import Overcooked_V1
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.logger import configure
import matplotlib.pyplot as plt
from gym_macro_overcooked.items import Tomato, Lettuce, Onion, Plate, Knife, Delivery, Agent, Food, DirtyPlate
import random
import time
import torch
import os
import torch
import torch.nn as nn
from collections import deque
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F


# ====== Global random seed ======
SEED = 42  # change this number to change the run's randomness

# Seed Python's `random`, NumPy's global RNG and torch's CPU RNG (in that order),
# then every CUDA device when CUDA is available.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)


# Helpers for persisting the ABI model's training buffer.
def save_dataset_to_disk(data_buffer, filepath="abi_training_data.pt"):
    """
    Serialize ``data_buffer`` to ``filepath`` with ``torch.save`` and print a confirmation.
    """
    destination = filepath
    torch.save(data_buffer, destination)
    print(f"✅ Data buffer saved to {destination}")

def load_dataset_from_disk(filepath="abi_training_data.pt"):
    """
    Load a buffer previously written with ``torch.save`` (e.g. by ``save_dataset_to_disk``).

    Depending on the installed PyTorch's ``weights_only`` default, ``torch.load`` may
    unpickle arbitrary objects — only load files from a trusted source.
    """
    buffer = torch.load(filepath)
    return buffer




def get_boltzmann_probs_from_policy(model, obs, beta: float = 1.0):
    """
    Return Boltzmann (softmax) action probabilities of an SB3 policy at ``obs``.

    The policy's action logits are multiplied by ``beta`` before the softmax, so ``beta``
    acts as a rationality / inverse-temperature knob: smaller values flatten the
    distribution (more random behaviour), larger values sharpen it.

    Assumes a discrete action space whose SB3 distribution wraps a torch ``Categorical``
    (exposes ``.distribution.logits``) — TODO confirm for other action spaces.

    Returns:
        numpy array of probabilities for the first observation in the batch.
    """
    obs_tensor, _ = model.policy.obs_to_tensor(obs)
    with torch.no_grad():
        sb3_dist = model.policy.get_distribution(obs_tensor)
        action_logits = sb3_dist.distribution.logits        # [B, n_actions]
        tempered = torch.softmax(action_logits * beta, dim=-1)
    return tempered.cpu().numpy()[0]

def sample_action_from_probs(probs: np.ndarray):
    """
    Draw one action index from the categorical distribution ``probs``.

    Uses NumPy's global RNG and returns a plain Python ``int``.
    """
    n_actions = len(probs)
    sampled = np.random.choice(n_actions, p=probs)
    return int(sampled)







# ===== Utility: KL between a Beta and the uniform prior, used to curb over-confidence =====
def kl_beta_to_uniform(alpha, beta):
    """
    Element-wise KL( Beta(alpha, beta) || Beta(1, 1) ).

        KL = -log B(a, b) + (a - 1) psi(a) + (b - 1) psi(b) - (a + b - 2) psi(a + b)

    where B is the Beta function and psi the digamma function (log B(1, 1) = 0).
    """
    total = alpha + beta
    log_beta_fn = torch.lgamma(alpha) + torch.lgamma(beta) - torch.lgamma(total)
    return (
        -log_beta_fn
        + (alpha - 1) * torch.digamma(alpha)
        + (beta - 1) * torch.digamma(beta)
        - (total - 2) * torch.digamma(total)
    )

# ===== Binary evidential loss: CE + lambda * KL(Beta || Uniform) =====
def evidential_binary_loss(alpha, beta, y, lam=1e-3):
    """
    Evidential loss for a binary target whose probability is modelled by Beta(alpha, beta).

    Args:
        alpha, beta: positive Beta parameters with the same shape as ``y`` (e.g. [B, 1]).
        y: float targets in {0, 1}.
        lam: weight of the KL(Beta || Uniform) regulariser that discourages over-confidence.

    Returns:
        (loss, metrics): ``loss`` is a scalar tensor; ``metrics`` holds floats 'ce', 'kl'
        and 'S_mean' (mean evidence strength alpha + beta).
    """
    strength = alpha + beta
    # The Beta mean serves as P(y = 1); clamping keeps the BCE finite.
    p_positive = (alpha / strength).clamp(1e-6, 1 - 1e-6)
    ce_term = F.binary_cross_entropy(p_positive, y, reduction='mean')
    kl_term = kl_beta_to_uniform(alpha, beta).mean()
    total_loss = ce_term + lam * kl_term
    metrics = {'ce': ce_term.item(), 'kl': kl_term.item(), 'S_mean': strength.mean().item()}
    return total_loss, metrics

class LightweightTransformerABIModel_adaptivelength(nn.Module):
    """
    ABI (Ability / Benevolence / Integrity) estimator for variable-length state histories.

    A shared Transformer encoder (with a padding mask) encodes the history. A shared
    attention vector then pools, separately for each ABI dimension, only the most recent
    K valid steps (K is configurable per dimension). Each dimension has its own head that
    outputs the parameters (alpha, beta) of a Beta distribution (evidential output):

    - prediction:  p = alpha / (alpha + beta)
    - uncertainty: S = alpha + beta (smaller means less certain), or the Beta variance.
    """

    def __init__(self, state_dim, hidden_dim=32, nhead=2, num_layers=1,
                 lr=1e-3, evidential_lambda=1e-3):
        super().__init__()
        self.input_proj = nn.Linear(state_dim, hidden_dim)

        # NOTE: TransformerEncoderLayer uses dropout (default 0.1), so train/eval mode matters.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=nhead,
            dim_feedforward=hidden_dim * 2,
            batch_first=True,
            activation='relu'
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Attention vector shared by the pooling of all three dimensions.
        self.attn_vec = nn.Parameter(torch.randn(hidden_dim))

        # Per-dimension intermediate heads (shared trunk, separate heads).
        self.head_A = nn.Sequential(nn.Linear(hidden_dim, 64), nn.ReLU())
        self.head_B = nn.Sequential(nn.Linear(hidden_dim, 64), nn.ReLU())
        self.head_I = nn.Sequential(nn.Linear(hidden_dim, 64), nn.ReLU())

        # Per-dimension Beta parameters; forward() applies softplus to keep them positive.
        self.fc_A_alpha = nn.Linear(64, 1)
        self.fc_A_beta  = nn.Linear(64, 1)
        self.fc_B_alpha = nn.Linear(64, 1)
        self.fc_B_beta  = nn.Linear(64, 1)
        self.fc_I_alpha = nn.Linear(64, 1)
        self.fc_I_beta  = nn.Linear(64, 1)

        self.evidential_lambda = evidential_lambda
        self.optimizer = torch.optim.Adam(self.parameters(), lr=lr)
        # List of (state_history tensor [T, D], ability, benevolence, integrity) tuples.
        self.data_buffer = []

    @staticmethod
    def _make_pad_mask(lengths, max_len):
        """Return a [B, T] bool mask that is True at padding positions (index >= length)."""
        device = lengths.device
        rng = torch.arange(max_len, device=device).unsqueeze(0)  # [1, T]
        pad_mask = rng >= lengths.unsqueeze(1)                   # [B, T], True=padding
        return pad_mask

    @staticmethod
    def _make_recent_window_mask(lengths, max_len, k_recent):
        """Return a [B, T] bool mask that is False only on the last ``k_recent`` valid steps."""
        device = lengths.device
        k = torch.clamp(torch.as_tensor(k_recent, device=device), min=1)
        start = torch.clamp(lengths - k, min=0)                          # [B]
        idx   = torch.arange(max_len, device=device).unsqueeze(0)        # [1, T]
        # Inside the window -> False (kept); everything else -> True (masked).
        in_window = (idx >= start.unsqueeze(1)) & (idx < lengths.unsqueeze(1))
        mask = ~in_window                                                 # [B, T]
        return mask

    def _attn_pool(self, h, attn_vec, mask_bool, lengths):
        """
        Attention-pool the encoded sequence into one vector per batch row.

        h:         [B, T, H] encoder outputs
        attn_vec:  [H] attention query vector
        mask_bool: [B, T] True = position excluded from pooling
        lengths:   [B] true sequence lengths
        Returns:   [B, H]
        """
        B = h.shape[0]
        logits = torch.matmul(h, attn_vec)                      # [B, T]
        logits = logits.masked_fill(mask_bool, float('-inf'))

        # A row with every position masked would yield NaN under softmax. Neutralise only
        # those rows before the softmax so that all other rows keep their normal weights.
        all_masked = torch.isinf(logits).all(dim=1)             # [B]
        safe_logits = logits.masked_fill(all_masked.unsqueeze(1), 0.0)
        weights = torch.softmax(safe_logits, dim=1)             # [B, T]

        if all_masked.any():
            # Fallback for the fully-masked rows only: all weight on the last valid token.
            fallback = torch.zeros_like(weights)
            last_idx = (lengths - 1).clamp(min=0)               # [B]
            fallback[torch.arange(B, device=h.device), last_idx] = 1.0
            weights = torch.where(all_masked.unsqueeze(1), fallback, weights)

        pooled = torch.sum(h * weights.unsqueeze(-1), dim=1)    # [B, H]
        return pooled

    def forward(self, x, lengths=None, k_ability=15, k_benevolence=30, k_integrity=30):
        """
        x:       [B, T, state_dim] (may contain padding)
        lengths: [B] true length of each sequence; None means every sequence is full length.
        k_*:     how many of the most recent valid steps each dimension pools over.
        Returns:
          (alpha_A, beta_A), (alpha_B, beta_B), (alpha_I, beta_I), each of shape [B, 1]
        """
        B, T, _ = x.shape
        device = x.device
        if lengths is None:
            lengths = torch.full((B,), T, dtype=torch.long, device=device)

        # 1) Input projection
        x_proj = self.input_proj(x)  # [B, T, H]

        # 2) Transformer encoding with the padding mask
        pad_mask = self._make_pad_mask(lengths, T)  # [B, T], True=padding
        h = self.transformer(x_proj, src_key_padding_mask=pad_mask)  # [B, T, H]

        # 3) Per-dimension "most recent K steps" windows, combined with the padding mask
        mask_A = pad_mask | self._make_recent_window_mask(lengths, T, k_ability)
        mask_B = pad_mask | self._make_recent_window_mask(lengths, T, k_benevolence)
        mask_I = pad_mask | self._make_recent_window_mask(lengths, T, k_integrity)

        # 4) Pooling with the shared attention vector
        pooled_A = self._attn_pool(h, self.attn_vec, mask_A, lengths)  # [B, H]
        pooled_B = self._attn_pool(h, self.attn_vec, mask_B, lengths)
        pooled_I = self._attn_pool(h, self.attn_vec, mask_I, lengths)

        # 5) Per-dimension heads
        zA = self.head_A(pooled_A)
        zB = self.head_B(pooled_B)
        zI = self.head_I(pooled_I)

        # softplus + eps keeps every Beta parameter strictly positive.
        eps = 1e-4
        alpha_A = F.softplus(self.fc_A_alpha(zA)) + eps
        beta_A  = F.softplus(self.fc_A_beta(zA))  + eps
        alpha_B = F.softplus(self.fc_B_alpha(zB)) + eps
        beta_B  = F.softplus(self.fc_B_beta(zB))  + eps
        alpha_I = F.softplus(self.fc_I_alpha(zI)) + eps
        beta_I  = F.softplus(self.fc_I_beta(zI))  + eps

        return (alpha_A, beta_A), (alpha_B, beta_B), (alpha_I, beta_I)

    @staticmethod
    def beta_mean_strength(alpha, beta):
        """Return (Beta mean clamped to [1e-6, 1 - 1e-6], evidence strength alpha + beta)."""
        S = alpha + beta
        p = (alpha / S).clamp(1e-6, 1-1e-6)
        return p, S

    def store_data(self, state_history, ability, benevolence, integrity):
        """Append one labelled state history (as a CPU tensor) to the training buffer."""
        if isinstance(state_history, torch.Tensor):
            sh = state_history.detach().cpu()
        else:
            # Stack into one ndarray first: building a tensor directly from a list of
            # ndarrays is very slow (PyTorch warns about it).
            sh = torch.as_tensor(np.asarray(state_history))
        self.data_buffer.append((sh, float(ability), float(benevolence), float(integrity)))

    def _collate_batch(self, batch):
        """Pad a list of buffer entries into (states [B, T_max, D], lengths [B], A, B, I [B, 1])."""
        seqs, A, B, I = [], [], [], []
        for sh, a, b, i in batch:
            t = torch.as_tensor(sh, dtype=torch.float32)
            seqs.append(t)
            A.append([a]); B.append([b]); I.append([i])
        lengths = torch.tensor([s.shape[0] for s in seqs], dtype=torch.long)
        states = pad_sequence(seqs, batch_first=True)  # [B, T_max, D]
        abilitys     = torch.tensor(A, dtype=torch.float32)
        benevolences = torch.tensor(B, dtype=torch.float32)
        integritys   = torch.tensor(I, dtype=torch.float32)
        return states, lengths, abilitys, benevolences, integritys

    def train_step(self, batch_size=None, k_ability=15, k_benevolence=30, k_integrity=30,
                   device=None, loss_weights=(1.0, 1.0, 1.0)):
        """
        Run one optimisation step on the buffered data.

        batch_size: None trains on the whole buffer; otherwise a fresh random subset of
                    that size is drawn each call (the whole buffer if it is smaller).
        Returns a dict of float metrics, or None when the buffer is empty.
        """
        if not self.data_buffer:
            return None

        if batch_size is None or batch_size >= len(self.data_buffer):
            batch = self.data_buffer
        else:
            # Slicing would always reuse the same prefix; sample so every entry gets used.
            batch = random.sample(self.data_buffer, batch_size)
        states, lengths, abilitys, benevolences, integritys = self._collate_batch(batch)

        if device is None:
            device = next(self.parameters()).device
        states       = states.to(device)
        lengths      = lengths.to(device)
        abilitys     = abilitys.to(device)
        benevolences = benevolences.to(device)
        integritys   = integritys.to(device)

        # Make sure dropout is active even if the model was left in eval mode.
        self.train()
        self.optimizer.zero_grad()
        (aA, bA), (aB, bB), (aI, bI) = self.forward(
            states, lengths=lengths,
            k_ability=k_ability, k_benevolence=k_benevolence, k_integrity=k_integrity
        )

        # Evidential loss for each of A / B / I
        lam = self.evidential_lambda
        wA, wB, wI = loss_weights
        loss_A, mA = evidential_binary_loss(aA, bA, abilitys,     lam=lam)
        loss_B, mB = evidential_binary_loss(aB, bB, benevolences, lam=lam)
        loss_I, mI = evidential_binary_loss(aI, bI, integritys,   lam=lam)

        total = wA*loss_A + wB*loss_B + wI*loss_I
        total.backward()
        self.optimizer.step()

        return {
            'total_loss': total.item(),
            'A_ce': mA['ce'], 'A_kl': mA['kl'], 'A_S_mean': mA['S_mean'],
            'B_ce': mB['ce'], 'B_kl': mB['kl'], 'B_S_mean': mB['S_mean'],
            'I_ce': mI['ce'], 'I_kl': mI['kl'], 'I_S_mean': mI['S_mean'],
        }

    def predict_beta_params(self, states_np, lengths_np=None,
                            k_ability=15, k_benevolence=30, k_integrity=30, device=None):
        """
        Inference helper: return alpha/beta plus p/S for every ABI dimension.

        states_np:  [B, T, D] (numpy array or tensor)
        lengths_np: [B] (optional; None means every sequence is full length)

        Runs in eval mode (dropout off) so predictions are deterministic, then restores the
        previous train/eval mode.
        """
        states = torch.as_tensor(states_np, dtype=torch.float32)
        B, T, _ = states.shape

        if lengths_np is None:
            lengths = torch.full((B,), T, dtype=torch.long)
        else:
            lengths = torch.as_tensor(lengths_np, dtype=torch.long)

        if device is None:
            device = next(self.parameters()).device
        states  = states.to(device)
        lengths = lengths.to(device)

        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                (aA, bA), (aB, bB), (aI, bI) = self.forward(
                    states, lengths=lengths,
                    k_ability=k_ability, k_benevolence=k_benevolence, k_integrity=k_integrity
                )
                pA, SA = self.beta_mean_strength(aA, bA)
                pB, SB = self.beta_mean_strength(aB, bB)
                pI, SI = self.beta_mean_strength(aI, bI)
        finally:
            self.train(was_training)

        out = {
            'A': {'alpha': aA.squeeze(-1), 'beta': bA.squeeze(-1), 'p': pA.squeeze(-1), 'S': SA.squeeze(-1)},
            'B': {'alpha': aB.squeeze(-1), 'beta': bB.squeeze(-1), 'p': pB.squeeze(-1), 'S': SB.squeeze(-1)},
            'I': {'alpha': aI.squeeze(-1), 'beta': bI.squeeze(-1), 'p': pI.squeeze(-1), 'S': SI.squeeze(-1)},
        }
        return out

    def save_model(self, path):
        """Save model weights, optimizer state and ``evidential_lambda`` to ``path``."""
        torch.save({
            'model_state_dict': self.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'evidential_lambda': self.evidential_lambda
        }, path)
        print(f"Model saved to {path}")

    def load_model(self, path, map_location=None):
        """Restore a checkpoint written by ``save_model``."""
        checkpoint = torch.load(path, map_location=map_location)
        self.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.evidential_lambda = checkpoint.get('evidential_lambda', self.evidential_lambda)
        print(f"Model loaded from {path}")




class LightweightTransformerABIModel(nn.Module):
    """
    Fixed-length predecessor of the adaptive-length ABI model.

    Encodes a [B, T, state_dim] state history with a small Transformer, pools it with one
    learnable attention vector and predicts ability / benevolence / integrity in [0, 1]
    (sigmoid outputs trained with MSE). No padding mask is used, so all buffered
    histories must have the same length for ``train_step`` to stack them.
    """
    def __init__(self, state_dim, seq_len=15, hidden_dim=32, nhead=2, num_layers=1):
        super().__init__()
        # NOTE: ``seq_len`` is accepted for backward compatibility but is not used.
        self.input_proj = nn.Linear(state_dim, hidden_dim)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=nhead,
            dim_feedforward=hidden_dim * 2,
            batch_first=True,
            activation='relu'  # lightweight
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Attention pooling: learnable attention vector
        self.attention_vector = nn.Parameter(torch.randn(hidden_dim))

        self.fc_ability = nn.Linear(hidden_dim, 1)
        self.fc_benevolence = nn.Linear(hidden_dim, 1)
        self.fc_integrity = nn.Linear(hidden_dim, 1)

        self.criterion = nn.MSELoss()
        self.optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        # List of (state_history, ability, benevolence, integrity) tuples.
        self.data_buffer = []

    def forward(self, x):
        """x: [B, T, state_dim] -> (ability, benevolence, integrity), each [B, 1] in (0, 1)."""
        x = self.input_proj(x)           # [B, T, H]
        h = self.transformer(x)          # [B, T, H]

        # Attention pooling
        attn_weights = torch.softmax(torch.matmul(h, self.attention_vector), dim=1)  # [B, T]
        pooled = torch.sum(h * attn_weights.unsqueeze(-1), dim=1)  # [B, H]

        ability     = torch.sigmoid(self.fc_ability(pooled))
        benevolence = torch.sigmoid(self.fc_benevolence(pooled))
        integrity   = torch.sigmoid(self.fc_integrity(pooled))
        return ability, benevolence, integrity

    def store_data(self, state_history, ability, benevolence, integrity):
        """Append one labelled state history to the training buffer (stored as given)."""
        self.data_buffer.append((state_history, ability, benevolence, integrity))

    def train_step(self):
        """
        Run one full-batch optimisation step over the whole buffer.

        Returns a dict of float losses, or None when the buffer is empty.
        """
        if not self.data_buffer:
            return None

        states, abilitys, benevolences, integritys = zip(*self.data_buffer)
        # Stack into one float32 ndarray first; torch.tensor() on nested lists of ndarrays
        # is very slow.
        states = torch.as_tensor(np.asarray(states, dtype=np.float32))
        abilitys = torch.tensor(abilitys, dtype=torch.float32).unsqueeze(-1)
        benevolences = torch.tensor(benevolences, dtype=torch.float32).unsqueeze(-1)
        integritys = torch.tensor(integritys, dtype=torch.float32).unsqueeze(-1)

        self.optimizer.zero_grad()
        pred_ability, pred_benevolence, pred_integrity = self.forward(states)

        # Separate loss per dimension
        loss_ability = self.criterion(pred_ability, abilitys)
        loss_benevolence = self.criterion(pred_benevolence, benevolences)
        loss_integrity = self.criterion(pred_integrity, integritys)

        total_loss = loss_ability + loss_benevolence + loss_integrity
        total_loss.backward()
        self.optimizer.step()

        # Returned as a dict for convenient logging.
        return {
            'total_loss': total_loss.item(),
            'loss_ability': loss_ability.item(),
            'loss_benevolence': loss_benevolence.item(),
            'loss_integrity': loss_integrity.item()
        }

    def save_model(self, path):
        """Save model weights and optimizer state to ``path``."""
        torch.save({
            'model_state_dict': self.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
        }, path)
        print(f"Model saved to {path}")

    def load_model(self, path, map_location=None):
        """Restore a checkpoint written by ``save_model`` (optionally remapping devices)."""
        checkpoint = torch.load(path, map_location=map_location)
        self.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        print(f"Model loaded from {path}")




















class SingleAgentWrapper(gym.Wrapper):
    """
    A wrapper to extract a single agent's perspective from a multi-agent environment.

    The wrapped agent's action is supplied by the caller of ``step``; the other agent's
    action comes from ``other_agent_model.predict(obs)`` (e.g. an SB3 model, whose
    ``predict`` returns ``(action, state)``). Uses the old gym API:
    ``reset() -> obs`` and ``step() -> (obs, reward, done, info)``.
    """

    def __init__(self, env, agent_index, other_agent_model=None):
        if agent_index not in (0, 1):
            raise ValueError(f"agent_index must be 0 or 1, got {agent_index!r}")
        super().__init__(env)
        self.agent_index = agent_index
        self.observation_space = env.observation_space
        self.action_space = env.action_space
        self.other_agent_model = other_agent_model

        self._reset_shaping_flags()

        # Latest macro observation for both agents; populated by reset()/step().
        self.obs = None

    def _reset_shaping_flags(self):
        """Re-arm the one-shot flags used by the (currently disabled) benevolence shaping."""
        self.firsttime_left_go_to_counter = True
        self.firsttime_left_get_lettuce = True
        self.firsttime_right_go_to_knife = True
        self.firsttime_right_go_to_counter = True

    def reset(self):
        """Reset the shared env and return this agent's observation."""
        self.obs = self.env.reset()
        self._reset_shaping_flags()
        return self.obs[self.agent_index]

    def step(self, action):
        """
        Advance the shared env one step, using ``action`` for this agent.

        Returns:
            (obs, reward, done, info) for this agent; ``obs`` is the macro observation.

        Raises:
            RuntimeError: if the wrapper was built without ``other_agent_model``.
        """
        if self.other_agent_model is None:
            raise RuntimeError(
                "SingleAgentWrapper.step() requires other_agent_model to act for the partner agent"
            )

        partner_index = 1 - self.agent_index
        partner_action = self.other_agent_model.predict(self.obs[partner_index])[0]

        actions = [0, 0]
        actions[self.agent_index] = action
        actions[partner_index] = partner_action

        # Cooperative reward shaping via check_action_benevolence(...) is currently disabled.
        benevolence_reward = 0

        # _computeLowLevelActions returns (primary_actions, executed_macro_actions) — the
        # data-collection loop in this file unpacks it the same way — and only the low-level
        # primary actions are passed to env.step.
        primary_actions, _ = self.env._computeLowLevelActions(actions)
        _, rewards, dones, info = self.env.step(primary_actions)
        # Policies act on macro observations, so replace the obs returned by env.step.
        self.obs = self.env._get_macro_obs()

        return self.obs[self.agent_index], rewards[self.agent_index] + benevolence_reward, dones, info





# Item names indexed by the integer codes stored in env.map.
ITEMNAME = "space counter agent tomato lettuce plate knife delivery onion dirtyplate badlettuce".split()

# Macro-action name -> action index (position in this list); index 0 is "stay".
_MACRO_ACTION_NAMES = [
    "stay",
    "get lettuce 1", "get lettuce 2", "get badlettuce",
    "get plate 1", "get plate 2",
    "go to knife 1", "go to knife 2",
    "deliver 1", "deliver 2",
    "chop", "go to counter",
    "right", "down", "left", "up",
]
macroActionDict = {name: index for index, name in enumerate(_MACRO_ACTION_NAMES)}


def find_best_reachable_index(can_reach_1, can_reach_2, can_reach_3, distance_1, distance_2, distance_3):
    """
    Return the 0-based index (0, 1 or 2) of the closest reachable candidate, or ``False``.

    A reachability code of 4 marks a candidate as unreachable; any other value counts as
    reachable. Distance ties go to the lowest index. When nothing is reachable the result
    is ``False`` — callers should test it with ``is False``, because ``False == 0``.
    """
    unreachable_code = 4
    pairs = ((can_reach_1, distance_1), (can_reach_2, distance_2), (can_reach_3, distance_3))
    candidates = [
        (slot, dist)
        for slot, (reach, dist) in enumerate(pairs)
        if reach != unreachable_code
    ]
    if not candidates:
        return False  # no reachable position
    best_slot, _ = min(candidates, key=lambda pair: pair[1])
    return best_slot


def intelligently_find_item_number(env, agent_item, raw_name):
    """
    Choose which numbered instance of an item the agent should go for.

    Probes the macro actions "<raw_name> 1" and "<raw_name> 2" (and, when ``raw_name`` is
    "get plate", also "get dirty plate" if that action exists in ``macroActionDict``) and
    returns the name of the closest reachable one, preferring the lower-numbered candidate
    on distance ties. Returns "stay" when none is reachable.

    Args:
        env: environment exposing ``_findPOitem``, ``_navigate`` and ``_calDistance``.
        agent_item: the agent whose position (``.x``, ``.y``) is the start point.
        raw_name: macro-action prefix without the trailing instance number, e.g. "get lettuce".

    Returns:
        A macro-action name (a key of ``macroActionDict``) or "stay".
    """
    unreachable_code = 4  # reachability value treated as "unreachable" throughout this file

    def _probe(action_name):
        """Return (reachability code, distance) from the agent to ``action_name``'s target."""
        target_x, target_y = env._findPOitem(agent_item, macroActionDict[action_name])
        can_reach = env._navigate(agent_item, target_x, target_y)
        distance = env._calDistance(target_x, target_y, agent_item.x, agent_item.y)
        return can_reach, distance

    candidates = [raw_name + " 1", raw_name + " 2"]
    # "get dirty plate" may be missing from macroActionDict (it is absent from the dict
    # defined in this file); only probe it when present instead of raising KeyError.
    if raw_name == "get plate" and "get dirty plate" in macroActionDict:
        candidates.append("get dirty plate")

    probes = [_probe(name) for name in candidates]
    # Pad unused candidate slots as unreachable.
    probes += [(unreachable_code, None)] * (3 - len(probes))

    (reach_1, dist_1), (reach_2, dist_2), (reach_3, dist_3) = probes
    best_index = find_best_reachable_index(reach_1, reach_2, reach_3, dist_1, dist_2, dist_3)

    # find_best_reachable_index returns False when nothing is reachable; since False == 0,
    # compare by identity so that case yields "stay" rather than "<raw_name> 1".
    if best_index is False:
        return "stay"
    return candidates[best_index]



def check_benevolence(env, best_action, action):
    """
    Return 50 if ``action`` equals the index of macro action ``best_action`` and that
    action is not "stay" (index 0); otherwise return 0.

    Side effect: the result is also stored in ``env.reward``, overwriting its previous value.
    """
    env.reward = 0
    wanted = macroActionDict[best_action]
    is_helpful_match = action == wanted and wanted != 0
    if is_helpful_match:
        env.reward += 50
    return env.reward



def check_action_benevolence(env, action_left, action_right, firsttime_left_go_to_counter, firsttime_left_get_lettuce, firsttime_right_go_to_knife, firsttime_right_go_to_counter):
    """
    Compute one-shot benevolence shaping bonuses for the left (env.agent[0]) and right
    (env.agent[1]) agents.

    Bonuses are only considered while neither hard-coded counter cell holds lettuce:
      * right agent, empty-handed: 20 the first time its action matches the best
        "get lettuce" macro action (gated by ``firsttime_right_go_to_knife``, despite the name);
      * right agent, holding a Lettuce: 20 the first time it chooses "go to counter"
        (gated by ``firsttime_right_go_to_counter``);
      * left agent: 1000 the first time it chooses "go to counter"
        (gated by ``firsttime_left_go_to_counter``).
    Each gating flag flips to False once its bonus is paid. ``firsttime_left_get_lettuce``
    is passed through unchanged.

    Note: ``check_benevolence`` overwrites ``env.reward`` as a side effect.

    Returns:
        (reward_bonus_left, reward_bonus_right, firsttime_left_go_to_counter,
         firsttime_left_get_lettuce, firsttime_right_go_to_knife, firsttime_right_go_to_counter)
    """
    agent_item = env.agent[1]  # right-hand agent

    # Hard-coded counter cells (x, y); presumably specific to map A — TODO confirm.
    counter_cells = ((12, 7), (13, 7))
    counters = [ITEMNAME[env.map[x][y]] for x, y in counter_cells]

    reward_bonus_left = 0
    reward_bonus_right = 0

    # Compare by equality: `c not in ("lettuce")` would be a substring test, because
    # ("lettuce") is a plain string, not a tuple.
    if all(counter != "lettuce" for counter in counters):
        if not agent_item.holding:
            best_action = intelligently_find_item_number(env, agent_item, "get lettuce")
            if firsttime_right_go_to_knife and check_benevolence(env, best_action, action_right) == 50:
                reward_bonus_right = 20
                firsttime_right_go_to_knife = False

        if agent_item.holding and isinstance(agent_item.holding, Lettuce):
            if firsttime_right_go_to_counter and check_benevolence(env, "go to counter", action_right) == 50:
                reward_bonus_right = 20
                firsttime_right_go_to_counter = False

        if firsttime_left_go_to_counter and check_benevolence(env, "go to counter", action_left) == 50:
            reward_bonus_left = 1000
            firsttime_left_go_to_counter = False

    return reward_bonus_left, reward_bonus_right, firsttime_left_go_to_counter, firsttime_left_get_lettuce, firsttime_right_go_to_knife, firsttime_right_go_to_counter




# Per-agent reward table; both agents currently use identical values.
_AGENT_REWARDS = {
    "minitask finished": 0,
    "minitask failed": 0,
    "metatask finished": 0,
    "metatask failed": 0,
    "goodtask finished": 10,
    "goodtask failed": 0,
    "subtask finished": 20,
    "subtask failed": 0,
    "correct delivery": 200,
    "wrong delivery": -50,
    "step penalty": -1,
    "penalize using dirty plate": 0,
    "penalize using bad lettuce": -20,
    "pick up bad lettuce": -1,
}
# One independent copy per agent so they can be tuned separately.
rewardList = [dict(_AGENT_REWARDS) for _ in range(2)]

# Environment id and constructor arguments passed to gym.make.
mac_env_id = 'Overcooked-MA-v0'
env_params = dict(
    grid_dim=[15, 15],
    task=["lettuce salad"],
    rewardList=rewardList,
    map_type="A",
    n_agent=2,
    obs_radius=0,
    mode="vector",
    debug=True,
)


# One environment instance shared by both agents' single-agent views.
shared_env = gym.make(mac_env_id, **env_params)

# Single-agent wrappers for agent 0 and agent 1 (no partner model attached here).
env_agent_0, env_agent_1 = (SingleAgentWrapper(shared_env, agent_index=idx) for idx in (0, 1))




















"""帮用干净"""

high_high_agent0_version2 = PPO.load("final_trained_models/[MapA]trustee_agent_highB_highI_version1_agent0/model_4100000", env=env_agent_0)
high_high_agent1_version2 = PPO.load("final_trained_models/[MapA]trustee_agent_highB_highI_version1_agent1/model_4100000", env=env_agent_1)

high_high_agent0_version3 = PPO.load("final_trained_models/[MapA]trustee_agent_highB_highI_version2_agent0/model_4100000", env=env_agent_0)
high_high_agent1_version3 = PPO.load("final_trained_models/[MapA]trustee_agent_highB_highI_version2_agent1/model_4100000", env=env_agent_0)

"""帮用脏"""
high_low_agent0 = PPO.load("final_trained_models/[MapA]trustee_agent_highB_lowI_agent0/model_4100000", env=env_agent_0)
high_low_agent1 = PPO.load("final_trained_models/[MapA]trustee_agent_highB_lowI_agent1/model_4100000", env=env_agent_0)

"""不帮用干净"""
low_high_agent0 = PPO.load("final_trained_models/[MapA]trustee_agent_lowB_highI_agent0/model_4100000", env=env_agent_0)
low_high_agent1 = PPO.load("final_trained_models/[MapA]trustee_agent_lowB_highI_agent1/model_4100000", env=env_agent_0)

"""不帮用脏"""
low_low_agent0 = PPO.load("final_trained_models/[MapA]trustee_agent_lowB_lowI_agent0/model_4100000", env=env_agent_0)
low_low_agent1 = PPO.load("final_trained_models/[MapA]trustee_agent_lowB_lowI_agent1/model_4100000", env=env_agent_0)




# ==== Data-collection setup ====
# A single reset is enough; the collection loop resets again at the start of every episode.
obs = shared_env.reset()

# The variables below are not read by the code that follows in this file; they are kept
# in case other code references them.
frames = []
previous_obs = []
reward_this = 0

firsttime_left_go_to_counter = True
firsttime_left_get_lettuce = True
firsttime_right_go_to_knife = True
firsttime_right_go_to_counter = True

# Partner pool: the same five trustee policies appear twice. Entries 0-4 are labelled
# high-ability and entries 5-9 low-ability (the collection loop perturbs low-ability
# partners with Boltzmann-sampled macro actions). B / I labels follow each policy's
# training variant (highB/lowB, highI/lowI).
_trustees_agent0 = [high_high_agent0_version2, high_high_agent0_version3, high_low_agent0, low_high_agent0, low_low_agent0]
_trustees_agent1 = [high_high_agent1_version2, high_high_agent1_version3, high_low_agent1, low_high_agent1, low_low_agent1]
model_0_list = _trustees_agent0 * 2
model_1_list = _trustees_agent1 * 2

ability_list = [1] * 5 + [0] * 5
benevolence_list = [1, 1, 1, 0, 0] * 2
integrity_list = [1, 1, 0, 1, 0] * 2

# ABI estimator over 6-dimensional per-step features — presumably the length of the
# vectors returned by _get_macro_vector_obs_for_ABImodel(); TODO confirm.
abi_model = LightweightTransformerABIModel_adaptivelength(6)



for round_idx in range(10):  # avoid shadowing the builtin `round`
    print(round_idx)
    # ==== Data collection: one episode per (partner policy, ABI label) pair ====
    for model_agent_0, model_agent_1, partner_ability, partner_benevolence, partner_integrity in zip(
        model_0_list, model_1_list, ability_list, benevolence_list, integrity_list
    ):
        obs = shared_env.reset()
        cur_episode_done = False
        # Sliding window of the partner's (agent 1) per-step feature vectors.
        state_history = deque(maxlen=30)

        # Boltzmann probabilities over the partner's macro actions (low-ability partners only).
        action_probabilities = []

        while not cur_episode_done:
            action, _ = model_agent_0.predict(obs[0])
            other_agent_action, _ = model_agent_1.predict(obs[1])
            actions = [action, other_agent_action]

            # Recompute the probabilities only when the partner is about to pick a new macro
            # action. The check must happen BEFORE _computeLowLevelActions, which presumably
            # updates cur_macro_action_done.
            if partner_ability == 0 and shared_env.macroAgent[1].cur_macro_action_done:
                action_probabilities = get_boltzmann_probs_from_policy(model_agent_1, obs[1], beta=0.3)

            primary_actions, _ = shared_env._computeLowLevelActions(actions)
            low_level_action_to_execute = primary_actions

            # While the partner's macro action is in progress, replace its low-level action with
            # one derived from a Boltzmann-sampled macro action. The length guard skips this if
            # no probabilities have been computed yet in this episode (sampling from an empty
            # distribution would raise).
            if (
                partner_ability == 0
                and not shared_env.macroAgent[1].cur_macro_action_done
                and len(action_probabilities) > 0
            ):
                boltzmann_macro_action = sample_action_from_probs(action_probabilities)
                boltzmann_actions = [action, boltzmann_macro_action]
                primary_actions_boltzmann, _ = shared_env._computeLowLevelActions_boltzmaanlowlevel(boltzmann_actions)
                low_level_action_to_execute = [primary_actions[0], primary_actions_boltzmann[1]]

            previous_holding = shared_env.agent[1].holding
            obs, _, cur_episode_done, _ = shared_env.step(low_level_action_to_execute)
            obs = shared_env._get_macro_obs()
            after_holding = shared_env.agent[1].holding

            partner_obs = shared_env._get_macro_vector_obs_for_ABImodel()
            state_history.append(partner_obs[1])

            # Key event: the partner was holding something before this step and nothing after.
            partner_key_event_happen = previous_holding is not None and after_holding is None

            # Store a labelled sample only once the window is full.
            if len(state_history) == state_history.maxlen and partner_key_event_happen:
                abi_model.store_data(
                    list(state_history),
                    partner_ability,
                    partner_benevolence,
                    partner_integrity
                )


# # ==== Save the collected data locally ====
# save_dataset_to_disk(abi_model.data_buffer, filepath="[NewENV][MapA]abi_training_data_Boltzmaan.pt")

# # ==== Reload the data and retrain ====
# abi_model.data_buffer = load_dataset_from_disk("[NewENV][MapA]abi_training_data_Boltzmaan.pt")


# === Output directory ===
output_dir = "[MapA]abi_model_outputs_Boltzmaan_10times_adaptivelength_FFFFFFFFFFFFFFINAL_with_uncertainty"
os.makedirs(output_dir, exist_ok=True)  # create the folder if it does not exist

# === Training ===
num_epochs = 10000
log_every = 1000

# Logged history: total loss, per-dimension CE and per-dimension mean evidence strength S.
_PER_DIM_METRICS = ('A_ce', 'B_ce', 'I_ce', 'A_S_mean', 'B_S_mean', 'I_S_mean')
loss_log = {'total': [], **{metric: [] for metric in _PER_DIM_METRICS}}


def _save_metric_plot(series, title, ylabel, fig_path):
    """Plot the ``loss_log`` series given as (key, legend label) pairs and save to ``fig_path``."""
    plt.figure(figsize=(8, 5))
    for key, label in series:
        plt.plot(loss_log[key], label=label)
    plt.title(title)
    plt.xlabel("Epoch")
    plt.ylabel(ylabel)
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(fig_path)
    plt.close()


for epoch in range(1, num_epochs + 1):
    loss_info = abi_model.train_step(k_ability=15, k_benevolence=30, k_integrity=30)

    # train_step returns None when the buffer is empty; nothing to record then.
    if loss_info is None:
        continue

    loss_log['total'].append(loss_info['total_loss'])
    for metric in _PER_DIM_METRICS:
        loss_log[metric].append(loss_info[metric])

    if epoch % log_every != 0:
        continue

    summary = (
        f"📈 Epoch {epoch} | "
        f"total={loss_info['total_loss']:.4f} | "
        f"A_ce={loss_info['A_ce']:.4f}, B_ce={loss_info['B_ce']:.4f}, I_ce={loss_info['I_ce']:.4f} | "
        f"A_S={loss_info['A_S_mean']:.2f}, B_S={loss_info['B_S_mean']:.2f}, I_S={loss_info['I_S_mean']:.2f}"
    )
    print(summary)

    # === Checkpoint ===
    model_path = os.path.join(output_dir, f"abi_model_step_{epoch}.pt")
    abi_model.save_model(model_path)

    # === Cross-entropy (classification error) curves ===
    fig_path_ce = os.path.join(output_dir, f"abi_ce_curve_epoch_{epoch}.png")
    _save_metric_plot(
        [('A_ce', "Ability CE"), ('B_ce', "Benevolence CE"), ('I_ce', "Integrity CE")],
        f"CE Curves up to Epoch {epoch}",
        "Cross-Entropy",
        fig_path_ce,
    )

    # === Evidence-strength curves (larger S_mean means more confident) ===
    fig_path_s = os.path.join(output_dir, f"abi_Smean_curve_epoch_{epoch}.png")
    _save_metric_plot(
        [('A_S_mean', "Ability S_mean"), ('B_S_mean', "Benevolence S_mean"), ('I_S_mean', "Integrity S_mean")],
        f"S_mean Curves up to Epoch {epoch}",
        "S_mean (alpha+beta)",
        fig_path_s,
    )