import torch
from umfavi.loglikelihoods.base import BaseLogLikelihood


class StopDecoder(BaseLogLikelihood):
    """
    Decoder for stop feedback using a discrete-time hazard model.

    The hazard at time t is h_t = 1 - exp(-lambda * R_t) where R_t is the
    discounted cumulative regret up to time t.

    Discounted cumulative regret: R_t = regret_discount * R_{t-1} + r_t
    This allows old regret to decay over time.

    NLL for stop at time tau:
        -log P(T=tau) = -sum_{t<tau} log(1-h_t) - log(h_tau)

    For censored observations (no stop), only survival likelihood:
        -log S(T) = -sum_{t=1}^T log(1-h_t)
    """

    def __init__(self):
        super().__init__()

    def forward(
        self,
        q_values: torch.Tensor,
        actions: torch.Tensor,
        stop_times: torch.Tensor,
        lambd: torch.Tensor,
        regret_discount: torch.Tensor,
        valid: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute negative log-likelihood of observed stop times.

        Args:
            q_values: Q-values for all actions, shape (batch, seq_len, n_actions)
            actions: Taken actions, shape (batch, seq_len) or (batch, seq_len, 1)
            stop_times: Stop timesteps, shape (batch,). -1 indicates censored (no stop)
            lambd: Stop sensitivity parameter, scalar or shape (batch,)
            regret_discount: Discount factor for old regret (0-1), scalar or shape (batch,)
            valid: Validity mask (bool or 0/1 numeric), shape (batch, seq_len)

        Returns:
            Scalar NLL averaged over batch.
        """
        _, seq_len, _ = q_values.shape  # also enforces a 3-D q_values tensor
        device = q_values.device
        dtype = q_values.dtype

        # Normalise the mask once so both the arithmetic masking and the
        # boolean `&` below work for bool as well as 0/1 float masks.
        valid_mask = valid.bool()

        # Instantaneous regret: max_a Q(s,a) - Q(s, a_taken)
        q_max = q_values.max(dim=-1).values  # (batch, seq_len)
        action_index = actions.long()
        if action_index.dim() == q_values.dim() - 1:
            # gather() requires the index to have the same rank as its input.
            action_index = action_index.unsqueeze(-1)  # (batch, seq_len, 1)
        q_taken = q_values.gather(dim=-1, index=action_index).squeeze(-1)
        instant_regret = (q_max - q_taken) * valid_mask.to(dtype)  # (batch, seq_len)

        # Discounted cumulative regret: R_t = regret_discount * R_{t-1} + r_t.
        # Built out-of-place (list + stack) so autograd can differentiate
        # through the recurrence, including w.r.t. regret_discount.
        discount = torch.as_tensor(
            regret_discount, dtype=dtype, device=device
        ).reshape(-1)  # (1,) or (batch,)
        running = instant_regret[:, 0]
        regret_steps = [running]
        for t in range(1, seq_len):
            running = discount * running + instant_regret[:, t]
            regret_steps.append(running)
        cum_regret = torch.stack(regret_steps, dim=1)  # (batch, seq_len)

        # Hazard: h_t = 1 - exp(-lambda * R_t)
        rate = torch.as_tensor(lambd, dtype=dtype, device=device).reshape(-1, 1)
        hazard = 1.0 - torch.exp(-rate * cum_regret)  # (batch, seq_len)

        # Clamp for numerical stability. The bounds must be representable in
        # the working dtype: in float32, 1 - 1e-8 rounds to exactly 1.0, which
        # would let log(1 - h) reach -inf and turn the masked sums into NaN.
        eps = 1e-8
        limits = torch.finfo(hazard.dtype)
        lower = max(eps, limits.tiny)
        upper = 1.0 - max(eps, limits.eps)
        hazard = hazard.clamp(min=lower, max=upper)

        log_survival = torch.log(1.0 - hazard)  # log(1 - h_t), (batch, seq_len)
        log_hazard = torch.log(hazard)  # log(h_t), (batch, seq_len)

        # Censored samples are encoded as stop time -1.
        is_stopped = stop_times >= 0  # (batch,)
        stop_index = stop_times.clamp(min=0).long()  # (batch,)

        # Survival mask: (t < tau AND valid) for stopped samples, the plain
        # validity mask for censored ones. Invalid timesteps never contribute.
        time_indices = torch.arange(seq_len, device=device).unsqueeze(0)  # (1, seq_len)
        before_stop = (time_indices < stop_index.unsqueeze(1)) & valid_mask
        survival_mask = torch.where(is_stopped.unsqueeze(1), before_stop, valid_mask)

        # Select rather than multiply so a non-finite value at a masked-out
        # position cannot leak into the sum.
        survival_nll = -log_survival.masked_fill(~survival_mask, 0.0).sum(dim=-1)

        # Hazard NLL: -log h_tau for stopped samples only, zero for censored.
        log_hazard_at_tau = log_hazard.gather(
            dim=-1, index=stop_index.unsqueeze(-1)
        ).squeeze(-1)
        hazard_nll = torch.where(
            is_stopped, -log_hazard_at_tau, torch.zeros_like(log_hazard_at_tau)
        )

        nll = survival_nll + hazard_nll
        return nll.mean()
