import gym
from stable_baselines3 import PPO
import numpy as np
from gym_macro_overcooked.overcooked_V1 import Overcooked_V1
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.logger import configure
import matplotlib.pyplot as plt
from gym_macro_overcooked.items import Tomato, Lettuce, Onion, Plate, Knife, Delivery, Agent, Food, DirtyPlate
import random
import time
import torch
import os
import torch
import torch.nn as nn
from collections import deque
import math

import matplotlib
matplotlib.use("Agg")
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F


# train_parallel.py
import os
os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("MKL_NUM_THREADS", "1")
os.environ.setdefault("OPENBLAS_NUM_THREADS", "1")
os.environ.setdefault("NUMEXPR_NUM_THREADS", "1")

import gym
import numpy as np
import random
import torch
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import SubprocVecEnv
from stable_baselines3.common.callbacks import BaseCallback
from collections import deque
import time
import matplotlib.pyplot as plt


SEED = 42

# Seed every RNG used in this file (Python, NumPy, PyTorch CPU and, when
# available, all CUDA devices) so training runs are reproducible.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)
del _seed_fn
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# Pin PyTorch to a single intra-op and a single inter-op thread (one worker
# per process). Done explicitly because the *_NUM_THREADS environment
# variables above are only set after numpy and torch were first imported at
# the top of this file, so they may not take effect for those libraries.
torch.set_num_threads(1)
torch.set_num_interop_threads(1)


from stable_baselines3.common.torch_layers import BaseFeaturesExtractor


class ABIGatedExtractorWithConf(BaseFeaturesExtractor):
    """Features extractor that gates a shared embedding by partner A/B/I estimates.

    Observations are flat vectors whose last six entries are
    ``[A, B, I, conf_A, conf_B, conf_I]``; the preceding entries are the raw
    environment observation. The raw part goes through a two-layer MLP
    backbone, and the resulting embedding ``f`` is multiplied by the positive
    and negative parts of A, B and I, giving six gated copies. The raw A, B, I
    values and the confidences (clamped to ``[0, 1]``) are appended, for
    ``6 * base_dim + 6`` output features.

    Args:
        observation_space: flat ``Box`` space; ``shape[0]`` is the obs size.
        base_dim: width of the backbone embedding.
        conf_power: stored as ``self.conf_power``; not used by ``forward``.

    Raises:
        ValueError: if the observation has fewer than six entries.
    """

    # Trailing observation entries holding [A, B, I, conf_A, conf_B, conf_I].
    _N_TAIL = 6

    def __init__(self, observation_space: gym.spaces.Box, base_dim: int = 64, conf_power: float = 1.0):
        obs_dim = observation_space.shape[0]
        # Explicit check instead of `assert`, so it still runs under `python -O`.
        if obs_dim < self._N_TAIL:
            raise ValueError(
                "obs needs at least 6 elements ending in "
                f"[A,B,I,conf_A,conf_B,conf_I], got {obs_dim}"
            )
        super().__init__(observation_space, features_dim=6 * base_dim + self._N_TAIL)

        # Plain attributes are assigned only after nn.Module.__init__ has run.
        self.obs_dim = obs_dim
        self.base_dim = base_dim
        self.conf_power = conf_power

        self.backbone = nn.Sequential(
            nn.Linear(self.obs_dim - self._N_TAIL, 256), nn.ReLU(),
            nn.Linear(256, base_dim), nn.ReLU()
        )

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        """Map observations ``[..., obs_dim]`` to features ``[..., 6 * base_dim + 6]``."""
        split = self.obs_dim - self._N_TAIL
        x = obs[..., :split]                      # raw environment observation
        tail = obs[..., split:self.obs_dim]       # [A, B, I, cA, cB, cI]
        A, B, I = tail[..., 0:1], tail[..., 1:2], tail[..., 2:3]
        conf = tail[..., 3:6].clamp(0.0, 1.0)     # [cA, cB, cI]

        f = self.backbone(x)  # [..., base_dim]

        # Positive / negative parts of each trait gate the shared embedding,
        # in the order A+, A-, B+, B-, I+, I-.
        gated = []
        for trait in (A, B, I):
            gated.append(f * torch.relu(trait))
            gated.append(f * torch.relu(-trait))

        return torch.cat([*gated, A, B, I, conf], dim=-1)
    



def save_dataset_to_disk(data_buffer, filepath="abi_training_data.pt"):
    """Persist a collected ABI training buffer with ``torch.save``.

    Args:
        data_buffer: any object ``torch.save`` can serialize, e.g. the
            ``data_buffer`` list of an ABI model.
        filepath: destination path; an existing file is overwritten.
    """
    destination = filepath
    torch.save(data_buffer, destination)
    message = f"✅ Data buffer saved to {destination}"
    print(message)

def load_dataset_from_disk(filepath="abi_training_data.pt", map_location=None):
    """Load a buffer previously written by ``save_dataset_to_disk``.

    Args:
        filepath: path of the file produced by ``torch.save``.
        map_location: forwarded to ``torch.load``. For example, ``"cpu"``
            loads a buffer that holds tensors saved on a GPU. ``None`` (the
            default) keeps ``torch.load``'s own default behavior.

    Returns:
        The deserialized object.

    Note:
        ``torch.load`` is pickle-based (unless the installed PyTorch runs it
        in ``weights_only`` mode); only load files from trusted sources.
    """
    return torch.load(filepath, map_location=map_location)



def kl_beta_to_uniform(alpha, beta):
    """Elementwise KL divergence from Beta(alpha, beta) to the uniform Beta(1, 1).

    KL(Beta(a, b) || Beta(1, 1))
        = -log B(a, b) + (a - 1) psi(a) + (b - 1) psi(b) - (a + b - 2) psi(a + b)

    where B is the Beta function and psi the digamma function (log B(1, 1) = 0).

    Args:
        alpha, beta: positive tensors of broadcast-compatible shape.

    Returns:
        Tensor of KL values with the broadcast shape.
    """
    total = alpha + beta
    log_beta_fn = torch.lgamma(alpha) + torch.lgamma(beta) - torch.lgamma(total)
    return (
        -log_beta_fn
        + (alpha - 1) * torch.digamma(alpha)
        + (beta - 1) * torch.digamma(beta)
        - (total - 2) * torch.digamma(total)
    )


def evidential_binary_loss(alpha, beta, y, lam=1e-3):
    """Evidential loss for a binary target predicted as a Beta(alpha, beta).

    The loss is the binary cross-entropy between the Beta mean
    ``alpha / (alpha + beta)`` (clamped to ``[1e-6, 1 - 1e-6]``) and ``y``,
    plus ``lam`` times the mean KL divergence to the uniform Beta(1, 1).

    Args:
        alpha, beta: positive Beta parameters, same shape as ``y``.
        y: targets; ``F.binary_cross_entropy`` expects values in ``[0, 1]``.
        lam: weight of the KL regulariser.

    Returns:
        ``(loss, metrics)``. ``loss`` is a scalar tensor. ``metrics`` is a
        dict of Python floats: ``ce``, ``kl`` and ``S_mean``, the mean
        evidence strength ``alpha + beta``.
    """
    strength = alpha + beta
    prob = (alpha / strength).clamp(1e-6, 1 - 1e-6)
    ce = F.binary_cross_entropy(prob, y, reduction='mean')
    kl = kl_beta_to_uniform(alpha, beta).mean()
    loss = ce + lam * kl
    metrics = dict(ce=ce.item(), kl=kl.item(), S_mean=strength.mean().item())
    return loss, metrics


class LightweightTransformerABIModel_adaptivelength(nn.Module):
    """Predict Beta distributions over a partner's ability, benevolence and
    integrity (A/B/I) from a variable-length history of states.

    A transformer encoder embeds the padded history. Each trait is then
    attention-pooled over its own most-recent window (``k_ability``,
    ``k_benevolence``, ``k_integrity`` valid steps), passed through a
    per-trait head and mapped to Beta parameters ``(alpha, beta)`` via
    softplus + eps. Training minimises ``evidential_binary_loss`` per trait.

    The model owns an Adam optimizer and an in-memory ``data_buffer`` of
    ``(state_history, ability, benevolence, integrity)`` samples.

    Args:
        state_dim: feature size of one step of the history.
        hidden_dim: transformer width.
        nhead: number of attention heads.
        num_layers: number of encoder layers.
        lr: Adam learning rate.
        evidential_lambda: weight of the KL-to-uniform term in the loss.
    """

    def __init__(self, state_dim, hidden_dim=32, nhead=2, num_layers=1,
                 lr=1e-3, evidential_lambda=1e-3):
        super().__init__()
        self.input_proj = nn.Linear(state_dim, hidden_dim)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=nhead,
            dim_feedforward=hidden_dim * 2,
            batch_first=True,
            activation='relu'
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Attention query shared by the three trait-specific pooling windows.
        self.attn_vec = nn.Parameter(torch.randn(hidden_dim))

        self.head_A = nn.Sequential(nn.Linear(hidden_dim, 64), nn.ReLU())
        self.head_B = nn.Sequential(nn.Linear(hidden_dim, 64), nn.ReLU())
        self.head_I = nn.Sequential(nn.Linear(hidden_dim, 64), nn.ReLU())

        self.fc_A_alpha = nn.Linear(64, 1)
        self.fc_A_beta  = nn.Linear(64, 1)
        self.fc_B_alpha = nn.Linear(64, 1)
        self.fc_B_beta  = nn.Linear(64, 1)
        self.fc_I_alpha = nn.Linear(64, 1)
        self.fc_I_beta  = nn.Linear(64, 1)

        self.evidential_lambda = evidential_lambda
        self.optimizer = torch.optim.Adam(self.parameters(), lr=lr)
        self.data_buffer = []

    @staticmethod
    def _make_pad_mask(lengths, max_len):
        """Return a ``[B, max_len]`` bool mask, True at padded steps (``t >= length``)."""
        device = lengths.device
        rng = torch.arange(max_len, device=device).unsqueeze(0)
        pad_mask = rng >= lengths.unsqueeze(1)
        return pad_mask

    @staticmethod
    def _make_recent_window_mask(lengths, max_len, k_recent):
        """Return a ``[B, max_len]`` bool mask that is False only on the last
        ``k_recent`` valid steps of each row (``k_recent`` is clamped to >= 1)."""
        device = lengths.device
        k = torch.clamp(torch.as_tensor(k_recent, device=device), min=1)
        start = torch.clamp(lengths - k, min=0)                          # [B]
        idx   = torch.arange(max_len, device=device).unsqueeze(0)        # [1, T]

        in_window = (idx >= start.unsqueeze(1)) & (idx < lengths.unsqueeze(1))
        mask = ~in_window                                                 # [B, T]
        return mask

    def _attn_pool(self, h, attn_vec, mask_bool, lengths):
        """Attention-pool ``h`` over time, ignoring masked steps.

        h: [B, T, H]
        attn_vec: [H]
        mask_bool: [B, T] True=mask
        lengths: [B]

        A row whose every step is masked (e.g. a zero-length sequence) cannot
        be softmax-normalised. Only such rows fall back to a one-hot weight on
        step ``lengths - 1``, clamped into ``[0, T - 1]``. All other rows keep
        their softmax weights.

        Returns: [B, H]
        """
        B, T, H = h.shape
        logits = torch.matmul(h, attn_vec)                     # [B, T]
        logits = logits.masked_fill(mask_bool, float('-inf'))

        all_masked = mask_bool.all(dim=1)                      # [B]
        if all_masked.any():
            # Give fully-masked rows finite dummy logits so softmax (and its
            # gradient) stays NaN-free, then swap in the one-hot fallback
            # for those rows only.
            safe_logits = logits.masked_fill(all_masked.unsqueeze(1), 0.0)
            soft = torch.softmax(safe_logits, dim=1)
            last_idx = (lengths - 1).clamp(min=0, max=T - 1)
            fallback = F.one_hot(last_idx, num_classes=T).to(soft.dtype)
            weights = torch.where(all_masked.unsqueeze(1), fallback, soft)
        else:
            weights = torch.softmax(logits, dim=1)             # [B, T]

        pooled = torch.sum(h * weights.unsqueeze(-1), dim=1)   # [B, H]
        return pooled

    def forward(self, x, lengths=None, k_ability=15, k_benevolence=30, k_integrity=30):
        """Predict Beta parameters for A, B and I.

        Args:
            x: padded histories ``[B, T, state_dim]``.
            lengths: valid length of each row ``[B]``; defaults to ``T`` for all.
            k_ability, k_benevolence, k_integrity: number of most recent valid
                steps each trait is pooled over.

        Returns:
            ``((alpha_A, beta_A), (alpha_B, beta_B), (alpha_I, beta_I))``,
            each tensor ``[B, 1]`` and strictly positive.
        """
        B, T, _ = x.shape
        device = x.device
        if lengths is None:
            lengths = torch.full((B,), T, dtype=torch.long, device=device)

        x_proj = self.input_proj(x)  # [B, T, H]

        pad_mask = self._make_pad_mask(lengths, T)  # [B, T], True=padding
        h = self.transformer(x_proj, src_key_padding_mask=pad_mask)  # [B, T, H]

        mask_A = pad_mask | self._make_recent_window_mask(lengths, T, k_ability)
        mask_B = pad_mask | self._make_recent_window_mask(lengths, T, k_benevolence)
        mask_I = pad_mask | self._make_recent_window_mask(lengths, T, k_integrity)

        pooled_A = self._attn_pool(h, self.attn_vec, mask_A, lengths)  # [B, H]
        pooled_B = self._attn_pool(h, self.attn_vec, mask_B, lengths)
        pooled_I = self._attn_pool(h, self.attn_vec, mask_I, lengths)

        zA = self.head_A(pooled_A)
        zB = self.head_B(pooled_B)
        zI = self.head_I(pooled_I)

        # softplus + eps keeps every Beta parameter strictly positive.
        eps = 1e-4
        alpha_A = F.softplus(self.fc_A_alpha(zA)) + eps
        beta_A  = F.softplus(self.fc_A_beta(zA))  + eps
        alpha_B = F.softplus(self.fc_B_alpha(zB)) + eps
        beta_B  = F.softplus(self.fc_B_beta(zB))  + eps
        alpha_I = F.softplus(self.fc_I_alpha(zI)) + eps
        beta_I  = F.softplus(self.fc_I_beta(zI))  + eps

        return (alpha_A, beta_A), (alpha_B, beta_B), (alpha_I, beta_I)

    @staticmethod
    def beta_mean_strength(alpha, beta):
        """Return the clamped Beta mean ``alpha / (alpha + beta)`` and strength ``alpha + beta``."""
        S = alpha + beta
        p = (alpha / S).clamp(1e-6, 1-1e-6)
        return p, S

    def store_data(self, state_history, ability, benevolence, integrity):
        """Append one ``(history, ability, benevolence, integrity)`` sample to the buffer.

        Tensor histories are detached and moved to CPU; labels become floats.
        """
        if isinstance(state_history, torch.Tensor):
            sh = state_history.detach().cpu()
        else:
            sh = torch.as_tensor(state_history)
        self.data_buffer.append((sh, float(ability), float(benevolence), float(integrity)))

    def _collate_batch(self, batch):
        """Pad variable-length histories into ``[B, T_max, D]`` and stack ``[B, 1]`` labels."""
        seqs, A, B, I = [], [], [], []
        for sh, a, b, i in batch:
            t = torch.as_tensor(sh, dtype=torch.float32)
            seqs.append(t)
            A.append([a]); B.append([b]); I.append([i])
        lengths = torch.tensor([s.shape[0] for s in seqs], dtype=torch.long)
        states = pad_sequence(seqs, batch_first=True)  # [B, T_max, D]
        abilitys     = torch.tensor(A, dtype=torch.float32)
        benevolences = torch.tensor(B, dtype=torch.float32)
        integritys   = torch.tensor(I, dtype=torch.float32)
        return states, lengths, abilitys, benevolences, integritys

    def train_step(self, batch_size=None, k_ability=15, k_benevolence=30, k_integrity=30,
                   device=None, loss_weights=(1.0, 1.0, 1.0)):
        """Run one optimizer step on the buffered samples.

        Uses the whole buffer, or its first ``batch_size`` samples (oldest
        first, no random sampling).

        Returns:
            ``None`` if the buffer is empty, otherwise a dict with the total
            loss and per-trait CE / KL / mean-strength metrics.
        """
        if not self.data_buffer:
            return None

        batch = self.data_buffer if batch_size is None else self.data_buffer[:batch_size]
        states, lengths, abilitys, benevolences, integritys = self._collate_batch(batch)

        if device is None:
            device = next(self.parameters()).device
        states       = states.to(device)
        lengths      = lengths.to(device)
        abilitys     = abilitys.to(device)
        benevolences = benevolences.to(device)
        integritys   = integritys.to(device)

        # Make sure dropout is active for the update regardless of prior mode.
        self.train()
        self.optimizer.zero_grad()
        (aA, bA), (aB, bB), (aI, bI) = self.forward(
            states, lengths=lengths,
            k_ability=k_ability, k_benevolence=k_benevolence, k_integrity=k_integrity
        )

        lam = self.evidential_lambda
        wA, wB, wI = loss_weights
        loss_A, mA = evidential_binary_loss(aA, bA, abilitys,     lam=lam)
        loss_B, mB = evidential_binary_loss(aB, bB, benevolences, lam=lam)
        loss_I, mI = evidential_binary_loss(aI, bI, integritys,   lam=lam)

        total = wA*loss_A + wB*loss_B + wI*loss_I
        total.backward()
        self.optimizer.step()

        return {
            'total_loss': total.item(),
            'A_ce': mA['ce'], 'A_kl': mA['kl'], 'A_S_mean': mA['S_mean'],
            'B_ce': mB['ce'], 'B_kl': mB['kl'], 'B_S_mean': mB['S_mean'],
            'I_ce': mI['ce'], 'I_kl': mI['kl'], 'I_S_mean': mI['S_mean'],
        }

    def predict_beta_params(self, states_np, lengths_np=None,
                            k_ability=15, k_benevolence=30, k_integrity=30, device=None):
        """Predict Beta parameters, means and strengths without gradients.

        The forward pass runs in eval mode, so dropout in the transformer
        layers is disabled and predictions are deterministic. The caller's
        train/eval mode is restored afterwards.

        Args:
            states_np: histories ``[B, T, state_dim]`` (array-like or tensor).
            lengths_np: optional valid lengths ``[B]``; defaults to ``T``.

        Returns:
            ``{'A'|'B'|'I': {'alpha', 'beta', 'p', 'S'}}`` with ``[B]`` tensors.
        """
        if not torch.is_tensor(states_np):
            states = torch.tensor(states_np, dtype=torch.float32)
        else:
            states = states_np.float()
        B, T, _ = states.shape

        if lengths_np is None:
            lengths = torch.full((B,), T, dtype=torch.long)
        else:
            # as_tensor also accepts an existing tensor without a copy warning.
            lengths = torch.as_tensor(lengths_np, dtype=torch.long)

        if device is None:
            device = next(self.parameters()).device
        states  = states.to(device)
        lengths = lengths.to(device)

        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                (aA, bA), (aB, bB), (aI, bI) = self.forward(
                    states, lengths=lengths,
                    k_ability=k_ability, k_benevolence=k_benevolence, k_integrity=k_integrity
                )
                pA, SA = self.beta_mean_strength(aA, bA)
                pB, SB = self.beta_mean_strength(aB, bB)
                pI, SI = self.beta_mean_strength(aI, bI)
        finally:
            self.train(was_training)

        out = {
            'A': {'alpha': aA.squeeze(-1), 'beta': bA.squeeze(-1), 'p': pA.squeeze(-1), 'S': SA.squeeze(-1)},
            'B': {'alpha': aB.squeeze(-1), 'beta': bB.squeeze(-1), 'p': pB.squeeze(-1), 'S': SB.squeeze(-1)},
            'I': {'alpha': aI.squeeze(-1), 'beta': bI.squeeze(-1), 'p': pI.squeeze(-1), 'S': SI.squeeze(-1)},
        }
        return out

    def save_model(self, path):
        """Save model weights, optimizer state and ``evidential_lambda`` to ``path``."""
        torch.save({
            'model_state_dict': self.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'evidential_lambda': self.evidential_lambda
        }, path)
        print(f"Model saved to {path}")

    def load_model(self, path, map_location=None):
        """Load a checkpoint written by ``save_model``."""
        checkpoint = torch.load(path, map_location=map_location)
        self.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.evidential_lambda = checkpoint.get('evidential_lambda', self.evidential_lambda)
        print(f"Model loaded from {path}")






class LightweightTransformerABIModel(nn.Module):
    """Fixed-length transformer regressor for partner ability/benevolence/integrity.

    Encodes a ``[B, T, state_dim]`` state history, attention-pools it over all
    ``T`` steps and predicts each trait with a sigmoid output trained by MSE.
    There is no padding or length handling, so all stored histories must have
    the same length.

    Args:
        state_dim: feature size of one step.
        seq_len: accepted for API compatibility; not used by the model.
        hidden_dim: transformer width.
        nhead: number of attention heads.
        num_layers: number of encoder layers.
    """

    def __init__(self, state_dim, seq_len=15, hidden_dim=32, nhead=2, num_layers=1):
        super().__init__()
        self.input_proj = nn.Linear(state_dim, hidden_dim)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=nhead,
            dim_feedforward=hidden_dim * 2,
            batch_first=True,
            activation='relu'
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Attention pooling query.
        self.attention_vector = nn.Parameter(torch.randn(hidden_dim))

        self.fc_ability = nn.Linear(hidden_dim, 1)
        self.fc_benevolence = nn.Linear(hidden_dim, 1)
        self.fc_integrity = nn.Linear(hidden_dim, 1)

        self.criterion = nn.MSELoss()
        self.optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        self.data_buffer = []

    def forward(self, x):
        """Map ``x`` ``[B, T, state_dim]`` to three ``[B, 1]`` sigmoid trait scores."""
        x = self.input_proj(x)           # [B, T, H]
        h = self.transformer(x)          # [B, T, H]

        # Attention pooling over all time steps.
        attn_weights = torch.softmax(torch.matmul(h, self.attention_vector), dim=1)  # [B, T]
        pooled = torch.sum(h * attn_weights.unsqueeze(-1), dim=1)  # [B, H]

        ability     = torch.sigmoid(self.fc_ability(pooled))
        benevolence = torch.sigmoid(self.fc_benevolence(pooled))
        integrity   = torch.sigmoid(self.fc_integrity(pooled))
        return ability, benevolence, integrity

    def store_data(self, state_history, ability, benevolence, integrity):
        """Append one ``(history, ability, benevolence, integrity)`` sample to the buffer."""
        self.data_buffer.append((state_history, ability, benevolence, integrity))

    def train_step(self):
        """Run one full-buffer MSE optimizer step.

        Returns:
            ``None`` if the buffer is empty, otherwise a dict with the total
            and per-trait losses as Python floats.
        """
        if not self.data_buffer:
            return None

        states, abilitys, benevolences, integritys = zip(*self.data_buffer)
        # Stack through NumPy first: building a tensor directly from a
        # sequence of ndarrays is very slow in PyTorch and emits a warning.
        states = torch.as_tensor(np.asarray(states, dtype=np.float32))
        abilitys = torch.tensor(abilitys, dtype=torch.float32).unsqueeze(-1)
        benevolences = torch.tensor(benevolences, dtype=torch.float32).unsqueeze(-1)
        integritys = torch.tensor(integritys, dtype=torch.float32).unsqueeze(-1)

        self.optimizer.zero_grad()
        pred_ability, pred_benevolence, pred_integrity = self.forward(states)

        loss_ability = self.criterion(pred_ability, abilitys)
        loss_benevolence = self.criterion(pred_benevolence, benevolences)
        loss_integrity = self.criterion(pred_integrity, integritys)

        total_loss = loss_ability + loss_benevolence + loss_integrity
        total_loss.backward()
        self.optimizer.step()

        return {
            'total_loss': total_loss.item(),
            'loss_ability': loss_ability.item(),
            'loss_benevolence': loss_benevolence.item(),
            'loss_integrity': loss_integrity.item()
        }

    def save_model(self, path):
        """Save model weights and optimizer state to ``path``."""
        torch.save({
            'model_state_dict': self.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
        }, path)
        print(f"Model saved to {path}")

    def load_model(self, path, map_location=None):
        """Load a checkpoint written by ``save_model``.

        ``map_location`` is forwarded to ``torch.load`` (e.g. ``"cpu"`` for
        GPU-saved checkpoints), as in the adaptive-length model.
        """
        checkpoint = torch.load(path, map_location=map_location)
        self.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        print(f"Model loaded from {path}")








class EpisodeRewardCallback(BaseCallback):
    """Checkpoint the model and plot a running-average episode reward.

    Rewards and episode ends are read from the first environment only
    (index ``0`` of ``self.locals['rewards']`` / ``self.locals['dones']``).
    Every ``save_freq`` calls to ``_on_step``, the model is saved to
    ``<save_path>/model_<step>.zip``. Once at least one episode has finished,
    a moving-average reward curve is also written to
    ``<save_path>/avg_reward_<step>.png``.
    """

    def __init__(self, save_path, save_freq=100000, verbose=0):
        super().__init__(verbose)
        self.save_path = save_path
        self.save_freq = save_freq
        self.step_counter = 0
        self.episode_rewards = []
        self.current_episode_reward = 0.0

        os.makedirs(self.save_path, exist_ok=True)

    @staticmethod
    def _trailing_means(values, window):
        """Mean of each element together with up to ``window - 1`` preceding ones."""
        means = []
        for end in range(1, len(values) + 1):
            start = max(0, end - window)
            means.append(sum(values[start:end]) / (end - start))
        return means

    def _on_step(self) -> bool:
        self.step_counter += 1
        self.current_episode_reward += self.locals['rewards'][0]

        # Close the running episode total when env 0 finishes an episode.
        if self.locals['dones'][0]:
            self.episode_rewards.append(self.current_episode_reward)
            self.current_episode_reward = 0.0

        if self.step_counter % self.save_freq == 0:
            self._checkpoint_and_plot()

        return True

    def _checkpoint_and_plot(self):
        """Save the model and, if any episode has finished, the reward plot."""
        step = self.step_counter
        model_path = os.path.join(self.save_path, f'model_{step}.zip')
        self.model.save(model_path)
        print(f"Step {step}: Model saved at {model_path}")

        if not self.episode_rewards:
            return

        window = min(100, len(self.episode_rewards))
        moving_avg = self._trailing_means(self.episode_rewards, window)

        plt.figure(figsize=(10, 5))
        plt.plot(moving_avg, label=f"Moving Avg (last {window} episodes)")
        plt.xlabel("Episode")
        plt.ylabel("Average Reward")
        plt.title("Training Progress")
        plt.legend()
        plt.grid()
        plt.tight_layout()
        plt.savefig(os.path.join(self.save_path, f'avg_reward_{step}.png'))
        plt.close()








class SingleAgentWrapper_Without_Latent(gym.Wrapper):
    """
    A wrapper to extract a single agent's perspective from a multi-agent environment.

    The other agent's action comes from ``other_agent_model.predict`` on its
    own observation. No partner-trait features are appended to the
    observation.
    """

    def __init__(self, env, agent_index, other_agent_model=None):
        super().__init__(env)
        self.agent_index = agent_index
        self.observation_space = env.observation_space
        self.action_space = env.action_space
        # Must provide a ``predict`` method before ``step`` is called.
        self.other_agent_model = other_agent_model
        self._reset_event_flags()
        self.obs = None

    def _reset_event_flags(self):
        """Set all ``firsttime_*`` flags back to True."""
        self.firsttime_left_go_to_counter = True
        self.firsttime_left_get_lettuce = True
        self.firsttime_right_go_to_knife = True
        self.firsttime_right_go_to_counter = True

    def reset(self):
        """Reset the env and return this agent's observation."""
        self.obs = self.env.reset()
        self._reset_event_flags()
        return self.obs[self.agent_index]

    def step(self, action):
        """Step the env with ``action`` for this agent and the model's action for the partner.

        Returns:
            ``(obs, reward, dones, info)`` for this agent. ``obs`` comes from
            ``env._get_macro_obs()`` after the step.
        """
        me = self.agent_index
        partner = 1 - me

        # ``predict`` returns (action, state); only the action is used.
        partner_action = self.other_agent_model.predict(self.obs[partner])[0]

        joint_actions = [0, 0]
        joint_actions[me] = action
        joint_actions[partner] = partner_action

        primary_actions = self.env._computeLowLevelActions(joint_actions, 0)
        self.obs, rewards, dones, info = self.env.step(primary_actions)
        # Replace the step observation with the macro-level observation.
        self.obs = self.env._get_macro_obs()

        benevolence_reward = 0  # shaping term; always 0 in this wrapper
        return self.obs[me], rewards[me] + benevolence_reward, dones, info





def get_boltzmann_probs_from_policy(model, obs, beta: float = 1.0):
    """Return tempered action probabilities of ``model``'s policy for ``obs``.

    The policy's categorical logits are multiplied by ``beta`` before the
    softmax: ``beta < 1`` flattens the distribution and ``beta > 1``
    sharpens it.

    Args:
        model: an SB3 model whose ``policy.get_distribution`` returns a
            distribution exposing ``.distribution.logits``.
        obs: a single observation.
        beta: inverse temperature.

    Returns:
        1-D NumPy array of probabilities for the first (only) batch entry.
    """
    policy = model.policy
    obs_tensor, _ = policy.obs_to_tensor(obs)
    with torch.no_grad():
        logits = policy.get_distribution(obs_tensor).distribution.logits
        tempered = torch.softmax(logits * beta, dim=-1)
        probs = tempered.cpu().numpy()[0]
    return probs

def sample_action_from_probs(probs: np.ndarray):
    """Draw one action index from the categorical distribution ``probs``.

    Uses NumPy's global RNG and returns a plain Python ``int``.
    """
    n_actions = len(probs)
    chosen = np.random.choice(n_actions, p=probs)
    return int(chosen)




class SingleAgentWrapper(gym.Wrapper):
    """
    A wrapper to extract a single agent's perspective from a multi-agent environment.
    """
    def __init__(self, env, partner_env, agent_index, state_dim, model_path=None):
        """Wrap ``env`` for one agent and load the partner pool and the ABI model.

        Args:
            env: multi-agent environment to wrap; its observation space must
                be a ``gym.spaces.Box``.
            partner_env: env passed to ``PPO.load`` for every partner policy.
            agent_index: index of the agent controlled through this wrapper.
            state_dim: per-step feature size of the ABI transformer input.
            model_path: optional checkpoint loaded into the ABI model.

        Raises:
            NotImplementedError: if ``env.observation_space`` is not a ``Box``.
        """
        super(SingleAgentWrapper, self).__init__(env)
        self.agent_index = agent_index
        self.observation_space = env.observation_space
        self.action_space = env.action_space

        self.firsttime_left_go_to_counter = True
        self.firsttime_left_get_lettuce = True
        self.firsttime_right_go_to_knife = True
        self.firsttime_right_go_to_counter = True

        self.obs = None

        # Partner (x, y) positions recorded during an episode.
        self.partner_position_history = []

        # Partner trait labels (+1 / -1), assigned per partner index in ``reset``.
        self.partner_ability = None
        self.partner_benevolence = None
        self.partner_integrity = None

        self.predicted_ability = None
        self.predicted_benevolence = None
        self.predicted_integrity = None

        # Beta(alpha, beta) belief parameters per trait; Beta(1, 1) is uniform.
        self.partner_A_alpha = 1
        self.partner_A_beta = 1

        self.partner_B_alpha = 1
        self.partner_B_beta = 1

        self.partner_I_alpha = 1
        self.partner_I_beta = 1

        # Assigned in ``reset`` / ``step``; initialised so they exist before
        # the first reset.
        self.other_agent_model = None
        self.action_probabilities = None

        # Partner pool. ``reset`` samples an index into this list and derives
        # the trait labels from the index, so the ORDER matters:
        #   0-5  : the six base partner policies
        #   6-11 : the same six policies again (labelled low-ability in ``reset``)
        #   12+  : nine more (v1, v2) pairs of the highB_highI policies
        # Duplicate entries raise how often a policy is sampled.
        hbhi_v1 = "final_trained_models/[MapB]trustee_agent_highB_highI_version1_agent1/model_4100000"
        hbhi_v2 = "final_trained_models/[MapB]trustee_agent_highB_highI_version2_agent1/model_4100000"
        hb_li = "final_trained_models/[MapB]trustee_agent_highB_lowI_agent1/model_4100000"
        lb_hi = "final_trained_models/[MapB]trustee_agent_lowB_highI_agent1/model_4100000"
        lb_li = "final_trained_models/[MapB]trustee_agent_lowB_lowI_agent1/model_4100000"
        base_pool = [hbhi_v1, hbhi_v2, hb_li, hb_li, lb_hi, lb_li]
        agent_lists = base_pool + base_pool + [hbhi_v1, hbhi_v2] * 9

        # PPO.load is expensive and the list has only five distinct paths, so
        # each path is loaded once and the instance is shared by its
        # duplicates. The visible ``step`` code only queries partner models
        # for actions / action probabilities.
        loaded_by_path = {}
        self.models = {}
        for i, path in enumerate(agent_lists):
            model_name = f"model{i}"
            if path not in loaded_by_path:
                loaded_by_path[path] = PPO.load(path, env=partner_env)
            self.models[model_name] = loaded_by_path[path]
            print(f"Loaded {model_name} from {path}")

        print(self.models)

        self.model = LightweightTransformerABIModel_adaptivelength(state_dim)
        if model_path:
            self.model.load_model(model_path)

        # Most recent (up to 30) partner-centred observations for the ABI model.
        self.state_history = deque(maxlen=30)

        self.partner_key_event_happen = False

        # Extend the Box observation space by 6 unbounded entries:
        # [A, B, I, conf_A, conf_B, conf_I] appended by ``reset``.
        if isinstance(env.observation_space, gym.spaces.Box):
            num_extra_dims = 6
            low = np.append(env.observation_space.low, [-np.inf] * num_extra_dims)
            high = np.append(env.observation_space.high, [np.inf] * num_extra_dims)
            self.observation_space = gym.spaces.Box(low=low, high=high, dtype=env.observation_space.dtype)
        else:
            raise NotImplementedError("This wrapper only works with Box observation spaces.")
        



    def reset(self):
        """Reset the episode, sample a partner policy and return the first observation.

        The partner is drawn uniformly from ``self.models``. Its trait labels
        (ability, benevolence, integrity, each +1 or -1) are looked up from
        the sampled index, and the trait beliefs are reset to the uniform
        Beta(1, 1) prior.

        Returns:
            ``self.obs[self.agent_index]`` with six entries appended,
            ``[A, B, I, conf_A, conf_B, conf_I] = [0.5, 0.5, 0.5, 0, 0, 0]``.
        """
        # (ability, benevolence, integrity) per partner index. Indices 6-11
        # use the same policies as 0-5 but are labelled low-ability; indices
        # not listed here (12 and above) are labelled all +1.
        labels_by_index = {
            0: (1, 1, 1),    1: (1, 1, 1),
            2: (1, 1, -1),   3: (1, 1, -1),
            4: (1, -1, 1),   5: (1, -1, -1),
            6: (-1, 1, 1),   7: (-1, 1, 1),
            8: (-1, 1, -1),  9: (-1, 1, -1),
            10: (-1, -1, 1), 11: (-1, -1, -1),
        }

        def beta_sd(a, b):
            # Standard deviation of Beta(a, b).
            s = a + b
            return float(math.sqrt((a * b) / (s ** 2 * (s + 1))))

        self.obs = self.env.reset()

        partner_agent = self.env.agent[1]
        self.partner_position_history = [[partner_agent.x, partner_agent.y]]

        self.firsttime_left_go_to_counter = True
        self.firsttime_left_get_lettuce = True
        self.firsttime_right_go_to_knife = True
        self.firsttime_right_go_to_counter = True

        # Uniform Beta(1, 1) prior on every partner trait.
        self.partner_A_alpha, self.partner_A_beta = 1, 1
        self.partner_B_alpha, self.partner_B_beta = 1, 1
        self.partner_I_alpha, self.partner_I_beta = 1, 1

        self.state_history.clear()
        partner_centred_obs = self.env._get_macro_vector_obs_for_ABImodel()
        self.state_history.append(partner_centred_obs[1])

        self.partner_key_event_happen = False

        # NOTE(review): this reseeds the *global* `random` module from OS
        # entropy on every reset, overriding the module-level SEED. It is
        # presumably deliberate, so that parallel env workers do not all draw
        # the same partner sequence. Confirm before changing.
        random.seed()
        partner_idx = random.randint(0, len(self.models) - 1)
        self.other_agent_model = self.models["model" + str(partner_idx)]

        (self.partner_ability,
         self.partner_benevolence,
         self.partner_integrity) = labels_by_index.get(partner_idx, (1, 1, 1))

        # Predictions start at the labels of the sampled partner.
        self.predicted_ability = self.partner_ability
        self.predicted_benevolence = self.partner_benevolence
        self.predicted_integrity = self.partner_integrity

        # Point estimates (the Beta(1, 1) mean) with zero confidence.
        self.partner_A = self.partner_B = self.partner_I = 0.5
        self.partner_A_confidence = self.partner_B_confidence = self.partner_I_confidence = 0

        self.partner_A_SD = beta_sd(self.partner_A_alpha, self.partner_A_beta)
        self.partner_B_SD = beta_sd(self.partner_B_alpha, self.partner_B_beta)
        self.partner_I_SD = beta_sd(self.partner_I_alpha, self.partner_I_beta)

        abi_features = [
            self.partner_A, self.partner_B, self.partner_I,
            self.partner_A_confidence, self.partner_B_confidence, self.partner_I_confidence,
        ]
        return np.concatenate([self.obs[self.agent_index], abi_features])
    




    def step(self, action):
        """Run one environment step for the learning agent and its partner.

        The learner's macro action ``action`` is paired with the partner's macro
        action from ``self.other_agent_model.predict``, converted to low-level
        actions by the env, and executed on ``self.env``. Before execution the
        partner's low-level action may be overridden:

        * when ``self.partner_ability == -1`` and the partner's current macro
          action is still running, by an action derived from a macro action
          sampled from ``self.action_probabilities`` (Boltzmann probabilities,
          beta=0.3, refreshed whenever the partner's macro action finishes);
        * when the partner's recorded position is unchanged over the last five
          entries of ``self.partner_position_history``, by a random move into an
          adjacent ``"space"`` cell — if no such cell exists, the planned action
          is kept (previously this crashed with ``IndexError``).

        When ``self.state_history`` holds exactly 30 entries and the partner's
        ``holding`` went from non-None to None during this step, the ABI model
        (``self.model``) is queried for Beta(alpha, beta) parameters over A/B/I.
        Each prediction is folded into a decayed Beta posterior whose mean is
        thresholded at 0.5 into the +1/-1 features ``partner_A/B/I``; the
        confidence features come from the model's own Beta parameters. The
        current window is then passed to ``self.model.store_data`` together with
        ``self.partner_ability/benevolence/integrity``.

        Args:
            action: Macro action chosen by the learning agent.

        Returns:
            ``(obs, reward, dones, info)`` where ``obs`` is the learner's macro
            observation followed by ``[partner_A, partner_B, partner_I,
            partner_A_confidence, partner_B_confidence, partner_I_confidence]``
            and ``reward`` is the sum of both agents' rewards plus
            ``benevolence_reward_left`` (currently always 0).
        """
        # Currently always 0; kept as an explicit term of the returned reward.
        benevolence_reward_left = 0

        actions = [0, 0]
        boltzmann_actions = [0, 0]

        # Partner's macro action, predicted from the partner's own observation.
        other_agent_action = self.other_agent_model.predict(self.obs[1 - self.agent_index])

        actions[self.agent_index] = action
        actions[1 - self.agent_index] = other_agent_action[0]

        # Refresh the partner's Boltzmann action distribution whenever its
        # current macro action has finished.
        if self.partner_ability == -1 and self.env.macroAgent[1].cur_macro_action_done == True:
            self.action_probabilities = get_boltzmann_probs_from_policy(self.other_agent_model, self.obs[1 - self.agent_index], beta = 0.3)

        primary_actions, real_execute_macro_actions = self.env._computeLowLevelActions(actions)

        # NOTE: this aliases primary_actions; the stuck-override below mutates it in place.
        low_level_action_to_execute = primary_actions

        # While the partner's macro action is in progress, replace its low-level
        # action with one derived from a Boltzmann-sampled macro action.
        if self.partner_ability == -1 and self.env.macroAgent[1].cur_macro_action_done != True:
            other_agent_action_boltzmaan = sample_action_from_probs(self.action_probabilities)

            boltzmann_actions[self.agent_index] = action
            boltzmann_actions[1 - self.agent_index] = other_agent_action_boltzmaan

            primary_actions_boltzmaan, real_execute_macro_actions = self.env._computeLowLevelActions_boltzmaanlowlevel(boltzmann_actions)

            low_level_action_to_execute = [primary_actions[0], primary_actions_boltzmaan[1]]

        previous_partner_holding_status = self.env.agent[1].holding

        # Unstick the partner: if its position is identical across the last five
        # recorded entries, move it randomly into an adjacent free ("space") cell.
        history = self.partner_position_history
        if len(history) > 5 and all(history[-k] == history[-1] for k in range(2, 6)):
            partner = self.env.agent[1]
            # NOTE(review): direction ids follow the env's low-level action
            # encoding; the (id -> neighbour cell) pairing below is copied from
            # the original checks, not verified against the env.
            neighbours = (
                (0, partner.x, partner.y + 1),
                (1, partner.x + 1, partner.y),
                (2, partner.x, partner.y - 1),
                (3, partner.x - 1, partner.y),
            )
            to_choose = [
                direction
                for direction, nx, ny in neighbours
                if ITEMNAME[partner.pomap[nx][ny]] == "space"
            ]
            # random.choice([]) raises IndexError; if the partner is boxed in,
            # keep its planned action instead of crashing the rollout.
            if to_choose:
                low_level_action_to_execute[1 - self.agent_index] = random.choice(to_choose)

        self.obs, rewards, dones, info = self.env.step(low_level_action_to_execute)

        # The wrapper works on macro observations, so replace the raw step obs.
        self.obs = self.env._get_macro_obs()

        self.partner_position_history.append([self.env.agent[1].x, self.env.agent[1].y])

        after_partner_holding_status = self.env.agent[1].holding

        partner_cented_obs = self.env._get_macro_vector_obs_for_ABImodel()

        self.state_history.append(partner_cented_obs[1])

        # Key event: the partner held something before this step and nothing after it.
        self.partner_key_event_happen = (
            previous_partner_holding_status is not None
            and after_partner_holding_status is None
        )

        # NOTE(review): presumably state_history is a deque(maxlen=30), so this
        # fires on every key event once the window is full — confirm in reset().
        if len(self.state_history) == 30 and self.partner_key_event_happen:

            state_np = np.asarray(self.state_history, dtype=np.float32)
            state_tensor = torch.from_numpy(state_np).unsqueeze(0)  # [1, T, D]

            device = next(self.model.parameters()).device
            state_tensor = state_tensor.to(device)

            T_cur = state_np.shape[0]

            # Lazily initialised hyper-parameters of the belief smoothing below.
            if not hasattr(self, 'posterior_decay'):  self.posterior_decay = 0.999
            if not hasattr(self, 'kappa_max'):        self.kappa_max = 2.0
            if not hasattr(self, 'conf_mode'):        self.conf_mode = "var"
            if not hasattr(self, 'S_max'):            self.S_max = 50.0

            def _ensure_beta(name):
                # Initialise a uniform Beta(1, 1) posterior for this trait if missing.
                a = f'partner_{name}_alpha'
                b = f'partner_{name}_beta'
                if not hasattr(self, a): setattr(self, a, 1.0)
                if not hasattr(self, b): setattr(self, b, 1.0)

            _ensure_beta('A'); _ensure_beta('B'); _ensure_beta('I')

            def _conf_from_ab(a: float, b: float):
                # Mean/variance/std of Beta(a, b) plus a confidence score in [0, 1].
                S = a + b
                if S <= 0.0:
                    mean = 0.5; var = 1.0/12.0
                else:
                    mean = a / S
                    var  = (a * b) / (S * S * (S + 1.0))
                std = math.sqrt(var)
                conf_var = max(0.0, min(1.0, 1.0 - 2.0 * std))  # 1 - Std/0.5
                conf_S   = S / (S + self.S_max)
                if self.conf_mode == "var":
                    conf = conf_var
                elif self.conf_mode == "strength":
                    conf = conf_S
                else:  # "hybrid"
                    conf = 0.5 * conf_var + 0.5 * conf_S
                return mean, var, std, conf, S

            with torch.no_grad():
                if hasattr(self.model, "predict_beta_params"):
                    out = self.model.predict_beta_params(state_tensor, lengths_np=[T_cur])
                    aA = float(out['A']['alpha']); bA = float(out['A']['beta'])
                    aB = float(out['B']['alpha']); bB = float(out['B']['beta'])
                    aI = float(out['I']['alpha']); bI = float(out['I']['beta'])
                else:
                    (aA_t,bA_t),(aB_t,bB_t),(aI_t,bI_t) = self.model(state_tensor, lengths=torch.tensor([T_cur]))
                    aA, bA = float(aA_t.item()), float(bA_t.item())
                    aB, bB = float(aB_t.item()), float(bB_t.item())
                    aI, bI = float(aI_t.item()), float(bI_t.item())

            meanA, varA, stdA, confA, SA = _conf_from_ab(aA, bA)
            meanB, varB, stdB, confB, SB = _conf_from_ab(aB, bB)
            meanI, varI, stdI, confI, SI = _conf_from_ab(aI, bI)

            self.partner_A_prob = float(meanA)
            self.partner_B_prob = float(meanB)
            self.partner_I_prob = float(meanI)

            self.partner_A_SD   = float(stdA)
            self.partner_B_SD   = float(stdB)
            self.partner_I_SD   = float(stdI)

            self.partner_A_conf = float(confA)
            self.partner_B_conf = float(confB)
            self.partner_I_conf = float(confI)

            def _update_posterior_for_smoothing(name, p, S_model):
                # Decay the running Beta posterior, then add the model's mean p as
                # a pseudo-observation weighted by min(S_model, kappa_max).
                a_name = f'partner_{name}_alpha'
                b_name = f'partner_{name}_beta'
                setattr(self, a_name, getattr(self, a_name) * self.posterior_decay)
                setattr(self, b_name, getattr(self, b_name) * self.posterior_decay)
                kappa = min(S_model, self.kappa_max)
                setattr(self, a_name, getattr(self, a_name) + kappa * p)
                setattr(self, b_name, getattr(self, b_name) + kappa * (1.0 - p))
                A = getattr(self, a_name); B = getattr(self, b_name)
                S_post = A + B
                p_smooth = A / S_post if S_post > 0 else 0.5
                setattr(self, f'partner_{name}_prob_smooth', float(p_smooth))

            _update_posterior_for_smoothing('A', meanA, SA)
            _update_posterior_for_smoothing('B', meanB, SB)
            _update_posterior_for_smoothing('I', meanI, SI)

            # Binarise the smoothed posterior means into +1 / -1 belief features.
            self.partner_A = 1 if self.partner_A_prob_smooth >= 0.5 else -1
            self.partner_B = 1 if self.partner_B_prob_smooth >= 0.5 else -1
            self.partner_I = 1 if self.partner_I_prob_smooth >= 0.5 else -1

            self.partner_A_confidence = self.partner_A_conf
            self.partner_B_confidence = self.partner_B_conf
            self.partner_I_confidence = self.partner_I_conf

            self.model.store_data(list(self.state_history),
                                  self.partner_ability,
                                  self.partner_benevolence,
                                  self.partner_integrity)

        return np.concatenate([
            self.obs[self.agent_index],
            [self.partner_A, self.partner_B, self.partner_I, self.partner_A_confidence, self.partner_B_confidence, self.partner_I_confidence]
        ]), rewards[self.agent_index]+rewards[1-self.agent_index] + benevolence_reward_left, dones, info
    




# Grid item type names; a grid cell's integer item id indexes into this list.
ITEMNAME = [
    "space", "counter", "agent", "tomato", "lettuce", "plate",
    "knife", "delivery", "onion", "dirtyplate", "badlettuce",
]

# Macro-action name -> integer id; ids are assigned in listing order (0..15).
macroActionDict = {
    name: idx
    for idx, name in enumerate([
        "stay",
        "get lettuce 1", "get lettuce 2", "get badlettuce",
        "get plate 1", "get plate 2",
        "go to knife 1", "go to knife 2",
        "deliver 1", "deliver 2",
        "chop", "go to counter",
        "right", "down", "left", "up",
    ])
}


mac_env_id = 'Overcooked-MA-v1'

# Two reward tables (presumably one per agent index — confirm in the env).
# They are identical except that only the first penalises bad lettuce.
rewardList = [{
    "minitask finished": 0,
    "minitask failed": 0,
    "metatask finished": 0,
    "metatask failed": 0,
    "goodtask finished": 10,
    "goodtask failed": 0,
    "subtask finished": 20,
    "subtask failed": 0,
    "correct delivery": 200,
    "wrong delivery": -50,
    "step penalty": -1,
    "penalize using dirty plate": 0,
    "penalize using bad lettuce": -20,
    "pick up bad lettuce": -100,
}]
# Copy of the first table with the bad-lettuce penalties switched off
# (overridden keys keep their original position in the dict).
rewardList.append({
    **rewardList[0],
    "penalize using bad lettuce": 0,
    "pick up bad lettuce": 0,
})

# Keyword arguments for gym.make(mac_env_id, **env_params).
env_params = dict(
    grid_dim=[15, 15],
    task=["lettuce salad"],
    rewardList=rewardList,
    map_type="B",
    n_agent=2,
    obs_radius=0,
    mode="vector",
    debug=False,
)


def make_env(rank: int, abi_model_path: str):
    """Return a zero-argument factory that builds the ``rank``-th training env.

    The factory stacks two wrappers on a single base env:
    base env -> partner wrapper (agent_index=1) -> training wrapper (agent_index=0).
    After construction it seeds the base env, NumPy and ``random`` with
    ``SEED + rank`` so that parallel workers differ deterministically.
    """
    def _build_env():
        raw_env = gym.make(mac_env_id, **env_params)

        # View of the same base env from the partner's side (agent 1).
        partner_view = SingleAgentWrapper_Without_Latent(raw_env, agent_index=1, other_agent_model=None)

        # Learner-facing env (agent 0); abi_model_path is handed over as model_path.
        learner_env = SingleAgentWrapper(
            raw_env,
            partner_view,
            agent_index=0,
            state_dim=6,
            model_path=abi_model_path,
        )

        # Seeding happens after both wrappers are built (order preserved).
        worker_seed = SEED + rank
        raw_env.seed(worker_seed)
        np.random.seed(worker_seed)
        random.seed(worker_seed)

        return learner_env

    return _build_env


from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
# Extra arguments for the SB3 "MlpPolicy".
policy_kwargs = {
    # Custom feature extractor defined near the top of this file; base_dim is
    # forwarded to its constructor.
    "features_extractor_class": ABIGatedExtractorWithConf,
    "features_extractor_kwargs": {"base_dim": 64},
    # Separate [128, 64] MLP heads for the policy (pi) and the value function (vf).
    # NOTE(review): SB3 >= 1.8 prefers net_arch=dict(pi=..., vf=...); the list
    # form is kept for compatibility with the gym-based SB3 version used here.
    "net_arch": [{"pi": [128, 64], "vf": [128, 64]}],
}



# Hyper-parameters passed as keyword arguments to stable_baselines3.PPO(...).
ppo_params = dict(
    learning_rate=3e-4,
    n_steps=3600,       # steps collected per environment per rollout
    batch_size=600,
    n_epochs=4,
    gamma=0.95,
    gae_lambda=0.95,
    clip_range=0.3,
    ent_coef=0.02,
    vf_coef=0.5,
    max_grad_norm=0.5,
    verbose=0,
)


if __name__ == "__main__":
    # Training entry point: PPO on n_envs parallel worker processes, run as
    # total_loops successive learn() calls of steps_per_loop timesteps each.
    import multiprocessing as mp
    # Each SubprocVecEnv worker starts a fresh interpreter instead of forking this one.
    mp.set_start_method("spawn", force=True)

    n_envs = 8
    abi_model_path = "[MapB]abi_model_outputs_Boltzmaan_10times_adaptivelength_FFFFFFFFFFFFFFINAL_with_uncertainty/abi_model_step_10000.pt"

    venv = SubprocVecEnv([make_env(i, abi_model_path) for i in range(n_envs)])

    def format_time(s):
        """Format a non-negative duration in seconds as '<minutes>min <seconds>s'."""
        m, ss = divmod(int(s), 60)
        return f"{m}min {ss}s"

    try:
        model = PPO("MlpPolicy", venv, policy_kwargs=policy_kwargs, **ppo_params)

        callback = EpisodeRewardCallback(
            save_path="final_trained_models/[MapB]TrustPOMDP",
            save_freq=100_000
        )

        start = time.monotonic()
        total_loops = 300
        steps_per_loop = 1_000_000

        for i in range(1, total_loops + 1):
            # reset_num_timesteps=False keeps the step counter running across calls,
            # so each call trains roughly steps_per_loop additional timesteps.
            model.learn(total_timesteps=steps_per_loop, callback=callback, reset_num_timesteps=False)
            # Report the real counter: PPO rounds up to whole rollouts, so it can
            # exceed i * steps_per_loop.
            print(f"[Loop {i}/{total_loops}] Trained {model.num_timesteps:,} steps | time spent {format_time(time.monotonic() - start)}")
    finally:
        # Shut down the worker processes even if training fails or is interrupted.
        venv.close()

