import numbers
from typing import Sequence, Tuple, Union

def _dim(x) -> int:
    """compatible with 0-1 masks or integer dimensions"""
    if isinstance(x, numbers.Number):
        return int(x)
    return int(getattr(x, "sum", lambda: x)( ))   # torch / np / list all work

def estimate_pruned_params(
    vectors: Sequence[Union[int, float]],          # like [attn_in, attn_out, mlp_in, mlp_mid, mlp_out] × L
    *,
    hidden_size: int        = 4096,                # original hidden dimension
    intermediate_size: int  = 11008,               # original FFN expansion dimension
    num_layers: int         = 32,                  # Transformer blocks
    vocab_size: int         = 32000,               # vocabulary size
    add_embeds: bool        = True,                # whether to include token embed
    add_lm_head: bool       = False,               # if you have extra lm_head
    original_param_cnt: int = 6738996480           # Llama-2-7B total FP weights (≈6.739 B, can be replaced)
) -> Tuple[int, float]:
    """Estimate the parameter count of a pruned Transformer.

    Args:
        vectors: Flat sequence of ``num_layers * 5`` entries, ordered per layer
            as ``attn_in, attn_out, mlp_in, mlp_mid, mlp_out``. Each entry is
            passed through ``_dim``, so it may be an integer dimension or a
            0/1 mask whose sum is the kept dimension.
        hidden_size: Unpruned hidden dimension; used for the attention
            projections and the embedding / LM head.
        intermediate_size: Unpruned FFN expansion dimension. Accepted for
            interface compatibility; the estimate itself does not read it.
        num_layers: Number of Transformer blocks described by ``vectors``.
        vocab_size: Vocabulary size used for the embedding / LM head terms.
        add_embeds: Add ``vocab_size * hidden_size`` for the token embedding.
        add_lm_head: Add ``vocab_size * hidden_size`` for a separate LM head
            (leave ``False`` when the head shares weights with the embedding,
            to avoid double counting).
        original_param_cnt: Parameter count of the unpruned model, used as the
            denominator of the returned ratio. Must be positive.

    Returns:
        ``(total, keep_ratio)`` where ``total`` is the estimated number of
        parameters after pruning and ``keep_ratio`` is
        ``total / original_param_cnt`` (the retained fraction). The pruning
        ratio is ``1 - keep_ratio``.

    Raises:
        ValueError: If ``len(vectors) != num_layers * 5`` or
            ``original_param_cnt`` is not positive.
    """
    entries_per_layer = 5
    expected_len = num_layers * entries_per_layer
    if len(vectors) != expected_len:
        raise ValueError(
            f"vectors length should be {expected_len} (= num_layers × 5), "
            f"but got {len(vectors)}, confirm order: attn_in, attn_out, mlp_in, mlp_mid, mlp_out × L"
        )
    if original_param_cnt <= 0:
        # Guard the division below; a non-positive baseline has no meaningful ratio.
        raise ValueError(
            f"original_param_cnt must be positive, but got {original_param_cnt}"
        )

    total = 0
    for layer in range(num_layers):
        base = layer * entries_per_layer
        attn_in, attn_out, mlp_in, mlp_mid, mlp_out = (
            _dim(vectors[base + offset]) for offset in range(entries_per_layer)
        )

        # ─── Multi-Head Attention ──────────────────────────────────────────────────
        # q_proj, k_proj, v_proj:  attn_in × hidden_size × 3
        # o_proj:                  attn_out × hidden_size
        total += attn_in * hidden_size * 3 + attn_out * hidden_size

        # ─── Feed-Forward (SwiGLU 3×linear) ─────────────────────────────────────────
        # W1, W2:  mlp_in  × mlp_mid      (two parallel paths)
        # W3:      mlp_mid × mlp_out
        total += mlp_in * mlp_mid * 2 + mlp_mid * mlp_out

    # ─── Embedding & Language Model Head ────────────────────────────────────────────
    embed_params = vocab_size * hidden_size
    if add_embeds:
        total += embed_params                      # token embedding
    if add_lm_head:
        total += embed_params                      # only when not tied to the embedding weights

    # Retained fraction of the original model. The previous expression
    # (1 - total / original_param_cnt) was the pruned fraction, which
    # contradicted the documented "retention ratio".
    keep_ratio = total / original_param_cnt
    return total, keep_ratio
