"""
Measure Bilevel FLOPs (theoretical vs measured with torch.profiler)
and compare to HALS theoretical FLOPs.

Bilevel FLOPs = m^2 + K*m +m
HALS FLOPs    = r*(2m^2 + 4mr + m)


"""

"""


@author: Anonymous

"""

import torch
import matplotlib.pyplot as plt


def fast_extended_hoyer_projection_with_loops(y: torch.Tensor,
                                              l: float,
                                              eps: float = 1e-12,
                                              max_iter: int = 50) -> "tuple[torch.Tensor, int]":
    """
    Extended Hoyer projection of a 1-D tensor, also reporting the loop count.

    The routine works on the magnitudes |y| and restores the signs at the end:

    1. Active-set loop: repeatedly compute a threshold ``alpha`` from the L1
       and L2 norms of the currently non-zero entries and zero out every entry
       below it, until the number of non-zero entries stops changing or
       ``max_iter`` loops have run.
    2. Blend the surviving entries with their mean using a factor ``lam``
       derived from ``alpha``.
    3. Restore the signs of ``y`` and rescale by <x, y> / <x, x>, i.e. the
       scalar s minimising ||s*x - y||^2.

    Args:
        y: 1-D input tensor (enforced with an assert).
        l: target level; enters the threshold through
           l * (nu - Hx) / (Hx * (nu - l)), where nu is the number of active
           (non-zero) entries and Hx = (||x||_1 / ||x||_2)^2 over them.
        eps: small constant guarding divisions and the L2 norm.
        max_iter: upper bound on the number of active-set loops.

    Returns:
        (x, loops): the projected tensor (same shape as ``y``) and the number
        of active-set loops executed. If no entry stays non-zero, ``x`` is
        all zeros.

    Raises:
        AssertionError: if ``y`` is not 1-D.
    """
    assert y.dim() == 1, "Expected 1D tensor"
    n = y.numel()
    sgn = torch.sign(y)
    # Work on magnitudes; signs are re-applied at the end.
    x = y.abs().clone()

    # Start with nu != nu_prev so the loop body runs at least once
    # (when max_iter > 0).
    nu_prev = int((x > 0).sum().item())
    nu = nu_prev + 1
    loops = 0
    alpha = torch.tensor(0.0, device=y.device, dtype=torch.get_default_dtype())
    # Iterate until the size of the active set stops changing.
    while nu != nu_prev and loops < max_iter:
        loops += 1
        nu_prev = nu
        mask = (x > 0)
        nu = int(mask.sum().item())
        if nu == 0:
            return torch.zeros_like(y), loops
        x_active = x[mask]
        l1_active = x_active.sum()
        l2_active = torch.sqrt((x_active * x_active).sum().clamp_min(eps))
        # Squared L1/L2 ratio of the active entries, clamped to [1, n].
        Hx = (l1_active / l2_active).pow(2).clamp_min(1.0).clamp_max(float(n))
        num = l * (nu - Hx)
        den = Hx * (nu - l + eps)
        # A negative ratio (e.g. nu < l with l > 0) is clamped to 0, which
        # makes root = 0 and alpha the mean of the active entries.
        frac = (num / den).clamp_min(0.0)
        root = torch.sqrt(frac)
        alpha = (l1_active / max(nu, 1)) * (1.0 - root)
        # Keep entries >= alpha, zero out the rest.
        x = torch.where(x >= alpha, x, torch.zeros_like(x))

    # Final active set after the loop.
    mask = (x > 0)
    nu = int(mask.sum().item())
    if nu == 0:
        return torch.zeros_like(y), loops
    x_active = x[mask]
    l1_active = x_active.sum()
    # Blend factor; falls back to lam = 1 (no blending) when denom is ~0.
    denom = (1.0 - (alpha * nu) / (l1_active + eps))
    lam = 1.0 / denom if abs(float(denom)) > eps else torch.tensor(1.0, device=y.device, dtype=torch.get_default_dtype())
    # d holds the mean of the active entries on the active set, 0 elsewhere.
    d = torch.zeros_like(x)
    d_val = (l1_active / max(nu, 1))
    d[mask] = d_val
    # Inactive entries stay 0 because both x and d are 0 there.
    x = lam * x + (1.0 - lam) * d

    # Restore signs, then rescale by the least-squares factor <x,y>/<x,x>.
    x = x * sgn
    xy = (x * y).sum()
    xx = (x * x).sum().clamp_min(eps)
    # NOTE(review): xx is clamped to >= eps above, so `xx > 0` is always true
    # and the else-branch below is unreachable.
    scale = (xy / xx) if xx > 0 else torch.tensor(0.0, device=y.device, dtype=torch.get_default_dtype())
    x = x * scale
    return x, loops





def proj_Hoyer(w0, level, device="cpu"):
    """Apply the extended Hoyer projection to ``w0``, preserving its shape.

    The input is flattened to 1-D (the projection routine only accepts 1-D
    tensors), projected, and reshaped back to its original shape.

    Args:
        w0: tensor or array-like to project. It is converted to a tensor of the
            default dtype on ``device``.
        level: target level, forwarded as a float to
            ``fast_extended_hoyer_projection_with_loops``.
        device: device used for the computation.

    Returns:
        A detached tensor with the shape of ``w0`` when ``w0`` is a tensor;
        otherwise a NumPy array (copied back to host memory first).
    """
    w = torch.as_tensor(w0, dtype=torch.get_default_dtype(), device=device)
    init_shape = w.size()

    # reshape(-1) is a no-op view for 1-D input, flattens N-D input, and also
    # lifts 0-d scalars to 1-D, which the projection routine requires.
    x, _loops = fast_extended_hoyer_projection_with_loops(w.reshape(-1), float(level))

    Q = x.reshape(init_shape).clone().detach()
    if not torch.is_tensor(w0):
        # Tensor.numpy() only works on CPU tensors, so move back to host
        # first (a no-op when the computation already ran on the CPU).
        Q = Q.cpu().numpy()
    return Q


def bilevel_proj_HoyerInftyball(w2, level, device="cpu"):
    """Bilevel Hoyer / L-infinity projection.

    A 1-D input is simply Hoyer-projected with ``proj_Hoyer``. For inputs with
    two or more dimensions:

    1. take the maximum absolute value of every column slice ``w[:, i]``;
    2. Hoyer-project that vector of column maxima (``proj_Hoyer``) -> ``PW``;
    3. clip every entry of column ``i`` in magnitude to ``PW[i]``, keeping
       the entry's original sign.

    Args:
        w2: tensor or array-like input.
        level: target level forwarded to ``proj_Hoyer``.
        device: device used when ``w2`` is not a tensor. Tensor inputs are
            always processed on their own device.

    Returns:
        For tensor input, a tensor. For ndim >= 2 this is a fresh leaf with
        ``requires_grad=True``. For non-tensor input, a NumPy array.
    """
    if torch.is_tensor(w2):
        # Tensor inputs stay on their own device; non-tensor inputs (lists,
        # arrays) have no reliable `.device`, so they use the parameter.
        device = w2.device
    w = torch.as_tensor(w2, dtype=torch.get_default_dtype(), device=device)

    if w.dim() == 1:
        Q = proj_Hoyer(w, level, device=device)
    else:
        ncol = w.shape[1]

        # Max |w| over every dim except dim 1, computed in one vectorized
        # pass instead of one `.item()` host sync per column.
        W = w.abs().movedim(1, 0).flatten(1).amax(dim=1)

        PW = proj_Hoyer(W, level, device=device)

        # Broadcast PW[i] along column i: shape (1, ncol, 1, ..., 1).
        bound = PW.reshape((1, ncol) + (1,) * (w.dim() - 2))
        # min(|w|, bound) is exactly clamp(|w|, max=bound); re-apply signs.
        Res = torch.minimum(w.abs(), bound) * torch.sign(w)

        Q = Res.detach().requires_grad_(True)

    if not torch.is_tensor(w2):
        # Tensor.numpy() needs a CPU tensor without grad tracking.
        Q = Q.detach().cpu().numpy()

    return Q

def measure_flops_once(m, l=100):
    """
    Measure FLOPs for one bilevel projection on an m x m random matrix.

    The matrix is drawn with ``torch.rand``. When CUDA is available, it is
    moved to the GPU and CUDA activity is profiled as well. The result is the
    sum of the FLOP estimates of every profiler event that reports one.
    torch.profiler only estimates FLOPs for a subset of operators, so this
    undercounts operators without a FLOP formula.

    Args:
        m: matrix side length.
        l: level forwarded to ``bilevel_proj_HoyerInftyball``.

    Returns:
        Total measured FLOPs as an int, or 0 if torch.profiler cannot be
        imported.
    """
    X = torch.rand(m, m)

    try:
        from torch.profiler import profile, ProfilerActivity
    except ImportError:
        # Best-effort: the benchmark keeps running without measurements.
        print("Torch profiler not available. Returning 0.")
        return 0

    activities = [ProfilerActivity.CPU]
    if torch.cuda.is_available():
        X = X.to("cuda")
        activities.append(ProfilerActivity.CUDA)

    with profile(activities=activities, with_flops=True) as prof:
        with torch.no_grad():
            _ = bilevel_proj_HoyerInftyball(X, l)

    # Events without a FLOP formula report None/0; they contribute nothing.
    return sum(
        int(evt.flops)
        for evt in prof.key_averages()
        if getattr(evt, "flops", None)
    )


def theoretical_bilevel(m):
    """Return the theoretical bilevel-projection FLOP count, m^2 + 5m.

    Args:
        m: matrix side length.
    """
    quadratic_term = m * m
    linear_term = 5 * m
    return quadratic_term + linear_term


def theoretical_hals(m, r):
    """Return the theoretical HALS FLOP count, r * (2m^2 + 4mr + m).

    Args:
        m: matrix side length.
        r: factorization rank.
    """
    per_rank_cost = 2 * m * m + 4 * m * r + m
    return r * per_rank_cost


def main():
    """Benchmark bilevel-projection FLOPs and compare them with HALS.

    For each matrix size m in ``ms``, measure the FLOPs of one bilevel
    projection with torch.profiler and compute the theoretical bilevel and
    HALS FLOP counts. Print all three in GFLOPs. Then save and show two
    plots: one with all three curves, and one with bilevel theoretical vs.
    measured only.
    """
    ms = [1000, 2000, 3000, 4000, 5000]
    l = 200  # used both as the Hoyer level and as the HALS rank r

    # Single source of truth for output paths (used by savefig and the logs).
    all_plot_path = "bilevel_hals_flops_comparison.png"
    bilevel_plot_path = "bilevel_flops_comparison.png"

    measured = []
    bilevel_theo = []
    hals_theo = []

    print("\n=== Measuring Bilevel FLOPs (and HALS theoretical) ===\n")
    # Every value in the table is divided by 1e9, so the columns are GFLOPs.
    print(f"{'m':>8} {'Measured Bilevel GFLOPs':>25} {'Theo Bilevel GFLOPs':>25} {'Theo HALS GFLOPs':>25}")
    print("-" * 90)

    for m in ms:
        meas = measure_flops_once(m, l)
        theo_b = theoretical_bilevel(m)
        theo_h = theoretical_hals(m, r=l)

        measured.append(meas)
        bilevel_theo.append(theo_b)
        hals_theo.append(theo_h)

        print(f"{m:8d} {meas / 1e9:25,.4f} {theo_b / 1e9:25,.4f} {theo_h / 1e9:25,.4f}")

    # -------- Plots (all values in GFLOPs) ---------
    bilevel_theo_gflops = [x / 1e9 for x in bilevel_theo]
    measured_gflops = [x / 1e9 for x in measured]
    hals_theo_gflops = [x / 1e9 for x in hals_theo]

    fig = plt.figure(figsize=(9, 7))
    plt.plot(ms, bilevel_theo_gflops, marker='o', label="Bilevel Theoretical (m² + 5m)")
    plt.plot(ms, measured_gflops, marker='s', label="Bilevel Measured (Profiler)")
    plt.plot(ms, hals_theo_gflops, marker='^', label=f"HALS Theoretical (r = {l})")

    plt.xlabel("Matrix size m", fontsize=12)
    plt.ylabel("GFLOPs", fontsize=12)
    plt.title("Bilevel Projection FLOPs vs HALS (Theoretical vs Measured)", fontsize=14)
    plt.grid(True, linestyle='--', alpha=0.6)
    plt.legend()
    plt.tight_layout()
    plt.savefig(all_plot_path, dpi=220)
    plt.show()
    # Release the figure explicitly; show() does not close it on every backend.
    plt.close(fig)

    print(f"\nPlot saved as: {all_plot_path}")

    fig = plt.figure()
    plt.plot(ms, bilevel_theo_gflops, marker='o', label="Theoretical FLOPs (m² + 5m)")
    plt.plot(ms, measured_gflops, marker='s', label="Measured FLOPs (Profiler)")

    plt.xlabel("Matrix size m")
    plt.ylabel("GFLOPs")
    plt.title("Bilevel Projection FLOPs — Theoretical vs Measured")
    plt.grid(True, linestyle='--')
    plt.legend()
    plt.tight_layout()
    plt.savefig(bilevel_plot_path, dpi=200)
    plt.show()
    plt.close(fig)

    print(f"\nPlot saved as: {bilevel_plot_path}\n")

if __name__ == "__main__":
    main()
