import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Optional, Tuple


def masked_softmax(
    logits: torch.Tensor,
    mask: torch.Tensor,
    dim: int = -1,
    temp: float = 1.0,
    eps: float = 1e-12
) -> torch.Tensor:
    """Temperature softmax over ``dim`` restricted to positions where ``mask`` is true.

    Masked-out positions get probability exactly 0 and the remaining positions are
    renormalised to sum to 1. Slices in which every position is masked return all
    zeros instead of NaN.

    Args:
        logits: unnormalised scores.
        mask: truthy where a position is allowed; must broadcast against ``logits``.
        dim: dimension to normalise over.
        temp: softmax temperature; logits are divided by it.
        eps: lower bound on the renormalisation denominator.

    Returns:
        Probabilities with the broadcast shape of ``logits`` and ``mask``.
    """
    mask = mask.bool()
    # Apply the temperature *before* masking so the fill value cannot be scaled past
    # the dtype's range, and use the dtype's most negative finite value instead of a
    # hard-coded -1e9 (which is not representable in float16 and makes masked_fill raise).
    scaled = logits / temp
    fill_value = torch.finfo(scaled.dtype).min
    probs = F.softmax(scaled.masked_fill(~mask, fill_value), dim=dim)

    # Zero masked entries explicitly and renormalise; an all-masked slice has a zero
    # denominator and is mapped to zeros by the ``where``.
    probs = probs * mask.to(probs.dtype)
    denom = probs.sum(dim=dim, keepdim=True)
    return torch.where(denom > 0, probs / denom.clamp_min(eps), torch.zeros_like(probs))


class BigMLP(nn.Module):
    """Plain feed-forward stack with an optional LayerNorm on the output.

    Layout: ``max(0, n_blocks - 1)`` repetitions of ``Linear -> ReLU`` (each
    followed by ``Dropout`` when ``dropout > 0``), then a final ``Linear`` to
    ``out_dim``. The output passes through ``LayerNorm(out_dim)`` when ``use_ln``
    is true, otherwise through an identity.
    """

    def __init__(self, in_dim, hidden_dim, out_dim, n_blocks=2, use_ln=False, dropout=0.0):
        super().__init__()
        n_hidden = n_blocks - 1 if n_blocks > 1 else 0

        modules = []
        width = in_dim
        for _ in range(n_hidden):
            modules.append(nn.Linear(width, hidden_dim))
            modules.append(nn.ReLU(inplace=True))
            if dropout > 0:
                modules.append(nn.Dropout(dropout))
            width = hidden_dim
        modules.append(nn.Linear(width, out_dim))

        self.net = nn.Sequential(*modules)
        if use_ln:
            self.ln = nn.LayerNorm(out_dim)
        else:
            self.ln = nn.Identity()

    def forward(self, x):
        hidden = self.net(x)
        return self.ln(hidden)


class MI_Action(nn.Module):
    """Conditional action models for a G2 agent group, with and without G1 actions.

    Two heads output a categorical distribution over ``n_actions`` for every agent
    slot (masked down to G2 agents):

    * ``phi`` conditions on the G1 joint-history embedding, the mean-pooled G1
      action embedding, and the G2 joint-history embedding;
    * ``psi`` conditions on the same history embeddings but not on G1 actions.

    KL(phi || psi), averaged over valid G2 agents, is computed by ``compute_losses``
    and minimised per G1 agent by ``select_min_mi_actions_per_agent`` (presumably as
    a proxy for the mutual information between G1 and G2 actions, per the class name).

    ``args`` must provide ``n_agents``, ``n_actions``, ``rnn_hidden_dim`` and
    ``width_mult``; ``temp`` (softmax temperature) is optional and defaults to 1.0.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.N_max = args.n_agents
        self.n_actions = args.n_actions
        self.hist_dim = args.rnn_hidden_dim

        # Widths scale with ``width_mult`` but never drop below fixed floors.
        wm = args.width_mult
        D_hist = max(32, int(128 * wm))
        H_ctx  = max(64, int(256 * wm))

        D_act0 = max(16, int(64 * wm))
        D_act  = max(16, int(64 * wm))

        self.D_hist = D_hist
        self.H_ctx  = H_ctx
        self.D_act0 = D_act0
        self.D_act  = D_act

        # Shared encoder for a flattened, group-masked joint history
        # (N_max * hist_dim -> D_hist); used for both G1 and G2 views.
        self.hist_embed_joint = BigMLP(
            self.hist_dim * self.N_max,
            hidden_dim=H_ctx,
            out_dim=D_hist,
            n_blocks=2,
            use_ln=False
        )

        self.act_emb = nn.Embedding(self.n_actions, D_act0)
        self.act_pool_mlp = BigMLP(D_act0, hidden_dim=H_ctx, out_dim=D_act, n_blocks=2, use_ln=False)

        # phi context sees G1 history + pooled G1 actions; psi context sees G1 history only.
        self.G1_pool_mlp_phi = BigMLP(D_hist + D_act, hidden_dim=H_ctx, out_dim=H_ctx, n_blocks=2, use_ln=False)
        self.G1_pool_mlp_psi = BigMLP(D_hist,         hidden_dim=H_ctx, out_dim=H_ctx, n_blocks=2, use_ln=False)

        # Each head emits logits for every (agent slot, action) pair at once.
        self.phi_head_joint = BigMLP(H_ctx + D_hist, hidden_dim=H_ctx, out_dim=self.N_max * self.n_actions, n_blocks=2, use_ln=False)
        self.psi_head_joint = BigMLP(H_ctx + D_hist, hidden_dim=H_ctx, out_dim=self.N_max * self.n_actions, n_blocks=2, use_ln=False)

    def _pad_to_Nmax(self, x: torch.Tensor, pad_dims: int) -> torch.Tensor:
        """Zero-pad dim 1 of ``x`` up to ``N_max``.

        ``pad_dims == 1`` treats ``x`` as ``(B, N)``; any other value treats it as
        ``(B, N, D)``. Inputs with ``N >= N_max`` are returned unchanged.
        """
        B, N = x.shape[:2]
        if N >= self.N_max:
            return x
        if pad_dims == 1:
            pad = x.new_zeros(B, self.N_max - N)
        else:
            pad = x.new_zeros(B, self.N_max - N, x.size(-1))
        return torch.cat([x, pad], dim=1)

    def flat_mask(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Zero agents outside ``mask`` and flatten ``(B, N, D)`` to ``(B, N_max * D)``."""
        B, N, D = x.shape
        x = self._pad_to_Nmax(x, pad_dims=2)
        m = self._pad_to_Nmax(mask, pad_dims=1).float()
        x = x * m.unsqueeze(-1)
        return x.reshape(B, self.N_max * D)

    def nll_loss_fast(self, probs: torch.Tensor, targets_all: torch.Tensor, valid_mask: torch.Tensor) -> torch.Tensor:
        """Mean negative log-likelihood of ``targets_all`` under ``probs`` over valid agents.

        ``probs`` is ``(B, N_max, A)``; targets are clamped into ``[0, A - 1]`` and
        probabilities are floored at 1e-12 before the log.
        """
        if probs.numel() == 0:
            return probs.new_tensor(0.0)
        eps = 1e-12
        B, Nmax, A = probs.shape
        targ = targets_all.long().clamp_min(0).clamp_max(A - 1)
        p_t = probs.gather(dim=-1, index=targ.unsqueeze(-1)).squeeze(-1)
        nll = -(p_t.clamp_min(eps).log()) * valid_mask.float()
        denom = valid_mask.float().sum().clamp_min(1.0)
        return nll.sum() / denom

    def categorical_kl(self, p: torch.Tensor, q: torch.Tensor, valid_mask: torch.Tensor) -> torch.Tensor:
        """KL(p || q) over the last dim, averaged over entries where ``valid_mask`` is set."""
        if p.numel() == 0:
            return p.new_tensor(0.0)
        eps = 1e-12
        logp = (p + eps).log()
        logq = (q + eps).log()
        kl_per_agent = (p * (logp - logq)).sum(dim=-1)
        kl_per_agent = kl_per_agent * valid_mask.float()
        denom = valid_mask.float().sum().clamp_min(1.0)
        return kl_per_agent.sum() / denom

    def _g1_action_pooled(
        self,
        actions: torch.Tensor,
        g1_mask: torch.Tensor
    ) -> torch.Tensor:
        """Mean action embedding over the agents selected by ``g1_mask``.

        Args:
            actions: ``(B, N)`` action indices (clamped into ``[0, n_actions - 1]``).
            g1_mask: ``(B, N)`` or ``(B, N_max)`` membership mask.

        Returns:
            ``(B, D_act0)``; all zeros for rows with an empty mask.
        """
        A = self.n_actions
        acts_pad = self._pad_to_Nmax(actions, pad_dims=1).long().clamp(0, A - 1)
        emb = self.act_emb(acts_pad)

        # Pad the mask too, so an unpadded (B, N) mask lines up with the N_max-wide
        # embeddings; this is a no-op for masks that already have N_max columns.
        g1_pad = self._pad_to_Nmax(g1_mask, pad_dims=1).float()
        m = g1_pad.unsqueeze(-1)
        denom = g1_pad.sum(dim=1, keepdim=True).clamp_min(1.0)
        return (emb * m).sum(dim=1) / denom

    def forward_distributions(
        self,
        history: torch.Tensor,
        g1_index_mask: torch.Tensor,
        g2_index_mask: torch.Tensor,
        actions_g1: Optional[torch.Tensor] = None,
        g2_action_mask: Optional[torch.Tensor] = None,
        aG1_pooled: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """Compute the phi (with G1 actions) and psi (without) action distributions.

        Args:
            history: ``(B, N, rnn_hidden_dim)`` per-agent histories.
            g1_index_mask / g2_index_mask: ``(B, N)`` group membership masks.
            actions_g1: ``(B, N)`` actions; only G1 entries are pooled. Required when
                ``aG1_pooled`` is not given.
            g2_action_mask: optional ``(B, N, A)``; entries ``> 0`` are allowed actions.
            aG1_pooled: optional precomputed pooled G1 action embedding ``(B, D_act0)``.

        Returns:
            Dict with ``phi_probs`` / ``psi_probs`` ``(B, N_max, A)``,
            ``valid_agent_mask`` ``(B, N_max)`` (G2 membership) and ``full_mask``
            ``(B, N_max, A)`` (allowed action AND G2 agent).
        """
        device = history.device
        B, N, H = history.shape
        A = self.n_actions
        temp = float(getattr(self.args, "temp", 1.0))

        g1m = self._pad_to_Nmax(g1_index_mask, pad_dims=1)
        g2m = self._pad_to_Nmax(g2_index_mask, pad_dims=1)

        hG1_flat = self.flat_mask(history, g1m)
        hG2_flat = self.flat_mask(history, g2m)
        hG1_e = self.hist_embed_joint(hG1_flat)
        hG2_e = self.hist_embed_joint(hG2_flat)

        if aG1_pooled is None:
            assert actions_g1 is not None
            aG1_pooled = self._g1_action_pooled(actions_g1, g1m)
        aG1_e = self.act_pool_mlp(aG1_pooled)

        phi_ctx = self.G1_pool_mlp_phi(torch.cat([hG1_e, aG1_e], dim=-1))
        psi_ctx = self.G1_pool_mlp_psi(hG1_e)

        phi_logits = self.phi_head_joint(torch.cat([phi_ctx, hG2_e], dim=-1)).view(B, self.N_max, A)
        psi_logits = self.psi_head_joint(torch.cat([psi_ctx, hG2_e], dim=-1)).view(B, self.N_max, A)

        if g2_action_mask is None:
            avail = torch.ones(B, N, A, device=device, dtype=torch.bool)
        else:
            avail = (g2_action_mask > 0)

        # A logit is usable only if the action is allowed AND the slot is a G2 agent.
        avail = self._pad_to_Nmax(avail, pad_dims=2)
        g2_blocks = g2m.unsqueeze(-1).bool()
        full_mask = avail & g2_blocks

        phi_probs = masked_softmax(phi_logits, full_mask, dim=-1, temp=temp)
        psi_probs = masked_softmax(psi_logits, full_mask, dim=-1, temp=temp)

        valid_agent_mask = g2m.bool()
        return dict(phi_probs=phi_probs, psi_probs=psi_probs, valid_agent_mask=valid_agent_mask, full_mask=full_mask)

    def select_min_mi_actions_per_agent(
        self,
        history: torch.Tensor,
        g1_index_mask: torch.Tensor,
        g2_index_mask: torch.Tensor,
        base_actions_all: torch.Tensor,
        avail_actions: torch.Tensor,
        g2_action_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """For each G1 agent, choose the legal action that minimises KL(phi || psi).

        Agents are optimised independently. When agent ``i``'s candidate action is
        scored, every other agent keeps its action from ``base_actions_all``. The
        score of a candidate is KL(phi || psi) averaged over valid G2 agents.
        Runs under ``torch.no_grad()``.

        Args:
            history: ``(B, N, rnn_hidden_dim)`` per-agent histories.
            g1_index_mask / g2_index_mask: ``(B, N)`` group membership masks.
            base_actions_all: assumed ``(B, N)`` action indices (TODO confirm shape).
            avail_actions: ``(B, N, A)``; entries ``> 0`` mark legal actions for G1 agents.
            g2_action_mask: optional ``(B, N, A)`` availability mask for G2 agents.

        Returns:
            ``(best_actions_all, kl_values)``. The first is a copy of
            ``base_actions_all`` in which each G1 agent with at least one legal action
            holds its KL-minimising action (ties go to the lowest action index). The
            second is ``(B, N)`` and holds that minimal KL; it is 0 for agents that
            were not updated.
        """
        device = history.device
        B, N, _ = history.shape
        A = self.n_actions
        temp = float(getattr(self.args, "temp", 1.0))
        eps = 1e-12

        best_actions_all = base_actions_all.clone()
        kl_values = history.new_zeros(B, N)

        with torch.no_grad():
            for b in range(B):
                g1_idxs = torch.nonzero(g1_index_mask[b].bool(), as_tuple=False).squeeze(-1).to(device)
                if g1_idxs.numel() == 0:
                    continue

                # Enumerate every (G1 agent, legal action) candidate at once. ``nonzero``
                # is row-major, so candidates are grouped by agent (ascending agent id)
                # with actions ascending within each agent.
                legal = avail_actions[b].to(device)[g1_idxs] > 0        # (n_g1, A_avail)
                cand_row, pairs_a = torch.nonzero(legal, as_tuple=True)
                if cand_row.numel() == 0:
                    continue
                pairs_i = g1_idxs[cand_row]
                C = pairs_i.numel()

                g1m = self._pad_to_Nmax(g1_index_mask[b:b+1], pad_dims=1).squeeze(0).bool()
                g2m = self._pad_to_Nmax(g2_index_mask[b:b+1], pad_dims=1).squeeze(0).bool()

                hG1_flat = self.flat_mask(history[b:b+1], g1m.unsqueeze(0))
                hG2_flat = self.flat_mask(history[b:b+1], g2m.unsqueeze(0))
                hG1_e = self.hist_embed_joint(hG1_flat)
                hG2_e = self.hist_embed_joint(hG2_flat)

                # psi does not depend on G1 actions: compute it once per batch element.
                psi_ctx = self.G1_pool_mlp_psi(hG1_e)
                psi_logits = self.psi_head_joint(torch.cat([psi_ctx, hG2_e], dim=-1)).view(1, self.N_max, A)

                if g2_action_mask is None:
                    avail = torch.ones(1, N, A, device=device, dtype=torch.bool)
                else:
                    avail = (g2_action_mask[b:b+1] > 0)
                avail = self._pad_to_Nmax(avail, pad_dims=2)
                full_mask = avail & g2m.unsqueeze(-1).unsqueeze(0)

                psi_probs = masked_softmax(psi_logits, full_mask, dim=-1, temp=temp)
                logq = (psi_probs + eps).log()
                valid_agent_mask = g2m.unsqueeze(0)

                acts_pad = self._pad_to_Nmax(best_actions_all[b:b+1], pad_dims=1).squeeze(0).long().clamp(0, A - 1)
                emb_all = self.act_emb(acts_pad)

                denom = g1m.float().sum().clamp_min(1.0)
                pooled_base = (emb_all * g1m.float().unsqueeze(-1)).sum(dim=0) / denom

                # Swapping one G1 agent's action changes the mean embedding by
                # (new - old) / |G1|, so every candidate's pooled vector is a cheap update.
                emb_old = emb_all[pairs_i]
                emb_new = self.act_emb(pairs_a.clamp(0, A - 1))
                pooled_c = pooled_base.unsqueeze(0).expand(C, -1) + (emb_new - emb_old) / denom

                aG1_e = self.act_pool_mlp(pooled_c)
                hG1_eC = hG1_e.expand(C, -1)
                hG2_eC = hG2_e.expand(C, -1)

                phi_ctx = self.G1_pool_mlp_phi(torch.cat([hG1_eC, aG1_e], dim=-1))
                phi_logits = self.phi_head_joint(torch.cat([phi_ctx, hG2_eC], dim=-1)).view(C, self.N_max, A)

                phi_probs = masked_softmax(phi_logits, full_mask.expand(C, -1, -1), dim=-1, temp=temp)

                logp = (phi_probs + eps).log()
                kl_per_agent = (phi_probs * (logp - logq.expand(C, -1, -1))).sum(dim=-1)
                valid = valid_agent_mask.expand(C, -1).float()
                kl_each = (kl_per_agent * valid).sum(dim=1) / valid.sum(dim=1).clamp_min(1.0)

                # Per-agent argmin without a Python loop: lay the candidate KLs out on an
                # (agent, action) grid padded with +inf so illegal actions never win.
                # ``argmin`` returns the first minimum, i.e. the lowest legal action on ties.
                kl_grid = kl_each.new_full(legal.shape, float("inf"))
                kl_grid[cand_row, pairs_a] = kl_each
                best_col = kl_grid.argmin(dim=1)
                best_kl = kl_grid.gather(1, best_col.unsqueeze(1)).squeeze(1)

                has_legal = legal.any(dim=1)
                agents = g1_idxs[has_legal]
                best_actions_all[b, agents.to(best_actions_all.device)] = best_col[has_legal].to(
                    device=best_actions_all.device, dtype=best_actions_all.dtype
                )
                kl_values[b, agents] = best_kl[has_legal].to(kl_values.dtype)

        return best_actions_all, kl_values

    def compute_losses(
        self,
        history: torch.Tensor,
        g1_index_mask: torch.Tensor,
        g2_index_mask: torch.Tensor,
        actions_all: torch.Tensor,
        g2_action_mask: Optional[torch.Tensor] = None
    ) -> Dict[str, torch.Tensor]:
        """Training losses for both heads plus their KL.

        Returns:
            ``loss_phi`` / ``loss_psi``: NLL of ``actions_all`` under phi / psi over
            valid G2 agents; ``kl``: KL(phi || psi) averaged over the same agents.
        """
        outs = self.forward_distributions(
            history=history,
            g1_index_mask=g1_index_mask,
            g2_index_mask=g2_index_mask,
            actions_g1=actions_all,
            g2_action_mask=g2_action_mask
        )

        phi_probs = outs["phi_probs"]
        psi_probs = outs["psi_probs"]
        valid = outs["valid_agent_mask"]

        targets_pad = self._pad_to_Nmax(actions_all, pad_dims=1)
        loss_phi = self.nll_loss_fast(phi_probs, targets_pad, valid)
        loss_psi = self.nll_loss_fast(psi_probs, targets_pad, valid)
        kl = self.categorical_kl(phi_probs, psi_probs, valid)

        return dict(loss_phi=loss_phi, loss_psi=loss_psi, kl=kl)
    

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Optional, Tuple

class MLP(nn.Module):
    """Feed-forward network with Xavier-uniform weights and zero biases.

    ``n_layers`` counts Linear layers: ``max(n_layers - 1, 0)`` hidden
    ``Linear -> ReLU`` pairs of width ``hidden``, then a final ``Linear`` to
    ``out_dim``.
    """

    def __init__(self, in_dim: int, hidden: int, out_dim: int, n_layers: int = 2):
        super().__init__()
        dims = [in_dim] + [hidden] * max(n_layers - 1, 0) + [out_dim]
        last_hidden = len(dims) - 2

        blocks = []
        for k, (d_in, d_out) in enumerate(zip(dims[:-1], dims[1:])):
            blocks.append(nn.Linear(d_in, d_out))
            if k < last_hidden:
                blocks.append(nn.ReLU(inplace=False))
        self.net = nn.Sequential(*blocks)

        # Re-initialise every Linear in construction order.
        for layer in self.net:
            if not isinstance(layer, nn.Linear):
                continue
            nn.init.xavier_uniform_(layer.weight)
            nn.init.zeros_(layer.bias)

    def forward(self, x):
        out = self.net(x)
        return out


class MI_Obs(nn.Module):
    """Pairwise scores of how much agent j's action helps predict agent i's next obs.

    A base Gaussian head predicts agent i's next observation from its own context
    (``tau`` plus its own action embedding). A low-rank pairwise term
    ``delta_out(ctx_proj(ctx_i) * act_proj(a_emb_j))`` shifts that mean using agent
    j's action; the base mean and log-variance are detached in the pairwise path.
    The per-dimension drop in Gaussian NLL from adding j's term, averaged over the
    batch, is tracked as an EMA in ``score_ema[i, j, :]`` (diagonal forced to 0).

    ``args`` must provide ``n_agents``, ``n_actions``, ``rnn_hidden_dim`` and ``obs_shape``.
    """

    def __init__(
        self,
        args,
        hidden: int = 256,
        act_emb_dim: int = 32,
        rank: int = 64,
        ema_alpha: float = 0.05,
        learn_logv: bool = False,
    ):
        """
        Args:
            args: config object (see class docstring).
            hidden: width of the context encoder output.
            act_emb_dim: action-embedding width.
            rank: width of the low-rank pairwise interaction.
            ema_alpha: EMA step size for ``score_ema``.
            learn_logv: if True, predict log-variance from the context; otherwise
                learn one global log-variance vector.
        """
        super().__init__()
        self.args = args
        self.n_agents  = int(args.n_agents)
        self.n_actions = int(args.n_actions)
        self.h_dim     = int(args.rnn_hidden_dim)
        self.obs_dim   = int(args.obs_shape)

        self.act_emb = nn.Embedding(self.n_actions, act_emb_dim)

        ctx_in = self.h_dim + act_emb_dim
        self.ctx_enc = MLP(ctx_in, hidden=hidden, out_dim=hidden, n_layers=2)

        self.base_mu = nn.Linear(hidden, self.obs_dim)

        self.learn_logv = bool(learn_logv)
        if self.learn_logv:
            self.base_logv = nn.Linear(hidden, self.obs_dim)
        else:
            self.global_logv = nn.Parameter(torch.zeros(self.obs_dim))

        self.ctx_proj = nn.Linear(hidden, rank)
        self.act_proj = nn.Linear(act_emb_dim, rank)
        self.delta_out = nn.Linear(rank, self.obs_dim)

        self.ema_alpha = float(ema_alpha)
        # score_ema[i, j, d]: EMA of the NLL improvement on obs dim d of agent i from
        # agent j's action; score_cnt[i, j, 0]: number of updates applied to that pair.
        self.register_buffer("score_ema", torch.zeros(self.n_agents, self.n_agents, self.obs_dim))
        self.register_buffer("score_cnt", torch.zeros(self.n_agents, self.n_agents, 1))

    @staticmethod
    def gaussian_nll_per_dim(mu: torch.Tensor, logv: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Element-wise Gaussian NLL with mean ``mu`` and log-variance ``logv``."""
        inv_var = torch.exp(-logv)
        return 0.5 * (math.log(2 * math.pi) + logv + (y - mu) * (y - mu) * inv_var)

    def _sample_j_idx_shared(self, J: int, device: torch.device) -> torch.Tensor:
        """Sample ``J`` partner indices ``j != i`` for every agent ``i``.

        The same ``(n_agents, J)`` index table is shared across the batch. Partners
        are drawn uniformly (with replacement) from the other ``n_agents - 1`` agents
        by adding a random offset in ``[1, n_agents)`` modulo ``n_agents``. Unlike
        rejection sampling, this needs no data-dependent host checks and has no
        fallback bias. With a single agent no other partner exists and all indices
        are 0.
        """
        N = self.n_agents
        if N <= 1:
            return torch.zeros(N, J, dtype=torch.long, device=device)
        offset = torch.randint(1, N, (N, J), device=device)
        ar = torch.arange(N, device=device).unsqueeze(1)
        return (ar + offset) % N

    def forward_all_pairs(
        self,
        tau: torch.Tensor,
        actions: torch.Tensor,
        pair_sample_j: Optional[int] = None,
    ) -> Dict[str, torch.Tensor]:
        """Base and pairwise Gaussian predictions of each agent's next observation.

        Args:
            tau: ``(B, n_agents, rnn_hidden_dim)`` per-agent inputs.
            actions: ``(B, n_agents)`` action indices (clamped into range).
            pair_sample_j: if None, evaluate all ``(i, j)`` pairs (including
                ``i == j``); otherwise evaluate this many sampled partners per agent.

        Returns:
            Dict with ``mu0`` / ``logv0`` ``(B, N, obs_dim)``; ``mu_full``
            ``(B, N, N or J, obs_dim)``; ``logv_full`` ``(B, N, 1, obs_dim)``; and
            ``j_idx`` (``None`` for all pairs, else the ``(N, J)`` partner table).
        """
        device = tau.device
        B, N, H = tau.shape
        assert N == self.n_agents and H == self.h_dim

        a_emb = self.act_emb(actions.clamp(0, self.n_actions - 1))
        ctx = self.ctx_enc(torch.cat([tau, a_emb], dim=-1))

        mu0 = self.base_mu(ctx)
        if self.learn_logv:
            logv0 = self.base_logv(ctx).clamp(-8.0, 8.0)
        else:
            logv0 = self.global_logv.view(1, 1, self.obs_dim).expand(B, N, self.obs_dim)

        cproj = self.ctx_proj(ctx)
        aproj = self.act_proj(a_emb)

        if pair_sample_j is None:
            # pair_feat[b, i, j] = cproj[b, i] * aproj[b, j]
            pair_feat = cproj.unsqueeze(2) * aproj.unsqueeze(1)
            delta_mu = self.delta_out(pair_feat)
            mu_full = mu0.detach().unsqueeze(2) + delta_mu
            logv_full = logv0.detach().unsqueeze(2)
            return dict(mu0=mu0, logv0=logv0, mu_full=mu_full, logv_full=logv_full, j_idx=None)

        J = int(pair_sample_j)
        j_idx = self._sample_j_idx_shared(J, device=device)

        R = aproj.size(-1)

        # aproj_j[b, i, k] = aproj[b, j_idx[i, k]]
        aproj_src = aproj.unsqueeze(1).expand(B, N, N, R)
        idx = j_idx.unsqueeze(0).expand(B, N, J)
        idx_exp = idx.unsqueeze(-1).expand(B, N, J, R)
        aproj_j = aproj_src.gather(dim=2, index=idx_exp)

        pair_feat = cproj.unsqueeze(2) * aproj_j
        delta_mu = self.delta_out(pair_feat)
        mu_full = mu0.detach().unsqueeze(2) + delta_mu
        logv_full = logv0.detach().unsqueeze(2)
        return dict(mu0=mu0, logv0=logv0, mu_full=mu_full, logv_full=logv_full, j_idx=j_idx)

    def compute_losses_and_update_scores(
        self,
        tau: torch.Tensor,
        actions: torch.Tensor,
        next_obs: torch.Tensor,
        pair_sample_j: Optional[int] = None,
        w_base: float = 1.0,
        w_full: float = 1.0,
        clamp_score_min0: bool = True,
    ) -> Dict[str, torch.Tensor]:
        """Compute training losses and update ``score_ema`` / ``score_cnt`` in place.

        Args:
            tau, actions, pair_sample_j: forwarded to ``forward_all_pairs``.
            next_obs: ``(B, n_agents, obs_dim)`` targets.
            w_base, w_full: weights of the base and pairwise losses in ``loss_total``.
            clamp_score_min0: clamp per-batch scores at 0 before the EMA update.

        Returns:
            Dict with ``loss_total``, ``loss_base`` and ``loss_full``.
        """
        outs = self.forward_all_pairs(tau, actions, pair_sample_j=pair_sample_j)

        mu0, logv0 = outs["mu0"], outs["logv0"]
        y = next_obs

        nll_base_dim = self.gaussian_nll_per_dim(mu0, logv0, y)
        loss_base = nll_base_dim.sum(dim=-1).mean()

        # Identical for both modes: pairwise NLL and the batch-mean NLL improvement.
        mu_full, logv_full = outs["mu_full"], outs["logv_full"]
        y_full = y.unsqueeze(2)
        nll_full_dim = self.gaussian_nll_per_dim(mu_full, logv_full, y_full)
        loss_full = nll_full_dim.sum(dim=-1).mean()

        score = (nll_base_dim.detach().unsqueeze(2) - nll_full_dim.detach()).mean(dim=0)
        if clamp_score_min0:
            score = score.clamp_min(0.0)

        alpha = self.ema_alpha
        j_idx = outs["j_idx"]
        if j_idx is None:
            # All pairs: score is (N, N, obs_dim); self-pairs carry no partner signal.
            diag = torch.arange(self.n_agents, device=score.device)
            score[diag, diag] = 0.0

            self.score_ema = (1 - alpha) * self.score_ema + alpha * score
            self.score_cnt = self.score_cnt + 1.0
        else:
            # Sampled pairs: score is (N, J, obs_dim); update only the sampled (i, j)
            # entries. Duplicate j's in a row carry identical values, so the scatter
            # applies one update per distinct pair.
            Do = self.obs_dim
            idx_e = j_idx.unsqueeze(-1).expand(self.n_agents, j_idx.size(1), Do)
            old = self.score_ema.gather(dim=1, index=idx_e)
            new = (1 - alpha) * old + alpha * score
            self.score_ema.scatter_(dim=1, index=idx_e, src=new)
            self.score_cnt.scatter_(dim=1, index=idx_e[..., :1], src=self.score_cnt.gather(1, idx_e[..., :1]) + 1.0)

            diag = torch.arange(self.n_agents, device=self.score_ema.device)
            self.score_ema[diag, diag] = 0.0

        total = w_base * loss_base + w_full * loss_full
        return {
            "loss_total": total,
            "loss_base": loss_base,
            "loss_full": loss_full,
        }

    @torch.no_grad()
    def get_scores(self) -> torch.Tensor:
        """Return a copy of the ``(n_agents, n_agents, obs_dim)`` EMA score tensor."""
        return self.score_ema.clone()