import torch
from torch import nn
import torch.nn.functional as F
from umfavi.loglikelihoods.base import BaseLogLikelihood


class RatingDecoder(BaseLogLikelihood):
    """Head for predicting ordinal ratings using an ordered-logit model.

    Uses the cumulative link model with a logistic (sigmoid) link function.
    With 1-indexed categories k = 1..K and aggregated reward r:

        P(y = k | r) = sigma(theta_k - r) - sigma(theta_{k-1} - r)

    where theta_0 = -inf and theta_K = +inf by convention. Ratings passed to
    ``forward`` are 0-indexed, so rating j corresponds to category k = j + 1.
    Since P(y <= k) = sigma(theta_k - r) decreases in r, larger rewards shift
    probability mass toward higher ratings.

    To ensure theta_1 < theta_2 < ... < theta_{K-1}, we parameterize:
    - theta_1 = raw_theta_1 (free parameter)
    - theta_k = theta_{k-1} + softplus(delta_k) for k > 1

    Log-probabilities are computed directly in log space (via ``logsigmoid``)
    instead of taking ``log`` of a clamped difference of sigmoids, so the loss
    and its gradient stay finite and informative even when the reward lies far
    from the cutpoints.

    Args:
        num_categories: Number of ordinal categories (K >= 2). Default is 5
            for a Likert scale.
        normalize_by_length: If True, use the mean reward over valid timesteps
            instead of the sum.
        enable_diagnostics: Stored on the instance; not read by this class.

    Raises:
        ValueError: If ``num_categories < 2``.
    """

    def __init__(
        self,
        num_categories: int = 5,
        normalize_by_length: bool = True,
        enable_diagnostics: bool = True
    ):
        super().__init__()
        if num_categories < 2:
            # With K < 2 there is no cutpoint to learn; without this check the
            # model would silently behave as a K == 2 model.
            raise ValueError(
                f"num_categories must be >= 2, got {num_categories}"
            )
        self.num_categories = num_categories
        self.normalize_by_length = normalize_by_length
        self.enable_diagnostics = enable_diagnostics

        # K-1 cutpoints for K categories.
        # Parameterized as theta_1 (free) plus K-2 increments made positive by softplus.
        self.raw_theta_1 = nn.Parameter(torch.tensor(0.0))
        if num_categories > 2:
            # Zero init gives a spacing of softplus(0) = ln 2 (~0.69) between
            # consecutive cutpoints.
            self.delta_increments = nn.Parameter(torch.zeros(num_categories - 2))
        else:
            # K == 2 has no increments; keep the attribute as an empty buffer
            # so it always exists.
            self.register_buffer('delta_increments', torch.empty(0))

    def get_cutpoints(self) -> torch.Tensor:
        """Compute ordered cutpoints from learnable parameters.

        Returns:
            Tensor of shape (K-1,) with strictly increasing cutpoints.
        """
        if self.num_categories == 2:
            return self.raw_theta_1.unsqueeze(0)

        theta_1 = self.raw_theta_1

        # Remaining cutpoints: theta_k = theta_{k-1} + softplus(delta_k)
        increments = F.softplus(self.delta_increments)
        cumulative_increments = torch.cumsum(increments, dim=0)

        return torch.cat([
            theta_1.unsqueeze(0),
            theta_1 + cumulative_increments
        ])

    def _log_category_probs(self, agg_rewards: torch.Tensor) -> torch.Tensor:
        """Return log P(y = j | r) for every 0-indexed rating j.

        Args:
            agg_rewards: Aggregated rewards of shape (batch_size,).

        Returns:
            Tensor of shape (batch_size, K) of log-probabilities.
        """
        cutpoints = self.get_cutpoints()  # (K-1,)

        # z[:, i] = theta_{i+1} - r, so sigma(z) = P(y <= i) and sigma(-z) = P(y > i).
        z = cutpoints.unsqueeze(0) - agg_rewards.unsqueeze(1)  # (batch_size, K-1)
        log_cdf = F.logsigmoid(z)
        log_sf = F.logsigmoid(-z)

        lowest = log_cdf[:, :1]   # log P(y = 0)   = log sigma(theta_1 - r)
        highest = log_sf[:, -1:]  # log P(y = K-1) = log(1 - sigma(theta_{K-1} - r))
        if self.num_categories == 2:
            return torch.cat([lowest, highest], dim=1)

        # Middle ratings 0 < j < K-1, with a = theta_j - r < b = theta_{j+1} - r:
        #   sigma(b) - sigma(a) = sigma(b) * sigma(-a) * (1 - exp(a - b))
        # a - b = -(theta_{j+1} - theta_j) = -softplus(delta) does not depend on r.
        # Using the increments directly avoids cancellation in cutpoint differences.
        spacing = F.softplus(self.delta_increments)  # (K-2,)
        tiny = torch.finfo(spacing.dtype).tiny
        # Clamp only guards the degenerate case where softplus underflows to 0.
        log_gap = torch.log((-torch.expm1(-spacing)).clamp(min=tiny))  # (K-2,)
        middle = log_cdf[:, 1:] + log_sf[:, :-1] + log_gap.unsqueeze(0)

        return torch.cat([lowest, middle, highest], dim=1)

    def forward(
        self,
        reward_samples: torch.Tensor,
        ratings: torch.Tensor,
        valid: torch.Tensor
    ) -> torch.Tensor:
        """Compute negative log-likelihood of observed ratings.

        Args:
            reward_samples: Reward estimates of shape (batch_size, segment_len)
            ratings: Ordinal ratings of shape (batch_size,), integer values in [0, K-1]
            valid: Boolean mask of shape (batch_size, segment_len)

        Returns:
            Scalar NLL averaged over the batch. Every row counts, including rows
            with no valid timestep, whose aggregated reward is 0.
        """
        valid_mask = valid.bool()

        # masked_fill (rather than multiplying by the mask) keeps non-finite
        # values at padded timesteps out of the aggregate: NaN * 0 is still NaN.
        masked_rewards = reward_samples.masked_fill(~valid_mask, 0.0)

        if self.normalize_by_length:
            # Use MEAN reward to prevent saturation
            valid_counts = valid_mask.sum(dim=-1).clamp(min=1)  # (batch_size,)
            agg_rewards = masked_rewards.sum(dim=-1) / valid_counts  # (batch_size,)
        else:
            agg_rewards = masked_rewards.sum(dim=-1)  # (batch_size,)

        log_probs = self._log_category_probs(agg_rewards)  # (batch_size, K)

        # Pick the log-probability of each observed rating.
        observed = log_probs.gather(1, ratings.long().unsqueeze(1)).squeeze(1)

        return -observed.mean()

