#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Oct 15 10:59:22 2025

Time comparison on Matrix Bilevel CFCP vs GSP Hybrid 

Constraint l 

@author: Anonymous
"""

import time
import argparse
import numpy as np
import torch
import matplotlib.pyplot as plt
#import math
from sklearn.datasets import make_low_rank_matrix

EPS = 1e-10  # Small stabiliser added under square roots and to denominators.


def l2(x):
    """Return the Euclidean norm of ``x``, with EPS added under the root."""
    squared_sum = torch.sum(x * x)
    return torch.sqrt(squared_sum + EPS)

# ----------------------------------------------------------------
# HYBRID NEWTON-BISECTION METHOD (from GSP)
# ----------------------------------------------------------------
def hybrid_newton_bisection_with_iters(y: torch.Tensor, l: float, precision: float = 1e-6, linrat: float = 0.9, max_iter: int = 100):
    """
    Hybrid Newton-bisection method for Hoyer sparsity projection.

    Searches for the multiplier ``mu`` for which the shifted and clipped
    vector ``max(|y| - mu * beta, 0)``, with ``beta = 1 / (sqrt(n) - 1)``,
    has an L1/L2 ratio of ``sqrt(l)``. A Newton step is taken whenever it
    stays strictly inside the current bracket ``[mulow, muup]``; otherwise
    the bracket is bisected. The clipped vector is finally normalised,
    given back the signs of ``y`` and rescaled by ``<x, y> / <x, x>``.

    Args:
        y: Input vector (1D tensor).
        l: Target Hoyer level, i.e. the target value of (L1/L2)^2
            (1 = maximally sparse, n = dense). Clamped to ``[1, n - 1]``.
        precision: Tolerance on ``|g(mu) - k|`` that stops the iteration.
        linrat: Linear-rate parameter; when a step neither shrinks the
            bracket below ``linrat * delta`` nor moves ``mu`` by at least
            ``(1 - linrat) * delta``, an extra bisection step is forced.
        max_iter: Iteration cap. A forced bisection is counted as an
            additional iteration, so the returned count may exceed the
            cap by one.

    Returns:
        (x, iters): Projected vector and number of iterations. When the
        input already satisfies the target (or has fewer than two
        entries) a copy of ``y`` and ``0`` are returned.
    """
    assert y.dim() == 1
    device, dtype = y.device, y.dtype
    n = float(y.numel())

    # With fewer than two entries sqrt(n) - 1 is <= 0, so beta is undefined
    # (and max() of an empty tensor raises). A single entry always has
    # L1/L2 == 1, hence there is nothing to project.
    if n < 2.0:
        return y.clone(), 0

    epsilon = 1e-15

    # Work on magnitudes; signs are restored at the end.
    sgn = torch.sign(y)
    pos_vector = y.abs()

    # Clamp l to [1, n - 1].
    l = max(1.0, min(float(l), n - 1.0))

    # Convert Hoyer level l to the GSP sparsity parameter sps:
    #   GSP:   sps = (sqrt(n) - L1/L2) / (sqrt(n) - 1), 0 = dense, 1 = sparse
    #   Hoyer: H = (L1/L2)^2, so the target L1/L2 is sqrt(l)
    # Plain Python floats keep scalar/tensor arithmetic unambiguous.
    sqrt_n = float(np.sqrt(n))
    sqrt_l = float(np.sqrt(l))
    sps = (sqrt_n - sqrt_l) / (sqrt_n - 1.0)

    betai = 1.0 / (sqrt_n - 1.0)
    r = 1  # Single vector

    # Target value of g(mu); algebraically equal to betai * sqrt(l).
    k = r * sqrt_n / (sqrt_n - 1.0) - r * sps

    # At mu = muup0 every shifted entry is <= 0, which bounds the search.
    max_val = pos_vector.max()
    muup0 = float(max_val * (sqrt_n - 1.0))

    def gmu_1d(vec, mu):
        """Return (g(mu), clipped vector, dg/dmu) for the shift ``mu``."""
        # Project: x = [vec - mu*beta]+
        xp = torch.clamp(vec - mu * betai, min=0.0)
        norm_xp = torch.norm(xp, p=2)

        if norm_xp > epsilon:
            nip = (xp > 0).sum().float()  # number of surviving entries

            # Derivative of g with respect to mu.
            term1 = -nip / norm_xp
            term2 = (xp.sum()) ** 2 / (norm_xp ** 3)
            gradg = (betai ** 2) * (term1 + term2)

            # g(mu) = beta * L1/L2 of the clipped vector.
            vgmu = betai * (xp / norm_xp).sum()
        else:
            # Everything was clipped away: fall back to a one-hot vector at
            # the largest entry, for which L1/L2 == 1.
            xp = torch.zeros_like(xp)
            xp[torch.argmax(vec)] = 1.0

            vgmu = betai
            gradg = torch.tensor(0.0, device=device, dtype=dtype)

        return vgmu, xp, gradg

    # Initial evaluation at mu = 0.
    gnew, xp, gpnew = gmu_1d(pos_vector, 0.0)

    # Already sparse enough - no projection needed.
    if gnew < k:
        return y.clone(), 0

    # Initial bracket and starting point.
    mulow = 0.0
    muup = muup0
    newmu = 0.0
    delta = muup - mulow

    numiter = 0

    while abs(gnew - k) > precision * r and numiter < max_iter:
        oldmu = newmu

        # Newton step.
        newmu = oldmu + (k - gnew) / (gpnew + epsilon)

        # If Newton leaves the bracket, bisect instead.
        if newmu >= muup or newmu <= mulow:
            newmu = (mulow + muup) / 2.0

        gnew, xp, gpnew = gmu_1d(pos_vector, newmu)

        # g decreases with mu: a value below k means mu is too large.
        if gnew < k:
            muup = newmu
        else:
            mulow = newmu

        # Guarantee linear convergence with a forced bisection step.
        if (muup - mulow) > linrat * delta and abs(oldmu - newmu) < (1 - linrat) * delta:
            newmu = (mulow + muup) / 2.0
            gnew, xp, gpnew = gmu_1d(pos_vector, newmu)

            if gnew < k:
                muup = newmu
            else:
                mulow = newmu

            numiter += 1

        numiter += 1

    # ``xp`` is the clipped vector from the most recent evaluation, which is
    # always the one at ``newmu``; re-evaluating would give the same result.
    norm_xp_final = torch.norm(xp, p=2)
    if norm_xp_final > epsilon:
        xp_normalized = xp / norm_xp_final
    else:
        xp_normalized = xp

    # Scale to the original norm and restore signs.
    orig_norm = torch.norm(pos_vector, p=2)
    x = xp_normalized * orig_norm * sgn

    # Final rescaling using the inner product (as in the GSP alpha step).
    alpha = (x * y).sum() / ((x * x).sum() + epsilon)
    x = alpha * x

    return x, numiter

# ----------------------------------------------------------------
# CFCP
# ----------------------------------------------------------------
def CFCP(y: torch.Tensor, l: float, max_iter: int = 4, track_alpha: bool = False, track_active: bool = False):
    """
    Iterative thresholding projection of a vector towards Hoyer level ``l``.

    The magnitudes of ``y`` are first thresholded at their mean. Each of the
    ``max_iter`` passes then computes ``Hx = L1^2 / L2^2`` of the current
    vector and the number ``nu`` of non-zero entries, derives the threshold
    ``alpha = (L1 / nu) * (1 - sqrt(l_eff * (nu - Hx) / (Hx * (nu - l_eff))))``
    and zeroes every entry below it. Finally the surviving entries are
    blended with their mean, ``x = lam * x + (1 - lam) * d`` with
    ``lam = 1 / (1 - alpha * nu / L1)``, and the signs of ``y`` are restored.

    NOTE(review): there is no guard against a zero or negative denominator
    in ``alpha`` or ``lam``, so non-finite values can appear in the output.

    Args:
        y: Input vector (1D tensor).
        l: Target Hoyer level; clamped to ``[1, n]``.
        max_iter: Number of thresholding passes (always all executed unless
            the vector becomes entirely zero).
        track_alpha: Record the threshold computed in every pass.
        track_active: Record the non-zero count seen in every pass.

    Returns:
        ``(x, loops)`` by default, or
        ``(x, loops, alpha_history, active_history)`` when either tracking
        flag is set (a history that was not requested is ``None``). The
        four-element form is returned on every exit path, including the
        early exits taken when no entry survives thresholding.
    """
    assert y.dim() == 1
    n = y.numel()
    sgn = torch.sign(y)
    x = y.abs().clone()

    alpha_history = [] if track_alpha else None
    active_history = [] if track_active else None

    def _result(x_out, loops_done):
        # Single place that decides the tuple shape so that every exit path
        # agrees with the tracking flags.
        if track_alpha or track_active:
            return x_out, loops_done, alpha_history, active_history
        return x_out, loops_done

    # Initial threshold: mean magnitude over all n entries.
    alpha = x.sum() / n
    x = torch.where(x >= alpha, x, torch.zeros_like(x))
    l_eff = float(max(1.0, min(float(l), float(n))))

    loops = 0
    while loops < max_iter:
        loops += 1
        nu = int((x > 0).sum().item())
        if track_active:
            active_history.append(nu)
        if nu == 0:
            return _result(torch.zeros_like(y), loops)

        # Zeros contribute nothing, so sums over x equal sums over the
        # active entries.
        l1a = x.sum()
        l2a = (x * x).sum()

        Hx = l1a * l1a / l2a
        num = l_eff * (nu - Hx)
        den = Hx * (nu - l_eff) + EPS
        root = torch.sqrt(num / den)

        alpha = (l1a / nu) * (1.0 - root)

        if track_alpha:
            alpha_history.append(float(alpha.item()))

        x = torch.where(x >= alpha, x, torch.zeros_like(x))

    # Support after the last thresholding pass.
    mask = (x > 0)
    nu = int(mask.sum().item())
    if track_active and (len(active_history) == 0 or active_history[-1] != nu):
        active_history.append(nu)

    if nu == 0:
        return _result(torch.zeros_like(y), loops)

    l1a = x.sum()

    # ``alpha`` is the threshold from the last pass (or the initial mean
    # when max_iter == 0); ``nu`` and ``l1a`` describe the final support.
    denom = 1.0 - (alpha * nu) / (l1a + EPS)
    lam = 1.0 / denom

    # d holds the mean of the surviving entries on the support, 0 elsewhere.
    d = torch.zeros_like(x)
    d[mask] = l1a / max(nu, 1)

    x = lam * x + (1.0 - lam) * d
    x = x * sgn

    return _result(x, loops)

# ----------------------------------------------------------------
# CFCP BILEVEL (for matrices)
# ----------------------------------------------------------------
def cfcp_bilevel_hoyer_projection_matrices(W: torch.Tensor, l: float, max_iter: int = 4, eps: float = 1e-12):
    """
    Bilevel Hoyer projection of a matrix, driven by its column norms.

    The vector of column L2 norms is projected with CFCP; every column is
    then shrunk (never enlarged) so that its norm does not exceed the
    projected value.

    Args:
        W: Input matrix (2D tensor)
        l: Target Hoyer sparsity level, forwarded to CFCP
        max_iter: Number of CFCP passes for the inner projection
        eps: Added to the column norms before dividing, for stability

    Returns:
        (projected_matrix, loops): Rescaled matrix and the CFCP loop count
    """
    assert W.dim() == 2, "Expected 2D tensor (matrix)"

    # Outer level: one L2 norm per column.
    norms = torch.linalg.norm(W, ord=2, dim=0)

    # Inner level: project the norm vector.
    target_norms, loops = CFCP(norms, l, max_iter=max_iter)

    # CFCP may emit NaN entries; treat those columns as fully suppressed.
    nan_positions = torch.isnan(target_norms)
    target_norms = torch.where(nan_positions, torch.zeros_like(target_norms), target_norms)

    # Per-column L2-ball projection: factor_i = min(1, target_i / ||col_i||_2)
    ratio = target_norms / (norms + eps)
    factors = torch.minimum(torch.ones_like(norms), ratio)

    # Broadcast the factors across rows.
    return W * factors.unsqueeze(0), loops

# ----------------------------------------------------------------
# GSP (Grouped Sparse Projection) - Direct from GSP library
# ----------------------------------------------------------------
def gsp_projection_matrices(W: torch.Tensor, l: float, precision: float = 1e-6, linrat: float = 0.9):
    """
    Grouped Sparse Projection (GSP) of a matrix, one group per column.

    A single multiplier ``mu`` is shared by all columns. It is chosen by a
    hybrid Newton-bisection search so that the L1/L2 ratio of the shifted
    and clipped columns ``max(|W| - mu * beta, 0)``, averaged over the
    columns, equals ``sqrt(l)`` (``beta = 1 / (sqrt(m) - 1)``, ``m`` rows).
    Each resulting unit-norm column is rescaled by its inner product with
    the corresponding column of ``|W|`` and given back the signs of ``W``.

    The original header describes this as a direct port of the GSP library
    implementation; that has not been verified here.

    Args:
        W: Input matrix (2D tensor)
        l: Target Hoyer sparsity level (target (L1/L2)^2 per column,
            relative to the number of rows)
        precision: Convergence tolerance per column
        linrat: Linear rate guarantee parameter

    Returns:
        (projected_matrix, numiter): Projected matrix and iteration count.
        A copy of ``W`` and ``0`` are returned when ``W`` is already sparse
        enough, has fewer than two rows, or has no columns.
    """
    assert W.dim() == 2, "Expected 2D tensor (matrix)"
    device = W.device

    m, n = W.shape

    # With a single row every column trivially has L1/L2 == 1 and
    # sqrt(m) - 1 == 0 would divide by zero; empty matrices have nothing to
    # project (and max() over them raises).
    if m < 2 or n == 0:
        return W.clone(), 0

    sqrt_m = float(np.sqrt(m))

    # Convert Hoyer level l to the GSP sparsity parameter sps:
    #   GSP:   sps = (sqrt(m) - L1/L2) / (sqrt(m) - 1)
    #   Hoyer: target L1/L2 = sqrt(l)
    sqrt_l = float(np.sqrt(float(l)))
    sps = (sqrt_m - sqrt_l) / (sqrt_m - 1.0)

    epsilon = 1e-15
    r = n  # Number of columns
    ni = m  # Number of rows

    # Target value of g(mu), summed over the r columns.
    k = r * sqrt_m / (sqrt_m - 1.0) - r * sps

    # Work on magnitudes; signs are restored at the end.
    matrix_sign = torch.sign(W)
    pos_matrix = matrix_sign * W

    betai = 1.0 / (torch.sqrt(torch.tensor(ni, dtype=torch.float32, device=device)) - 1)

    # At mu = muup0 every shifted entry is <= 0, which bounds the search.
    max_elems = torch.max(pos_matrix, 0)[0]
    muup0 = float(max_elems.max() * (sqrt_m - 1.0))

    def gmu_matrix(pos_mat, mu):
        """Return (g(mu), column-normalised clipped matrix, dg/dmu)."""
        # Project: xp = [pos_mat - mu*beta]+
        xp = pos_mat - (mu * betai)
        xp.relu_()

        # Column norms; zero norms become inf so their inverse powers are 0.
        mnorm = torch.norm(xp, dim=0)
        mnorm_inf = mnorm.clone()
        mnorm_inf[mnorm_inf == 0] = float("Inf")

        col_norm_mask = (mnorm > 0)
        nip = torch.sum(xp > 0, dim=0)

        # Derivative of g with respect to mu, summed over columns.
        term2 = torch.pow(torch.sum(xp, dim=0), 2)
        mnorm_inv = torch.pow(mnorm_inf, -1)
        mnorm_inv3 = torch.pow(mnorm_inf, -3)

        gradg_mat = torch.pow(betai, 2) * (-nip * mnorm_inv + term2 * mnorm_inv3)
        gradg = torch.sum(gradg_mat)

        # Normalise the non-zero columns to unit L2 norm.
        xp[:, col_norm_mask] /= mnorm[col_norm_mask]

        # Fully clipped columns become one-hot at their largest entry.
        max_elem_rows = torch.argmax(pos_mat, dim=0)[~col_norm_mask]
        xp[max_elem_rows, ~col_norm_mask] = 1

        # g(mu) = sum over columns of beta * L1 of the unit-norm column.
        vgmu_mat = betai * torch.sum(xp, dim=0)
        vgmu = torch.sum(vgmu_mat)

        return vgmu, xp, gradg

    # Initial evaluation at mu = 0. ``xnew`` is bound here so that it is
    # defined even when the loop below does not execute (start point
    # already within tolerance of k but not below it).
    gnew, xnew, gpnew = gmu_matrix(pos_matrix, 0.0)

    # Already sparse enough - no projection needed.
    if gnew < k:
        return W.clone(), 0

    # Initial bracket and starting point.
    mulow = 0.0
    muup = muup0
    newmu = 0.0
    delta = muup - mulow

    numiter = 0

    # Main iteration loop (hybrid Newton-bisection).
    while abs(gnew - k) > precision * r and numiter < 100:
        oldmu = newmu

        # Newton step.
        newmu = oldmu + (k - gnew) / (gpnew + epsilon)

        # If Newton leaves the bracket, bisect instead.
        if newmu >= muup or newmu <= mulow:
            newmu = (mulow + muup) / 2.0

        gnew, xnew, gpnew = gmu_matrix(pos_matrix, newmu)

        # A value below k means mu is too large.
        if gnew < k:
            muup = newmu
        else:
            mulow = newmu

        # Guarantee linear convergence with a forced bisection step.
        if (muup - mulow) > linrat * delta and abs(oldmu - newmu) < (1 - linrat) * delta:
            newmu = (mulow + muup) / 2.0
            gnew, xnew, gpnew = gmu_matrix(pos_matrix, newmu)

            if gnew < k:
                muup = newmu
            else:
                mulow = newmu

            numiter += 1

        numiter += 1

    # Final per-column scaling alpha_j = <xnew[:, j], |W|[:, j]>.
    # NOTE(review): this forms the full n x n product only to read its
    # diagonal; (xnew * pos_matrix).sum(dim=0) would give the same values
    # far more cheaply. It is left as is because this routine serves as the
    # timed baseline and the step is commented as mirroring the GSP library.
    alpha_mat = torch.matmul(xnew.T, pos_matrix)
    alpha = torch.diagonal(alpha_mat)
    projected = alpha * (matrix_sign * xnew)

    return projected, numiter


# ----------------------------------------------------------------
# Benchmark matrices (CFCP and HYBRID bilevel)
# ----------------------------------------------------------------
def _timed_ms(fn, device: torch.device):
    """Call ``fn()`` and return ``(result, elapsed milliseconds)``.

    On CUDA devices the device is synchronised before starting and before
    stopping the clock so that queued kernels are included in the timing.
    """
    if device.type == "cuda":
        torch.cuda.synchronize()
    t0 = time.perf_counter()
    out = fn()
    if device.type == "cuda":
        torch.cuda.synchronize()
    t1 = time.perf_counter()
    return out, (t1 - t0) * 1000.0


def _l0_and_hoyer(W_proj: torch.Tensor, n: int):
    """Return ``(l0, hoyer)`` for a projected matrix.

    ``l0`` is the fraction of entries with magnitude below 1e-6; ``hoyer``
    is ``(L1/L2)^2 / n`` computed on the vector of column L2 norms.
    """
    l0_frac = float((W_proj.abs() < 1e-6).float().mean().item())
    col_norms = torch.linalg.norm(W_proj, ord=2, dim=0)
    l1_norm = col_norms.sum()
    l2_norm = torch.sqrt((col_norms * col_norms).sum() + 1e-12)
    hoyer = float(((l1_norm / l2_norm) ** 2 / n).item())
    return l0_frac, hoyer


def bench_matrices(m: int, n: int, l: float, rep: int, device: torch.device, effective_rank: int = None, seed: int = 0):
    """
    Benchmark the CFCP bilevel and GSP projections on random low-rank matrices.

    Args:
        m: Number of rows
        n: Number of columns
        l: Target Hoyer sparsity level
        rep: Number of repetitions
        device: Torch device
        effective_rank: Effective rank passed to ``make_low_rank_matrix``
            (default when None: max(10, min(m, n) // 10))
        seed: Base random seed; repetition ``i`` uses ``seed + i`` so runs
            are reproducible and repetitions differ from one another

    Returns:
        Dictionary with timing statistics (mean, std in ms), mean iteration
        counts, mean sparsity measures, and the mean absolute/relative
        Frobenius difference of the GSP result with respect to CFCP.
    """
    t_cfcp = []
    t_hybrid = []
    it_cfcp_loops = []
    it_hybrid = []
    l0_sparsities_cfcp = []
    l0_sparsities_hybrid = []
    hoyer_sparsities_cfcp = []
    hoyer_sparsities_hybrid = []
    diff_hybrid_vs_cfcp = []
    rel_hybrid_vs_cfcp = []

    if effective_rank is None:
        effective_rank = max(10, min(m, n) // 10)

    for rep_idx in range(rep):
        # The seed is taken from the argument: a lookup of a module-level
        # ``args`` can never succeed because ``args`` only exists as a
        # local variable inside main().
        W_np = make_low_rank_matrix(
            n_samples=m,
            n_features=n,
            effective_rank=effective_rank,
            tail_strength=0.1,
            random_state=seed + rep_idx,
        )
        W = torch.from_numpy(W_np).float().to(device)

        # CFCP bilevel projection (the clone is part of the timed region,
        # for both methods alike).
        (W_cfcp, itF), ms_cfcp = _timed_ms(
            lambda: cfcp_bilevel_hoyer_projection_matrices(W.clone(), l), device
        )
        t_cfcp.append(ms_cfcp)
        it_cfcp_loops.append(itF)

        # GSP projection
        (W_gsp, itG), ms_gsp = _timed_ms(
            lambda: gsp_projection_matrices(W.clone(), l), device
        )
        t_hybrid.append(ms_gsp)
        it_hybrid.append(itG)

        # Error vs CFCP (reference)
        norm_cfcp = torch.norm(W_cfcp, p='fro').item() + EPS
        diff_G = torch.norm(W_gsp - W_cfcp, p='fro').item()
        diff_hybrid_vs_cfcp.append(diff_G)
        rel_hybrid_vs_cfcp.append(diff_G / norm_cfcp)

        # Sparsity measures of both results
        l0_c, hoyer_c = _l0_and_hoyer(W_cfcp, n)
        l0_sparsities_cfcp.append(l0_c)
        hoyer_sparsities_cfcp.append(hoyer_c)

        l0_g, hoyer_g = _l0_and_hoyer(W_gsp, n)
        l0_sparsities_hybrid.append(l0_g)
        hoyer_sparsities_hybrid.append(hoyer_g)

    def mean_std(a):
        # Sample standard deviation (ddof=1); 0.0 when fewer than 2 values.
        a = np.asarray(a, dtype=float)
        mu = a.mean() if a.size else 0.0
        sd = a.std(ddof=1) if a.size > 1 else 0.0
        return float(mu), float(sd)

    return {
        "time_cfcp": mean_std(t_cfcp),
        "time_hybrid": mean_std(t_hybrid),
        "loops_cfcp": float(np.mean(it_cfcp_loops)),
        "iters_hybrid": float(np.mean(it_hybrid)),
        "l0_sparsity_cfcp": float(np.mean(l0_sparsities_cfcp)),
        "l0_sparsity_hybrid": float(np.mean(l0_sparsities_hybrid)),
        "hoyer_sparsity_cfcp": float(np.mean(hoyer_sparsities_cfcp)),
        "hoyer_sparsity_hybrid": float(np.mean(hoyer_sparsities_hybrid)),
        "diff_hybrid_vs_cfcp": float(np.mean(diff_hybrid_vs_cfcp)),
        "rel_hybrid_vs_cfcp": float(np.mean(rel_hybrid_vs_cfcp)),
    }

def main():
    """Parse the command line, run the matrix benchmarks and save the plots.

    For every Hoyer level and effective rank, square matrices of each
    requested size are benchmarked. One time-vs-size figure is saved per
    (level, rank) pair and one time-vs-rank bar chart per level, the latter
    for the middle entry of ``--matrix_sizes``.
    """
    parser = argparse.ArgumentParser(description="Benchmark CFCP vs GSP with Condat-simplex (sorted); errors vs CFCP.")
    parser.add_argument("--ls", type=float, nargs="+", default=[800], help="Hoyer levels l to test.")
    parser.add_argument("--rep", type=int, default=100, help="Repetitions per (n,l).")
    parser.add_argument("--seed", type=int, default=0, help="Random seed.")
    parser.add_argument("--bench_matrices", action='store_true', help="Also benchmark matrices.", default=True)
    # The flag above defaults to True and can only set True, so a separate
    # switch is required to actually turn the benchmarks off.
    parser.add_argument("--no_bench_matrices", dest="bench_matrices", action='store_false', help="Skip the matrix benchmarks.")
    parser.add_argument("--matrix_sizes", type=int, nargs="+", default=[1000, 2000, 3000, 4000, 5000], help="Matrix sizes (m x m) to test.")
    parser.add_argument("--effective_ranks", type=int, nargs="+", default=[10, 100], help="Effective ranks for low-rank matrices.")
    args = parser.parse_args()

    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}\n\nls={args.ls}\nrep={args.rep}\n")

    if not args.bench_matrices:
        return

    # ========== MATRIX BENCHMARKS ==========
    print("\n" + "="*80)
    print("MATRIX BENCHMARKS (CFCP vs GSP Methods)")
    print("Using LOW-RANK matrices for fair comparison")
    print(f"Effective ranks: {args.effective_ranks}")
    print("="*80)

    # Results for all ranks: {l: {rank: [(size, res), ...]}}
    all_results = {}

    for l in args.ls:
        all_results[l] = {}
        print(f"\n{'='*80}")
        print(f"=== Matrix benchmarks for l = {l} ===")
        print(f"{'='*80}")

        for rank in args.effective_ranks:
            print(f"\n--- Effective rank: {rank} ---")
            matrix_rows = []

            for size in args.matrix_sizes:
                m, n = size, size  # Square matrices
                res = bench_matrices(m, n, l, args.rep, device, effective_rank=rank, seed=args.seed)
                matrix_rows.append((size, res))
                # ``res`` also carries loop/iteration counts and L0/Hoyer
                # sparsity figures that are not printed here.
                print(
                    f"\nMatrix size {m}x{n} | "
                    f"CFCP {res['time_cfcp'][0]:.2f}±{res['time_cfcp'][1]:.2f} ms | "
                    f"GSP {res['time_hybrid'][0]:.2f}±{res['time_hybrid'][1]:.2f} ms | "
                )

            all_results[l][rank] = matrix_rows

    # Plot 1: CFCP vs GSP time against matrix size, one figure per (l, rank)
    for l in args.ls:
        for rank in args.effective_ranks:
            plt.figure(figsize=(10, 6))
            rows = all_results[l][rank]
            sizes = [r[0] for r in rows]
            times_cfcp = [r[1]["time_cfcp"][0] for r in rows]
            times_gsp = [r[1]["time_hybrid"][0] for r in rows]

            plt.plot(sizes, times_cfcp, marker="o", linewidth=2, markersize=8, label="CFCP")
            plt.plot(sizes, times_gsp, marker="d", linewidth=2, markersize=8, label="GSP")

            plt.xlabel("Matrix size (m=n)", fontsize=12)
            plt.ylabel("Avg time (ms)", fontsize=12)
            plt.title(f"Matrix Projection Time: CFCP vs GSP\n(l={int(l)})", fontsize=14)
            plt.grid(True, linestyle="--", linewidth=0.5, alpha=0.7)
            plt.legend(fontsize=11)
            plt.tight_layout()

            filename = f'matrix_time_l{int(l)}_rank{rank}.png'
            plt.savefig(filename, dpi=300, bbox_inches='tight')
            print(f"\nSaved plot: {filename}")
            plt.show()

    # Plot 2: compare ranks at a fixed matrix size (the middle one)
    mid_idx = len(args.matrix_sizes) // 2
    mid_size = args.matrix_sizes[mid_idx]

    for l in args.ls:
        plt.figure()

        times_cfcp_by_rank = []
        times_gsp_by_rank = []

        for rank in args.effective_ranks:
            res = all_results[l][rank][mid_idx][1]  # Results for middle size
            times_cfcp_by_rank.append(res["time_cfcp"][0])
            times_gsp_by_rank.append(res["time_hybrid"][0])

        x = np.arange(len(args.effective_ranks))
        width = 0.35

        plt.bar(x - width/2, times_cfcp_by_rank, width, label='CFCP', alpha=0.8)
        plt.bar(x + width/2, times_gsp_by_rank, width, label='GSP', alpha=0.8)

        plt.xlabel('Effective Rank', fontsize=12)
        plt.ylabel('Avg time (ms)', fontsize=12)
        plt.title(f'Time vs Effective Rank: CFCP vs GSP\n(l={int(l)}, matrix size={mid_size}x{mid_size})', fontsize=14)
        plt.xticks(x, args.effective_ranks)
        plt.legend(fontsize=11)
        plt.grid(True, axis='y', linestyle='--', linewidth=0.5, alpha=0.7)
        plt.tight_layout()

        filename = f'time_vs_rank_l{int(l)}_size{mid_size}.png'
        plt.savefig(filename, dpi=300, bbox_inches='tight')
        print(f"\nSaved plot: {filename}")
        plt.show()

    print("\nMatrix benchmarks complete.")

if __name__ == "__main__":
    # Run the benchmark CLI only when executed as a script, not on import.
    main()
