import torch
from torch import nn
from umfavi.loglikelihoods.base import BaseLogLikelihood

class DemonstrationsDecoder(BaseLogLikelihood):
    """Negative log-likelihood of demonstrated actions given Q-values.

    Discrete actions are scored under a Boltzmann-rational (softmax) policy
    over the supplied Q-values. Continuous actions are scored with the
    unnormalized term ``-Q(s, a)``.
    """

    def __init__(self, actions_discrete: bool = True) -> None:
        """
        Args:
            actions_discrete: If True, ``forward`` treats ``q_curr`` as logits
                over a discrete action set. Otherwise it treats ``q_curr`` as
                the Q-value of the demonstrated continuous action.
        """
        # Initialize the base class before assigning attributes. This is the
        # required order for nn.Module subclasses. NOTE(review): assumes
        # BaseLogLikelihood.__init__ does not itself set `actions_discrete`.
        super().__init__()
        self.actions_discrete = actions_discrete

    def forward(
        self,
        acts_curr: torch.Tensor,
        q_curr: torch.Tensor,
        valid: torch.Tensor,
    ) -> torch.Tensor:
        """
        Computes the NLL of demonstrations under a Boltzmann-rational policy.

        For discrete actions:
            π(a | s) ∝ exp(Q(s, a)), so the NLL is the per-sample cross-entropy
            of the demonstrated action index against logits Q(s, ·).

        For continuous actions:
            -Q(s, a) is used as the NLL. It omits the log-partition term
            log Z(s), so it equals the true NLL only up to an additive
            per-state constant.

        Args:
            acts_curr: Actions tensor
                - Discrete: shape (batch_size, 1) or (batch_size,) with action
                  indices. Non-integer dtypes are cast to long.
                - Continuous: shape (batch_size, action_dim) with continuous
                  actions. This branch does not read the value.
            q_curr: Q-values tensor
                - Discrete: shape (batch_size, n_actions), Q-values for all actions.
                - Continuous: shape (batch_size, 1), Q-value for the specific (s, a) pair.
            valid: Boolean mask of shape (batch_size,) or (batch_size, 1).
                Nonzero entries are treated as valid.

        Returns:
            Scalar NLL averaged over valid entries only. Returns 0 when no
            entry is valid.
        """
        # Flatten with reshape, because view() fails on non-contiguous
        # tensors. Cast to bool so the masking below does not depend on
        # the mask's dtype.
        mask = valid.reshape(-1).to(torch.bool)

        if self.actions_discrete:
            # cross_entropy needs integer class indices of shape (batch_size,).
            # reshape(-1), unlike squeeze(-1), keeps a 1-element batch 1-D.
            targets = acts_curr.reshape(-1).long()
            # Invalid (e.g. padded) rows may hold out-of-range action values.
            # Replace them with a legal index so cross_entropy cannot fail on
            # entries that are discarded below anyway.
            targets = targets.masked_fill(~mask, 0)
            nll = nn.functional.cross_entropy(q_curr, targets, reduction='none')
        else:
            # Treat Q(s, a) as an unnormalized log-probability, so NLL ≈ -Q(s, a).
            nll = -q_curr.reshape(-1)

        # Zero invalid entries with masked_fill instead of multiplying by the
        # mask. Multiplication would turn an inf/NaN in an invalid row into
        # NaN (0 * inf) and corrupt the result and its gradients.
        nll = nll.masked_fill(~mask, 0.0)
        n_valid = mask.sum().clamp(min=1)
        return nll.sum() / n_valid