import torch
from torch import nn
from umfavi.loglikelihoods.base import BaseLogLikelihood


class PreferenceDecoder(BaseLogLikelihood):
    """Head for predicting pairwise preferences between two segments.

    Per-timestep rewards are masked, aggregated per segment (mean or sum),
    and ``beta * (agg_0 - agg_1)`` is used as the logit of a binary
    cross-entropy against the preference labels.

    Args:
        normalize_by_length: If True, use mean reward instead of sum to prevent
            logit saturation with long segments. Recommended for segment_len > 10.
        enable_diagnostics: Stored on the instance as ``enable_diagnostics``.
            Nothing in this class reads it; presumably it is consumed
            elsewhere (base class or caller) -- TODO confirm.
    """

    def __init__(self, normalize_by_length: bool = True, enable_diagnostics: bool = True):
        super().__init__()
        self.normalize_by_length = normalize_by_length
        self.enable_diagnostics = enable_diagnostics

    def forward(self, reward_samples: torch.Tensor, prefs: torch.Tensor, beta: torch.Tensor, valid: torch.Tensor) -> torch.Tensor:
        """Return the mean binary cross-entropy of the preference labels.

        Args:
            reward_samples: Per-timestep rewards. The shape notes below assume
                (batch_size, 2, num_steps); index 0 / 1 on the second-to-last
                axis select the two segments being compared.
            prefs: Target for ``sigmoid(beta * (agg_0 - agg_1))``. May be a
                floating-point tensor (used as-is) or a bool/integer tensor
                (cast to the logit dtype). Must have the same shape as the
                logits.
            beta: Scale applied to the aggregated reward difference; must
                broadcast against it.
            valid: Mask multiplied into ``reward_samples``; timesteps where it
                is zero are excluded from the aggregate.

        Returns:
            Loss averaged over all logits (``reduction='mean'``).
        """
        # Mask invalid timesteps before aggregating rewards. Multiplying by the
        # mask is not sufficient on its own: NaN * 0 and inf * 0 are NaN, so a
        # non-finite reward at a padded position would poison the whole loss.
        # Hard-zero those positions instead.
        masked_rewards = (reward_samples * valid).masked_fill(valid == 0, 0)  # (batch_size, 2, num_steps)

        if self.normalize_by_length:
            # Use MEAN reward to prevent logit saturation with long segments
            # This keeps logits in a reasonable range regardless of segment length
            # clamp(min=1) avoids 0/0 for segments without any valid timestep.
            valid_counts = valid.sum(dim=-1).clamp(min=1)  # (batch_size, 2)
            agg_r_per_traj = masked_rewards.sum(dim=-1) / valid_counts  # (batch_size, 2)
        else:
            # Original: use SUM of rewards (can cause saturation with long segments)
            agg_r_per_traj = torch.sum(masked_rewards, dim=-1)  # (batch_size, 2)

        # Compute preference logits and loss
        logits = beta * (agg_r_per_traj[..., 0] - agg_r_per_traj[..., 1])

        # binary_cross_entropy_with_logits needs a floating-point target;
        # accept bool/integer labels by casting them. Float labels are left
        # untouched so existing callers see identical results.
        if not prefs.is_floating_point():
            prefs = prefs.to(logits.dtype)

        loss = nn.functional.binary_cross_entropy_with_logits(logits, prefs, reduction='mean')

        return loss
