#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""

@author: Anonymous
"""


import time
import argparse
import numpy as np
import torch
import matplotlib.pyplot as plt
import math
try:
    from torch.profiler import profile, ProfilerActivity
    _HAS_TORCH_PROFILER = True
except Exception:
    _HAS_TORCH_PROFILER = False

# Small positive constant added inside square roots and denominators in this
# file so that zero vectors do not produce sqrt(0) or a division by zero.
EPS = 1e-10

def l2(x):
    """Return the Euclidean norm of ``x``, regularised by ``EPS`` under the root.

    The result is a 0-dim tensor equal to ``sqrt(sum(x * x) + EPS)``, so it is
    strictly positive even for an all-zero input.
    """
    squared_sum = torch.sum(x * x)
    return torch.sqrt(squared_sum + EPS)

def hoyer_H(v):
    """Return the squared L1/L2 ratio ``(||v||_1 / ||v||_2) ** 2`` as a Python float.

    The L2 norm comes from :func:`l2`, i.e. it carries the ``EPS`` regulariser.
    """
    ratio = torch.sum(torch.abs(v)) / l2(v)
    return float(ratio ** 2)





# ----------------------------------------------------------------
# SIMPLEX PROJECTION (Condat 2016), sort-based variant
#   x = argmin ||x - v||_2  s.t. x >= 0, sum(x) = s
#   u = sort(v, desc), cs = cumsum(u)
#   rho = max { j : u_j > (cs_j - s)/j }
#   tau = (cs_rho - s)/rho
#   x = max(v - tau, 0)
# ----------------------------------------------------------------
def proj_simplex_condat_sorted(v: torch.Tensor, s: float) -> torch.Tensor:
    """Project a 1-D tensor onto the simplex ``{x >= 0, sum(x) = s}`` via sorting.

    Args:
        v: 1-D input tensor.
        s: Target sum (radius of the simplex).

    Returns:
        A tensor shaped like ``v``. An empty ``v`` is returned as-is; for
        ``s <= 0`` (or when no sorted entry passes the threshold test) the
        result is all zeros.
    """
    assert v.dim() == 1, "v must be 1D"
    size = v.numel()
    if size == 0:
        return v
    if s <= 0:
        return torch.zeros_like(v)

    # Sort descending and compare each entry with the running threshold
    # (cumsum_j - s) / j to find the size of the support.
    sorted_desc = torch.sort(v, descending=True).values
    shifted_cumsum = sorted_desc.cumsum(dim=0) - s
    ranks = torch.arange(1, size + 1, device=v.device, dtype=v.dtype)
    passes = sorted_desc > shifted_cumsum / ranks
    if not passes.any():
        return torch.zeros_like(v)

    # 1-based index of the last entry that passes the test.
    support = passes.nonzero(as_tuple=False)[-1].item() + 1
    threshold = shifted_cumsum[support - 1] / support
    projected = (v - threshold).clamp(min=0.0)

    # Rescale so the sum is exactly s despite floating-point round-off.
    total = projected.sum()
    if total > 0:
        projected = projected * (s / total)
    return projected


# ===============================================================
# Hoyer 'projfunc' (2004) — PyTorch port with Condat simplex init (from matlab)
# ===============================================================
def projfunc_torch(s: torch.Tensor, k1: float, k2: float, nn: bool):
    """
    PyTorch port of Patrik O. Hoyer's 'projfunc' (2004),
    modified with Condat simplex projection as initialization:

        v = proj_simplex_condat_sorted(s, k1)

    Target constraints (as stated for the original routine):
        sum(abs(v)) = k1
        sum(v^2)    = k2
        nn=True  -> v >= 0 (non-negativity)

    Args:
        s:  1-D input tensor.
        k1: target L1 norm.
        k2: target squared L2 norm.
        nn: if False, the signs of ``s`` are recorded, the iteration runs on
            ``abs(s)``, and the signs are re-applied to the result.

    Returns:
        v (tensor), usediters (int)
    """
    assert s.dim() == 1, "s must be a 1D tensor"
    device = s.device
    dtype  = s.dtype
    N      = s.numel()

    # If non-negativity flag not set, record signs and take abs
    if not nn:
        isneg = s < 0
        s = s.abs()
    else:
        isneg = None

    # --- Modified initialisation: direct projection onto the simplex ---
    v = proj_simplex_condat_sorted(s, float(k1))

    zerocoeff = torch.zeros(N, dtype=torch.bool, device=device)  # True where a coefficient has been frozen at zero
    j = 0

    while True:
        # Active set: coefficients not yet frozen at zero
        active_count = int((~zerocoeff).sum().item())
        if active_count <= 0:
            # Degenerate case: nothing active; put all of k1 on the largest |s| entry
            v = torch.zeros_like(v)
            if N > 0:
                idx = torch.argmax(s.abs())
                v[idx] = k1
            usediters = j + 1
            break

        # Point with k1 spread evenly over the active coefficients (0 elsewhere)
        midpoint_val = k1 / active_count
        midpoint = torch.where(
            zerocoeff,
            torch.zeros((), dtype=dtype, device=device),
            torch.tensor(midpoint_val, dtype=dtype, device=device)
        )

        w = v - midpoint

        # Coefficients of the quadratic a*alpha^2 + b*alpha + c = 0, i.e.
        # ||v + alpha*w||^2 = k2
        a = (w * w).sum()
        b = 2.0 * (w * v).sum()
        c = (v * v).sum() - k2

        # alphap = (-b + sqrt(b^2 - 4ac)) / (2a)
        if float(a.abs()) <= 1e-20:
            # v already coincides with the midpoint: no direction to move along
            alphap = torch.zeros((), dtype=dtype, device=device)
        else:
            # Negative discriminants are clamped to 0; eps guards the division
            disc = (b * b - 4.0 * a * c).clamp_min(0.0)
            sqrt_disc = torch.sqrt(disc)
            alphap = (-b + sqrt_disc) / (2.0 * a + torch.finfo(dtype).eps)

        v = alphap * w + v  # update

        # If every entry is >= 0, we are done
        if torch.all(v >= 0):
            usediters = j + 1
            break

        # Otherwise freeze the non-positive entries at zero and iterate again
        j += 1
        neg_mask = v <= 0
        zerocoeff = zerocoeff | neg_mask
        v = torch.where(neg_mask, torch.zeros_like(v), v)

        active_count = int((~zerocoeff).sum().item())
        if active_count <= 0:
            # Degenerate case: everything was frozen; same fallback as above
            v = torch.zeros_like(v)
            if N > 0:
                idx = torch.argmax(s.abs())
                v[idx] = k1
            usediters = j
            break


        # Re-project the remaining entries onto the simplex of radius k1.
        # NOTE(review): `shift` is a per-element vector (vp / active_count),
        # unlike the scalar shift of the commented-out variant below;
        # confirm this is intended.
        vp = proj_simplex_condat_sorted(v, float(k1))
        shift = vp / active_count

        # Previous variant (disabled):
        #tempsum = v.sum()
        #shift = (k1 - tempsum) / active_count
        v = vp + shift
        v = torch.where(zerocoeff, torch.zeros_like(v), v)

    # Restore the signs if nn=False
    if not nn:
        sign = torch.where(
            isneg,
            torch.tensor(-1.0, dtype=dtype, device=device),
            torch.tensor(1.0,  dtype=dtype, device=device),
        )
        v = v * sign




    return v, usediters


# ----------------------------------------------------------------
# ORIGINAL HOYER:
# ----------------------------------------------------------------
def hoyer_original_projection(y: torch.Tensor, l: float,
                              tol: float = 1e-7, max_iter: int = 200,
                              restore_sign: bool = True):
    """Project ``y`` with the Hoyer-style routine :func:`projfunc_torch`.

    The targets handed to ``projfunc_torch`` are ``k1 = sqrt(l) * ||y||_2``
    (with the ``EPS``-regularised norm from :func:`l2`) and
    ``k2 = ||y||_2^2``; ``nn=False`` makes it work on ``abs(y)`` and restore
    the signs afterwards.

    Args:
        y: 1-D input tensor.
        l: Hoyer level; its square root scales the L1 target.
        tol, max_iter, restore_sign: accepted for interface compatibility
            but not used by this wrapper.

    Returns:
        ``(x, usediters)`` exactly as returned by ``projfunc_torch``.
    """
    assert y.dim() == 1
    squared_norm = float((y * y).sum().item())
    l1_target = float(l ** 0.5) * float(l2(y).item())
    return projfunc_torch(y.clone(), k1=l1_target, k2=squared_norm, nn=False)

# ----------------------------------------------------------------
# Iterative
# ----------------------------------------------------------------
def naive_with_iters(y: torch.Tensor, l: float, max_iter: int = 200):
    """Iteratively push ``|y|`` away from its mean until the Hoyer ratio is <= l.

    Each iteration computes the squared L1/L2 ratio of the current vector
    (clamped to ``[1, n]``), stops once it is at most the effective level, and
    otherwise replaces ``x`` by ``lam * x + (1 - lam) * mean`` clamped at zero.
    The result gets the signs of ``y`` back and is rescaled by the
    least-squares factor ``<x, y> / (<x, x> + EPS)``.

    Args:
        y: 1-D input tensor.
        l: requested Hoyer level, clipped to ``[1, n]``.
        max_iter: maximum number of iterations.

    Returns:
        ``(x, iters)``: the resulting tensor and the number of iterations run.
    """
    assert y.dim() == 1
    size = y.numel()
    signs = torch.sign(y)
    mag = y.abs().clone()
    target = float(max(1.0, min(float(l), float(size))))

    iters = 0
    for step in range(1, max_iter + 1):
        iters = step
        # NOTE(review): after abs()/clamp every non-NaN entry satisfies >= 0,
        # so this mask also selects entries already clamped to zero.
        active = mag >= 0
        n_active = int(active.sum().item())
        if n_active == 0:
            return torch.zeros_like(y), iters

        kept = mag[active]
        l1_active = kept.sum()
        l2_active = torch.sqrt((kept * kept).sum() + EPS)
        hoyer = torch.clamp(
            torch.clamp((l1_active / l2_active).pow(2), min=1.0),
            max=float(size),
        )

        # Already at (or below) the requested level: stop iterating.
        if float(hoyer) <= target + 1e-10:
            break

        ratio = (hoyer * (n_active - target)) / (target * (n_active - hoyer + EPS))
        lam = torch.sqrt(torch.clamp(ratio, min=0.0))

        # Vector holding the active mean on the active entries, zero elsewhere.
        centre = torch.zeros_like(mag)
        centre[active] = l1_active / max(n_active, 1)

        mag = torch.clamp(lam * mag + (1.0 - lam) * centre, min=0.0)

    signed = mag * signs
    scale = (signed * y).sum() / ((signed * signed).sum() + EPS)
    return signed * scale, iters

# ----------------------------------------------------------------
# FAST Closed-form Projection 
# ----------------------------------------------------------------
def fast_extended_hoyer_projection_with_loops(y: torch.Tensor, l: float, max_iter: int = 50):
    """Threshold-based Hoyer projection; also reports the number of loops used.

    Works on ``|y|``: each loop computes a threshold ``alpha`` from the L1 and
    squared-L2 sums of the strictly positive entries and zeroes every entry
    below it. The loop stops when the positive count seen at the start of two
    consecutive loops is equal, or after ``max_iter`` loops. The survivors are
    then moved affinely around their mean, the signs of ``y`` are restored and
    the result is rescaled by ``<x, y> / (<x, x> + EPS)``.

    Args:
        y: 1-D input tensor.
        l: requested Hoyer level, clipped to ``[1, n]``.
        max_iter: maximum number of thresholding loops.

    Returns:
        ``(x, loops)``; ``x`` is all zeros if no strictly positive entry remains.
    """
    assert y.dim() == 1
    size = y.numel()
    signs = torch.sign(y)
    mag = y.abs().clone()
    target = float(max(1.0, min(float(l), float(size))))

    # Start with support != prev_support so the loop body runs at least once.
    prev_support = int((mag > 0).sum().item())
    support = prev_support + 1
    loops = 0
    alpha = torch.tensor(0.0, device=y.device, dtype=y.dtype)

    while support != prev_support and loops < max_iter:
        loops += 1
        prev_support = support
        positive = mag > 0
        support = int(positive.sum().item())
        if support == 0:
            return torch.zeros_like(y), loops

        kept = mag[positive]
        l1_active = kept.sum()
        sq_active = (kept * kept).sum()

        hoyer = l1_active * l1_active / sq_active
        ratio = (target * (support - hoyer)) / (hoyer * (support - target) + EPS)
        # Deliberately unclamped: a negative ratio yields NaN here.
        alpha = l1_active / support * (1.0 - torch.sqrt(ratio))

        mag = torch.where(mag >= alpha, mag, torch.zeros_like(mag))

    positive = mag > 0
    support = int(positive.sum().item())
    if support == 0:
        return torch.zeros_like(y), loops

    l1_active = mag[positive].sum()
    lam = 1.0 / (1.0 - (alpha * support) / (l1_active + EPS))

    # Vector holding the survivors' mean on the support, zero elsewhere.
    centre = torch.zeros_like(mag)
    centre[positive] = l1_active / max(support, 1)

    mag = lam * mag + (1.0 - lam) * centre

    signed = mag * signs
    scale = (signed * y).sum() / ((signed * signed).sum() + EPS)
    return signed * scale, loops

# ----------------------------------------------------------------
# Benchmark vs FAST (reference)
# ----------------------------------------------------------------
def _estimate_flops(method: str, n: int, iters: float) -> float:
    """Return a rough analytic FLOP estimate for ``iters`` iterations on length ``n``."""
    if method == 'orig':
        # The original method sorts once and then performs iterative updates.
        sort_cost = 5.0 * n * math.log2(max(2, n))
        return sort_cost + float(iters) * 15.0 * n
    per_iter = 12.0 * n if method == 'fast' else 10.0 * n
    return float(iters) * per_iter


def _measure_flops(fn, y: torch.Tensor, l: float) -> int:
    """Run ``fn(y.clone(), l)`` once under torch.profiler; return summed FLOPs (0 on failure)."""
    import warnings

    activities = [ProfilerActivity.CPU]
    if torch.cuda.is_available():
        activities.append(ProfilerActivity.CUDA)
    try:
        with profile(activities=activities, record_shapes=False, with_flops=True) as prof:
            fn(y.clone(), l)
        return sum(
            int(evt.flops)
            for evt in prof.key_averages()
            if getattr(evt, 'flops', None) is not None
        )
    except Exception as exc:
        # Profiling is best-effort: report the failure and record 0 FLOPs.
        warnings.warn(f"torch.profiler FLOP measurement failed: {exc!r}")
        return 0


def _timed_call(fn, y: torch.Tensor, l: float, device: torch.device):
    """Run ``fn(y.clone(), l)``; return ``(result, iters, elapsed_ms)`` with CUDA syncs around it."""
    if device.type == "cuda":
        torch.cuda.synchronize()
    start = time.perf_counter()
    result, iters = fn(y.clone(), l)
    if device.type == "cuda":
        torch.cuda.synchronize()
    return result, iters, (time.perf_counter() - start) * 1000.0


def _mean_std(values):
    """Return ``(mean, sample std)`` as floats; 0.0 where there is too little data."""
    arr = np.asarray(values, dtype=float)
    mu = arr.mean() if arr.size else 0.0
    sd = arr.std(ddof=1) if arr.size > 1 else 0.0
    return float(mu), float(sd)


def _mean_or_zero(values) -> float:
    """Return the mean of ``values`` as a float, or 0.0 when it is empty."""
    return float(np.mean(values)) if values else 0.0


def bench_one_setting(n: int, l: float, rep: int, device: torch.device, profile_mode: bool = False, profile_reps: int = 1):
    """Benchmark the three projections on ``rep`` random vectors of length ``n``.

    For each repetition a vector is drawn with ``torch.rand`` and the fast,
    original and naive projections are run on it, in that order. Wall-clock
    time and iteration counts are recorded, and the original and naive outputs
    are compared against the fast output (absolute and relative L2 distance).

    FLOP profiling runs as a separate, untimed call so that profiler overhead
    never leaks into the timing statistics.

    Args:
        n: vector length.
        l: Hoyer level passed to every method.
        rep: number of repetitions; with ``rep == 0`` every statistic is 0.0.
        device: torch device the vectors are created on.
        profile_mode: measure FLOPs with torch.profiler when it is available.
        profile_reps: number of repetitions to profile per method.

    Returns:
        dict with keys ``time_*`` ((mean, std) in ms), ``iters_orig``,
        ``iters_naiv``, ``loops_fast``, ``diff_*_vs_fast``, ``rel_*_vs_fast``,
        ``flops_*`` (analytic estimates) and ``meas_flops_*`` (profiler
        measurements, 0.0 when none were collected).
    """
    # Evaluation order matters only for reproducing the original call sequence.
    methods = (
        ("fast", fast_extended_hoyer_projection_with_loops),
        ("orig", hoyer_original_projection),
        ("naiv", naive_with_iters),
    )
    times = {key: [] for key, _ in methods}
    iters = {key: [] for key, _ in methods}
    meas_flops = {key: [] for key, _ in methods}
    diff_orig_vs_fast, diff_naiv_vs_fast = [], []
    rel_orig_vs_fast, rel_naiv_vs_fast = [], []

    do_profile = profile_mode and _HAS_TORCH_PROFILER

    for _ in range(rep):
        y = torch.rand(n, device=device)

        outputs = {}
        for key, fn in methods:
            if do_profile and len(meas_flops[key]) < profile_reps:
                meas_flops[key].append(_measure_flops(fn, y, l))
            outputs[key], used, elapsed_ms = _timed_call(fn, y, l, device)
            times[key].append(elapsed_ms)
            iters[key].append(used)

        # Errors relative to the fast method (the reference).
        norm_fast = torch.norm(outputs["fast"], p=2).item() + EPS
        d_orig = torch.norm(outputs["orig"] - outputs["fast"], p=2).item()
        d_naiv = torch.norm(outputs["naiv"] - outputs["fast"], p=2).item()
        diff_orig_vs_fast.append(d_orig)
        diff_naiv_vs_fast.append(d_naiv)
        rel_orig_vs_fast.append(d_orig / norm_fast)
        rel_naiv_vs_fast.append(d_naiv / norm_fast)

    return {
        "time_orig": _mean_std(times["orig"]),
        "time_naiv": _mean_std(times["naiv"]),
        "time_fast": _mean_std(times["fast"]),
        "iters_orig": _mean_or_zero(iters["orig"]),
        "iters_naiv": _mean_or_zero(iters["naiv"]),
        "loops_fast": _mean_or_zero(iters["fast"]),
        "diff_orig_vs_fast": _mean_or_zero(diff_orig_vs_fast),
        "diff_naiv_vs_fast": _mean_or_zero(diff_naiv_vs_fast),
        "rel_orig_vs_fast": _mean_or_zero(rel_orig_vs_fast),
        "rel_naiv_vs_fast": _mean_or_zero(rel_naiv_vs_fast),
        "flops_orig": _mean_or_zero([_estimate_flops('orig', n, it) for it in iters["orig"]]),
        "flops_naiv": _mean_or_zero([_estimate_flops('naive', n, it) for it in iters["naiv"]]),
        "flops_fast": _mean_or_zero([_estimate_flops('fast', n, it) for it in iters["fast"]]),
        # Measured FLOPs from the profiler (averaged across profile_reps)
        "meas_flops_orig": _mean_or_zero(meas_flops["orig"]),
        "meas_flops_naiv": _mean_or_zero(meas_flops["naiv"]),
        "meas_flops_fast": _mean_or_zero(meas_flops["fast"]),
    }

def main():
    """Command-line entry point: run the benchmark grid, print results and plot them.

    For every Hoyer level in ``--ls`` and every vector length in ``--ns`` this
    calls ``bench_one_setting`` and prints a summary. It then shows five
    figures (time, iteration counts, L2 difference, relative error, measured
    FLOPs) and saves the last one to ``measured_flops_vs_n.png``.

    FLOP profiling is on by default; pass ``--no-profile`` to turn it off.
    """
    parser = argparse.ArgumentParser(description="Benchmark (Original/Naive/Fast) with Condat-simplex (sorted); errors vs Fast.")
    parser.add_argument("--ns", type=int, nargs="+", default=[1000,2000,3000,4000,5000], help="Vector sizes.")
    parser.add_argument("--ls", type=float, nargs="+", default=[ 100,200], help="Hoyer levels l to test.")
    parser.add_argument("--rep", type=int, default=100, help="Repetitions per (n,l).")
    parser.add_argument("--seed", type=int, default=0, help="Random seed.")
    # Profiling defaults to on; --no-profile is the only way to switch it off,
    # since a store_true flag with default=True can never become False.
    parser.add_argument("--profile", dest="profile", default=True, action='store_true', help="Enable torch.profiler measured FLOPs (slower).")
    parser.add_argument("--no-profile", dest="profile", action='store_false', help="Disable torch.profiler measured FLOPs.")
    parser.add_argument("--profile_reps", type=int, default=1, help="Number of reps to profile per method (default 1).")
    args = parser.parse_args()

    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}\nns={args.ns}\nls={args.ls}\nrep={args.rep}\n")

    def level_label(level):
        # Show integer-valued levels without a decimal point, but keep
        # fractional levels intact instead of truncating them with int().
        return int(level) if float(level).is_integer() else level

    all_results = {}
    for l in args.ls:
        print(f"=== l = {l} ===")
        rows = []
        for n in args.ns:
            res = bench_one_setting(n, l, args.rep, device, profile_mode=args.profile, profile_reps=args.profile_reps)
            rows.append((n, res))
            print(
                f"\nn={n:5d} | "
                f"\nOrig {res['time_orig'][0]:.2f}±{res['time_orig'][1]:.2f} ms | "
                f"\nNaiv {res['time_naiv'][0]:.2f}±{res['time_naiv'][1]:.2f} ms | "
                f"\nFast {res['time_fast'][0]:.2f}±{res['time_fast'][1]:.2f} ms | "
                f"\nIters: Orig={res['iters_orig']:.2f}, Naiv={res['iters_naiv']:.2f}, FastLoops={res['loops_fast']:.2f} | "
                f"\nErr vs Fast: diff(Orig)={res['diff_orig_vs_fast']:.3e}, diff(Naiv)={res['diff_naiv_vs_fast']:.3e}, "
                f"\nrel(Orig)={res['rel_orig_vs_fast']:.3e}, rel(Naiv)={res['rel_naiv_vs_fast']:.3e} | "
                f"\nRealFlops(M): Orig={res['meas_flops_orig']/1e6:.6f}, Naiv={res['meas_flops_naiv']/1e6:.6f}, Fast={res['meas_flops_fast']/1e6:.6f}"
            )
        all_results[l] = rows
        print()

    # PLOTS
    # Figure 1: mean wall-clock time per method.
    plt.figure()
    for l, rows in all_results.items():
        lab = level_label(l)
        ns = [r[0] for r in rows]
        to = [r[1]["time_orig"][0] for r in rows]
        tf = [r[1]["time_fast"][0] for r in rows]
        plt.plot(ns, to, marker="^", label=f"Original (l={lab})")
        plt.plot(ns, tf, marker="o", label=f"Fast (l={lab})")
    plt.xlabel("Vector length n"); plt.ylabel("Avg time (ms)")
    # The inputs come from torch.rand, i.e. a uniform distribution.
    plt.title("Time vs n (Uniform distribution)")
    plt.grid(True, linestyle="--", linewidth=0.5); plt.legend(); plt.tight_layout(); plt.show()

    # Figure 2: mean iteration / loop counts.
    plt.figure()
    for l, rows in all_results.items():
        lab = level_label(l)
        ns = [r[0] for r in rows]
        io = [r[1]["iters_orig"] for r in rows]
        lf = [r[1]["loops_fast"] for r in rows]
        plt.plot(ns, io, marker="^", label=f"Orig iters (l={lab})")
        plt.plot(ns, lf, marker="o", label=f"Fast loops (l={lab})")
    plt.xlabel("Vector length n"); plt.ylabel("Avg iterations / loops")
    plt.title("Iteration and loops counts vs n (Uniform distribution)")
    plt.grid(True, linestyle="--", linewidth=0.5); plt.legend(); plt.tight_layout(); plt.show()

    # Figure 3: absolute L2 distance between the original and fast outputs.
    plt.figure(figsize=(8,6))
    for l, rows in all_results.items():
        lab = level_label(l)
        ns = [r[0] for r in rows]
        dO = [r[1]["diff_orig_vs_fast"] for r in rows]
        plt.plot(ns, dO, marker="^", label=f"‖Orig − Fast‖₂ (l={lab})")
    plt.xlabel("Vector length n"); plt.ylabel("L2 difference")
    plt.title(f"L2 Difference vs Fast (rep={args.rep})")
    plt.grid(True, linestyle="--", linewidth=0.5); plt.legend(); plt.tight_layout(); plt.show()

    # Figure 4: relative L2 error of the original and naive outputs vs fast.
    plt.figure()
    for l, rows in all_results.items():
        lab = level_label(l)
        ns = [r[0] for r in rows]
        rO = [r[1]["rel_orig_vs_fast"] for r in rows]
        rN = [r[1]["rel_naiv_vs_fast"] for r in rows]
        plt.plot(ns, rO, marker="^", label=f"RelErr(Original vs Fast) (l={lab})")
        plt.plot(ns, rN, marker="o", label=f"RelErr(Naive vs Fast) (l={lab})")
    # Raw string: the backslashes belong to the mathtext, not to Python escapes.
    plt.xlabel("Vector length n"); plt.ylabel(r"Relative $ \ell_2$ error $ \%$")
    plt.title(f"Relative Error  Fast vs Original (rep={args.rep})")
    plt.grid(True, linestyle="--", linewidth=0.5); plt.legend(); plt.tight_layout(); plt.show()

    # Figure 5: measured FLOPs (from profiler) vs n for each method and l
    plt.figure()
    for l, rows in all_results.items():
        lab = level_label(l)
        ns = [r[0] for r in rows]
        meas_orig = [r[1].get('meas_flops_orig', 0.0) / 1e6 for r in rows]
        meas_fast = [r[1].get('meas_flops_fast', 0.0) / 1e6 for r in rows]
        plt.plot(ns, meas_orig, marker='^', linestyle='-', label=f"Measured Orig (l={lab})")
        plt.plot(ns, meas_fast, marker='o', linestyle='-.', label=f"Measured Fast (l={lab})")
    plt.xlabel('Vector length n')
    plt.ylabel('Measured FLOPs (MFLOPs)')
    plt.gca().yaxis.set_major_locator(plt.MaxNLocator(nbins=20))
    plt.title('Measured FLOPs vs n (Profiler)')
    plt.grid(True, linestyle='--', linewidth=0.5)
    plt.legend()
    plt.tight_layout()
    plt.savefig('measured_flops_vs_n.png')
    plt.show()
    plt.close()

# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
