import torch
from torch import nn
from umfavi.loglikelihoods.base import BaseLogLikelihood


class RankingDecoder(BaseLogLikelihood):
    """Head for computing the Plackett-Luce ranking negative log-likelihood.

    The Plackett-Luce model defines the probability of a ranking as:

        P(sigma | r) = prod_{i=1}^{k} exp(beta * r_{sigma_i}) / sum_{j=i}^{k} exp(beta * r_{sigma_j})

    This decomposes into sequential softmax choices: at each position i, the
    item ranked i-th is chosen among the items that have not been placed yet
    (positions i to k).

    The negative log-likelihood is:

        -log P(sigma | r) = sum_{i=1}^{k} [-beta * r_{sigma_i} + logsumexp_{j>=i}(beta * r_{sigma_j})]

    Args:
        normalize_by_length: If True, each segment is scored by its mean reward
            over valid timesteps instead of the sum, to prevent logit
            saturation with long segments. Recommended for segment_len > 10.
    """

    def __init__(self, normalize_by_length: bool = True):
        super().__init__()
        self.normalize_by_length = normalize_by_length

    def forward(
        self,
        reward_samples: torch.Tensor,
        ranks: torch.Tensor,
        beta: torch.Tensor,
        valid: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the Plackett-Luce negative log-likelihood.

        Args:
            reward_samples: Predicted rewards, shape (batch_size, k, segment_len).
            ranks: Rank assignments for each segment, shape (batch_size, k),
                where rank 0 = best and rank k-1 = worst. Tied ranks are
                ordered however ``argsort`` happens to break them.
            beta: Rationality parameter; a tensor of shape (batch_size,) or
                (batch_size, 1), a 0-dim tensor, or a plain Python number.
            valid: Validity mask, shape (batch_size, k, segment_len). Entries
                equal to zero/False mark padded timesteps, whose rewards are
                ignored even when they are NaN or infinite.

        Returns:
            Mean negative log-likelihood over the batch (0-dim tensor). The
            result is computed in the promoted dtype of the inputs, but never
            below float32.

        Raises:
            ValueError: If ``reward_samples`` is not 3-dimensional.
        """
        if reward_samples.dim() != 3:
            raise ValueError(
                "reward_samples must have shape (batch_size, k, segment_len), "
                f"got {tuple(reward_samples.shape)}"
            )

        # Zero padded timesteps *before* multiplying by the mask: a bare
        # `reward * 0` would turn NaN/inf padding into NaN and poison the loss.
        # The multiplication is kept so non-binary mask weights still apply.
        is_valid = valid.bool()
        masked_rewards = reward_samples.masked_fill(~is_valid, 0.0) * valid

        # Per-segment score: sum of valid rewards, optionally length-normalized.
        agg_r = masked_rewards.sum(dim=-1)  # (batch_size, k)
        if self.normalize_by_length:
            # clamp avoids 0/0 for segments without any valid timestep; such
            # segments end up with an aggregated reward of 0.
            agg_r = agg_r / valid.sum(dim=-1).clamp(min=1)

        # Accept plain Python scalars for beta; tensors are used as given.
        if not torch.is_tensor(beta):
            beta = torch.as_tensor(beta, dtype=agg_r.dtype, device=agg_r.device)
        beta = beta.reshape(-1, 1)  # (batch_size, 1) or (1, 1), broadcasts over k

        # ranks[b, i] is the rank of item i, so argsort(ranks)[b, r] is the
        # index of the item placed at rank r (0 = best).
        permutation = ranks.argsort(dim=-1)  # (batch_size, k)

        # Logits of the items, laid out from best-ranked to worst-ranked.
        ordered_logits = (beta * agg_r).gather(-1, permutation)  # (batch_size, k)

        # Accumulate in at least float32, and keep float64 when given float64.
        ordered_logits = ordered_logits.to(
            torch.promote_types(ordered_logits.dtype, torch.float32)
        )

        # remaining_lse[b, i] = logsumexp(ordered_logits[b, i:]): a cumulative
        # logsumexp taken from the worst-ranked end, i.e. the normalizer over
        # the items still unplaced at position i.
        remaining_lse = torch.logcumsumexp(ordered_logits.flip(-1), dim=-1).flip(-1)

        # Sum the per-position softmax NLL terms, then average over the batch.
        nll = (remaining_lse - ordered_logits).sum(dim=-1)  # (batch_size,)
        return nll.mean()
