import torch

def alternating_least_squares(subset,name,dev,spectral_factor,i,ratio,max_iter,tau,rho):
    """
    Alternating Least Squares (ALS) for low-rank factorization of a weight matrix.

    Given a weight W and a profiling matrix N (from SIMT), iteratively updates
    U and V such that U V N ≈ W N, with ridge (tau) and proximal (rho)
    regularization. Each half-step solves, in closed form,

        A-step:  min_U ||U (V N) - W N||^2 + tau ||U||^2 + rho ||U - U_prev||^2
        B-step:  min_V ||U V N   - W N||^2 + tau ||V||^2 + rho ||V - V_prev||^2

    Args:
        subset: Mapping such that ``subset[name].weight.data`` is the 2-D
            weight tensor to factorize.
        name: Key of the projection inside ``subset`` and ``spectral_factor[i]``.
        dev: Device on which all computation is carried out.
        spectral_factor: Indexable container; ``spectral_factor[i][name]`` is
            the matrix N that right-multiplies W (so its first dimension must
            equal ``W.shape[1]``).
        i: Index into ``spectral_factor`` (presumably the layer index — the
            caller is not visible from here).
        ratio: Compression ratio; the rank is
            ``int(rows * cols * ratio / (rows + cols))``.
        max_iter: Number of ALS sweeps. With 0 the SVD initialization is
            returned unchanged.
        tau: Ridge regularization weight.
        rho: Proximal regularization weight (pull towards the previous iterate).

    Returns:
        Tuple ``(svd_u, svd_v)`` of float32 tensors on ``dev`` with shapes
        ``(W.shape[0], r)`` and ``(r, W.shape[1])``.

    Raises:
        torch.linalg.LinAlgError: If the regularized Gram matrix of the A-step
            is not positive-definite (possible when ``tau + rho`` is 0 and
            ``V N`` is rank-deficient).
    """
    # Original weight matrix of the given projection
    W = subset[name].weight.data.float().to(dev)
    # spectral_factor matrix from SIMT corresponding to this layer/projection
    N = spectral_factor[i][name].to(dev).float()
    # Precompute product W * N (activation-aware reweighting)
    M = torch.matmul(W, N)
    # Compute target rank r based on global compression ratio
    num_s_after_trunc = int(W.shape[0] * W.shape[1] * ratio / (W.shape[0] + W.shape[1]))

    # SVD initialization of W: U * sqrt(Σ) and sqrt(Σ) * Vᵀ (better conditioning).
    # Broadcasting scales columns/rows directly instead of building a dense
    # diagonal matrix.
    U, S, VT = torch.linalg.svd(W, full_matrices=False)
    sqrt_sigma = torch.sqrt(S[:num_s_after_trunc])
    svd_u = U[:, :num_s_after_trunc] * sqrt_sigma
    svd_v = sqrt_sigma.unsqueeze(1) * VT[:num_s_after_trunc, :]
    # W and its full SVD are not used by the iterations; drop the references
    # now so they do not add to peak memory during the loop.
    del W, U, S, VT, sqrt_sigma

    if max_iter <= 0:
        return svd_u, svd_v

    # ---------------- Loop-invariant quantities ----------------
    # N never changes, so its SVD is computed once rather than per iteration.
    U_N, S_N, Vt_N = torch.linalg.svd(N, full_matrices=False)
    S_N_sq = S_N**2
    # dtype is pinned to the working dtype so a non-default global dtype
    # cannot promote the Gram matrix away from the right-hand side's dtype.
    reg = (tau + rho) * torch.eye(svd_v.shape[0], device=dev, dtype=M.dtype)

    # ---------------- ALS iterations ----------------
    for _ in range(max_iter):
        # -------- Update U (A-step) --------
        VN = svd_v @ N
        # Cholesky solve for ridge+proximal regularized least squares
        L = torch.linalg.cholesky(VN @ VN.T + reg)
        R = VN @ M.T + rho * svd_u.T
        # Solve (L Lᵀ) X = R for X using two triangular solves
        Y = torch.linalg.solve_triangular(L, R, upper=False)
        X = torch.linalg.solve_triangular(L.T, Y, upper=True)
        svd_u = X.T
        # Release A-step temporaries before the B-step allocations
        del VN, L, R, Y, X

        # -------- Update V (B-step) --------
        # In the singular bases of the current U and of N the normal
        # equations decouple, so the solve is an elementwise division.
        U_A, S_A, Vt_A = torch.linalg.svd(svd_u, full_matrices=False)
        # C = diag(S_A) U_Aᵀ M V_N diag(S_N) + rho * V_Aᵀ V_prev U_N
        C = ((S_A.unsqueeze(1) * U_A.T) @ M @ Vt_N.T) * S_N + rho * Vt_A @ svd_v @ U_N
        Bi_pie = C / (torch.outer(S_A**2, S_N_sq) + tau + rho)
        # Rotate back from the singular bases to recover the new V
        svd_v = Vt_A.T @ Bi_pie @ U_N.T
        # Release B-step temporaries before the next iteration
        del U_A, S_A, Vt_A, C, Bi_pie

    # Return learned low-rank factors (U, V)
    return svd_u, svd_v