from __future__ import annotations

import math
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple

import torch


@dataclass(frozen=True)
class ZSICConfig:
    """Immutable settings bundle for ZSIC compression.

    Only ``target_rate_bits`` is required; every other field has a default.
    Within this module, ``compress_zsic`` / ``compress_zsic_with_binary_search``
    read ``target_rate_bits``, the ``binary_search*`` fields, ``percdamp``,
    ``qronos`` and ``residual_compensation``. The remaining fields are not read
    by the code in this module (presumably consumed elsewhere — TODO confirm).
    """

    # Desired coding rate in bits per weight entry.
    target_rate_bits: float

    # Variant / T-Gamma rescaler knobs (not read in this module).
    sic_variant: str = "compress_w2q"
    apply_tgamma: bool = True
    tgamma_ridge: float = 1e-6
    tgamma_max_iter: int = 1000
    tgamma_tol: float = 3e-4

    # Side-information overhead accounting (not read in this module).
    include_overheads: bool = True
    overhead_bits_per_param: int = 16

    # percdamp: diagonal damping added to the Hessian, as a fraction of its mean
    # diagonal. The cholesky_* fields are not read in this module.
    percdamp: float = 1e-4
    cholesky_max_tries: int = 6
    cholesky_eps: float = 1e-6
    cholesky_growth: float = 10.0

    # Bisection over the internal target rate so that the estimated rate
    # approaches target_rate_bits; the search interval is [left, right].
    binary_search: bool = False
    binary_search_iters: int = 15
    binary_search_left: float = -10.0
    binary_search_right: float = 10.0
    binary_search_row_fraction: float = 0.1

    # Qronos mode: quantize against E[X̂X̂^T] / E[XX̂^T] instead of E[XX^T].
    qronos: bool = False
    # Residual compensation for wo/w2 layers (layers that output to the residual
    # stream). When enabled, the quantization target becomes
    #   ŷ = (W Σ_{X,X̂} + Σ_{ΔR,X̂}) (L̂^T)^{-1}   with   Σ_{ΔR,X̂} = E[(R - R̂)X̂^T]
    residual_compensation: bool = False


def compute_entropy(zdata: torch.Tensor) -> float:
    """Return the empirical Shannon entropy (in bits) of the values in ``zdata``.

    Every element is treated as one symbol; the tensor's shape is ignored.
    """
    _, counts = torch.unique(zdata.reshape(-1), return_counts=True)
    p = counts.float() / zdata.numel()
    return float(-(p * torch.log2(p)).sum())


@torch.no_grad()
def find_optimal_rescalers3(
    W_hat: torch.Tensor,   # shape: a x n
    W: torch.Tensor,       # shape: a x n
    Sig_X: torch.Tensor,    # shape: n x n, = E[X X^T]
    Sig_hX: torch.Tensor = None, # shape: n x n, = E[\hat X \hat X^T]
    Sig_X_hX: torch.Tensor = None, # shape: n x n, = E[ X \hat X^T]
    Sig_delta_R_Xhat: torch.Tensor = None,  # shape: a x n, = E[(R - R̂) X̂^T] for residual compensation
    max_iter: int = 1000,
    tol: float = 3e-4,
    ridge_eps: float = 1e-11,    # small Tikhonov for Gamma-step and T-step
    t_init: torch.Tensor = None,
    gamma_init: torch.Tensor = None,
    gamma_clip_min: float = -0.1,  # Clip gamma to prevent wild values at low rates
    gamma_clip_max: float = 1.5,   # (columns quantized to ~0 can cause gamma blowup)
    quiet: bool = False,
):
    """
    Alternating minimization over a diagonal row scaler T and a diagonal column
    scaler Gamma of the proxy loss actually evaluated below:

        J(T, Gamma) = tr(W Sig_X W^T - 2 C What^T + What Sig_hX What^T) / (n * a)
        What        = T @ W_hat @ Gamma
        C           = W @ Sig_X_hX  (+ Sig_delta_R_Xhat when provided)

    If Sig_hX and Sig_X_hX are both None, X = X_hat is assumed and both are set
    to Sig_X. For a symmetric Sig_X, J then reduces to
    tr((What - W) Sig_X (What - W)^T) / (n * a).
    Sig_X itself only enters the constant first term of J.

    Conventions:
      - a := number of rows of W_hat
      - n := number of columns of W_hat (and size of Sig_X)
      - T  = diag(t) with t in R^a
      - Gamma = diag(gamma) with gamma in R^n
      - Normalization: t.abs().sum() == a after every iteration (scale absorbed into Gamma).

    Each iteration:
      1. Gamma-step: solve (Sig_hX ∘ (W_hat^T T^2 W_hat) + ridge_eps I) gamma = diag(W_hat^T T C)
         (∘ = Hadamard product; falls back to a pseudo-inverse if the solve raises).
      2. Clamp gamma to [gamma_clip_min, gamma_clip_max].
      3. T-step: t_i = diag(C Gamma W_hat^T)_i / (diag(W_hat Gamma Sig_hX Gamma W_hat^T)_i + ridge_eps).
      4. Rescale t so that t.abs().sum() == a and multiply gamma by the same factor;
         therefore the returned gamma may lie outside the clip range.
      5. Stop once the relative change of J is below `tol` (or after max_iter iterations).

    t_init / gamma_init only set the starting point (and the initial reported loss);
    the first Gamma-step overwrites gamma. Sig_delta_R_Xhat is cast to W_hat's
    dtype/device; the other matrices are used as given (NOTE(review): they are
    presumably already float64 on W_hat's device — confirm at call sites).

    Raises:
      AssertionError: if W_hat is not float64, on shape mismatch, if J increases
        after a Gamma-step, or if J increases over an iteration in which no
        clipping happened; with ridge_eps == 0, also if the T-step produces
        non-finite values.

    Returns:
      T (a x a, diagonal), Gamma (n x n, diagonal), both in W_hat's dtype/device.
    """

    # ----- basic checks & shape harmonization -----
    a, n = W_hat.shape
    assert W.shape == (a, n)
    assert Sig_X.shape == (n, n)

    if Sig_hX is None:
        assert Sig_X_hX is None
        # Assuming that X=\hat X (e.g. first layer)
        Sig_hX = Sig_X
        Sig_X_hX = Sig_X

    device = W_hat.device
    dtype = W_hat.dtype

    assert dtype == torch.double, 'ERROR: this code (at high-rates) does not work in torch.float...'

    # Compute the effective cross-term C = W @ Sig_X_hX + Sig_delta_R_Xhat (if provided)
    # This accounts for residual compensation in the T/Gamma optimization
    W_Sig_X_hX_eff = W @ Sig_X_hX
    if Sig_delta_R_Xhat is not None:
        W_Sig_X_hX_eff = W_Sig_X_hX_eff + Sig_delta_R_Xhat.to(dtype=dtype, device=device)
        if not quiet:
            print(f"[find_optimal_rescalers3] applying residual compensation in T/Gamma optimization")

    # Proxy loss J(What); uses the effective cross-term C
    mse_loss = lambda What: torch.trace(W @ Sig_X @ W.T - 2 * W_Sig_X_hX_eff @ What.T + What @ Sig_hX @ What.T) / (n * a)

    # ----- initialization: ones(a), ones(n), or the provided t_init / gamma_init -----
    if t_init is None:
        t = torch.ones(a, device=device, dtype=dtype)
    else:
        t = t_init.clone().to(device=device, dtype=dtype)

    if gamma_init is None:
        gamma = torch.ones(n, device=device, dtype=dtype)
    else:
        gamma = gamma_init.clone().to(device=device, dtype=dtype)

    T = torch.diag(t)
    Gamma = torch.diag(gamma)

    # enforce t.abs().sum() = a and absorb scale into Gamma (keeps T @ W_hat @ Gamma unchanged)
    s0 = t.abs().sum() / a
    if s0 > 0:
        t = t / s0
        gamma = gamma * s0
        T = torch.diag(t)
        Gamma = torch.diag(gamma)

    # ----- step 0: report mse_loss -----
    loss_prev = mse_loss(T @ W_hat @ Gamma).detach()
    if not quiet:
        print(f"iter 0 | mse_loss = {float(loss_prev):.6e}")

    for it in range(1, max_iter + 1):

        # ===== Gamma-step (given T): exact minimizer of J over gamma (up to the ridge) =====
        # T is diagonal, so T ** 2 (elementwise) equals diag(t^2).
        F2 = W_hat.T @ (T ** 2) @ W_hat                 # nxn
        F3 = Sig_hX * F2                            # nxn (Hadamard product!!)
        # F4 uses the effective cross-term (includes residual compensation if enabled)
        F4 = W_hat.T @ T @ W_Sig_X_hX_eff
        f4_vec = torch.diagonal(F4)                # n

        if ridge_eps > 0.0:
            F5 = F3 + ridge_eps * torch.eye(n, device=device, dtype=dtype)
        else:
            F5 = F3

        try:
            gamma = torch.linalg.solve(F5.double(), f4_vec.double())
        except RuntimeError:
            # NOTE: printed even when quiet=True.
            print('WARNING: linalg.solve() failed, using pseudo-inverse (please set ridge>0.0 for stability!)')
            # Run it in double
            F6 = torch.linalg.pinv(F5.double())
            gamma = F6 @ f4_vec.double()

        gamma = gamma.to(dtype)

        # Diagnostics only: how far the first solve moved gamma away from gamma_init.
        if (it == 1) and (gamma_init is not None) and (t_init is None):
            mean_diff = (gamma - gamma_init).abs().mean()
            mask = gamma_init > 0
            rel_mean_diff = (1 - gamma_init[mask] / gamma[mask]).abs().mean()
            if not quiet:
                print(f'iter 1 | gamma changed by {mean_diff:.5g} (rel = {rel_mean_diff:.5g})')

        Gamma = torch.diag(gamma)

        # Mid-step loss check (this `rel` is recomputed after the T-step and not used here).
        loss_curr = mse_loss(T @ W_hat @ Gamma)
        rel = torch.abs(loss_curr - loss_prev) / (torch.abs(loss_prev) + 1e-12)

        assert loss_curr < (loss_prev + 1e-12), f"ERROR: (MIDSTEP) this is a min-min procedure but it INCREASED loss. BUG!\n" \
                                                f"(loss_curr={loss_curr}, loss_prev={loss_prev}, it={it}"

        # ===== clip gamma to prevent wild values (columns quantized to ~0 can blow up) =====
        # Clip BEFORE T-step so T is computed for the clipped gamma
        n_clipped_low = (gamma < gamma_clip_min).sum().item()
        n_clipped_high = (gamma > gamma_clip_max).sum().item()
        if (n_clipped_low > 0 or n_clipped_high > 0) and not quiet:
            print(f"iter {it} | clipping gamma: {n_clipped_low} below {gamma_clip_min}, {n_clipped_high} above {gamma_clip_max}")
        gamma = gamma.clamp(min=gamma_clip_min, max=gamma_clip_max)
        Gamma = torch.diag(gamma)

        # ===== T-step (given Gamma): closed-form per-row minimizer, t_i = f7_i / f8_i =====
        # F7t uses the effective cross-term (includes residual compensation if enabled)
        F7t = W_Sig_X_hX_eff @ Gamma @ W_hat.T        # axa
        f7_vec = torch.diagonal(F7t)                  # a
        F8 = W_hat @ Gamma @ Sig_hX @ Gamma @ W_hat.T   # axa
        f8_vec = torch.diagonal(F8)                  # a

        if ridge_eps > 0.0:
            t = f7_vec / (f8_vec + ridge_eps)
        else:
            t = f7_vec / f8_vec
            # 0/0 rows (all-zero rows of W_hat) get t = 0 instead of NaN
            t[(f7_vec == 0) & (f8_vec == 0)] = 0.0
            assert torch.all(t.isfinite()), 'ERROR: T-step divides by 0'
        T = torch.diag(t)

        # ===== normalization: t.abs().sum() = a, absorb scale into Gamma =====
        # (T @ W_hat @ Gamma is unchanged; gamma may now exceed the clip range.)
        s = t.abs().sum() / a
        if float(s) > 0.0:
            t = t / s
            T = torch.diag(t)
            gamma = gamma * s
            Gamma = torch.diag(gamma)

        # ===== report & stopping based on mse_loss changes =====
        loss_curr = mse_loss(T @ W_hat @ Gamma)
        rel = torch.abs(loss_curr - loss_prev) / (torch.abs(loss_prev) + 1e-12)

        # Note: clipping gamma can increase loss (it's a projection, not a descent step)
        # Only assert when no clipping occurred
        clipped_this_iter = (n_clipped_low > 0 or n_clipped_high > 0)
        if not clipped_this_iter:
            assert loss_curr < (loss_prev + 1e-12), f"ERROR: this is a min-min procedure but it INCREASED loss. BUG!\n" \
                                                    f"(loss_curr={loss_curr}, loss_prev={loss_prev}, it={it}"

        if not quiet:
            print(f"iter {it} | mse_loss = {float(loss_curr):.6e} | rel change = {float(rel):.3e}")

        if float(rel) < tol:
            break
        loss_prev = loss_curr

    # Print statistics of the final diagonals:
    if not quiet:
        t = T.diag()
        gamma = Gamma.diag()

        def print_stats(tens, name):
            mmin, mmax = tens.flatten().min(), tens.flatten().max()
            q25, q75 = torch.quantile(tens.flatten(), 0.25), torch.quantile(tens.flatten(), 0.75)
            mean, stddev = tens.mean(), torch.std(tens)
            print('Tensor ' + name + f' stats: min={mmin:.3g}, q25={q25:.3g}, mean = {mean:.3g}, q75={q75:.3g}, max={mmax:.3g};  std = {stddev:.3g}')

        print_stats(t, 'row-rescaler T:')
        print_stats(gamma, 'column-rescaler Gamma:')

    return T, Gamma


@torch.no_grad()
def compress_w2q(
    W, Sig_X, target_rate=1.5, quiet=False, Sig_hX=None, Sig_X_hX=None, percdamp=0.0001,
    Sig_delta_R_Xhat=None,  # Residual compensation: Σ_{ΔR,X̂} = E[(R - R̂)X̂^T]
):
    """
    LDLQ with per-column gamma adaptation + T/Gamma rescaler optimization.

    Args:
        W: Weight matrix (a, n)
        Sig_X, Sig_hX, Sig_X_hX: E[XX^T], E[Xhat Xhat^T], E[X Xhat^T].
            Passing Sig_hX (which then requires Sig_X_hX) selects the Qronos
            path, which factors Sig_hX instead of Sig_X.
        target_rate: Target compression rate in bits; sets the uniform step size c_param.
        quiet: Suppress print output
        percdamp: Damping added to the Hessian diagonal before the Cholesky
                  factorization, as a fraction of its mean diagonal.
        Sig_delta_R_Xhat: Optional residual compensation term, shape (a, n).
                          When provided, modifies target: ŷ = (W Σ_{X,X̂} + Σ_{ΔR,X̂})(L̂^T)^{-1}

    Returns:
        (final_loss, final_rate, What, frame):
          final_loss: proxy MSE of What under the un-damped covariances,
          final_rate: entropy(Zsic) + 16/a + 16/n bits per entry,
          What: reconstructed weights cast back to W's original dtype,
          frame: this function's ``locals()``. Callers read e.g. 'Zsic',
                 'alphas', 'gammas', 'T', 'Gamma', 'entropy', 'a', 'n' from it,
                 so local variable names are part of the interface.
    """
    a, n = W.shape
    dtype_orig = W.dtype

    # Convert all inputs to double for numerical stability
    W = W.double()
    Sig_X = Sig_X.double()
    if Sig_hX is not None:
        Sig_hX = Sig_hX.double()
    if Sig_X_hX is not None:
        Sig_X_hX = Sig_X_hX.double()

    if Sig_hX is not None:
        assert Sig_X_hX is not None
        H = Sig_hX
        qronos = True
    else:
        H = Sig_X
        qronos = False

    # Add damping to Hessian diagonal (only for the factorization; the losses
    # below use the un-damped covariances).
    damp = percdamp * torch.mean(torch.diag(H))
    H_damped = H + damp * torch.eye(n, device=H.device, dtype=H.dtype)
    L = torch.linalg.cholesky(H_damped, upper=False)
    assert torch.all(L.diag() >= 0)

    if qronos:
        # Qronos: Y = (W @ Σ_{X,X̂} + Σ_{ΔR,X̂}) @ L̂^{-T}
        target = W @ Sig_X_hX
        if Sig_delta_R_Xhat is not None:
            if not quiet:
                print(f"[compress_w2q] applying residual compensation (Qronos)")
            target = target + Sig_delta_R_Xhat.double()
        Ycur = torch.linalg.solve_triangular(L.T, target, left=False, upper=True)
    else:
        # LDLQ: Y = W @ L + Σ_{ΔR,X̂} @ L^{-T}
        Ycur = W @ L
        if Sig_delta_R_Xhat is not None:
            if not quiet:
                print(f"[compress_w2q] applying residual compensation")
            Ycur = Ycur + torch.linalg.solve_triangular(
                L, Sig_delta_R_Xhat.double().T, left=True, upper=False
            ).T

    Sw = W.T @ W / W.shape[0]
    target_rate_nats = target_rate * math.log(2)  # in nats
    # Uniform step: geometric mean over columns of sqrt(12 * Sw_jj * L_jj^2), times 2^{-target_rate}.
    c_param = torch.exp(torch.log(12 * Sw.diag() * (L.diag() ** 2)).mean() / 2 - target_rate_nats)

    alphas = c_param / L.diag()
    # Per-column LDLQ gains. Kept in float64 (like everything else here, and like
    # _fast_rate_estimate): they feed the error-feedback corrector and the
    # rescaler optimization, which requires double precision.
    gammas = torch.ones(n, device=W.device, dtype=torch.double)
    column_snr_orig = Sw.diag() / (alphas ** 2)
    column_snr_ldlq = torch.zeros_like(column_snr_orig)

    Zsic = torch.zeros_like(W, dtype=torch.int64, device=W.device)

    ## Perform uneven rate LDLQ (last column first, with error feedback)
    for col in range(n - 1, -1, -1):
        wcol = Ycur[:, col]
        column_snr_ldlq[col] = (wcol ** 2).mean() / (c_param ** 2)
        zcol = torch.round(wcol / c_param).long()
        Zsic[:, col] = zcol
        f1 = (zcol.double() * wcol).sum()
        f2 = (zcol.double() ** 2).sum()
        if f2 > 0:
            # Least-squares gain of the integer codes against the current target column.
            gammas[col] = f1 / f2 / c_param
            corrector = torch.outer(zcol.double(), gammas[col] * alphas[col] * L[col, :])
            assert Ycur.shape == corrector.shape
            Ycur = Ycur - corrector
        else:
            gammas[col] = 0

    # Zsic @ diag(alphas * gammas), as a broadcast column scaling (O(a*n), not O(a*n^2)).
    What_pre = Zsic.double() * (alphas * gammas)

    ## Remember to use un-regularized Hessian
    if qronos:
        mse_loss_func = lambda What: torch.trace(W @ Sig_X @ W.T - 2 * W @ Sig_X_hX @ What.T + What @ Sig_hX @ What.T) / (n * a)
    else:
        mse_loss_func = lambda What: torch.trace((What - W) @ Sig_X @ (What - W).T) / (n * a)

    mse_out = mse_loss_func(What_pre)
    mse_null = mse_loss_func(torch.zeros_like(W))
    rel_mse_out = mse_out / mse_null
    if not quiet:
        print(f'Target rate = {math.log2(math.exp(target_rate_nats))}, MSE = {mse_out}, relative_mse = {rel_mse_out}')
        print(f'Zsic: min = {Zsic.min()}, max = {Zsic.max()}, mean = {Zsic.float().mean()}, stddev = {math.sqrt(Zsic.float().var())}')
    zsic_elts, zsic_counts = torch.unique(Zsic.flatten(), return_counts=True)
    probs = zsic_counts.float() / Zsic.numel()
    entropy = -torch.sum(probs * torch.log2(probs))
    if not quiet:
        print(f"Huffman coded compression rate = {entropy + 16 / a} bit/entry.    Zsic entrywise entropy: {entropy.item()} bits")

    ## Now let us optimize diagonal row- and column- scalers.
    if not quiet:
        print('... optimizing diagonal rescalers')
    What_pre0 = Zsic.double() * alphas  # Zsic @ diag(alphas): remove Gamma multiplier

    # find_optimal_rescalers3 handles optional residual compensation via Sig_delta_R_Xhat.
    # In the non-Qronos path Sig_hX/Sig_X_hX must both be None (it asserts that).
    T, Gamma = find_optimal_rescalers3(
        What_pre0, W, Sig_X, gamma_init=gammas, quiet=quiet,
        Sig_hX=Sig_hX if qronos else None,
        Sig_X_hX=Sig_X_hX if qronos else None,
        Sig_delta_R_Xhat=Sig_delta_R_Xhat,
    )

    # T @ Zsic @ Gamma @ diag(alphas) for diagonal T, Gamma, as broadcast scalings
    # evaluated in the same multiplication order.
    What = T.diagonal()[:, None] * Zsic.double() * Gamma.diagonal() * alphas
    final_loss = mse_loss_func(What)
    final_rate = entropy + 16 / a + 16 / n
    if not quiet:
        print(f'Final loss: {final_loss:.3g}, Final rate = {final_rate:.3g} bit/entry\n')

    return final_loss, final_rate, What.to(dtype_orig), locals()


# =============================================================================
# Wrapper for pipeline integration - compress_zsic
# =============================================================================

@torch.no_grad()
def compress_zsic(
    W: torch.Tensor,
    H: torch.Tensor,
    *,
    cfg: ZSICConfig,
    Sig_X: torch.Tensor | None = None,
    Sig_hX: torch.Tensor | None = None,
    Sig_X_hX: torch.Tensor | None = None,
    Sig_delta_R_Xhat: torch.Tensor | None = None,  # Residual compensation
) -> Tuple[torch.Tensor, float, float, Dict[str, object]]:
    """Pipeline entry point for ZSIC compression (optionally Qronos / residual-compensated).

    Qronos is used only when ``cfg.qronos`` is set AND both ``Sig_hX`` and
    ``Sig_X_hX`` are given. Residual compensation is used only when
    ``cfg.residual_compensation`` is set AND ``Sig_delta_R_Xhat`` is given.
    With ``cfg.binary_search`` the work is delegated to
    ``compress_zsic_with_binary_search``; otherwise ``compress_w2q`` is run once
    at ``cfg.target_rate_bits``.

    Args:
        W: Weight matrix
        H: Hessian (typically E[X X^T]); stands in for Sig_X when Sig_X is None.
        cfg: ZSIC configuration
        Sig_X: E[X X^T] - unquantized activations covariance (optional)
        Sig_hX: E[X̂ X̂^T] - quantized activations covariance (for Qronos)
        Sig_X_hX: E[X X̂^T] - cross-covariance (for Qronos)
        Sig_delta_R_Xhat: E[(R - R̂) X̂^T] - residual compensation term (for wo/w2 layers)

    Returns:
        (W_hat in W's dtype, final loss, final rate in bits/entry, frame dict).
    """
    use_qronos = cfg.qronos and Sig_hX is not None and Sig_X_hX is not None
    use_residual = cfg.residual_compensation and Sig_delta_R_Xhat is not None
    sig_x_eff = H if Sig_X is None else Sig_X
    residual_term = Sig_delta_R_Xhat if use_residual else None

    def _tag(value, missing="None"):
        return "provided" if value is not None else missing

    summary = ", ".join((
        f"binary_search={cfg.binary_search}",
        f"qronos={use_qronos}",
        f"residual_comp={use_residual}",
        f"Sig_X={_tag(Sig_X, 'using H')}",
        f"Sig_hX={_tag(Sig_hX)}",
        f"Sig_X_hX={_tag(Sig_X_hX)}",
        f"Sig_delta_R_Xhat={_tag(Sig_delta_R_Xhat)}",
    ))
    print(f"[compress_zsic] {summary}", flush=True)

    if cfg.binary_search:
        print(f"[compress_zsic] calling compress_zsic_with_binary_search with desired_rate={cfg.target_rate_bits}", flush=True)
        return compress_zsic_with_binary_search(
            W, H, cfg=cfg, desired_rate=cfg.target_rate_bits,
            Sig_X=sig_x_eff, Sig_hX=Sig_hX, Sig_X_hX=Sig_X_hX,
            Sig_delta_R_Xhat=residual_term,
        )

    # Single direct pass; compress_w2q covers both LDLQ and Qronos, each with or
    # without residual compensation.
    loss, rate, W_hat, w2q_locals = compress_w2q(
        W, sig_x_eff, target_rate=cfg.target_rate_bits, quiet=False,
        Sig_hX=Sig_hX if use_qronos else None,
        Sig_X_hX=Sig_X_hX if use_qronos else None,
        percdamp=cfg.percdamp,
        Sig_delta_R_Xhat=residual_term,
    )

    frame = _build_frame_from_locals(w2q_locals, cfg, use_qronos, use_residual)
    return W_hat.to(W.dtype), float(loss), float(rate), frame


def _build_frame_from_locals(
    frame_locals: dict, cfg: ZSICConfig, qronos: bool, residual_comp: bool = False
) -> Dict[str, object]:
    """Build the pipeline frame dict from ``compress_w2q``'s ``locals()``.

    Every value read from ``frame_locals`` may be missing; missing tensors and
    scalars map to ``None`` (and ``rate_overhead`` to 0) instead of raising.

    Args:
        frame_locals: The ``locals()`` dict returned by ``compress_w2q``.
        cfg: Config; only ``cfg.target_rate_bits`` is read.
        qronos: Whether Qronos mode was used (stored verbatim).
        residual_comp: Whether residual compensation was used (stored verbatim).

    Returns:
        Dict with the integer codes ``Z``, effective per-column scale ``alpha``
        (alphas * gammas), the base scale ``alpha_base``, the rescaler
        diagonals ``t_vec``/``g_vec`` and rate/loss bookkeeping.
    """
    Zsic = frame_locals.get('Zsic')
    alphas = frame_locals.get('alphas')
    gammas = frame_locals.get('gammas')
    T = frame_locals.get('T')
    Gamma = frame_locals.get('Gamma')
    entropy = frame_locals.get('entropy')
    final_loss = frame_locals.get('final_loss')
    a = frame_locals.get('a')
    n = frame_locals.get('n')

    t_vec = T.diag() if T is not None else None
    g_vec = Gamma.diag() if Gamma is not None else None

    # entropy is a 0-dim tensor in compress_w2q; tolerate plain numbers and absence.
    if entropy is None:
        entropy_value = None
    elif hasattr(entropy, 'item'):
        entropy_value = entropy.item()
    else:
        entropy_value = float(entropy)

    frame = {
        "Z": Zsic,
        "alpha": (alphas * gammas) if alphas is not None and gammas is not None else None,
        "alpha_base": alphas,
        "zero_point": None,
        "apply_tgamma": True,  # compress_w2q always applies tgamma
        "t_vec": t_vec,
        "g_vec": g_vec,
        "sic_variant": "compress_w2q",
        "target_rate_bits": cfg.target_rate_bits,
        "entropy": entropy_value,
        # Same side-information overhead compress_w2q adds to final_rate.
        "rate_overhead": 16 / a + 16 / n if a and n else 0,
        "loss": float(final_loss) if final_loss is not None else None,
        "qronos": qronos,
        "residual_compensation": residual_comp,
    }
    return frame


@torch.no_grad()
def _fast_rate_estimate(W: torch.Tensor, H: torch.Tensor, target_rate: float,
                        row_fraction: float = 0.1, percdamp: float = 0.0001,
                        Sig_hX: torch.Tensor = None, Sig_X_hX: torch.Tensor = None,
                        Sig_delta_R_Xhat: torch.Tensor = None) -> float:
    """Estimate the code entropy LDLQ reaches at ``target_rate``, using a subset of rows.

    Mirrors the quantization loop of ``compress_w2q`` (same damping, step size
    ``c_param`` and per-column gain / error feedback) but skips the T/Gamma
    rescaler optimization. When ``row_fraction < 1`` it runs on
    ``max(1, int(a * row_fraction))`` rows drawn with ``torch.randperm``, so the
    result is random; with ``row_fraction >= 1`` it is deterministic.

    Args:
        W: Weight matrix (a, n); any floating dtype, processed in float64.
        H: Hessian (n, n), factored when ``Sig_hX`` is None.
        target_rate: Internal rate knob (bits) that sets the step size.
        row_fraction: Fraction of rows to sample.
        percdamp: Diagonal damping as a fraction of the Hessian's mean diagonal.
        Sig_hX, Sig_X_hX: Qronos covariances. Sig_hX (if given) is factored;
            the Qronos target is formed only when both are given.
        Sig_delta_R_Xhat: Optional residual-compensation term (a, n), subsampled
            with the same rows as W.

    Returns:
        Empirical entropy (bits per entry) of the sampled integer codes; the
        side-information overhead is not included.
    """
    a, n = W.shape

    if row_fraction < 1.0:
        n_rows = max(1, int(a * row_fraction))
        indices = torch.randperm(a, device=W.device)[:n_rows]
        W_sampled = W[indices]
        Sig_delta_R_Xhat_sampled = Sig_delta_R_Xhat[indices] if Sig_delta_R_Xhat is not None else None
    else:
        W_sampled = W
        Sig_delta_R_Xhat_sampled = Sig_delta_R_Xhat

    # Work in float64 from here on, exactly like compress_w2q. Computing Sw in
    # the input dtype (bf16/fp16/fp32) would perturb c_param and bias the
    # estimate relative to the full compression it is meant to predict.
    W_sampled = W_sampled.double()

    # Use the appropriate Hessian
    if Sig_hX is not None:
        H_work = Sig_hX.double()
    else:
        H_work = H.double()

    # Add damping to Hessian diagonal
    damp = percdamp * torch.mean(torch.diag(H_work))
    H_damped = H_work + damp * torch.eye(n, device=H_work.device, dtype=H_work.dtype)
    L = torch.linalg.cholesky(H_damped, upper=False)

    if Sig_hX is not None and Sig_X_hX is not None:
        # Qronos: Y = (W @ Σ_{X,X̂} + Σ_{ΔR,X̂}) @ L̂^{-T}
        target = W_sampled @ Sig_X_hX.double()
        if Sig_delta_R_Xhat_sampled is not None:
            target = target + Sig_delta_R_Xhat_sampled.double()
        Ycur = torch.linalg.solve_triangular(L.T, target, left=False, upper=True)
    else:
        # LDLQ: Y = W @ L + Σ_{ΔR,X̂} @ L^{-T}
        Ycur = W_sampled @ L
        if Sig_delta_R_Xhat_sampled is not None:
            Ycur = Ycur + torch.linalg.solve_triangular(
                L, Sig_delta_R_Xhat_sampled.double().T, left=True, upper=False
            ).T

    Sw = W_sampled.T @ W_sampled / W_sampled.shape[0]
    target_rate_nats = target_rate * math.log(2)
    c_param = torch.exp(torch.log(12 * Sw.diag() * (L.diag() ** 2)).mean() / 2 - target_rate_nats)

    alphas = c_param / L.diag()
    gammas = torch.ones(n, device=W.device, dtype=torch.double)

    Zsic = torch.zeros((W_sampled.shape[0], n), dtype=torch.int64, device=W.device)

    # Same last-to-first LDLQ loop with error feedback as compress_w2q.
    for col in range(n - 1, -1, -1):
        wcol = Ycur[:, col]
        zcol = torch.round(wcol / c_param).long()
        Zsic[:, col] = zcol
        f1 = (zcol.double() * wcol).sum()
        f2 = (zcol.double() ** 2).sum()
        if f2 > 0:
            gammas[col] = f1 / f2 / c_param
            corrector = torch.outer(zcol.double(), gammas[col] * alphas[col] * L[col, :])
            Ycur = Ycur - corrector
        else:
            gammas[col] = 0

    # Compute entropy
    zsic_elts, zsic_counts = torch.unique(Zsic.flatten(), return_counts=True)
    probs = zsic_counts.float() / Zsic.numel()
    entropy = -torch.sum(probs * torch.log2(probs))

    return entropy.item()


@torch.no_grad()
def compress_zsic_with_binary_search(
    W: torch.Tensor,
    H: torch.Tensor,
    *,
    cfg: ZSICConfig,
    desired_rate: float,
    Sig_X: torch.Tensor | None = None,
    Sig_hX: torch.Tensor | None = None,
    Sig_X_hX: torch.Tensor | None = None,
    Sig_delta_R_Xhat: torch.Tensor | None = None,  # Residual compensation
) -> Tuple[torch.Tensor, float, float, Dict[str, object]]:
    """ZSIC with a bisection search over the internal target rate.

    ``compress_w2q`` derives its step size heuristically from ``target_rate``,
    so the realized code entropy need not equal it. This bisects
    ``target_rate`` on [cfg.binary_search_left, cfg.binary_search_right] for
    ``cfg.binary_search_iters`` steps, scoring each midpoint with
    ``_fast_rate_estimate`` on a ``cfg.binary_search_row_fraction`` row subset.
    It then runs the full ``compress_w2q`` at the midpoint whose estimate came
    closest to ``desired_rate``.

    The estimates factor the same Hessian as the final compression: Sig_X
    (falling back to H) in LDLQ mode, Sig_hX in Qronos mode.
    NOTE: the estimate is pure code entropy, while the returned final rate
    also includes compress_w2q's 16/a + 16/n side-information bits.

    Raises:
        ValueError: if ``cfg.binary_search_iters < 1``.

    Returns:
        (W_hat in W's dtype, final loss, final rate, frame dict with
        binary-search diagnostics).
    """
    n_iters = cfg.binary_search_iters
    if n_iters < 1:
        raise ValueError(f"binary_search_iters must be >= 1 (got {n_iters})")
    qronos_mode = cfg.qronos and Sig_hX is not None and Sig_X_hX is not None
    residual_comp_mode = cfg.residual_compensation and Sig_delta_R_Xhat is not None
    dtype = W.dtype

    # Use H as Sig_X if not provided; this is also the Hessian compress_w2q
    # factors in non-Qronos mode, so the search must estimate with it too.
    Sig_X_work = Sig_X if Sig_X is not None else H
    delta_R = Sig_delta_R_Xhat if residual_comp_mode else None

    left, right = cfg.binary_search_left, cfg.binary_search_right

    best_target, best_diff = None, float('inf')
    for i in range(n_iters):
        mid = (left + right) / 2.0

        # Fast rate estimate (includes residual compensation if enabled)
        actual_rate = _fast_rate_estimate(
            W, Sig_X_work, mid,
            row_fraction=cfg.binary_search_row_fraction,
            percdamp=cfg.percdamp,
            Sig_hX=Sig_hX if qronos_mode else None,
            Sig_X_hX=Sig_X_hX if qronos_mode else None,
            Sig_delta_R_Xhat=delta_R,
        )

        diff = abs(actual_rate - desired_rate)

        if diff < best_diff:
            best_target, best_diff = mid, diff

        # Bisection assumes the estimated rate is non-decreasing in the target.
        if actual_rate > desired_rate:
            right = mid
        else:
            left = mid

        if i == 0 or i == n_iters - 1:
            print(f"[binary-search] iter {i+1}/{n_iters}: target={mid:.4f} actual_rate={actual_rate:.4f} desired={desired_rate:.4f}", flush=True)

    print(f"[binary-search] done, best_target={best_target:.4f} best_diff={best_diff:.4f}", flush=True)

    # Run full compression with best target
    # compress_w2q handles both qronos and non-qronos cases
    print(f"[zsic] starting core compression (qronos={qronos_mode}, residual_comp={residual_comp_mode})", flush=True)

    # Residual compensation can work with or without Qronos
    final_loss, final_rate, What, frame_locals = compress_w2q(
        W, Sig_X_work, target_rate=best_target, quiet=False,
        Sig_hX=Sig_hX if qronos_mode else None,
        Sig_X_hX=Sig_X_hX if qronos_mode else None,
        percdamp=cfg.percdamp,
        Sig_delta_R_Xhat=delta_R,
    )

    frame = _build_frame_from_locals(frame_locals, cfg, qronos_mode, residual_comp_mode)
    frame["binary_search_iterations"] = n_iters
    frame["binary_search_target_used"] = best_target
    frame["binary_search_desired"] = desired_rate
    frame["binary_search_final_diff"] = best_diff

    return What.to(dtype), float(final_loss), float(final_rate), frame
# =============================================================================
# Dequantization
# =============================================================================

@torch.no_grad()
def sic_decode(Z: torch.Tensor, alpha: torch.Tensor, zero_point: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Dequantize integer codes: ``Z * alpha`` (+ ``zero_point``), in ``alpha``'s dtype.

    ``alpha`` (and ``zero_point``, if given) broadcast against ``Z``; a 1-D
    ``alpha`` of length n scales the columns of an (a, n) code matrix.
    """
    target_dtype = alpha.dtype
    decoded = Z.to(target_dtype) * alpha
    if zero_point is None:
        return decoded
    return decoded + zero_point.to(target_dtype)


@torch.no_grad()
def dequantize_zsic(
    Z: torch.Tensor,
    alpha: torch.Tensor,
    *,
    alpha_base: torch.Tensor | None = None,
    zero_point: torch.Tensor | None = None,
    apply_tgamma: bool = False,
    t_vec: torch.Tensor | None = None,
    g_vec: torch.Tensor | None = None,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Reconstruct weights from ZSIC codes and cast the result to ``dtype``.

    Without T/Gamma, the result is ``sic_decode(Z, alpha, zero_point)``. With
    ``apply_tgamma=True``, the codes are decoded with ``alpha_base``, then each
    row i is scaled by ``t_vec[i]`` and each column j by ``g_vec[j]``.

    Raises:
        ValueError: if ``apply_tgamma`` is set but any of ``alpha_base``,
            ``t_vec``, ``g_vec`` is missing.
    """
    if not apply_tgamma:
        return sic_decode(Z, alpha, zero_point=zero_point).to(dtype)

    if alpha_base is None or t_vec is None or g_vec is None:
        raise ValueError("alpha_base, t_vec, g_vec required when apply_tgamma=True")

    base = sic_decode(Z, alpha_base, zero_point=zero_point)
    row_scale = t_vec.to(base.dtype)[:, None]
    col_scale = g_vec.to(base.dtype)[None, :]
    rescaled = (row_scale * base) * col_scale
    return rescaled.to(dtype)
