# modules/opponent/contrastive_encoder.py
# Contains FiLM, ContextEncoder, ProtoQueue and InfoNCE loss.

import torch
import torch.nn as nn
import torch.nn.functional as F
import pdb

class FiLM(nn.Module):
    """Feature-wise linear modulation: scale and shift features using a latent code.

    Two linear maps turn the conditioning vector ``z`` into a per-feature
    scale (``gamma``) and offset (``beta``), applied as ``gamma(z) * h + beta(z)``.
    """

    def __init__(self, in_dim: int, z_dim: int):
        super().__init__()
        # Both heads map the latent code to one coefficient per feature of h.
        self.gamma = nn.Linear(z_dim, in_dim)
        self.beta = nn.Linear(z_dim, in_dim)

    def forward(self, h: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        """Modulate ``h`` with coefficients predicted from ``z``.

        Args:
            h: features, ``[B*A, H]`` or ``[B, H]``.
            z: conditioning codes, ``[B, Z]``.

        Returns:
            Tensor with the shape of ``h``.
        """
        needs_tiling = h.dim() == 2 and z.dim() == 2 and z.size(0) != h.size(0)
        if needs_tiling:
            # Repeat every code so that consecutive rows of h (the rows that
            # belong to the same batch item) all receive that item's code.
            total_rows = h.size(0)
            copies = total_rows // z.size(0)
            z = z[:, None, :].expand(-1, copies, -1).reshape(total_rows, -1)

        scale = self.gamma(z)
        shift = self.beta(z)
        return scale * h + shift

class ContextEncoder(nn.Module):
    """Encode each agent's context sequence into a bounded latent vector.

    A GRU consumes every agent's sequence independently; the output at the
    final time step is projected to ``z_dim`` and squashed with ``tanh``.
    """

    def __init__(self, in_dim: int, hid_dim: int = 64, z_dim: int = 16, num_layers: int = 1):
        super().__init__()
        self.gru = nn.GRU(
            input_size=in_dim,
            hidden_size=hid_dim,
            num_layers=num_layers,
            batch_first=True,
        )
        # Tanh keeps every latent component inside (-1, 1).
        self.proj = nn.Sequential(nn.Linear(hid_dim, z_dim), nn.Tanh())

    def forward(self, ctx_seq: torch.Tensor, h0: torch.Tensor = None) -> torch.Tensor:
        """Return one latent code per agent.

        Args:
            ctx_seq: context features, ``[B, T_ctx, N, CTX]``.
            h0: optional initial GRU state, passed straight to ``nn.GRU``
                (whose batch dimension here is ``B * N``).

        Returns:
            Tensor of shape ``[B, N, z_dim]``.
        """
        batch, steps, agents, feat = ctx_seq.shape

        # Fold the agent axis into the batch axis so the GRU sees one
        # independent sequence per (batch item, agent): [B*N, T_ctx, CTX].
        per_agent = ctx_seq.permute(0, 2, 1, 3).reshape(batch * agents, steps, feat)

        hidden_seq, _ = self.gru(per_agent, h0)  # [B*N, T_ctx, hid_dim]
        final_step = hidden_seq[:, -1]           # [B*N, hid_dim]
        latent = self.proj(final_step)           # [B*N, z_dim]

        # Unfold the agent axis again.
        return latent.reshape(batch, agents, -1)

class ProtoQueue(nn.Module):
    """
    Momentum-averaged prototype memory organised as a ring buffer.

    Use push() to record episode prototypes and sample_negatives() to draw
    InfoNCE negatives. ``protos`` and ``valid`` are registered buffers, so they
    follow ``Module.to()`` and are stored in the state dict. The write pointer
    ``ptr`` is a plain Python int and is therefore NOT part of the state dict.
    """

    def __init__(self, z_dim: int, max_size: int = 4096, momentum: float = 0.9, device: str = "cpu"):
        super().__init__()
        self.register_buffer("protos", torch.zeros(max_size, z_dim))
        self.register_buffer("valid", torch.zeros(max_size, dtype=torch.bool))
        self.momentum = momentum
        self.max_size = max_size
        # Monotonically increasing write counter; the slot is ptr % max_size.
        self.ptr = 0
        # Only recorded: the buffers above are created without a device
        # argument and are moved with Module.to() like any other buffer.
        self.device_str = device

    @torch.no_grad()
    def _update_slot(self, idx: int, proto: torch.Tensor):
        """Write ``proto`` into slot ``idx``.

        An empty slot is overwritten; an occupied slot is blended as
        ``momentum * old + (1 - momentum) * proto``.
        """
        if self.valid[idx]:
            self.protos[idx] = self.momentum * self.protos[idx] + (1 - self.momentum) * proto
        else:
            self.protos[idx] = proto
            self.valid[idx] = True

    @torch.no_grad()
    def push(self, episode_protos: torch.Tensor):
        """Store a batch of prototypes in consecutive ring-buffer slots.

        Args:
            episode_protos: ``[B, z_dim]``.

        Returns:
            ``(indices, stored)`` where ``indices`` is a long tensor ``[B]`` of
            the slots written and ``stored`` is ``[B, z_dim]`` holding the slot
            contents after the (possibly momentum-blended) update.
        """
        # Move the whole batch once instead of once per row.
        batch = episode_protos.to(self.protos.device)
        idxs = []
        # Slots are updated sequentially so that a batch larger than max_size
        # wraps around and blends into the slots it wrote earlier.
        for b in range(batch.size(0)):
            idx = self.ptr % self.max_size
            self._update_slot(idx, batch[b])
            idxs.append(idx)
            self.ptr += 1
        idx_tensor = torch.tensor(idxs, device=self.protos.device, dtype=torch.long)
        return idx_tensor, self.protos[idx_tensor]

    def sample_negatives(self, exclude_indices: torch.Tensor, num_neg: int):
        """Draw negatives uniformly (with replacement) from the filled slots.

        Args:
            exclude_indices: slot indices that must not be drawn (typically the
                indices returned by push() for the current batch), or ``None``
                to exclude nothing. Its first dimension sets the number of rows
                sampled; with ``None`` a single row is sampled.
            num_neg: negatives per row.

        Returns:
            ``[rows, num_neg, z_dim]`` tensor, or ``None`` when no eligible
            slot exists (empty queue, or every filled slot is excluded).
        """
        valid_idxs = torch.nonzero(self.valid, as_tuple=False).squeeze(-1)
        if valid_idxs.numel() == 0:
            return None

        mask = torch.ones_like(self.valid, dtype=torch.bool)
        if exclude_indices is None:
            # No batch to size the sample against: return one shared row.
            num_rows = 1
        else:
            num_rows = exclude_indices.size(0)
            if exclude_indices.numel() > 0:
                # Indices may live on another device than the buffers
                # (e.g. CUDA indices vs. a CPU queue); indexing requires
                # them on the buffer's device.
                mask[exclude_indices.to(mask.device)] = False

        pool = torch.nonzero(mask & self.valid, as_tuple=False).squeeze(-1)
        if pool.numel() == 0:
            return None

        # Sample positions inside the pool with replacement.
        choice = torch.randint(0, pool.numel(), (num_rows, num_neg), device=pool.device)
        neg_idxs = pool[choice]       # [rows, num_neg]
        return self.protos[neg_idxs]  # [rows, num_neg, z_dim]

def info_nce_loss(z: torch.Tensor, pos_proto: torch.Tensor, neg_protos: torch.Tensor, temperature: float = 0.07) -> torch.Tensor:
    """
    InfoNCE loss between anchors and prototypes using cosine similarity.

    z: [B, Z], pos_proto: [B, Z], neg_protos: [B, K, Z] or None

    All inputs are L2-normalised; the positive logit is placed at class index
    0, followed by the K negative logits, and cross entropy against class 0 is
    returned as a scalar. The computation runs on ``z.device``; prototypes are
    moved there if they live elsewhere (e.g. a CPU-resident queue). Negatives
    are detached, so no gradient flows into them.

    When ``neg_protos`` is None the softmax has a single class, so the loss is
    zero for any finite input and carries no gradient.

    Non-finite inputs or a non-finite result are reported on stdout; they are
    not raised as errors.
    """
    # Diagnostics only: report bad values but let the computation proceed.
    for label, tensor in (("z", z), ("pos_proto", pos_proto), ("neg_protos", neg_protos)):
        if tensor is not None and not torch.isfinite(tensor).all():
            print(f"{label} is nan or inf")

    eps = 1e-8
    # Follow the anchor's device instead of assuming CUDA is present.
    device = z.device

    z = F.normalize(z, dim=-1, eps=eps)
    pos = F.normalize(pos_proto.to(device), dim=-1, eps=eps)
    pos_logit = (z * pos).sum(-1, keepdim=True) / temperature  # [B, 1]
    # Guard against overflow in the softmax for very small temperatures.
    pos_logit = torch.clamp(pos_logit, min=-100, max=100)

    if neg_protos is None:
        logits = pos_logit
    else:
        neg = F.normalize(neg_protos.to(device).detach(), dim=-1, eps=eps)  # [B, K, Z]
        neg_logits = torch.einsum("bd,bkd->bk", z, neg) / temperature
        neg_logits = torch.clamp(neg_logits, min=-100, max=100)
        logits = torch.cat([pos_logit, neg_logits], dim=1)  # [B, 1+K]

    # The positive always sits in column 0.
    labels = torch.zeros(z.size(0), dtype=torch.long, device=device)
    result = F.cross_entropy(logits, labels)
    if not torch.isfinite(result).all():
        print('result is nan or inf')
    return result
