import torch
import torch.nn.functional as F
from dataclasses import dataclass

@dataclass
class SMArgs:
    """Configuration for Softmasking.

    Attributes:
        sm_alg: mixing algorithm; expected to be "none",
            "mixinputs_with_topk" or "mixinputs_with_temp".
        sm_schedule: how ``scale`` varies with decoding progress:
            "none", "linear" or "stepwise".
        scale: overall strength of mixing; 0 disables mixing.
        steepness: slope of the sigmoid that maps negative entropy to lambda.
        offset: shift applied to the negative entropy inside that sigmoid.
        mixinputs_k: number of top tokens kept; only consulted when
            ``sm_alg == "mixinputs_with_topk"``.
        mixinputs_temp: softmax temperature; only consulted when
            ``sm_alg == "mixinputs_with_temp"``.
    """

    # -- algorithm / schedule selection --
    sm_alg: str = "none"
    sm_schedule: str = "none"

    # -- entropy -> lambda sigmoid parameters --
    scale: float = 0.0
    steepness: float = 0.0
    offset: float = 0.0

    # -- algorithm-specific knobs --
    mixinputs_k: int = 3
    mixinputs_temp: float = 1.0

def get_mixing_factors_for_softmasking(input_ids, logits_prelim, mask_token_id, max_gen_length, sm_args):
    """Compute the Softmasking output distribution for every position.

    Each position receives the convex combination
    ``(1 - lambda) * one_hot(input_ids) + lambda * p``. Here ``p`` is the
    model's distribution: temperature-scaled for "mixinputs_with_temp", or
    top-k truncated for "mixinputs_with_topk". ``lambda`` is zero everywhere
    except at positions holding ``mask_token_id``.

    Args:
        input_ids: integer token ids, one-hot encoded over the last
            (vocabulary) dim of ``logits_prelim``. Presumably shaped (B, T) —
            TODO confirm.
        logits_prelim: preliminary logits whose last dim is the vocabulary.
        mask_token_id: id of the mask token; only these positions are mixed.
        max_gen_length: length the schedule divides the per-sequence mask
            count by to obtain a progress fraction.
        sm_args: ``SMArgs`` with the algorithm, schedule and lambda parameters.

    Returns:
        Tensor shaped like ``logits_prelim`` holding the mixed distributions.
    """
    xt_one_hot = F.one_hot(input_ids, num_classes=logits_prelim.shape[-1]).to(logits_prelim.dtype)
    mask_positions = input_ids == mask_token_id

    # The entropy that drives lambda is computed on the temperature-scaled
    # distribution for "mixinputs_with_temp", and at temperature 1 otherwise.
    temperature = sm_args.mixinputs_temp if sm_args.sm_alg == "mixinputs_with_temp" else 1.0
    neg_entropy, p = get_neg_entropy_and_probabilities(logits_prelim, temperature=temperature)

    scale = sm_args.scale
    if sm_args.sm_schedule != "none":
        # Use the mean per-sequence mask count: max_gen_length is a
        # per-sequence length, so summing mask tokens over the whole batch
        # would push the progress fraction above 1 for batches of size > 1.
        # For a single sequence this equals the plain mask count.
        per_seq_masks = mask_positions.sum(dim=-1)
        num_mask_token = per_seq_masks.float().mean().item() if per_seq_masks.numel() > 0 else 0
        scale = get_time_dependence(
            max_gen_length=max_gen_length,
            num_mask_token=num_mask_token,
            scale=sm_args.scale,
            schedule=sm_args.sm_schedule,
        )

    lambda_tensor = calculate_lambda_tensor(
        neg_entropy, mask_positions, scale, sm_args.steepness, sm_args.offset
    )
    # Guarantee a trailing singleton axis so lambda broadcasts against the
    # (..., V) distributions, even if calculate_lambda_tensor returned a
    # tensor without it (e.g. on its zero-scale early return).
    if lambda_tensor.dim() == xt_one_hot.dim() - 1:
        lambda_tensor = lambda_tensor.unsqueeze(-1)

    if sm_args.sm_alg == "mixinputs_with_topk":
        # Replace p by a distribution supported only on the top-k tokens.
        p = get_only_topk_probs(logits_prelim, sm_args.mixinputs_k)

    # Convex combination of the current token (one-hot) and the model distribution.
    return (1 - lambda_tensor) * xt_one_hot + lambda_tensor * p

def get_neg_entropy_and_probabilities(logits, temperature=1.0):
    """Return the negative Shannon entropy and the softmax distribution of ``logits``.

    Args:
        logits: unnormalized scores; the distribution is taken over the last dim.
        temperature: divisor applied to ``logits`` before the softmax.

    Returns:
        ``(neg_entropy, p)`` where ``p = softmax(logits / temperature, dim=-1)``
        and ``neg_entropy = sum(p * log(p), dim=-1)`` (always <= 0), with
        ``0 * log(0)`` taken as 0. ``neg_entropy`` drops the last dim of ``p``.
    """
    p = torch.softmax(logits / temperature, dim=-1)
    # xlogy(p, p) computes p * log(p) and defines it as 0 where p == 0. An
    # additive epsilon inside the log is avoided: in float16 a tiny epsilon such
    # as 1e-10 rounds to 0, and 0 * log(0) = 0 * -inf would produce NaN.
    neg_entropy = torch.xlogy(p, p).sum(dim=-1)
    return neg_entropy, p

def calculate_lambda_tensor(neg_entropy, mask_positions, scale, steepness, offset):
    """Map per-position negative entropy to a mixing weight lambda.

    lambda = scale * sigmoid(steepness * (neg_entropy - offset)) at positions
    where ``mask_positions`` is True, and 0 everywhere else.

    Args:
        neg_entropy: per-position negative entropy tensor.
        mask_positions: boolean tensor, same shape as ``neg_entropy``,
            selecting the positions that may be mixed.
        scale: overall strength; 0 disables mixing.
        steepness: sigmoid slope.
        offset: sigmoid shift applied to ``neg_entropy``.

    Returns:
        Tensor of shape ``neg_entropy.shape + (1,)``. The trailing singleton
        axis lets it broadcast over the vocabulary dim.

    Raises:
        TypeError: if ``neg_entropy`` is None.
    """
    if neg_entropy is None:
        raise TypeError("neg_entropy must be a tensor, got None")

    if scale == 0.0:
        # Mixing disabled. Keep the trailing singleton axis so the result
        # has the same shape as in the non-zero case and still broadcasts.
        return torch.zeros_like(neg_entropy).unsqueeze(-1)

    # Squash the negative entropy into (0, scale) with a shifted sigmoid.
    lam = scale * torch.sigmoid(steepness * (neg_entropy - offset))

    # Only masked positions get a non-zero mixing weight.
    lam = torch.where(mask_positions, lam, torch.zeros_like(lam))
    return lam.unsqueeze(-1)

def get_only_topk_probs(logits, mixinputs_k=3):
    """Return a full-vocabulary distribution supported only on the top-k logits.

    For every position, the ``mixinputs_k`` largest logits along the last dim
    are renormalized with a softmax among themselves. Every other vocabulary
    entry is 0.

    Args:
        logits: tensor whose last dim is the vocabulary; any number of
            leading dims is accepted.
        mixinputs_k: number of tokens kept per position. ``torch.topk`` raises
            if it exceeds the vocabulary size.

    Returns:
        Tensor with the same shape and dtype as ``logits``.
    """
    top_vals, top_idx = torch.topk(logits, k=mixinputs_k, dim=-1)
    top_probs = torch.softmax(top_vals, dim=-1)

    # topk indices are distinct per position, so the scatter fills exactly k
    # slots per position and no count check is needed. A "> 0" count check
    # would misfire whenever a top-k probability underflows to 0 (large logit
    # gaps, float16, -inf logits).
    probs_full = torch.zeros_like(logits)
    probs_full.scatter_(-1, top_idx, top_probs)
    return probs_full

def get_time_dependence(
    max_gen_length: int,
    num_mask_token: int,
    scale: float,
    schedule: str,
    sm_to_hm: bool = True,
    threshold: float = 0.5,
) -> float:
    """Return the mixing scale adjusted for decoding progress.

    Progress is measured as ``t = num_mask_token / max_gen_length``, the
    fraction of generation positions still masked. ``t`` is 1.0 when
    ``max_gen_length`` is 0 or otherwise falsy.

    Args:
        max_gen_length: number of generated positions used to normalize
            ``num_mask_token``.
        num_mask_token: number of positions currently holding the mask token.
        scale: base mixing strength.
        schedule: one of:
            - "none": return ``scale`` unchanged.
            - "linear": return ``scale * t`` (``sm_to_hm=True``) or
              ``scale * (1 - t)`` (``sm_to_hm=False``).
            - "stepwise": return ``scale`` while ``t > threshold``
              (``sm_to_hm=True``) or ``t < threshold`` (``sm_to_hm=False``),
              and 0.0 otherwise.
        sm_to_hm: direction of the schedule. When True, the scale is larger
            the more tokens are still masked.
        threshold: switch point for the "stepwise" schedule.

    Returns:
        The scheduled scale.

    Raises:
        ValueError: if ``schedule`` is not one of the values above.
    """
    t = num_mask_token / max_gen_length if max_gen_length else 1.0

    if schedule == "none":
        return scale

    if schedule == "linear":
        return scale * (t if sm_to_hm else 1 - t)

    if schedule == "stepwise":
        active = t > threshold if sm_to_hm else t < threshold
        # Float zero, consistent with the ``-> float`` annotation.
        return scale if active else 0.0

    raise ValueError(f"Unknown schedule: {schedule}")
