import torch
from torch.autograd import Function

class IMLETopK(Function):
    """Hard top-k mask with an I-MLE style (perturb-and-MAP) gradient estimate.

    ``forward`` returns a 0/1 mask marking the ``k`` largest entries of
    ``scores + noise`` in every row.  ``backward`` re-solves the top-k problem
    on the target scores ``scores - lambda_val * grad_output + noise`` and
    returns ``(forward_mask - target_mask) / lambda_val`` as the gradient
    with respect to ``scores``.

    All selections run along ``dim=1``, so ``scores`` and ``noise`` are
    expected to be 2-D tensors of identical shape (rows are independent
    problems).
    """

    @staticmethod
    def forward(ctx, scores, k, r, lambda_val, noise, device):
        """Compute the top-k mask of the perturbed scores.

        Args:
            ctx: autograd context.
            scores: 2-D tensor of scores, one selection problem per row.
            k: number of entries to select per row.
            r: number of extra runner-up entries kept as candidates for the
                screened solve in ``backward``.  Clamped so that ``k + r``
                never exceeds the number of columns.
            lambda_val: step size of the target distribution; also divides
                the returned gradient, so it must be non-zero.
            noise: perturbation added to ``scores`` (same shape); reused
                unchanged in ``backward``.
            device: stored on ``ctx`` only; not used by the computation.

        Returns:
            float32 tensor shaped like ``scores`` with ones at the selected
            positions and zeros elsewhere.
        """
        ctx.k = k
        ctx.r = r
        ctx.lambda_val = lambda_val
        ctx.device = device

        perturbed_scores = scores + noise

        # ``r`` is only a screening buffer, so clamp the candidate width to
        # the available columns instead of failing inside ``topk``.  The
        # ``max`` keeps the width at ``k`` at least, so an oversized ``k``
        # still raises exactly as before.
        width = max(k, min(k + r, scores.size(1)))
        top_k_vals, top_k_indices = torch.topk(perturbed_scores, width, dim=1)

        # ``topk`` returns values sorted in descending order, so the first
        # ``k`` columns are the actual selection.
        mask = torch.zeros_like(scores, dtype=torch.float32)
        mask.scatter_(1, top_k_indices[:, :k], 1.0)

        ctx.save_for_backward(scores, noise, top_k_indices, top_k_vals)
        return mask

    @staticmethod
    def backward(ctx, grad_output):
        """Return the I-MLE gradient estimate for ``scores``.

        Rows whose largest score shift ``max |lambda_val * grad|`` is at most
        half the gap between the k-th and the last saved candidate value
        cannot pull a non-candidate into the top-k, so their target top-k is
        solved over the saved candidates only.  All other rows fall back to a
        full top-k over every column.

        Returns:
            A 6-tuple matching the inputs of ``forward``; only the first
            entry (gradient w.r.t. ``scores``) is not ``None``.
        """
        scores, noise, top_kr_indices, top_kr_vals = ctx.saved_tensors
        k, lambda_val = ctx.k, ctx.lambda_val

        # grad_output is only read, so no defensive copy is needed.
        step = lambda_val * grad_output
        perturbed_target = scores - step + noise

        # Adaptive screening.  The last saved column is the (k + r)-th
        # largest forward value (or the smallest one when r was clamped).
        gap = top_kr_vals[:, k - 1] - top_kr_vals[:, -1]
        delta = step.abs().max(dim=1).values
        use_subset = delta <= 0.5 * gap

        sub_rows = torch.nonzero(use_subset, as_tuple=True)[0]
        full_rows = torch.nonzero(~use_subset, as_tuple=True)[0]

        s_prime = torch.zeros_like(scores)

        # Full solve fallback
        if full_rows.numel() > 0:
            idx_full = torch.topk(perturbed_target[full_rows], k, dim=1).indices
            s_prime[full_rows.unsqueeze(1), idx_full] = 1.0

        # Screened solve: top-k restricted to the saved candidate columns,
        # then mapped back to global column indices.
        if sub_rows.numel() > 0:
            candidates = top_kr_indices[sub_rows]
            candidate_scores = perturbed_target[sub_rows].gather(1, candidates)
            idx_local = torch.topk(candidate_scores, k, dim=1).indices
            idx_global = candidates.gather(1, idx_local)
            s_prime[sub_rows.unsqueeze(1), idx_global] = 1.0

        # Rebuild the forward mask from the saved indices so it is exactly
        # the selection returned by ``forward`` (a second topk call is
        # wasted work and may break ties differently).
        s_forward = torch.zeros_like(scores).scatter_(1, top_kr_indices[:, :k], 1.0)

        gradient = (s_forward - s_prime) / lambda_val

        return gradient, None, None, None, None, None

def update_lambda(lambda_curr, grad_s, q=0.9, beta=0.9, lambda_min=1e-2, lambda_max=10.0):
    """Derive a new lambda from the magnitude of a gradient tensor.

    The ``q``-quantile of ``|grad_s|`` (over all elements) is blended with
    ``lambda_curr`` and the result is inverted and clamped.

    Args:
        lambda_curr: previous lambda (float or tensor), or ``None`` on the
            first call, in which case the quantile is used unblended.
        grad_s: gradient tensor of any shape and memory layout; must be
            non-empty (``torch.quantile`` rejects empty input).
        q: quantile in ``[0, 1]`` of the absolute gradient values.
        beta: weight given to ``lambda_curr`` in the blend.
        lambda_min: lower clamp bound of the result.
        lambda_max: upper clamp bound of the result.

    Returns:
        0-dim tensor holding the new lambda, within
        ``[lambda_min, lambda_max]``.
    """
    # reshape (not view) so non-contiguous gradients, e.g. transposed
    # tensors, are accepted as well.
    grad_flat = grad_s.reshape(-1)
    # torch.quantile only supports float32/float64 inputs.
    if grad_flat.dtype not in (torch.float32, torch.float64):
        grad_flat = grad_flat.to(torch.float32)
    quantile_val = torch.quantile(grad_flat.abs(), q)

    # Simple heuristic for lambda update
    # NOTE(review): this blends the previous *lambda* with a gradient
    # magnitude and then inverts the result, i.e. lambda_curr is treated as
    # if it were the previous scale statistic.  Kept as-is to preserve
    # existing behavior; confirm whether 1 / lambda_curr was intended.
    if lambda_curr is None:
        s_t = quantile_val
    else:
        s_t = beta * lambda_curr + (1 - beta) * quantile_val

    # The small epsilon guards against division by zero for all-zero grads.
    new_lambda = 1.0 / (s_t + 1e-6)
    return torch.clamp(new_lambda, lambda_min, lambda_max)