import math
import os
import threading

import torch
import torch.nn as nn

def analysis_diff(original, approximation):
    """Return the relative error ``||original - approximation|| / ||original||``.

    Both tensors are converted with ``.float()`` before the norms are taken
    with ``torch.linalg.norm``.

    Args:
        original: Reference tensor.
        approximation: Tensor compared against ``original``.

    Returns:
        The relative error as a Python float, or ``0.0`` when the norm of
        ``original`` is zero (the ratio would otherwise be undefined).
    """
    reference = original.float()
    estimate = approximation.float()

    denominator = torch.linalg.norm(reference)
    # Guard against dividing by a zero reference norm.
    if denominator == 0:
        return 0.0

    numerator = torch.linalg.norm(reference - estimate)
    return (numerator / denominator).item()


def bit_plane_decomposition(q, bits: int):
    """Split integer codes into their individual bit planes.

    Args:
        q: Tensor of codes; it is cast to ``torch.int32`` before the bits
            are extracted.
        bits: Number of planes to extract, starting at the least
            significant bit.

    Returns:
        A ``torch.uint8`` tensor with one extra leading dimension of size
        ``bits``; slice ``i`` holds bit ``i`` of every element of ``q``.
    """
    codes = q.to(torch.int32)
    bit_planes = []
    for shift in range(bits):
        # Move bit `shift` down to position 0 and mask everything else away.
        single_bit = (codes >> shift) & 1
        bit_planes.append(single_bit.to(torch.uint8))
    return torch.stack(bit_planes, dim=0)


def solve_init_C(w_ori, basis_plane, num_plane, alpha=1e-4, dtype_store=torch.float16, dtype_acc=torch.float32, U_g=None):
    """Fit per-row bit-plane coefficients (plus a bias) by ridge least squares.

    For every row ``b`` a design matrix ``X_b = [planes_b | 1]`` of shape
    ``[group_size, num_plane + 1]`` is built and the coefficients ``c_b``
    minimising

        ``|| U_g^{-T} (X_b c_b - w_ori[b]) ||^2 + alpha * ||c_b||^2``

    are computed from the regularised normal equations.

    Args:
        w_ori: Target weights, shape ``[B, group_size]``.
        basis_plane: Bit planes, shape ``[num_plane, B, group_size]``.
        num_plane: Number of bit planes; one bias feature is appended.
        alpha: Ridge term added to the diagonal of the normal matrix
            (the bias entry is regularised as well).
        dtype_store: Dtype of the returned coefficients.
        dtype_acc: Dtype used for all intermediate computations.
        U_g: ``[group_size, group_size]`` factor defining the error metric.
            Its transpose is handed to ``solve_triangular`` with
            ``upper=False``, i.e. ``U_g`` is treated as upper triangular.
            ``None`` selects the identity metric (no whitening).

    Returns:
        Tensor of shape ``[B, num_plane + 1, 1]`` in ``dtype_store``; the
        last feature is the bias.

    Raises:
        ValueError: If the solution contains NaN or Inf.

    Side effects:
        Prints the relative reconstruction error (unwhitened, computed with
        the ``dtype_acc`` solution) to stdout.
    """
    device = w_ori.device
    B, group_size = w_ori.shape
    n_features = num_plane + 1

    # Design matrix: bit planes followed by a constant column for the bias.
    planes = basis_plane.permute(1, 2, 0).to(dtype=dtype_acc)                 # [B, group_size, num_plane]
    ones_batched = torch.ones(B, group_size, 1, device=device, dtype=dtype_acc)
    X_batched = torch.cat((planes, ones_batched), dim=2)                      # [B, group_size, n_features]
    Y = w_ori.to(dtype_acc)                                                   # [B, group_size]

    if U_g is None:
        # Identity metric: whitening is a no-op. Previously the default
        # U_g=None crashed on `U_g.to(...)`.
        Xp, Yp = X_batched, Y
    else:
        # Pre-whiten both sides with U_g^{-T}; the transposed factor is
        # lower triangular, hence upper=False.
        U_T = U_g.to(dtype=dtype_acc).transpose(0, 1)

        # Fold batch and feature axes together so one triangular solve
        # handles every column at once.
        X_rhs = X_batched.permute(1, 0, 2).reshape(group_size, B * n_features).contiguous()  # [G, B*n]
        Xp_flat = torch.linalg.solve_triangular(U_T, X_rhs, upper=False)                     # [G, B*n]
        Xp = Xp_flat.reshape(group_size, B, n_features).permute(1, 0, 2).contiguous()        # [B, G, n]

        Y_rhs = Y.transpose(0, 1).contiguous()                                               # [G, B]
        Yp_T = torch.linalg.solve_triangular(U_T, Y_rhs, upper=False)                        # [G, B]
        Yp = Yp_T.transpose(0, 1).contiguous()                                               # [B, G]

    # Regularised normal equations: (X^T X + alpha I) c = X^T y, per row.
    XT = Xp.transpose(1, 2)                                                   # [B, n, G]
    ridge = torch.eye(n_features, device=device, dtype=dtype_acc) * float(alpha)
    A = XT @ Xp + ridge                                                       # [B, n, n]
    rhs = XT @ Yp.unsqueeze(-1)                                               # [B, n, 1]

    # solve_ex reports failures through `info` instead of raising, so only
    # the failing batch entries are re-solved with the slower lstsq.
    sol, info = torch.linalg.solve_ex(A, rhs)                                 # [B, n, 1]
    bad = info != 0                                                           # [B]
    if torch.any(bad):
        sol[bad] = torch.linalg.lstsq(A[bad], rhs[bad]).solution
    if not torch.isfinite(sol).all():
        raise ValueError("Non-finite values (NaN or Inf) detected in solution C")

    error_groupwise_coeffs = analysis_diff(w_ori, (X_batched @ sol).squeeze(-1))
    print(f"*BPD-ALS (k={num_plane})* init error: {error_groupwise_coeffs:.6f}")

    return sol.to(dtype_store)



def quantize_bpdq_main(
    w_ori, 
    scales, 
    zeros, 

    W_tail_init, U_rows, d_g,

    bplane_bits=8, msb_num=4, alpha=1e-4,
    n_iters=5, candidate_vectors=None, candidate_vectors_f=None, 
    dtype_store=torch.float16, dtype_acc=torch.float32, U_g=None, maxq=-1, 
):
    """Quantize one weight group with a bit-plane model, keeping the best of ``n_iters`` passes.

    Outline:
      1. Round/clamp ``w_ori`` to integer codes, split them into
         ``bplane_bits`` bit planes and fit initial per-row coefficients on
         the top ``msb_num`` planes with ``solve_init_C``.
      2. For each iteration, restart from a copy of ``W_tail_init`` and walk
         the ``G`` group columns left to right: pick, per row, the candidate
         whose reconstructed value is closest to the current column, then
         push the scaled error into the remaining tail columns via
         ``U_rows``.
      3. Refit the coefficients on the chosen bits, rebuild the quantized
         group, correct the accumulated error and the tail beyond the group
         for the change, and score the iteration by the sum of squares of
         the corrected error.

    Args:
        w_ori: Original group weights, shape ``[rows, G]``.
        scales, zeros: Quantization grid parameters; they must broadcast
            against ``w_ori`` (presumably per-row -- TODO confirm).
        W_tail_init: Working weights, shape ``[rows, tail_len]``; columns
            ``0..G-1`` are read as the group, so ``tail_len >= G`` is assumed.
        U_rows: Propagation rows; indexed as ``U_rows[t, t:tail_len]`` for
            ``t < G`` and multiplied as ``delta_S @ U_rows[:, G:tail_len]``,
            so it needs ``G`` rows and at least ``tail_len`` columns.
        d_g: Per-column divisor for the error, indexed by ``t`` in ``0..G-1``
            (presumably the matching diagonal of ``U_rows`` -- TODO confirm).
        bplane_bits: Number of bit planes extracted from the integer codes.
        msb_num: Number of top planes kept (``k``).
        alpha: Ridge term forwarded to ``solve_init_C``.
        n_iters: Number of refinement iterations.
        candidate_vectors: Indexed by candidate id to yield ``[rows, k]`` bit
            vectors, which are written into a ``uint8`` tensor.
        candidate_vectors_f: 2-D tensor whose transpose is right-multiplied
            onto the ``[rows, k]`` coefficients; presumably the float copy of
            ``candidate_vectors`` with shape ``[ncand, k]`` -- TODO confirm.
        dtype_store: Dtype of the stored coefficients and of ``Q_store``.
        dtype_acc: Dtype used for the working computations.
        U_g: ``[G, G]`` factor; its transpose is passed to
            ``solve_triangular`` with ``upper=False``. It is required here
            (``U_g.transpose`` is called unconditionally).
        maxq: Upper clamp bound for the integer codes, as an ``int`` or an
            object with ``.item()``. NOTE(review): the default ``-1`` gives
            an inverted clamp range ``[0, -1]``; callers should pass a real
            value.

    Returns:
        ``None`` when ``n_iters <= 0``; otherwise a dict for the
        lowest-scoring iteration with keys ``W_tail`` (updated working
        weights), ``Q_store`` (quantized group in ``dtype_store``), ``S``
        (corrected scaled errors ``[rows, G]``), ``B_final`` (chosen bits
        ``[k, rows, G]``, ``uint8``) and ``c_store`` (coefficients
        ``[rows, k + 1, 1]``).
    """
    # Disabled debug identifiers, used only by the disabled progress print below.
    # pid = os.getpid()
    # tid = threading.get_ident()

    # maxq may be a plain int or a tensor-like exposing .item().
    maxq_val = int(maxq) if isinstance(maxq, int) else int(maxq.item())
    # Integer codes on the (scales, zeros) grid, clamped to [0, maxq_val].
    q = torch.clamp(torch.round(w_ori / scales) + zeros, 0, maxq_val)


    # all_planes: [BITS, B, group_size]
    all_planes = bit_plane_decomposition(q, bplane_bits) 
    # Planes are ordered LSB first, so the top msb_num planes sit at the end.
    start_bit_index = bplane_bits - msb_num
    c_store_iter = solve_init_C(w_ori, all_planes[start_bit_index:], num_plane=msb_num, alpha=alpha, 
                                dtype_store=dtype_store, dtype_acc=dtype_acc, U_g=U_g)
    
    device = w_ori.device
    rows, G = w_ori.shape
    tail_len = W_tail_init.shape[1]

    # Pristine copy of the tail; every iteration restarts from it.
    base_tail = W_tail_init.to(dtype_acc).contiguous()

    best = None
    best_score = torch.tensor(float("inf"), device=device, dtype=dtype_acc)
    for it in range(n_iters):
        # restore tail to same start
        W_tail = base_tail.clone()

        # 1) build v_cands from c: the value each candidate bit vector
        #    reconstructs to, per row, under the current coefficients.
        #    Coefficients go through dtype_store first, so the values match
        #    what the stored coefficients would reproduce.
        C_used = c_store_iter.to(dtype_acc)
        C_coeffs = C_used[:, :-1, 0].contiguous()   # [rows, k]
        C_bias   = C_used[:,  -1, 0].contiguous()   # [rows]
        v_cands  = C_coeffs.matmul(candidate_vectors_f.t()) + C_bias.unsqueeze(1)  # [rows, ncand]

        # 2) GPTQ-style B-step + propagation on tail
        k = msb_num
        B_final = torch.empty((k, rows, G), device=device, dtype=torch.uint8)  # chosen bits
        Q_old   = torch.empty((rows, G), device=device, dtype=dtype_acc)       # chosen values
        S_old   = torch.empty((rows, G), device=device, dtype=dtype_acc)       # scaled errors

        # Columns must be processed in order: each step rewrites the columns
        # to its right before they are quantized.
        for t in range(G):
            y_col = W_tail[:, t]                                  # [rows]

            # Nearest candidate per row in squared distance.
            diff = y_col.unsqueeze(1) - v_cands                   # [rows, ncand]
            best_idx = diff.pow(2).argmin(dim=1)                  # [rows]

            best_b = candidate_vectors[best_idx]                  # [rows, k]
            B_final[:, :, t] = best_b.transpose(0, 1).contiguous()

            q_col = v_cands.gather(1, best_idx.view(-1, 1)).squeeze(1)  # [rows]
            Q_old[:, t] = q_col

            # Quantization error of this column, scaled by d_g[t].
            err = (y_col - q_col) / d_g[t]
            S_old[:, t] = err

            # propagate to tail: W_tail[:, t:] -= err * U_rows[t, t:]
            U_seg = U_rows[t, t:tail_len].to(dtype_acc)           # [tail_len - t]
            W_tail[:, t:tail_len] = W_tail[:, t:tail_len] - err.unsqueeze(1) * U_seg.unsqueeze(0)

            # keep exact q at position t
            W_tail[:, t] = q_col

        # 3) refit c under U_g metric with fixed B_final (prewhiten LS);
        #    the fit always targets the original weights w_ori.
        c_store = solve_init_C(w_ori,                # use Y0_base 
                               B_final, msb_num, alpha, dtype_store, dtype_acc, U_g)

        # rebuild Q_new using inference-equivalent c_store
        c_used = c_store.to(dtype_acc)
        b_stacked = B_final.permute(1, 2, 0).to(dtype_acc)        # [rows, G, k]
        ones = torch.ones(rows, G, 1, device=device, dtype=dtype_acc)
        X = torch.cat([b_stacked, ones], dim=2)                   # [rows, G, k+1]
        Q_new = (X @ c_used).squeeze(-1)                          # [rows, G]

        # delta-correction: ΔS U_g = Q_old - Q_new
        # Solved in transposed form: U_g^T ΔS^T = D^T with a lower-triangular solve.
        D = (Q_old - Q_new)                                       # [rows, G]
        delta_S_T = torch.linalg.solve_triangular(U_g.transpose(0, 1).to(dtype_acc), D.transpose(0, 1).contiguous(),upper=False,)
        delta_S = delta_S_T.transpose(0, 1).contiguous()          # [rows, G]

        # update S and tail to match new Q
        S_new = S_old + delta_S

        # overwrite group with Q_new (final)
        W_tail[:, :G] = Q_new

        # correct tail beyond group: W_tail[:, G:] -= ΔS @ U_rows[:, G:]
        if tail_len > G:
            W_tail[:, G:tail_len] = W_tail[:, G:tail_len] - delta_S.matmul(U_rows[:, G:tail_len].to(dtype_acc))

        score = S_new.pow(2).sum()  # score by ||S_g||^2

        # Keep the lowest-scoring iteration. W_tail, Q_new and S_new are
        # created fresh each iteration, so storing them without a copy is safe.
        if (best is None) or (score < best_score): 
            best_score = score
            best = dict(W_tail=W_tail, Q_store=Q_new.to(dtype_store), S=S_new, B_final=B_final.clone(), c_store=c_store.clone())
            # Disabled progress print (needs the pid/tid lines at the top):
            # print(f"[pid={pid}|tid={tid}] iter={it+1}/{n_iters}||best_score={best_score.detach().item():.6e}")

        # iterate c <- refit(c)
        c_store_iter = c_store

    return best


class Quantizer(nn.Module):
    """Module holding quantization state and an entry point to BPDQ quantization.

    Registers three buffers: ``maxq`` (scalar, initialised to 0) and
    ``scale`` / ``zero`` (zeros of the requested ``shape``).
    """

    def __init__(self, qcfg: QuantizeConfig, shape=1, name: str=None):
        """Store the config and name, and register the zero-initialised buffers.

        Args:
            qcfg: Quantization configuration; stored as ``self.qcfg``.
            shape: Shape passed to ``torch.zeros`` for ``scale`` and ``zero``.
            name: Optional label; stored as ``self.name``.
        """
        super().__init__()

        self.qcfg = qcfg

        # Registration order is kept (maxq, scale, zero) so the ordering
        # inside state_dict does not change.
        initial_buffers = (
            ("maxq", torch.tensor(0)),
            ("scale", torch.zeros(shape)),
            ("zero", torch.zeros(shape)),
        )
        for buffer_name, initial_value in initial_buffers:
            self.register_buffer(buffer_name, initial_value)

        self.name = name

    def quantize_bpdq(self, w_ori, scales, zeros, W_tail_init, U_rows, d_g, bplane_bits, msb_num,alpha, 
                       n_iters, candidate_vectors, candidate_vectors_f, dtype_store, dtype_acc, U_g):
        """Forward all arguments to ``quantize_bpdq_main``, using ``self.maxq`` as ``maxq``.

        Returns whatever ``quantize_bpdq_main`` returns.
        """
        return quantize_bpdq_main(
            w_ori,
            scales,
            zeros,
            W_tail_init,
            U_rows,
            d_g,
            bplane_bits=bplane_bits,
            msb_num=msb_num,
            alpha=alpha,
            n_iters=n_iters,
            candidate_vectors=candidate_vectors,
            candidate_vectors_f=candidate_vectors_f,
            dtype_store=dtype_store,
            dtype_acc=dtype_acc,
            U_g=U_g,
            maxq=self.maxq,
        )
    












































