from __future__ import annotations

from typing import Tuple

import torch

try:
    import torchsort
except Exception:
    torchsort = None



def hard_topk_mask(scores: torch.Tensor, k: int) -> torch.Tensor:
    """Build a 0/1 mask that marks the ``k`` largest entries of the last dim.

    Args:
        scores: Tensor whose last dimension is ranked.
        k: Number of entries to keep per slice; must be positive.

    Returns:
        A tensor with the same shape and dtype as ``scores`` holding 1 at the
        positions selected by ``torch.topk`` and 0 elsewhere. When ``k`` is at
        least the size of the last dimension, every entry is 1.

    Raises:
        ValueError: If ``k`` is not positive.
    """
    if k <= 0:
        raise ValueError("k must be positive.")

    width = scores.size(-1)
    if width <= k:
        # Nothing can be dropped: every position is selected.
        return torch.ones_like(scores)

    _, winners = scores.topk(k, dim=-1)
    # scatter_ works in place and returns the tensor it modified.
    return torch.zeros_like(scores).scatter_(-1, winners, 1.0)


def _validate_scores(scores: torch.Tensor) -> None:
    """Raise ``ValueError`` unless ``scores`` has exactly 2 or 3 dimensions."""
    if scores.dim() not in (2, 3):
        raise ValueError(f"scores must be 2D or 3D, got shape {scores.shape}.")


def _alpha_shape(scores: torch.Tensor, alpha_granularity: str) -> Tuple[int, ...]:
    """Return the broadcastable shape of the threshold tensor for ``scores``.

    For 2D scores the shape is ``(scores.size(0), 1)`` and
    ``alpha_granularity`` is not consulted. For other inputs, ``"layer"``
    yields ``(1, scores.size(1), 1)`` and ``"batch_layer"`` yields
    ``(scores.size(0), scores.size(1), 1)``.

    Raises:
        ValueError: If ``scores`` is not 2D and ``alpha_granularity`` is
            neither ``"layer"`` nor ``"batch_layer"``.
    """
    if scores.dim() == 2:
        return (scores.size(0), 1)
    if alpha_granularity == "layer":
        return (1, scores.size(1), 1)
    if alpha_granularity == "batch_layer":
        return (scores.size(0), scores.size(1), 1)
    raise ValueError(f"Unsupported alpha_granularity: {alpha_granularity}")


def sigmoid_normalize(
    scores: torch.Tensor,
    k: int,
    *,
    alpha_granularity: str = "layer",
    max_iter: int = 50,
    tol: float = 1e-4,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute ``lambda = sigmoid(scores - alpha)`` with ``sum(lambda) ~= k``.

    The threshold ``alpha`` is found by bisection along the last dimension.
    The search runs in float32 and the results are cast back to the dtype of
    ``scores``.

    Args:
        scores: 2D or 3D tensor; the last dimension is the one normalized.
        k: Target sum along the last dimension; must be positive.
        alpha_granularity: Only consulted for 3D input. ``"batch_layer"``
            solves one alpha per ``(dim 0, dim 1)`` slice. ``"layer"`` solves
            one alpha per index of dim 1 on the scores averaged over dim 0,
            so the sum-to-``k`` property holds for those averaged scores and
            not necessarily for every individual slice.
        max_iter: Maximum number of bisection steps.
        tol: The search stops once every solved row's sum is within ``tol``
            of ``k``.

    Returns:
        ``(lambda_scores, alpha)`` where ``lambda_scores`` has the shape and
        dtype of ``scores`` and ``alpha`` is broadcastable against it (see
        ``_alpha_shape``). When ``k`` is at least the size of the last
        dimension, ``lambda_scores`` is all ones and ``alpha`` is ``-inf``.

    Raises:
        ValueError: If ``k`` is not positive, ``scores`` is not 2D/3D, or
            ``alpha_granularity`` is unsupported for 3D input.
    """
    if k <= 0:
        raise ValueError("k must be positive.")
    _validate_scores(scores)

    last_dim = scores.size(-1)
    if k >= last_dim:
        # Every entry is kept, which corresponds to a threshold of -inf.
        alpha_shape = _alpha_shape(scores, alpha_granularity)
        alpha = torch.full(alpha_shape, float("-inf"), device=scores.device, dtype=scores.dtype)
        return torch.ones_like(scores, dtype=scores.dtype), alpha

    orig_dtype = scores.dtype
    scores_fp32 = scores.float()

    if scores_fp32.dim() == 2:
        scores_for_alpha = scores_fp32
    else:
        if alpha_granularity == "layer":
            scores_for_alpha = scores_fp32.mean(dim=0)
        elif alpha_granularity == "batch_layer":
            scores_for_alpha = scores_fp32
        else:
            raise ValueError(f"Unsupported alpha_granularity: {alpha_granularity}")

    # Bracket the root: at `low` every sigmoid is close to 1 (sum > k) and at
    # `high` every sigmoid is close to 0 (sum < k).
    margin = 20.0
    low = scores_for_alpha.min(dim=-1, keepdim=True).values - margin
    high = scores_for_alpha.max(dim=-1, keepdim=True).values + margin
    target = float(k)

    alpha = (low + high) / 2
    for _ in range(max_iter):
        row_sum = torch.sigmoid(scores_for_alpha - alpha).sum(dim=-1, keepdim=True)
        # Test convergence *before* shrinking the bracket so that the alpha we
        # return is the candidate that actually met the tolerance, not the
        # midpoint of the next (possibly still wide) half-interval.
        if bool(((row_sum - target).abs() <= tol).all()):
            break
        # The sum decreases as alpha grows: too large a sum means alpha must rise.
        too_large = row_sum > target
        low = torch.where(too_large, alpha, low)
        high = torch.where(too_large, high, alpha)
        alpha = (low + high) / 2

    if scores_fp32.dim() == 2:
        alpha_broadcast = alpha
    else:
        # "layer" alphas were solved on the dim-0 mean; restore the leading axis.
        alpha_broadcast = alpha.unsqueeze(0) if alpha_granularity == "layer" else alpha

    lambda_scores = torch.sigmoid(scores_fp32 - alpha_broadcast)
    lambda_scores = lambda_scores.to(orig_dtype)
    alpha_broadcast = alpha_broadcast.to(orig_dtype)
    return lambda_scores, alpha_broadcast


def soft_topk(
    scores: torch.Tensor,
    k: int,
    *,
    alpha_granularity: str = "layer",
    max_iter: int = 50,
    tol: float = 1e-4,
    regularization_strength: float = 1.0,
) -> torch.Tensor:
    """Sigmoid-normalize ``scores`` and zero everything outside the top ``k``.

    The scores are first passed through ``sigmoid_normalize``; the result is
    then multiplied by the binary mask from ``hard_topk_mask`` computed on
    those normalized values, so only the ``k`` largest entries along the last
    dimension stay non-zero.

    ``alpha_granularity``, ``max_iter`` and ``tol`` are forwarded to
    ``sigmoid_normalize``. ``regularization_strength`` is accepted purely for
    backward compatibility and has no effect.
    """
    del regularization_strength  # intentionally ignored (legacy parameter)

    soft_weights = sigmoid_normalize(
        scores,
        k,
        alpha_granularity=alpha_granularity,
        max_iter=max_iter,
        tol=tol,
    )[0]
    keep = hard_topk_mask(soft_weights, k)
    return soft_weights * keep


def soft_topk_torchsort(
    scores: torch.Tensor,
    k: int,
    *,
    regularization_strength: float = 1.0,
) -> torch.Tensor:
    """Legacy SoftTopK built on ``torchsort.soft_rank``.

    The input is flattened to rows over its last dimension and converted to
    float32. Soft ranks of the negated scores are turned into weights
    ``relu(k - rank)``, which are rescaled so each row sums to (about) ``k``.
    The result is reshaped like ``scores`` and cast back to its dtype.

    NOTE(review): if ``soft_rank`` produces 1-based ranks, the entry ranked
    exactly ``k`` receives zero weight here -- confirm this is intended.

    Raises:
        ValueError: If ``k`` is not positive.
        ImportError: If ``torchsort`` could not be imported.
    """
    if k <= 0:
        raise ValueError("k must be positive.")
    if torchsort is None:
        raise ImportError(
            "torchsort is required for soft_topk_torchsort. Install with `pip install torchsort`."
        )

    rows = scores.reshape(-1, scores.size(-1)).float()

    # Negate so that the largest score receives the smallest rank.
    soft_ranks = torchsort.soft_rank(
        -rows,
        regularization="l2",
        regularization_strength=regularization_strength,
    )
    weights = torch.relu(k - soft_ranks)
    # The small epsilon guards against division by zero when all weights vanish.
    row_total = weights.sum(dim=-1, keepdim=True) + 1e-9

    result = (k * weights / row_total).reshape_as(scores)
    target_dtype = scores.dtype
    return result if result.dtype == target_dtype else result.to(target_dtype)
