####2025-05-06
import torch
import matplotlib.pyplot as plt
import numpy as np
import torch

def remove_zeros_and_sort(values):
    """
    Collect the non-zero entries of a 1D tensor, ordered from largest to
    smallest magnitude, together with their positions in the input.

    Args:
        values (torch.Tensor): A one-dimensional tensor of numeric values.

    Returns:
        sorted_abs_vals (torch.Tensor):
            1D tensor holding the absolute values of the non-zero entries,
            in descending order.
        global_indices (torch.Tensor):
            1D tensor giving, for each entry of `sorted_abs_vals`, the index
            of that entry in the original `values` tensor.
    """
    # Positions (along the first axis) of every entry that differs from zero.
    kept_positions = torch.nonzero(values != 0, as_tuple=True)[0]

    # Magnitudes of the surviving entries, in their original order.
    magnitudes = torch.abs(values[kept_positions])

    # Descending sort; `order` is the permutation applied to `magnitudes`.
    sorted_abs_vals, order = magnitudes.sort(descending=True)

    # Apply the same permutation to the positions so both outputs line up.
    return sorted_abs_vals, kept_positions[order]


def example_nonzero_sort(pos_output, neg_output):
    """
    Run `remove_zeros_and_sort` on the positive and the negative partition
    of a feature tensor and return both results side by side.

    Args:
        pos_output (torch.Tensor): 1D tensor with positive values or zeros.
        neg_output (torch.Tensor): 1D tensor with negative values or zeros.

    Returns:
        tuple: `(pos_sorted_vals, pos_sorted_idx, neg_sorted_vals,
        neg_sorted_idx)` where each `*_vals` tensor holds the descending
        absolute values of the non-zero entries of the matching input and
        each `*_idx` tensor holds their indices in that input.
    """
    pos_result = remove_zeros_and_sort(pos_output)
    neg_result = remove_zeros_and_sort(neg_output)
    # Flatten the two (values, indices) pairs into a single 4-tuple.
    return (*pos_result, *neg_result)


def store_csr_non_cumulative(csr_tensor, values=None):
    """
    Decompose a 2D PyTorch sparse CSR tensor into a plain dictionary of its
    components.

    The row pointer array is stored exactly as returned by
    `crow_indices()` (i.e. it is the cumulative CSR pointer array, despite
    the "row_counts" key name); no conversion is applied.

    Args:
        csr_tensor (torch.Tensor): A 2D sparse tensor in CSR layout.
        values (torch.Tensor, optional): The values to store alongside the
            CSR structure (typically the quantized values). When omitted,
            the values already held by `csr_tensor` are used.

    Returns:
        dict: A dictionary containing:
            - "row_counts" (torch.Tensor): The CSR row pointer array of length nrows+1.
            - "col_indices" (torch.Tensor): The column indices array of length nnz.
            - "values" (torch.Tensor): The stored values.
            - "nrows" (int): Number of rows of `csr_tensor`.
            - "ncols" (int): Number of columns of `csr_tensor`.
    """
    # Fall back to the tensor's own values so single-argument calls work.
    if values is None:
        values = csr_tensor.values()

    # Unpacking into two names means only 2D CSR tensors are supported.
    nrows, ncols = csr_tensor.size()

    return {
        "row_counts": csr_tensor.crow_indices(),
        "col_indices": csr_tensor.col_indices(),
        "values": values,
        "nrows": nrows,
        "ncols": ncols,
    }

def restore_csr_from_non_cumulative(csr_info, device):
    """
    Build a PyTorch sparse CSR tensor from the component dictionary
    produced by `store_csr_non_cumulative`.

    Args:
        csr_info (dict): Dictionary containing:
            - "row_counts" (list or torch.Tensor): Row pointer array (length nrows+1).
            - "col_indices" (list or torch.Tensor): Column indices of the stored elements.
            - "values" (list or torch.Tensor): Stored values.
            - "nrows" (int): Number of rows of the matrix.
            - "ncols" (int): Number of columns of the matrix.
        device (torch.device): Target device for the reconstructed tensor.

    Returns:
        torch.Tensor: Sparse CSR tensor of shape (nrows, ncols).
    """
    def _tensor_from(data, dtype):
        # Only Python lists are converted; anything else is passed through.
        if isinstance(data, list):
            return torch.tensor(data, dtype=dtype, device=device)
        return data

    # List inputs become int32 index tensors / float32 value tensors.
    crow = _tensor_from(csr_info["row_counts"], torch.int32)
    col = _tensor_from(csr_info["col_indices"], torch.int32)
    vals = _tensor_from(csr_info["values"], torch.float32)
    matrix_size = (csr_info["nrows"], csr_info["ncols"])

    return torch.sparse_csr_tensor(crow, col, vals, size=matrix_size, device=device)


def AIQ(tensor, bits=8):
    """
    Perform Asymmetric Integer Quantization (AIQ) along the last dimension
    of an input tensor, computing scale and zero-point parameters.

    Each vector along the last dimension is mapped affinely onto
    [0, 2**bits - 1], rounded and clamped. A vector whose minimum equals
    its maximum (including a single-element vector) has no range to
    spread; it is quantized with scale 1 and zero-point `-min`, so that
    `DeAIQ` reproduces the constant exactly instead of yielding 0 or NaN.

    Args:
        tensor (torch.Tensor): Input tensor to be quantized (non-empty
            along the last dimension).
        bits (int): Bit width for the quantization range.

    Returns:
        q_tensor (torch.Tensor): Quantized integer-valued tensor (float dtype).
        scale (torch.Tensor): Scale factor used for quantization.
        zero_point (torch.Tensor): Zero-point offset for quantization.
    """
    qmin, qmax = 0.0, 2.0 ** bits - 1.0

    # Per-vector extrema along the last dimension (kept for broadcasting).
    min_val, _ = torch.min(tensor, dim=-1, keepdim=True)
    max_val, _ = torch.max(tensor, dim=-1, keepdim=True)

    raw_scale = (max_val - min_val) / (qmax - qmin + 1e-12)
    zero_point = qmin - min_val / (raw_scale + 1e-12)

    # A zero scale would make the division below produce inf/NaN; switch
    # those vectors to an identity scale with a zero-point that cancels
    # the constant value. Non-degenerate vectors are left untouched.
    flat_range = raw_scale == 0
    scale = torch.where(flat_range, torch.ones_like(raw_scale), raw_scale)
    zero_point = torch.where(
        flat_range, (qmin - min_val).to(zero_point.dtype), zero_point
    )

    # Quantization transform: tensor / scale + zero_point, then round to
    # the nearest integer and clamp into [qmin, qmax].
    q_tensor = tensor.div(scale).add_(zero_point)
    q_tensor.round_().clamp_(qmin, qmax)
    return q_tensor, scale, zero_point


def DeAIQ(q_tensor, scale, zero_point):
    """
    Invert the AIQ transform: subtract the zero-point from the quantized
    values, then multiply by the scale.

    Args:
        q_tensor (torch.Tensor): Quantized tensor (float dtype).
        scale (torch.Tensor): Scale factor from AIQ.
        zero_point (torch.Tensor): Zero-point from AIQ.

    Returns:
        torch.Tensor: Reconstructed floating-point tensor.
    """
    # Undo the offset first, then the scaling.
    shifted = q_tensor - zero_point
    return shifted * scale


def DS(t1, t2, el):
    """
    Distortion Score: the sum of absolute element-wise differences between
    two tensors, divided by `el`.

    Args:
        t1 (torch.Tensor): Reference integer tensor.
        t2 (torch.Tensor): Test integer tensor.
        el (int): Number of elements for normalization.

    Returns:
        torch.Tensor: Scalar tensor holding sum(|t1 - t2|) / el.
    """
    total_abs_error = torch.sum(torch.abs(t1 - t2))
    return total_abs_error / el


def ABSQ_with_bit(tensor, q, delta, under_bound=1, USE=True):
    """
    Adaptive bit-width search for quantization. Starting from `q` bits,
    the bit-width is lowered one step at a time until the distortion score
    reaches `delta`, in which case the last accepted bit-width is returned.

    For each candidate bit-width the `q`-bit quantization is rescaled by
    `2 ** (q - bit)` and rounded, then compared (via `DS`) against the
    direct quantization at that bit-width.

    Args:
        tensor (torch.Tensor): Input tensor of values to quantize.
        q (int): Maximum bit-width to start quantization.
        delta (float): Distortion score at or above which a reduction is
            rejected.
        under_bound (int): Minimum bit-width to consider; the effective
            lower bound is `max(under_bound, q // 2)`.
        USE (bool): If False, skip bitwidth adaptation.

    Returns:
        B_IQ (torch.Tensor): Final quantized tensor.
        scale (torch.Tensor): Scale factor from AIQ.
        zp (torch.Tensor): Zero-point from AIQ.
        bit_used (int): Bit-width selected after adaptation (0 for an
            empty input).
    """
    # Empty input: nothing to quantize, return neutral parameters.
    if tensor.numel() == 0:
        return (
            torch.empty_like(tensor),
            torch.tensor(1.0, dtype=tensor.dtype, device=tensor.device),
            torch.tensor(0.0, dtype=tensor.dtype, device=tensor.device),
            0
        )

    lower_bit = max(under_bound, q // 2)

    # `accepted` always holds the AIQ result for `accepted_bit`, so the
    # result that gets returned never has to be computed a second time.
    accepted = AIQ(tensor, q)
    accepted_bit = q

    # If adaptation is disabled, return the full bit-width quantization.
    if not USE:
        return accepted + (q,)

    reference = accepted[0]
    numel = tensor.numel()
    for bit in range(q - 1, lower_bit - 1, -1):
        candidate = AIQ(tensor, bit)
        # Bring the q-bit reference onto the candidate's integer grid.
        ref = reference.div(2 ** (q - bit)).round()
        if DS(ref, candidate[0], numel) >= delta:
            # Threshold reached: keep the previous (higher) bit-width.
            return accepted + (accepted_bit,)
        accepted, accepted_bit = candidate, bit

    # Every reduction was acceptable (or none was possible). The loop ends
    # on `lower_bit` unless it never ran with lower_bit != q; only then is
    # a fresh quantization at `lower_bit` required.
    if accepted_bit != lower_bit:
        accepted = AIQ(tensor, lower_bit)
    return accepted + (lower_bit,)


def SF_exact(x: torch.Tensor, s: float, lambd: float = 0.0):
    """
    Exact Top-K sparsification: retain the `int((1 - s) * N)` largest
    elements in absolute value and zero out the rest.

    Args:
        x (torch.Tensor): Input tensor (any shape, contiguous or not).
        s (float): Fraction of elements to zero out. Values outside
            [0, 1] are tolerated: the number of kept elements is clamped
            to [0, N].
        lambd (float): Unused; accepted for signature parity with `SF`.

    Returns:
        torch.Tensor: Sparsified tensor with same shape as x.
    """
    # reshape (not view) so non-contiguous inputs are accepted.
    x_flat = x.reshape(-1)
    N = x_flat.numel()

    # Clamp: a negative count would otherwise act as a "drop from the end"
    # slice bound and keep the wrong elements when s > 1.
    k_keep = min(max(int((1 - s) * N), 0), N)

    # Indices ordered by descending magnitude; the first k_keep survive.
    _, sorted_idx = torch.sort(x_flat.abs(), descending=True)
    keep_idx = sorted_idx[:k_keep]

    # Zeros everywhere except at the retained positions.
    out_flat = torch.zeros_like(x_flat)
    out_flat[keep_idx] = x_flat[keep_idx]
    return out_flat.reshape(x.shape)


def SF(x: torch.Tensor, s: float, lambd: float = 0.0):
    """
    Asymmetric Top-K filtering with optional threshold shift.

    The magnitude of the k-th largest element (k = int((1 - s) * N)) is
    used as a threshold. Elements strictly above `threshold * (1 + lambd)`
    or strictly below `-threshold * (1 - lambd)` are kept; if fewer than k
    elements qualify, the remaining slots are filled by randomly chosen
    elements whose magnitude equals the threshold.

    With `lambd == 0` exactly k elements are retained. With a non-zero
    `lambd` the shifted thresholds can select more or fewer than k
    elements.

    Args:
        x (torch.Tensor): Input tensor of any shape (contiguous or not).
        s (float): Fraction of values to zero out (0 <= s <= 1).
        lambd (float): Asymmetry factor to bias positive vs. negative thresholds.

    Returns:
        torch.Tensor: Sparsified tensor with the same shape as the input.
    """
    # reshape (not view) so non-contiguous inputs are accepted.
    x_flat = x.reshape(-1)
    N = x_flat.numel()
    k_keep = int((1 - s) * N)

    # Trivial cases: keep nothing / keep everything.
    if k_keep <= 0:
        return torch.zeros_like(x)
    if k_keep >= N:
        return x.clone()

    # Magnitude of the k-th largest element.
    abs_flat = x_flat.abs()
    topk_vals, _ = torch.topk(abs_flat, k_keep, sorted=True)
    threshold = topk_vals[-1]

    # Asymmetric thresholds for the positive and negative side.
    thr_pos = threshold * (1 + lambd)
    thr_neg = -threshold * (1 - lambd)

    # Indices strictly beyond the thresholds. as_tuple=True keeps the
    # result 1D even for a single hit (a squeezed 0-dim index tensor
    # cannot be concatenated below).
    strict_mask = (x_flat > thr_pos) | (x_flat < thr_neg)
    strict_idx = strict_mask.nonzero(as_tuple=True)[0]

    # Fill any remaining slots from the elements sitting on the threshold.
    num_needed = k_keep - strict_idx.numel()
    if num_needed > 0:
        tie_idx = (abs_flat == threshold).nonzero(as_tuple=True)[0]
        if tie_idx.numel() == 1:
            # Single candidate: no random draw required.
            chosen = tie_idx
        else:
            perm = torch.randperm(tie_idx.numel(), device=x.device)
            chosen = tie_idx[perm[:num_needed]]
        keep_idx = torch.cat([strict_idx, chosen], dim=0)
    else:
        keep_idx = strict_idx

    # Zeros everywhere except at the retained positions.
    out_flat = torch.zeros_like(x_flat)
    out_flat[keep_idx] = x_flat[keep_idx]
    return out_flat.reshape(x.shape)

def split_and_sort_full(splitted: torch.Tensor):
    """
    Split a tensor into its positive and negative parts and, for each
    sign, list the non-zero entries by descending magnitude.

    The input is flattened first; all returned tensors are 1D and all
    indices refer to positions in the flattened input.

    Args:
        splitted (torch.Tensor): Tensor obtained after initial sparsification.

    Returns:
        pos_output (torch.Tensor):
            Flattened input with negative entries replaced by zero.
        neg_output (torch.Tensor):
            Magnitudes of the negative entries of the flattened input,
            with zero at every position that was not negative.
        pos_sorted_vals (torch.Tensor):
            Values of the strictly positive entries, in descending order.
        pos_sorted_idx (torch.Tensor):
            Flat indices of the entries in `pos_sorted_vals`.
        neg_sorted_vals (torch.Tensor):
            Magnitudes of the strictly negative entries, in descending order.
        neg_sorted_idx (torch.Tensor):
            Flat indices of the entries in `neg_sorted_vals`.
    """
    flat = splitted.view(-1)

    # Sign-separated dense copies: clamping at zero removes the other sign.
    pos_output = flat.clamp(min=0.0)
    neg_output = flat.neg().clamp(min=0.0)

    # One magnitude sort over all entries serves both partitions.
    magnitude_sorted, order = flat.abs().sort(descending=True)

    # Signed values in magnitude order decide which partition each
    # sorted entry belongs to; exact zeros fall into neither.
    signed_sorted = flat[order]
    is_pos = signed_sorted > 0
    is_neg = signed_sorted < 0

    return (
        pos_output,
        neg_output,
        magnitude_sorted[is_pos],
        order[is_pos],
        magnitude_sorted[is_neg],
        order[is_neg],
    )


def MS(split_output, sorted_tensor, sorted_indices, M, SHAPE):
    """
    Distribute the sorted non-zero entries over M consecutive blocks and
    return each block as a sparse CSR tensor of shape `SHAPE`.

    Every block receives `len(sorted_tensor) // M` consecutive entries of
    `sorted_indices`; the last block additionally takes the remainder.

    Args:
        split_output (torch.Tensor):
            1D tensor with zeros at pruned positions.
        sorted_tensor (torch.Tensor):
            Sorted absolute values of non-zero entries (only its length
            is used).
        sorted_indices (torch.Tensor):
            Indices into `split_output` corresponding to `sorted_tensor`.
        M (int):
            Number of blocks to partition non-zero values into.
        SHAPE (tuple):
            Target shape for reconstructing CSR sub-tensors.

    Returns:
        List[torch.Tensor]: One sparse CSR tensor per block (empty list
        when M is 0).
    """
    if M == 0:
        return []

    per_block = sorted_tensor.numel() // M
    result = []

    for b in range(M):
        start = b * per_block
        # An open-ended slice lets the final block absorb the remainder.
        stop = None if b == M - 1 else start + per_block
        members = sorted_indices[start:stop]

        # Dense scratch vector carrying only this block's entries.
        dense = split_output.new_zeros(split_output.shape)
        dense[members] = split_output[members]

        result.append(dense.reshape(SHAPE).to_sparse_csr())

    return result


def _check_dim_3_or_4(IF):
    """
    Validate that the input intermediate feature tensor has 3 or 4 dimensions.

    Args:
        IF (torch.Tensor): Input tensor.

    Raises:
        ValueError: If tensor dimensionality is not 3D or 4D.
    """
    if IF.dim() not in (3, 4):
        raise ValueError(
            f"SLICER functions support only 3D/4D tensors; got shape={tuple(IF.shape)}"
        )


def _adapt_dtype_for_compute(IF):
    """
    Ensure the input tensor is in floating-point format for computation.

    Args:
        IF (torch.Tensor): Input tensor, possibly boolean or integer type.

    Returns:
        adapted_tensor (torch.Tensor):
            Tensor cast to float32 if original was not floating-point.
        original_dtype (torch.dtype):
            The original data type of IF, for later restoration.
    """
    original_dtype = IF.dtype
    if not IF.is_floating_point():
        return IF.to(torch.float32), original_dtype
    return IF, original_dtype


def _cast_back_if_needed(tensor, original_dtype):
    """
    Restore a computed tensor to its original data type if necessary.

    Args:
        tensor (torch.Tensor): Tensor after computation, typically float32.
        original_dtype (torch.dtype): Desired output dtype.

    Returns:
        torch.Tensor: Tensor cast back to `original_dtype`, if different.
    """
    if tensor.dtype != original_dtype:
        return tensor.to(original_dtype)
    return tensor


def SLICER_only_SF(IF, sc_config):
    """
    Run only the sparsification (SF) stage on an intermediate feature,
    returning a tensor with the input's shape and dtype.

    Args:
        IF (torch.Tensor): Input intermediate feature (3D/4D).
        sc_config (SCConfig): Configuration providing the sparsity `s`
            and the asymmetry factor `lambd`.

    Returns:
        torch.Tensor: Sparsified tensor in original dtype.

    Raises:
        ValueError: If `IF` is not 3D or 4D.
    """
    _check_dim_3_or_4(IF)
    compute_input, source_dtype = _adapt_dtype_for_compute(IF)

    # Asymmetric Top-K filtering, then back to the caller's shape.
    sparsified = SF(compute_input, sc_config.s, sc_config.lambd).view_as(IF)

    return _cast_back_if_needed(sparsified, source_dtype)


def SLICER(IF, sc_config):
    """
    Run the full compression pipeline (SF sparsification, MS block split,
    ABSQ quantization) and immediately reconstruct the feature locally.

    The feature is reshaped to a 2D matrix of shape
    `(IF.shape[-1], -1)` before sparsification; each sign partition is
    split into one block per configured bit-width, quantized, stored as
    CSR components, and finally de-quantized and summed back (positive
    blocks added, negative blocks subtracted).

    Args:
        IF (torch.Tensor): Input intermediate feature (3D/4D).
        sc_config (SCConfig): Configuration with parameters:
            - s: sparsity ratio
            - lambd: asymmetry factor
            - Q, Q_n: lists of bit-widths for positive/negative partitions
            - delta: distortion threshold
            - use_ABSQ: flag to enable adaptive quantization
            - under_bound: minimum bit-width

    Returns:
        (output_tensor, bit_usage_list):
            output_tensor (torch.Tensor): Reconstructed feature in the
                input's shape and dtype.
            bit_usage_list (List[int]): Bit-widths chosen for each block,
                positive blocks first, then negative blocks.

    Raises:
        ValueError: If `IF` is not 3D or 4D.
    """
    # Hard-wired switch: when True the bit usage is returned with the tensor.
    verbose = True
    _check_dim_3_or_4(IF)
    x, orig_dtype = _adapt_dtype_for_compute(IF)
    device, ori_shape = x.device, x.shape

    # Configuration
    s, lambd = sc_config.s, sc_config.lambd
    Q, Q_neg = sc_config.Q, sc_config.Q_n
    delta, use_ABSQ = sc_config.delta, sc_config.use_ABSQ
    under_bd = sc_config.under_bound

    # Sparsify on a 2D view of the feature
    matrix = x.reshape(ori_shape[-1], -1)
    sparse_matrix = SF(matrix, s, lambd)

    # Sign partition + one block per configured bit-width
    pos_out, neg_out, pos_vals, pos_idx, neg_vals, neg_idx = split_and_sort_full(sparse_matrix)
    pos_blocks = MS(pos_out, pos_vals, pos_idx, len(Q), matrix.shape)
    neg_blocks = MS(neg_out, neg_vals, neg_idx, len(Q_neg), matrix.shape)

    def _quantize(blocks, widths):
        # One record per block: [csr components, bits used, scale, zero-point].
        records = []
        for blk, bits in zip(blocks, widths):
            B_IQ, scale, zp, used = ABSQ_with_bit(blk.values(), bits, delta, under_bd, use_ABSQ)
            records.append([store_csr_non_cumulative(blk, B_IQ), used, scale, zp])
        return records

    pos_list = _quantize(pos_blocks, Q)
    neg_list = _quantize(neg_blocks, Q_neg)
    bit_used_all = [rec[1] for rec in pos_list] + [rec[1] for rec in neg_list]

    # Local reconstruction: positive blocks are added, negative blocks
    # (stored as magnitudes) are subtracted.
    out = torch.zeros(ori_shape, dtype=x.dtype, device=device)
    for records, accumulate in ((pos_list, out.add_), (neg_list, out.sub_)):
        for info, _, scale, zp in records:
            info['values'] = DeAIQ(info['values'], scale, zp)
            restored = restore_csr_from_non_cumulative(info, device)
            accumulate(restored.to_dense().reshape(ori_shape))

    out = _cast_back_if_needed(out, orig_dtype)
    if verbose:
        return out, bit_used_all
    return out

def SLICER_EDGE(IF, sc_config):
    """
    Edge-side compression: sparsify, split by sign, partition into blocks,
    quantize each block and pack everything into a transferable packet.

    Pipeline:
        1. Thresholding: the k-th largest magnitude
           (k = min(int(N * (1 - s)), N - 1)) is the base threshold;
           only entries strictly above `thr * (1 + lambd)` or strictly
           below `thr * (lambd - 1)` are kept.
        2. The kept entries are split into a positive part and the
           magnitudes of the negative part, each sorted by magnitude.
        3. Each part is split into `len(Q)` / `len(Q_n)` blocks (`MS`).
           Blocks are laid out as 2D matrices of shape
           `(IF.shape[-1], N // IF.shape[-1])`, the same layout `SLICER`
           uses, because the CSR helpers in this module handle 2D tensors.
        4. Each block is quantized with `ABSQ_with_bit` and stored via
           `store_csr_non_cumulative`.

    Args:
        IF (torch.Tensor): Input intermediate feature (3D/4D).
        sc_config (SCConfig): Configuration providing `s`, `lambd`, `Q`,
            `Q_n`, `delta`, `use_ABSQ` and `under_bound`.

    Returns:
        dict: Packet with keys
            - "shape": shape of the (float) input feature,
            - "pos_list" / "neg_list": one entry per block of the form
              `[csr_dict, bit_used, scale, zero_point]`,
            - "orig_dtype": dtype of `IF` before the float conversion.
        Both lists are empty when the input has no elements or `s > 1`
        makes the keep count negative.

    Raises:
        ValueError: If `IF` is not 3D or 4D.
    """
    _check_dim_3_or_4(IF)
    x, orig_dtype = _adapt_dtype_for_compute(IF)
    shape = x.shape

    s        = sc_config.s
    lambd    = sc_config.lambd
    Q        = sc_config.Q
    Q_neg    = sc_config.Q_n
    delta    = sc_config.delta
    use_ABSQ = sc_config.use_ABSQ
    under_bd = sc_config.under_bound

    # 1) SF: threshold taken from the sorted magnitudes.
    x_flat = x.reshape(-1)
    sorted_vals, _ = torch.sort(x_flat.abs(), descending=True)
    length = sorted_vals.shape[0]
    k = min(int(length * (1 - s)), length - 1)
    if k < 0:
        # Empty input (length == 0) or s > 1: nothing to transmit.
        return {
            "shape": shape,
            "pos_list": [],
            "neg_list": [],
            "orig_dtype": orig_dtype
        }

    val = sorted_vals[k]
    thr_pos = val * (1 + lambd)
    thr_neg = val * (lambd - 1)
    keep_mask = (x_flat > thr_pos) | (x_flat < thr_neg)
    splitted = torch.where(keep_mask, x_flat, torch.zeros_like(x_flat))

    # 2) Sign split: negatives are carried as magnitudes.
    pos_output = splitted.clamp(min=0)
    neg_output = splitted.neg().clamp(min=0)
    pos_sorted_vals, pos_sorted_idx = remove_zeros_and_sort(pos_output)
    neg_sorted_vals, neg_sorted_idx = remove_zeros_and_sort(neg_output)

    # 3) MS on a 2D layout. length > 0 here, so shape[-1] > 0.
    mat_shape = (shape[-1], length // shape[-1])
    pos_blocks = MS(pos_output, pos_sorted_vals, pos_sorted_idx, len(Q), mat_shape)
    neg_blocks = MS(neg_output, neg_sorted_vals, neg_sorted_idx, len(Q_neg), mat_shape)

    # 4) ABSQ_with_bit -> CSR components (MS returns one block per bit-width).
    pos_list = []
    for blk, bits in zip(pos_blocks, Q):
        B_IQ_f, scale_f, zp_f, bit_used = ABSQ_with_bit(blk.values(), bits, delta, under_bd, use_ABSQ)
        pos_list.append([store_csr_non_cumulative(blk, B_IQ_f), bit_used, scale_f, zp_f])

    neg_list = []
    for blk, bits in zip(neg_blocks, Q_neg):
        B_IQ_f, scale_f, zp_f, bit_used = ABSQ_with_bit(blk.values(), bits, delta, under_bd, use_ABSQ)
        neg_list.append([store_csr_non_cumulative(blk, B_IQ_f), bit_used, scale_f, zp_f])

    return {
        "shape": shape,
        "pos_list": pos_list,
        "neg_list": neg_list,
        "orig_dtype": orig_dtype
    }


def recover_SLICER_EDGE(packet, device):
    """
    Reconstruct a dense float32 feature from a `SLICER_EDGE` packet.

    Each block's stored values are de-quantized with its scale and
    zero-point, rebuilt as a sparse CSR tensor, densified and accumulated:
    blocks from "pos_list" are added, blocks from "neg_list" (stored as
    magnitudes) are subtracted.

    De-quantization is done on a float32 copy of the stored values, so the
    packet is never modified and integer-typed values (e.g. those produced
    by `split_pos_list`) are not truncated.

    Args:
        packet (dict): Packet with keys "shape", "pos_list" and
            "neg_list"; each list entry is
            `[csr_dict, bit_used, scale, zero_point]`.
        device (torch.device): Device on which to reconstruct.

    Returns:
        torch.Tensor: Dense float32 tensor of shape `packet["shape"]`.
        NOTE(review): the packet's "orig_dtype" is not applied to the
        result (the cast was disabled in the original code) — confirm
        whether callers expect float32.
    """
    shape = packet["shape"]
    out = torch.zeros(shape, dtype=torch.float32, device=device)

    def _on_device(param):
        # Scale / zero-point may be tensors created on another device.
        return param.to(device) if isinstance(param, torch.Tensor) else param

    for key, accumulate in (("pos_list", out.add_), ("neg_list", out.sub_)):
        for csr_t_dict, _bit_used, scale, zp in packet[key]:
            # Shallow copy so the packet's own dictionary stays untouched.
            info = dict(csr_t_dict)
            quantized = torch.as_tensor(info["values"], dtype=torch.float32, device=device)
            info["values"] = DeAIQ(quantized, _on_device(scale), _on_device(zp))

            dense_val = restore_csr_from_non_cumulative(info, device).to_dense()
            accumulate(dense_val.reshape(shape))

    return out

def cat_pos_list(comp_if, sign="pos_list"):
    """
    Flatten all blocks of one packet list into a single int32 tensor plus
    the metadata needed to split it again.

    For every block, the row pointer array, the column indices and the
    values (each converted to int32) are appended in that order.

    Args:
        comp_if (dict): Packet holding the block lists.
        sign (str): Key of the list to serialize ("pos_list" or
            "neg_list").

    Returns:
        concat_tensor (torch.Tensor): 1D int32 tensor with all blocks
            concatenated (an empty int32 tensor if the list is empty).
        metadata (list[tuple]): One tuple per block:
            `(len_rc, len_ci, len_val, nrows, ncols, bit_used, scale, zp)`,
            where tensor-valued scale / zero-point are converted to
            Python floats.
    """
    entries = comp_if[sign]
    if len(entries) == 0:
        return torch.tensor([], dtype=torch.int32), []

    pieces = []
    metadata = []

    for entry in entries:
        # entry layout: [csr_dict, bit_used, scale, zero_point]
        csr_info = entry[0]
        bit_used = entry[1]
        scale_f = entry[2]
        zp_f = entry[3]

        row_counts = csr_info['row_counts'].int()
        col_indices = csr_info['col_indices'].int()
        values = csr_info['values'].int()

        # Tensor parameters are reduced to plain floats for the metadata.
        if isinstance(scale_f, torch.Tensor):
            scale_f = float(scale_f.item())
        if isinstance(zp_f, torch.Tensor):
            zp_f = float(zp_f.item())

        metadata.append((
            row_counts.shape[0],
            col_indices.shape[0],
            values.shape[0],
            int(csr_info['nrows']),
            int(csr_info['ncols']),
            bit_used,
            scale_f,
            zp_f,
        ))
        pieces.extend((row_counts, col_indices, values))

    return torch.cat(pieces, dim=0).int(), metadata



def split_pos_list(concat_tensor, metadata):
    """
    Inverse of `cat_pos_list`: cut a concatenated tensor back into
    per-block CSR component dictionaries.

    Args:
        concat_tensor (torch.Tensor): 1D tensor holding, per block, the
            row pointer array, column indices and values back to back.
        metadata (list[tuple]): One tuple per block:
            `(len_rc, len_ci, len_val, nrows, ncols, bit_used, scale, zp)`.

    Returns:
        list: One entry per block of the form
        `[csr_dict, bit_used, scale, zp]`, where `csr_dict` has the keys
        "row_counts", "col_indices", "values", "nrows" and "ncols". The
        tensors are slices of `concat_tensor`.
    """
    partitions = []
    cursor = 0

    for (len_rc, len_ci, len_val,
         nrows, ncols, bit_used,
         scale_val, zp_val) in metadata:
        # Consume the three consecutive segments belonging to this block.
        segments = []
        for size in (len_rc, len_ci, len_val):
            segments.append(concat_tensor[cursor:cursor + size])
            cursor += size
        row_counts, col_indices, values = segments

        csr_info = {
            "row_counts":  row_counts,
            "col_indices": col_indices,
            "values":      values,
            "nrows":       nrows,
            "ncols":       ncols
        }
        partitions.append([csr_info, bit_used, scale_val, zp_val])

    return partitions



def IMG_intermeidate(model, sc_config, loader, shape, device, slicer_ftn, samples, save_dir):
    """
    Dump the (optionally compressed) intermediate feature of each batch to
    a text file, one integer per line.

    For every batch the model's edge-side output at
    `sc_config.split_layer` is computed. If `slicer_ftn` is given, the
    feature is reshaped to `shape`, compressed, and the positive and
    negative block lists are serialized with `cat_pos_list` and
    concatenated. Otherwise the raw feature is flattened and converted to
    int32. The result is written to `<save_dir>/<batch_idx>.txt`.

    Args:
        model: Model exposing `eval()` and
            `split_edge_output(inputs, split_layer)`.
        sc_config (SCConfig): Configuration; `split_layer` is read here
            and the whole object is forwarded to `slicer_ftn`.
        loader: Iterable yielding `(inputs, targets)` batches.
        shape (tuple): Shape the feature is reshaped to before
            compression (only used when `slicer_ftn` is given).
        device (torch.device): Device the inputs are moved to.
        slicer_ftn (callable or None): Compression function returning a
            packet with "pos_list" and "neg_list" (e.g. `SLICER_EDGE`),
            or None to dump the uncompressed feature.
        samples (int): Number of batches to process before stopping.
        save_dir (str): Output directory (created if missing).
    """
    import os

    os.makedirs(save_dir, exist_ok=True)

    model.eval()

    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(loader):
            inputs = inputs.to(device)
            total += 1

            IF = model.split_edge_output(inputs, sc_config.split_layer)

            if slicer_ftn is not None:
                packet = slicer_ftn(IF.reshape(shape), sc_config)
                concat_tensor_pos, _ = cat_pos_list(packet, "pos_list")
                concat_tensor_neg, _ = cat_pos_list(packet, "neg_list")
                # An empty list yields a CPU tensor from cat_pos_list, so
                # bring both halves to the CPU before concatenating.
                concat_tensor = torch.cat(
                    [concat_tensor_pos.cpu(), concat_tensor_neg.cpu()]
                )
            else:
                concat_tensor = IF.reshape(-1).int()

            concat_values = concat_tensor.detach().cpu().tolist()

            save_path = os.path.join(save_dir, f"{batch_idx}.txt")

            with open(save_path, 'w') as f:
                f.writelines(f"{val}\n" for val in concat_values)

            print(f"Saved txt file at: {save_path}")

            if total >= samples:
                break
