import numpy as np  
import torch

def dense_from_sparse(idx: torch.Tensor, val: torch.Tensor, size: int) -> torch.Tensor:
    """Scatter ``val`` into a zero tensor of shape ``size`` at positions ``idx``.

    The result has the same dtype and device as ``val``.  ``idx`` is cast to
    int64 before indexing, so integer index tensors of any width are accepted.
    An empty ``idx`` yields an all-zero result.

    NOTE(review): if ``idx`` contains repeated positions, PyTorch does not
    guarantee which value is kept; presumably callers pass unique indices --
    confirm.
    """
    dense = val.new_zeros(size)
    if idx.numel() == 0:
        # Nothing to scatter; skip the indexing op entirely.
        return dense
    dense[idx.long()] = val
    return dense


def alpha_policy(
    eta_r, T, delta, L_est=1,
    controller=None,
    rho_tgt=0.3,
    *,
    a_floor=0.8,
    a_cap=0.95,
):
    """Choose the step-ahead coefficient ``a`` and report its implied contraction.

    The contraction model is

        rho(a) = gamma * (2 * (1 - a)**2 + 24 * a**2 * s**2)
               = gamma * (2 - 4 a + A a**2),      A = 2 + 24 s**2,

    with ``s = eta_r * L_est * T`` and ``gamma = 1 - 1/delta``.  Two candidates
    are combined:

    * ``a_star = 1 / (1 + 12 s**2)``: the minimiser of ``rho`` (vertex ``2/A``).
    * ``a_rho``: the smallest ``a`` with ``rho(a) <= rho_tgt``, i.e. the lower
      root of ``A a**2 - 4 a + (2 - rho_tgt/gamma)``.  If no ``a`` reaches the
      target, the discriminant is clamped to 0 and ``a_rho`` equals the vertex.

    The larger candidate is clamped to ``[a_floor, a_cap]``.

    Args:
        eta_r, T, L_est: multiplied together to form the local scale ``s``.
        delta: compression parameter.  ``delta <= 1`` is treated as no
            compression error (``gamma = 0``), in which case ``rho`` is 0 for
            every ``a`` and only ``a_star`` matters.
        controller: not used by this function; kept for call-site compatibility.
        rho_tgt: target contraction used to derive ``a_rho``.
        a_floor, a_cap: clamp bounds for the returned ``a``.  The defaults
            reproduce the previously hard-coded ``[0.8, 0.95]`` floor/cap.

    Returns:
        ``(a, s, rho)``: chosen coefficient, local scale, and ``rho(a)``.
    """
    s = float(eta_r) * float(L_est) * float(T)
    delta = float(delta)
    # gamma in (0, 1) for biased compression (delta > 1).  For delta <= 1 there
    # is no compression error; guarding here also prevents the former
    # ZeroDivisionError in ``rho_tgt / gamma`` (and in ``1 / delta`` at 0).
    gamma = 1.0 - 1.0 / delta if delta > 1.0 else 0.0

    # theory minimiser of rho(a)
    a_star = 1.0 / (1.0 + 12.0 * s * s)

    if gamma > 0.0:
        A = 2.0 + 24.0 * s * s
        C = 2.0 - (rho_tgt / gamma)
        # Clamping to 0 makes the root collapse onto the vertex 2/A when the
        # target rho_tgt is unreachable.
        disc = max(0.0, 16.0 - 4.0 * A * C)
        a_rho = (4.0 - np.sqrt(disc)) / (2.0 * A)
    else:
        # rho(a) == 0 <= rho_tgt for every a: the target imposes no floor.
        a_rho = 0.0
    a_nom = max(a_star, a_rho)

    # strong floor under heavy compression
    a = float(np.clip(a_nom, a_floor, a_cap))

    rho = gamma * (2.0 * (1.0 - a) ** 2 + 24.0 * (a ** 2) * (s ** 2))
    return a, s, rho

# -------------------------
# Cosine schedules
# -------------------------
def cosine_eta(r, R, eta_max, eta_min=0.0):
    """Return the cosine-annealed step size for round ``r`` out of ``R``.

    The value starts at ``eta_max`` for ``r == 0`` and reaches ``eta_min`` at
    ``r == R - 1``.  Rounds outside ``[0, R - 1]`` are clamped to the nearest
    endpoint.  With ``R <= 1`` there is nothing to anneal and ``eta_max`` is
    returned as given.
    """
    if R > 1:
        progress = np.clip(r / (R - 1), 0.0, 1.0)
        amplitude = eta_max - eta_min
        return eta_min + 0.5 * amplitude * (1.0 + np.cos(np.pi * progress))
    return eta_max
import numpy as np

def alpha_with_contraction_floor_updated(
    r, R, eta_r, L_est, T, delta,
    *,
    taper="cosine",          # {"cosine","linear",None}
    kill_ratio=0.1,         #  0.1 -> turn off step-ahead when ||e||/||e0|| < 0.1
    e_norm=None, e0_norm=None,
):
    """Step-ahead coefficient alpha with a contraction floor and a round taper.

    Contraction model: ρ_r = γ[2 - 4α + (2 + 24 s²) α²] with
    s = eta_r * L_est * T and γ = 1 - 1/δ (γ = 0 for δ <= 1).

    * ``alpha_min`` is the smallest α with ρ_r <= 1 (lower root of
      A α² - 4α + (2 - 1/γ), A = 2 + 24 s²).  If that quadratic has no real
      root, ρ_r > 1 for every α; the floor then falls back to the α that
      minimises ρ_r, the vertex 2/A (== ``alpha_star``).
    * ``alpha_star = 1 / (1 + 12 s²)`` is the minimiser of ρ_r.
    * The taper weight w moves alpha from ``alpha_star`` (w = 1) toward
      ``alpha_min`` (w = 0) as r/R goes from 0 to 1.  Progress is clamped at 1,
      so rounds past R stay at ``alpha_min``.  Any ``taper`` value other than
      "cosine"/"linear" disables tapering (w = 1).
    * If ``kill_ratio``, ``e_norm`` and a positive ``e0_norm`` are all given and
      e_norm / e0_norm < kill_ratio, alpha is forced to 0.

    Returns:
        ``(alpha, s, rho, alpha_min, alpha_star)`` where ``rho`` is ρ_r at the
        returned alpha.
    """
    # local scale
    s = float(eta_r) * float(L_est) * float(T)

    # γ = 1 - 1/δ; δ <= 1 (unbiased compressor) ⇒ γ = 0
    gamma = 0.0 if delta <= 1.0 else 1.0 - 1.0 / float(delta)  # else in (0,1)

    A = 2.0 + 24.0 * (s * s)
    # ---- theory minimiser of ρ_r; equals the vertex 2/A of the quadratic ----
    alpha_star = 1.0 / (1.0 + 12.0 * (s * s))

    # ---- contraction floor α_min from ρ_r ≤ 1 ----
    # Solve: A α^2 - 4 α + (2 - c) ≤ 0, with c := 1/γ
    if gamma == 0.0:
        # δ <= 1 ⇒ ρ_r = 0 always; no floor needed
        alpha_min = 0.0
    else:
        c = 1.0 / gamma
        disc = 16.0 - 4.0 * A * (2.0 - c)   # Δ = b^2 - 4ac, b = -4
        # In the regime s ≤ 1/8 (A ≤ 2.375, 2 - c < 1) disc is always > 0.
        if disc >= 0.0:
            lower = (4.0 - np.sqrt(disc)) / (2.0 * A)
            alpha_min = float(np.clip(lower, 0.0, 1.0))
        else:
            # No real roots ⇒ ρ_r > 1 for every α, so no floor achieves
            # contraction. Pick the α that minimises ρ_r (vertex 2/A) rather
            # than 0.0, which gives the largest ρ_r = 2γ.
            alpha_min = float(np.clip(alpha_star, 0.0, 1.0))

    # ---- taper toward α_min to avoid late noise ----
    if taper == "cosine":
        # clamp so the weight does not rise again once r exceeds R
        progress = min(r / max(1, R), 1.0)
        w = 0.5 * (1.0 + np.cos(np.pi * progress))
    elif taper == "linear":
        w = max(0.0, 1.0 - (r / max(1.0, float(R))))
    else:
        w = 1.0  # no taper
    alpha = alpha_min + (alpha_star - alpha_min) * w
    alpha = float(np.clip(alpha, alpha_min, 1.0))

    # ---- optional residual-trigger kill switch ----
    if (kill_ratio is not None) and (e_norm is not None) and (e0_norm is not None) and (e0_norm > 0):
        if (e_norm / e0_norm) < kill_ratio:
            alpha = 0.0  # switch to EF when residual is tiny

    # implied contraction
    rho = gamma * (2.0 * (1.0 - alpha) ** 2 + 24.0 * (alpha ** 2) * (s ** 2))
    return alpha, s, rho, alpha_min, alpha_star

def alpha_with_contraction_floor(r, R, eta_r, L, T, delta):
    """Step-ahead coefficient alpha with a contraction floor and cosine blend.

    Contraction model: ρ = γ[2 - 4α + (2 + 24 s²) α²] with s = eta_r * L * T
    and γ = 1 - 1/δ for δ > 1 (γ = 0 for δ <= 1, i.e. no compression error).

    ``alpha_min`` is the lower root of A α² - 4α + (2 - 1/γ) (the smallest α
    with ρ <= 1).  The cosine blend moves alpha from ``alpha_star =
    1 / (1 + 12 s²)`` at r = 0 toward ``alpha_min`` at r = R.  R <= 0 is
    treated as R = 1.

    Returns:
        ``(alpha, s, rho)`` where ``rho`` is ρ at the returned alpha.
    """
    # local scale
    s = eta_r * L * T

    # γ = 1 - 1/δ only for δ > 1. For 0 < δ < 1 the raw expression is
    # negative and used to yield a negative logged rho.
    gamma = 1.0 - (1.0 / delta) if delta > 1.0 else 0.0
    if delta <= 1.0:
        # No contraction penalty — any alpha works; put floor at 0
        alpha_min = 0.0
    else:
        # Solve: γ * (2 - 4a + (2 + 24 s^2) a^2) <= 1
        # => A a^2 - 4 a + (2 - c) <= 0, where c = 1/γ > 1
        A = 2.0 + 24.0 * (s * s)
        c = 1.0 / gamma
        disc = 16.0 - 4.0 * A * (2.0 - c)  # discriminant
        if disc <= 0.0:
            # No real root: ρ > 1 for every α, so α = 1 is not a contraction
            # guarantee either (ρ(1) = γ·24 s²).
            # NOTE(review): α = 2/A (= alpha_star) would minimise ρ; α = 1 is
            # kept here to preserve existing behaviour.
            alpha_min = 1.0
        else:
            sqrt_disc = np.sqrt(disc)
            # interval between roots is feasible; take lower root as floor
            lower = (4.0 - sqrt_disc) / (2.0 * A)
            # ensure within [0,1]
            alpha_min = float(np.clip(lower, 0.0, 1.0))

    # Theory minimiser (for the residual contraction)
    alpha_star = 1.0 / (1.0 + 12.0 * (s * s))

    # Cosine blend: alpha_star at r = 0, relaxing to alpha_min at r = R.
    # max(1, R) avoids a ZeroDivisionError for R = 0.
    blend = 0.5 * (1.0 + np.cos(np.pi * r / max(1, R)))
    alpha = alpha_min + (alpha_star - alpha_min) * blend
    # Final clamp to [alpha_min, 1]
    alpha = float(np.clip(alpha, alpha_min, 1.0))

    # For logging: implied contraction
    rho = gamma * (2.0 * (1.0 - alpha) ** 2 + 24.0 * (alpha ** 2) * (s ** 2))
    return alpha, s, rho

def check_descent_coupling(eta_server, eta_r, L, T, beta):
    """Check three step-size conditions; return a dict of named booleans.

    With s = eta_r * L * T:

    * ``"s<=1/8"``: s <= 1/8
    * ``"18β^2 s^2<=1/8"``: 18 β² s² <= 1/8
    * ``"eta_server_coupling"``: eta_server <= 1 / (256 β² L eta_r T).
      When that denominator is zero (any of beta, L, eta_r, T is 0), the bound
      is treated as +inf, so the condition holds.  Previously this raised
      ZeroDivisionError.
    """
    s = eta_r * L * T
    ok_s = s <= 1.0 / 8.0
    ok_quad = 18.0 * (beta ** 2) * (s ** 2) <= 1.0 / 8.0

    denom = 256.0 * (beta ** 2) * L * eta_r * T
    eta_server_max = float("inf") if denom == 0 else 1.0 / denom
    ok_eta_server = eta_server <= eta_server_max

    return {"s<=1/8": ok_s, "18β^2 s^2<=1/8": ok_quad, "eta_server_coupling": ok_eta_server}