import torch

def innerp(x, y=None, out=None):
    if y is None:
        y = x
    if out is not None:
        out = out[:, None, None]
    return torch.matmul(x[..., None, :], y[..., :, None], out=out)[..., 0, 0]

def _omp_v0(X, y, XTX, n_nonzero_coefs=None, tol=None):
    """
    inputs
    X : dictionary (batch_size, signal_dim, dictionary_size)
    y : signal (batch_size, n_signals, signal_dim)
    XTX : (batch_size, dictionary_size, dictionary_size)

    outputs
    sets : dictionary coefficients (batch_size, n_signals, n_nonzero_coefs)
    result_solutions : dictionary weights (batch_size, n_signals, n_nonzero_coefs, 1)
    errors : reconstruction l2 norm errors (batch_size, n_signals)
    normr2_init : initial l2 norm of signals (batch_size, n_signals)
    lengths : number of coefficients for each signal (batch_size, n_signals)
    above_thres : is signal recon error still above error threshold (batch_size, n_signals)
    """
    B, b, _ = y.shape
    normr2_init = innerp(y)
    normr2 = normr2_init.clone()
    projections = torch.bmm(X.transpose(2, 1), y.transpose(1, 2)).transpose(1, 2)
    sets = y.new_zeros(n_nonzero_coefs, B, b, dtype=torch.int64)

    F = torch.eye(n_nonzero_coefs, dtype=y.dtype, device=y.device).repeat(B, b, 1, 1)
    a_F = y.new_zeros((n_nonzero_coefs, B, b, 1), dtype=y.dtype)

    D_mybest = y.new_empty(B, b, n_nonzero_coefs, XTX.shape[1])
    temp_F_k_k = y.new_ones((B, b, 1))

    if tol:
        result_lengths = sets.new_zeros((y.shape[0], y.shape[1]))
        result_solutions = y.new_zeros((y.shape[0], y.shape[1], n_nonzero_coefs, 1))
        finished_problems = sets.new_zeros((y.shape[0], y.shape[1]), dtype=torch.bool)
        tol = normr2_init * (tol * tol)

    for k in range(n_nonzero_coefs+(tol is not None)):
        # STOPPING CRITERIA
        if tol is not None:
            problems_done = normr2 <= tol 
            if k == n_nonzero_coefs:
                below_tol = problems_done.clone()
                problems_done[:, :] = True
            
            if problems_done.any():
                new_problems_done = problems_done & ~finished_problems
                finished_problems.logical_or_(problems_done)
                result_lengths[new_problems_done] = k
                result_solutions.view(-1, n_nonzero_coefs, 1)[new_problems_done.flatten(), :k] = \
                    F.view(-1, n_nonzero_coefs, n_nonzero_coefs)[new_problems_done.flatten(), :k, :k].permute(0, 2, 1) @ a_F.view(n_nonzero_coefs, -1, 1)[:k, new_problems_done.flatten()].permute(1, 0, 2)
                print("problems done", problems_done)
                if problems_done.all():
                    if k == n_nonzero_coefs:
                        return sets.permute(1, 2, 0), result_solutions, normr2, normr2_init, result_lengths, ~below_tol
                    else:
                        return sets.permute(1, 2, 0), result_solutions, normr2, normr2_init, result_lengths, ~problems_done

        sets[k] = projections.abs().argmax(2)
        torch.gather(XTX, 1, sets[k, :, :, None].expand(-1, -1, XTX.shape[2]), out=D_mybest[:, :, k, :])
        if k:
            D_mybest_maxindices = D_mybest.permute(0, 1, 3, 2)[
                torch.arange(D_mybest.shape[0], dtype=sets.dtype, device=sets.device).unsqueeze(1), 
                torch.arange(D_mybest.shape[1], dtype=sets.dtype, device=sets.device).unsqueeze(0), 
                sets[k],
                :k
            ]

            torch.rsqrt(1 - innerp(D_mybest_maxindices),
                        out=temp_F_k_k[:, :, 0])
            D_mybest_maxindices *= -temp_F_k_k
            D_mybest[:, :, k, :] *= temp_F_k_k
            D_mybest[:, :, k, :, None].view(-1, XTX.shape[1], 1).baddbmm_(D_mybest[:, :, :k, :].permute(0, 1, 3, 2).view(-1, XTX.shape[1], k), D_mybest_maxindices[:, :, :, None].view(-1, k, 1))

        temp_a_F = temp_F_k_k * torch.gather(projections, 2, sets[k, :, :, None])
    
        normr2.sub_((temp_a_F * temp_a_F).squeeze(-1))

        projections -= temp_a_F * D_mybest[:, :, k, :]
        a_F[k] = temp_a_F
        if k:
            torch.bmm(D_mybest_maxindices[:, :, None, :].view(-1, 1, k), F[:, :, :k, :].view(-1, k, n_nonzero_coefs), out=F[:, :, k, None, :].view(-1, 1, n_nonzero_coefs))
            F[:, :, k, k] = temp_F_k_k[..., 0]
    else: # FIXME: 
        solutions = F.permute(0, 1, 3, 2) @ a_F.squeeze(-1).permute(1, 2, 0)[:, :, :, None]

    return sets.permute(1, 2, 0).to(torch.int32), solutions, normr2, normr2_init, None, None

def omp(X, y, n_nonzero_coefs=None, tol=None):
    XTX = torch.bmm(X.permute(0, 2, 1), X)
    sets, solutions, errors, kv_normr2, lengths, above_thres = omp_v0(X, y, XTX, n_nonzero_coefs, tol)
    
    sets = sets.squeeze(0)
    solutions = solutions.squeeze(0).squeeze(-1)
    
    if lengths is not None:
        lengths = lengths.squeeze(0)
    else:
        lengths = torch.full((y.shape[1],), n_nonzero_coefs)

    data = torch.cat([solutions[i, :lengths[i]] for i in range(y.shape[1])]).to(torch.float8_e4m3fn)
    indices = torch.cat([sets[i, :lengths[i]] for i in range(y.shape[1])]).to(torch.int16)

    indptr = torch.zeros(y.shape[1] + 1, dtype=torch.int32, device=sets.device)
    indptr[1:] = lengths.cumsum(dim=0)

    return indptr, indices, data


def omp_v0(X, y, XTX, n_nonzero_coefs=None, tol=None):
    """
    inputs
    X : dictionary (batch_size, signal_dim, dictionary_size)
    y : signal (batch_size, n_signals, signal_dim)
    XTX : (batch_size, dictionary_size, dictionary_size)

    outputs
    sets : dictionary coefficients (batch_size, n_signals, n_nonzero_coefs)
    result_solutions : dictionary weights (batch_size, n_signals, n_nonzero_coefs, 1)
    errors : reconstruction l2 norm errors (batch_size, n_signals)
    normr2_init : initial l2 norm of signals (batch_size, n_signals)
    lengths : number of coefficients for each signal (batch_size, n_signals)
    above_thres : is signal recon error still above error threshold (batch_size, n_signals)
    """
    B, b, _ = y.shape
    normr2_init = innerp(y)
    normr2 = normr2_init.clone()
    projections = torch.bmm(X.transpose(2, 1), y.transpose(1, 2)).transpose(1, 2) # (B, b, dict_size)
    sets = y.new_zeros(n_nonzero_coefs, B, b, dtype=torch.int64)

    F = torch.eye(n_nonzero_coefs, dtype=y.dtype, device=y.device).repeat(B, b, 1, 1)
    a_F = y.new_zeros((n_nonzero_coefs, B, b, 1), dtype=y.dtype)

    D_mybest = y.new_empty(B, b, n_nonzero_coefs, XTX.shape[1])
    temp_F_k_k = y.new_ones((B, b, 1)) # (B,b,1) for broadcasting

    # --- 수치 안정성을 위한 파라미터 ---
    # 이 값들은 실험을 통해 조절해야 합니다.
    INNERP_MAX_CLAMP = 1.0 - 1e-7 # innerp 결과의 상한 (1에 매우 가깝게)
    RSQRT_INPUT_MIN_CLAMP = 1e-7 # 1 - innerp_val 의 하한
    RSQRT_OUTPUT_MAX_CLAMP = 1e4 # 경험적 상한 (temp_F_k_k의 최대값)
    
    # D_mybest_maxindices 와 D_mybest 의 값 범위를 제어하기 위한 클램핑 값
    # 이 값들이 너무 크면 이후 bmm 등에서 inf 발생 가능
    MAX_COEF_VAL = 1e5 # 예시 값, D_mybest_maxindices, D_mybest 등의 최대값

    # --- tol 관련 초기화 (원본과 동일) ---
    if tol:
        result_lengths = sets.new_zeros((y.shape[0], y.shape[1]))
        result_solutions = y.new_zeros((y.shape[0], y.shape[1], n_nonzero_coefs, 1))
        finished_problems = sets.new_zeros((y.shape[0], y.shape[1]), dtype=torch.bool)
        tol_squared = normr2_init * (tol * tol) # tol을 제곱 형태로 미리 계산

    for k in range(n_nonzero_coefs + (1 if tol is not None else 0)): # 루프 범위 명확화        
        if tol is not None:
            problems_done = normr2 <= tol 
            if k == n_nonzero_coefs:
                below_tol = problems_done.clone()
                problems_done[:, :] = True
            
            if problems_done.any():
                new_problems_done = problems_done & ~finished_problems
                finished_problems.logical_or_(problems_done)
                result_lengths[new_problems_done] = k
                result_solutions.view(-1, n_nonzero_coefs, 1)[new_problems_done.flatten(), :k] = \
                    F.view(-1, n_nonzero_coefs, n_nonzero_coefs)[new_problems_done.flatten(), :k, :k].permute(0, 2, 1) @ a_F.view(n_nonzero_coefs, -1, 1)[:k, new_problems_done.flatten()].permute(1, 0, 2)
                print("problems done", problems_done)
                if problems_done.all():
                    if k == n_nonzero_coefs:
                        return sets.permute(1, 2, 0), result_solutions, normr2, normr2_init, result_lengths, ~below_tol
                    else:
                        return sets.permute(1, 2, 0), result_solutions, normr2, normr2_init, result_lengths, ~problems_done

        
        # --- 가장 큰 projection을 가진 atom 선택 ---
        # projections: (B, b, dict_size)
        sets[k] = projections.abs().argmax(2)

        # --- D_mybest 업데이트 (선택된 atom에 대한 XTX의 행/열) ---
        # XTX: (B, dict_size, dict_size)
        # sets[k]: (B, b) - 선택된 인덱스
        # D_mybest[:, :, k, :]: (B, b, dict_size)
        # gather의 index는 XTX의 dim 1을 따라감. (B, b, dict_size)가 되어야 함.
        # sets[k, :, :, None] -> (B, b, 1) 이것을 expand -> (B, b, XTX.shape[2])
        # gather(XTX.expand(B,b,-1,-1), dim=2, index=sets[k].unsqueeze(-1).unsqueeze(-1).expand(B,b,1,XTX.shape[2]))
        # 더 효율적인 방법: advanced indexing + expand_as 또는 직접 루프(느림)
        # batch 단위로 gather 수행
        # XTX_expanded = XTX.unsqueeze(1).expand(-1, b, -1, -1) # (B, b, dict_size, dict_size)
        # gathered_XTX_rows = torch.gather(XTX_expanded, 2, sets[k].view(B,b,1,1).expand(B,b,1,XTX.shape[2]))
        # D_mybest[:, :, k, :] = gathered_XTX_rows.squeeze(2)
        
        # 원본 gather 방식이 더 효율적일 수 있음. XTX (B, D, D), sets[k] (B,b)
        # D_mybest (B,b,k_iter,D)
        # torch.gather(XTX, 1, sets[k, :, :, None].expand(-1, -1, XTX.shape[2]), out=D_mybest[:, :, k, :])
        # 이 gather는 XTX의 dim 1 (source dim)에 대해 sets[k] (index)를 사용.
        # XTX의 shape (B, dict_size, dict_size)
        # sets[k,:,:,None].expand(-1,-1,XTX.shape[2]) -> (B, b, dict_size) - index 텐서
        # out: D_mybest[:,:,k,:] -> (B, b, dict_size)
        # 이 gather는 XTX의 첫 번째 차원 B에 대해서만 동작하며, XTX가 (dict_size, dict_size)이고
        # index가 (b, dict_size) 여야 할 것 같음. XTX가 (B,D,D) 이면,
        # index (B,b,D) 여야 하고, dim=1 (D 차원)을 따라 gather.
        # 원본 코드가 맞다고 가정하고 진행.
        torch.gather(XTX, 1, sets[k].unsqueeze(-1).expand(B,b,XTX.shape[2]), out=D_mybest[:,:,k,:])


        if k > 0:
            # D_mybest_maxindices 계산
            # D_mybest: (B, b, k_iter, dict_size)
            # sets[k]: (B,b) - 현재 선택된 atom의 인덱스
            # D_mybest.permute(0,1,3,2) -> (B, b, dict_size, k_iter)
            # advanced indexing으로 D_mybest_maxindices 추출
            # D_mybest_maxindices_val: (B, b, k_prev) - 현재 선택된 atom과 이전 k개 atom들 간의 XTX 값
            
            # Advanced indexing 방식 사용
            # batch_indices = torch.arange(B, device=y.device).view(B, 1, 1).expand(B, b, k)
            # signal_indices = torch.arange(b, device=y.device).view(1, b, 1).expand(B, b, k)
            # atom_indices = sets[k].unsqueeze(-1).expand(B, b, k) # 현재 선택된 atom
            # prev_atom_iterations = torch.arange(k, device=y.device).view(1,1,k).expand(B,b,k) # 0 to k-1
            # D_mybest_maxindices_val = D_mybest[batch_indices, signal_indices, prev_atom_iterations, atom_indices]
            
            # 원본 코드의 indexing 방식 (더 복잡해 보이지만, PyTorch 내부 최적화가 있을 수 있음)
            # D_mybest (B,b,n_nz_coef, D) -> permute (B,b,D,n_nz_coef)
            # sets[k] (B,b) - 선택된 atom index
            # D_mybest_maxindices_val (B,b,k)
            permuted_D_mybest = D_mybest.permute(0, 1, 3, 2) # (B,b,D,k_iter)
            idx_B = torch.arange(B, device=y.device)[:, None]
            idx_b = torch.arange(b, device=y.device)[None, :]
            current_selected_atom_indices = sets[k] # (B,b)
            
            # D_mybest_maxindices_val는 현재 선택된 atom과 이전에 선택된 k개의 atom들 사이의 "상관관계"를 나타냄
            # (D_mybest의 k번째 iteration에 저장된 XTX[selected_atom, :] 와 이전 selected_atom들 간의 값)
            # D_mybest_maxindices_val: (B,b,k)
            D_mybest_maxindices_val = permuted_D_mybest[idx_B, idx_b, current_selected_atom_indices, :k]


            # --- Gram-Schmidt 직교화와 관련된 부분 ---
            # innerp_val: D_mybest_maxindices_val의 자기 내적 (L_k,k 계산에 사용됨)
            # innerp_val: (B,b)
            innerp_val_raw = innerp(D_mybest_maxindices_val) # D_mybest_maxindices_val는 (B,b,k)
            innerp_val = torch.clamp(innerp_val_raw, min=0.0, max=INNERP_MAX_CLAMP) # 0 ~ (1-eps)
            
            # temp_F_k_k_val: 1 / sqrt(1 - sum_of_squares_of_correlations) -> L_k,k^-1
            # temp_F_k_k_val: (B,b)
            # rsqrt_input = torch.clamp(1.0 - innerp_val, min=RSQRT_INPUT_MIN_CLAMP) # 이미 innerp_val clamp로 처리됨
            rsqrt_input = 1.0 - innerp_val # 최소 RSQRT_INPUT_MIN_CLAMP 보장됨
            temp_F_k_k_val = torch.rsqrt(rsqrt_input)
            # temp_F_k_k_val.clamp_(max=RSQRT_OUTPUT_MAX_CLAMP) # in-place clamp
            
            # temp_F_k_k[:,:,0] 에 (B,b) 형태의 temp_F_k_k_val 할당
            temp_F_k_k[:, :, 0] = temp_F_k_k_val


            # D_mybest_maxindices_val 와 D_mybest[:, :, k, :] 업데이트 (새로운 직교 벡터 성분 계산)
            # D_mybest_maxindices_val: (B,b,k)
            # temp_F_k_k_val: (B,b) -> unsqueeze로 (B,b,1) 만들어 브로드캐스팅
            update_factor = -temp_F_k_k_val.unsqueeze(-1) # (B,b,1)
            D_mybest_maxindices_val_updated = D_mybest_maxindices_val * update_factor
            D_mybest_maxindices_val_updated.clamp_(max=MAX_COEF_VAL, min=-MAX_COEF_VAL)
            
            # D_mybest[:, :, k, :] (B,b,D) 에도 동일한 스케일링 적용
            # D_mybest의 k번째 반복에 해당하는 슬라이스
            D_mybest_k_slice = D_mybest[:, :, k, :] # (B,b,D)
            D_mybest_k_slice_updated = D_mybest_k_slice * temp_F_k_k_val.unsqueeze(-1) # (B,b,D) * (B,b,1)

            # D_mybest_k_slice_updated에 이전 직교 성분들을 빼서 새로운 직교 벡터 계산
            # D_mybest[:, :, :k, :] -> (B,b,k,D) -> permute (B,b,D,k)
            # D_mybest_maxindices_val_updated -> (B,b,k) -> unsqueeze (B,b,k,1)
            # baddbmm: (B*b, D, 1) += (B*b, D, k) @ (B*b, k, 1)
            # view(-1, XTX.shape[1], 1) -> (B*b, D, 1)
            # D_mybest[:,:,:k,:].permute(0,1,3,2).view(-1, XTX.shape[1], k) -> (B*b, D, k)
            # D_mybest_maxindices_val_updated[:,:,:,None].view(-1, k, 1) -> (B*b, k, 1)
            
            # D_mybest_k_slice_updated를 view로 전달하여 in-place처럼 동작하게 함
            # (이 view가 D_mybest_k_slice_updated와 메모리를 공유해야 함)
            # 그러나 baddbmm은 out 파라미터가 없으므로, 결과를 받아서 다시 할당해야 함.
            # 혹은 D_mybest_k_slice_updated 자체를 baddbmm의 첫번째 인자로 사용.
            
            # 연산을 위한 view 생성
            target_for_baddbmm = D_mybest_k_slice_updated.reshape(B * b, XTX.shape[1], 1)
            mat1_for_baddbmm = D_mybest[:, :, :k, :].permute(0, 1, 3, 2).reshape(B * b, XTX.shape[1], k)
            # mat2_for_baddbmm = D_mybest_maxindices_val_updated.unsqueeze(-1).reshape(B * b, k, 1)
            mat2_for_baddbmm = D_mybest_maxindices_val_updated.reshape(B * b, k, 1)
            
            # baddbmm 연산 (alpha=1, beta=1 기본값)
            # target = beta * target + alpha * (mat1 @ mat2)
            # 여기서는 target += mat1 @ mat2
            torch.baddbmm(target_for_baddbmm, mat1_for_baddbmm, mat2_for_baddbmm, out=target_for_baddbmm) # in-place update
            
            # D_mybest_maxindices_val_updated와 D_mybest_k_slice_updated를 원래 위치에 다시 저장
            # 이 값들은 이제 clamp 대상. clamp는 루프 마지막이나 필요시 최소한으로.
            # D_mybest_maxindices에 해당 값 저장하지 않음 (D_mybest_maxindices_val_updated는 중간 계산값)
            # D_mybest의 k번째 슬라이스에 업데이트된 값 저장
            D_mybest[:, :, k, :] = target_for_baddbmm.view(B, b, XTX.shape[1]) # D_mybest_k_slice_updated와 동일

            # D_mybest_maxindices_val_updated 와 D_mybest[:,:,k,:] 에 대한 clamp는
            # 이후 연산에서 이 값들이 사용될 때 문제가 생기지 않을 정도의 값으로만 제한.
            # 또는 루프 마지막에 한 번만 수행하는 것을 고려.
            # 현재 구조에서는 이 값들이 다음 bmm 등에 직접 사용되므로, 여기서 clamp가 필요할 수 있음.
            # 하지만, clamp 값 (MAX_COEF_VAL) 이 매우 크다면, clamp의 효과는 inf/nan 방지 수준.
            D_mybest[:, :, k, :].clamp_(max=MAX_COEF_VAL, min=-MAX_COEF_VAL) # in-place
            # D_mybest_maxindices_val_updated도 필요하면 clamp (하지만 다음 bmm의 입력으로 직접 쓰이지는 않음)


        # --- 잔차 및 projection 업데이트 ---
        # temp_a_F: 현재 선택된 atom에 대한 coefficient (alpha_k)
        # projections: (B,b,D)
        # sets[k]: (B,b) -> gather를 위해 (B,b,1)로
        # temp_a_F: (B,b,1)
        gathered_projections = torch.gather(projections, 2, sets[k].unsqueeze(-1))
        temp_a_F_val = temp_F_k_k * gathered_projections # (B,b,1) * (B,b,1) = (B,b,1)

        # normr2 업데이트 (잔차의 제곱 norm)
        normr2.sub_((temp_a_F_val * temp_a_F_val).squeeze(-1)) # normr2는 (B,b)
        # normr2.clamp_(min=0.0) # 제곱 norm이므로 음수가 될 수 없음 (수치 오류 방지)

        # projections 업데이트: y_residual -= alpha_k * selected_atom_vector
        # 여기서 D_mybest[:, :, k, :]는 선택된 k번째 직교화된 atom (또는 그와 관련된 것)
        # projections: (B,b,dict_size)
        # temp_a_F_val: (B,b,1)
        # D_mybest[:,:,k,:]: (B,b,dict_size)
        projections.sub_(temp_a_F_val * D_mybest[:, :, k, :])
        # projections에 대한 clamp는 신중해야 함. 잔차의 의미를 왜곡할 수 있음.
        # 만약 필요하다면, 매우 큰 값만 제한.
        projections.clamp_(min=-MAX_COEF_VAL, max=MAX_COEF_VAL) # 예: -1e6, 1e6. 원본은 min=-1e5

        a_F[k] = temp_a_F_val # (B,b,1)

        # --- F 매트릭스 업데이트 (L^-1 업데이트와 관련) ---
        if k > 0:
            # D_mybest_maxindices_val_updated 사용 (위에서 계산된 값) (B,b,k)
            # F: (B,b,n_nz,n_nz)
            # F[:,:,:k,:] -> (B,b,k,n_nz)
            # out: F[:,:,k,None,:] -> (B,b,1,n_nz)
            
            # bmm을 위한 view
            # D_mybest_maxindices_val_updated.unsqueeze(-2) -> (B,b,1,k)
            # .view(-1,1,k)
            mat1_F_update = D_mybest_maxindices_val_updated.unsqueeze(-2).reshape(B * b, 1, k)
            
            # F[:,:,:k,:].view(-1,k,n_nonzero_coefs)
            mat2_F_update = F[:, :, :k, :].reshape(B * b, k, n_nonzero_coefs)
            
            out_F_update = F[:, :, k, None, :].reshape(B * b, 1, n_nonzero_coefs) # view for out param

            torch.bmm(mat1_F_update, mat2_F_update, out=out_F_update)
            
            # F의 대각 성분 업데이트 (L_k,k^-1)
            # F[:,:,k,k] (B,b) 에 temp_F_k_k_val (B,b) 할당
            F[:, :, k, k] = temp_F_k_k_val
            
            # F에 대한 clamp는 매우 신중해야 함. F는 역행렬과 관련된 정보를 담고 있을 수 있음.
            # 만약 F의 값들이 너무 커진다면, 이는 XTX의 조건수(condition number)가 나쁘거나
            # 선택된 atom들이 매우 유사하여 발생하는 문제일 수 있음.
            # clamp를 최소화하고, 필요하다면 XTX 정규화나 atom 선택 전략을 검토.
            # 원본 코드는 F를 clamp함. (-1e5, 1e5)
            F.clamp_(min=-MAX_COEF_VAL, max=MAX_COEF_VAL) # MAX_COEF_VAL 사용 또는 F 전용 clamp 값 사용
    
    # 루프 종료 후 최종 clamp (필요한 경우에만)
    # D_mybest.clamp_(min=-MAX_COEF_VAL, max=MAX_COEF_VAL)
    # F.clamp_(min=-MAX_COEF_VAL, max=MAX_COEF_VAL)

    # # --- tol이 없는 경우의 최종 해 계산 (또는 tol 루프가 정상 종료된 경우) ---
    # # if tol is None or (tol is not None and not finished_problems.all()): # 모든 문제가 tol 조건으로 끝나지 않은 경우
    # if tol is None or (k == n_nonzero_coefs and (tol is not None and not finished_problems.all())):
    #     # F는 (B,b,n_nz,n_nz), a_F는 (n_nz,B,b,1)
    #     # solutions: (B,b,n_nz,1)
    #     # a_F를 (B,b,n_nz,1) 형태로 변경: a_F.permute(1,2,0,3)
    #     # F.permute(0,1,3,2) @ a_F.permute(1,2,0,3) -> F.T @ a_F (만약 F가 L^-1 이면 (L^-1).T @ a_F)
    #     # 실제 해는 X_active @ beta = y, beta = (X_active.T @ X_active)^-1 @ X_active.T @ y
    #     # F와 a_F가 이 구성요소를 어떻게 나타내는지에 따라 달라짐.
    #     # 원본: solutions = F.permute(0, 1, 3, 2) @ a_F.squeeze(-1).permute(1, 2, 0)[:, :, :, None]
    #     # a_F.squeeze(-1) -> (n_nz, B, b) -> permute(1,2,0) -> (B,b,n_nz) -> unsqueeze -> (B,b,n_nz,1)
    #     a_F_reshaped = a_F.squeeze(-1).permute(1,2,0).unsqueeze(-1)
    #     solutions = torch.bmm(F.transpose(-2,-1).reshape(B*b, n_nonzero_coefs, n_nonzero_coefs),
    #                           a_F_reshaped.reshape(B*b, n_nonzero_coefs, 1)
    #                          ).view(B,b,n_nonzero_coefs,1)
    # elif tol is not None: # tol 모드에서 모든 문제가 완료되어 루프가 일찍 종료된 경우
    #     solutions = result_solutions # 이미 계산된 해 사용
    else: # FIXME: 
        solutions = F.permute(0, 1, 3, 2) @ a_F.squeeze(-1).permute(1, 2, 0)[:, :, :, None]

    return sets.permute(1, 2, 0).to(torch.int32), solutions, normr2, normr2_init, None, None