"""
Weight compression utilities for low-rank approximation and quantization.

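This module provides functions for applying low-rank approximation to neural network weights
using various methods including SVD, power iteration, and quantized residuals. It also contains
simulated ("fake") integer, FP8 and FP4 quantization helpers, and optional Hadamard, PCA and
covariance-based transforms that can be applied before the low-rank step.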
"""

import torch
import torch.nn.functional as F
import logging
import gc
from typing import Optional, Dict, Any

try:
    from .hadamard import Transform_Dict, next_power_of_2
    from .tensor_visualizer import visualize_tensor_3d
except ImportError:
    try:
        # Fallback import
        import sys
        import os
        sys.path.append(os.path.join(os.path.dirname(__file__), 'GEARLM', 'Simulated'))
        from hadamard import Transform_Dict, next_power_of_2
        from tensor_visualizer import visualize_tensor_3d
    except ImportError:
        # If hadamard is not available, raise an informative error
        raise ImportError("Hadamard transform requires the hadamard module from GEARLM.Simulated")

# Module-level Transform_Dict instance (imported from the hadamard module above). The Hadamard
# helpers in this file fetch their transform matrices from it via ``get_or_register``.
# NOTE(review): presumably it caches matrices so each size is built only once — confirm in hadamard.py.
transform_cache = Transform_Dict()
# Module-level logger used for warnings and progress messages.
logger = logging.getLogger(__name__)


def remove_outliers_2d(tensor: torch.Tensor, sparsity: float) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
    """
    Replace the most extreme entries of every row of a 2D tensor with that row's mean.

    For each row, ``int(int(numel * sparsity) / rows / 2)`` of the smallest and as many of the
    largest entries are overwritten with the row mean. The mean is computed before any
    replacement, so it still includes the outliers. The input tensor is never modified.

    Args:
        tensor: 2D tensor to process. Tensors of any other rank are returned as an unmodified copy.
        sparsity: Fraction of all elements treated as outliers (0.0 to 1.0), split evenly
            between the smallest and the largest values of each row.

    Returns:
        Tuple of (processed_tensor, smallest_value, smallest_indices, largest_value, largest_indices).
        If ``sparsity <= 0``, the tensor is not 2D, or the per-row count rounds down to zero,
        processed_tensor is a plain copy and the last four entries are None.
    """
    # Guard clauses: nothing to remove.
    if not (sparsity > 0.0) or tensor.dim() != 2:
        return tensor.clone(), None, None, None, None

    outliers_total = int(tensor.numel() * sparsity)
    per_row = int(outliers_total / tensor.shape[0] / 2)  # half low, half high
    if per_row <= 0:
        return tensor.clone(), None, None, None, None

    low_vals, low_idx = torch.topk(tensor, per_row, dim=-1, largest=False)
    high_vals, high_idx = torch.topk(tensor, per_row, dim=-1, largest=True)

    # Row mean broadcast to one value per replaced slot.
    row_mean = tensor.mean(dim=-1, keepdim=True)
    fill = row_mean.expand(-1, per_row).contiguous()

    cleaned = tensor.clone()
    cleaned.scatter_(-1, low_idx, fill)
    cleaned.scatter_(-1, high_idx, fill)

    return cleaned, low_vals, low_idx, high_vals, high_idx


def restore_outliers_2d(tensor: torch.Tensor, smallest_value: Optional[torch.Tensor], smallest_indices: Optional[torch.Tensor], 
                        largest_value: Optional[torch.Tensor], largest_indices: Optional[torch.Tensor]) -> torch.Tensor:
    """
    Write saved outlier values back into a 2D tensor (the inverse of ``remove_outliers_2d``).

    Args:
        tensor: Tensor to write the outliers into. It is not modified.
        smallest_value: Saved smallest values, one row per tensor row.
        smallest_indices: Column indices of the smallest values.
        largest_value: Saved largest values.
        largest_indices: Column indices of the largest values.

    Returns:
        A copy of ``tensor`` with the saved values scattered back along the last dimension.
        The smallest values are written first and the largest second, so the largest value
        wins if an index appears in both. If any argument is None, ``tensor`` itself is
        returned unchanged (no copy).
    """
    saved_parts = (smallest_indices, largest_indices, smallest_value, largest_value)
    if any(part is None for part in saved_parts):
        return tensor

    result = tensor.clone()
    for indices, values in ((smallest_indices, smallest_value), (largest_indices, largest_value)):
        result.scatter_(-1, indices, values)
    return result


def per_channel_symmetric_quantization(tensor: torch.Tensor, num_bits: int = 8) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Simulate per-channel (per-row) symmetric quantization of a 2D tensor.

    Each row is scaled by ``max(|row|) / (2**(num_bits-1) - 1)``, rounded onto the signed
    integer grid ``[-2**(num_bits-1), 2**(num_bits-1) - 1]`` and immediately dequantized
    ("fake quantization").

    Args:
        tensor: Input tensor to quantize (2D, one channel per row)
        num_bits: Number of bits for quantization (default: 8). Must be >= 2.

    Returns:
        Tuple of (dequantized_tensor, scale_factors)
        - dequantized_tensor: fake-quantized tensor with the input's shape and dtype
        - scale_factors: per-row scales, shape [rows, 1], input dtype

    Raises:
        ValueError: if ``tensor`` is not 2D or ``num_bits < 2``.
    """
    if tensor.dim() != 2:
        raise ValueError(f"Tensor must be 2D for per-channel quantization, got {tensor.dim()}D")
    if num_bits < 2:
        # With a single bit qmax would be 0, which makes every scale infinite and the output NaN.
        raise ValueError(f"num_bits must be >= 2 for symmetric quantization, got {num_bits}")

    dtype = tensor.dtype

    # scale = max(|row|) / qmax, so every value lies within [-qmax, qmax] after scaling.
    max_vals = torch.max(torch.abs(tensor), dim=1, keepdim=True)[0]  # [rows, 1]
    qmax = 2 ** (num_bits - 1) - 1  # e.g. 127 for 8-bit
    scale = max_vals / qmax
    # All-zero rows would give 0/0 = NaN. Any nonzero scale maps them to 0 exactly.
    scale = torch.where(scale == 0, torch.ones_like(scale), scale)

    quantized = torch.round(tensor / scale).clamp(-qmax - 1, qmax)  # e.g. [-128, 127] for 8-bit
    dequantized = quantized * scale

    return dequantized.to(dtype), scale.to(dtype)


def per_channel_asymmetric_quantization(tensor: torch.Tensor, num_bits: int = 8) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Apply per-channel asymmetric (affine) quantization to a 2D tensor.

    Each row gets its own scale and integer zero-point on the unsigned grid
    ``[0, 2**num_bits - 1]``, and the result is dequantized immediately ("fake quantization").
    The quantization range of every row is widened to include 0. This keeps the zero-point
    inside the integer grid, as standard affine observers do. Without it, rows whose values
    are all positive or all negative would be clamped to a wrong range.

    Args:
        tensor: Input tensor to quantize (2D)
        num_bits: Number of bits for quantization (default: 8). Must be >= 1.

    Returns:
        Tuple of (dequantized_tensor, scale_factors, zero_points)
        - dequantized_tensor: dequantized float tensor with original dtype
        - scale_factors: per-channel scales (shape [rows, 1])
        - zero_points: per-channel zero-points (int32, shape [rows, 1])

    Raises:
        ValueError: if ``tensor`` is not 2D or ``num_bits < 1``.
    """
    if tensor.dim() != 2:
        raise ValueError(f"Tensor must be 2D for per-channel quantization, got {tensor.dim()}D")
    if num_bits < 1:
        raise ValueError(f"num_bits must be >= 1 for asymmetric quantization, got {num_bits}")

    dtype = tensor.dtype

    # Per-row range, widened to contain 0. For rows that already straddle 0 this is a no-op.
    min_vals = torch.min(tensor, dim=1, keepdim=True)[0].clamp(max=0)
    max_vals = torch.max(tensor, dim=1, keepdim=True)[0].clamp(min=0)

    # Asymmetric quantization range (unsigned)
    qmin = 0
    qmax = 2 ** num_bits - 1

    # Scale and zero-point. All-zero rows get scale 1 to avoid division by zero.
    scale = (max_vals - min_vals) / float(qmax - qmin)
    scale = torch.where(scale == 0, torch.ones_like(scale), scale)
    # Because min_vals <= 0 <= max_vals, -min/scale already lies in [qmin, qmax].
    # The clamp only guards against rounding at the edges.
    zero_point = torch.round(qmin - min_vals / scale).clamp(qmin, qmax)

    # Quantize, then dequantize.
    quantized = torch.round(tensor / scale + zero_point).clamp(qmin, qmax)
    dequantized = (quantized - zero_point) * scale

    return dequantized.to(dtype), scale.to(dtype), zero_point.to(torch.int32)

def fake_groupwise_channel_fp8_quantization_2d(
    input: torch.Tensor, fp8_format="e4m3", group_size=128, max_chunk_size=1000000
):
    """
    FP8 quantization with group-wise scale operations for 2D weight tensors.
    Quantizes values in groups using FP8 format with per-group scaling.
    Note: Subnormal values are flushed to zero.

    A group is a run of ``group_size`` consecutive elements of the row-major flattened tensor,
    so a group may span row boundaries. Large tensors are processed in chunks whose boundaries
    coincide with group boundaries, so chunking does not change which elements share a scale.

    Args:
        input: Input tensor of shape [in_features, out_features]
        fp8_format: FP8 format, either "e4m3" or "e5m2"
        group_size: Size of each group for group-wise quantization
        max_chunk_size: Maximum number of elements to process at once for memory efficiency
            (rounded down to a multiple of ``group_size``, but never below one group)

    Returns:
        Dequantized tensor with the same shape and dtype as ``input``

    Raises:
        ValueError: if ``fp8_format`` is not supported.
    """
    in_features, out_features = input.shape
    dtype = input.dtype

    # The quantizer only needs the largest finite magnitude, the smallest normal magnitude
    # and the mantissa width. The exponent width and bias are listed for reference.
    if fp8_format == "e4m3":
        # E4M3: 1 sign bit, 4 exponent bits (bias 7), 3 mantissa bits
        # max_normal = (2 - 2**-3) * 2**7 = 240.0, min_normal = 2**-6
        max_normal, min_normal, mantissa_bits = 240.0, 2**(-6), 3
    elif fp8_format == "e5m2":
        # E5M2: 1 sign bit, 5 exponent bits (bias 15), 2 mantissa bits
        # max_normal = (2 - 2**-2) * 2**15 = 57344.0, min_normal = 2**-14
        max_normal, min_normal, mantissa_bits = 57344.0, 2**(-14), 2
    else:
        raise ValueError(f"Unsupported FP8 format: {fp8_format}")

    total_elements = in_features * out_features

    if total_elements <= max_chunk_size:
        # The helper computes in float32; cast back to the documented output dtype.
        result = _process_fp8_quantization_chunk(input, fp8_format, group_size, max_normal, min_normal, mantissa_bits)
        return result.to(dtype)

    # Align chunks to whole groups. Otherwise a group would be split across two chunks and
    # quantized with two different scales, giving a different result than the unchunked path.
    chunk_size = max(group_size, (max_chunk_size // group_size) * group_size)
    input_flat = input.reshape(-1)
    result_chunks = []

    for i, start_idx in enumerate(range(0, total_elements, chunk_size)):
        chunk = input_flat[start_idx:start_idx + chunk_size]
        # The helper flattens its 2D input row-major, so a single row keeps the element order.
        processed_chunk = _process_fp8_quantization_chunk(
            chunk.unsqueeze(0), fp8_format, group_size, max_normal, min_normal, mantissa_bits
        )
        # Cast each chunk right away so the float32 intermediates are not all kept alive.
        result_chunks.append(processed_chunk.reshape(-1).to(dtype))
        del chunk, processed_chunk

        # Periodically release cached memory.
        if (i + 1) % 5 == 0:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            gc.collect()

    result = torch.cat(result_chunks).reshape(in_features, out_features)
    del input_flat, result_chunks

    # Final memory cleanup
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

    return result


def _process_fp8_quantization_chunk(input: torch.Tensor, fp8_format: str, group_size: int, 
                                   max_normal: float, min_normal: float, mantissa_bits: int) -> torch.Tensor:
    """
    Fake-quantize a 2D tensor to FP8 with one scale per group of ``group_size`` elements.

    The tensor is flattened row-major, zero-padded to a multiple of ``group_size`` and split
    into consecutive groups. Each group is scaled so its largest magnitude maps to
    ``max_normal``. Every value is then rounded to the nearest number with ``mantissa_bits``
    explicit mantissa bits, with magnitudes below ``min_normal`` flushed to zero. Finally the
    group scale is divided out again.

    Args:
        input: 2D tensor to quantize
        fp8_format: Format name. Not used here, because the numeric arguments fully describe
            the format; kept for interface compatibility.
        group_size: Number of consecutive elements sharing one scale
        max_normal: Largest finite magnitude of the format
        min_normal: Smallest normal magnitude of the format
        mantissa_bits: Number of explicit mantissa bits

    Returns:
        float32 tensor with the same shape as ``input``.
    """
    in_features, out_features = input.shape
    num_elements = in_features * out_features
    flat = input.reshape(-1)

    # Pad to whole groups. Zero padding never raises a group's max magnitude and is cropped below.
    remainder = num_elements % group_size
    if remainder != 0:
        padding = torch.zeros(group_size - remainder, device=input.device, dtype=input.dtype)
        flat = torch.cat([flat, padding])

    groups = flat.view(-1, group_size).float()

    # Per-group scale mapping the largest magnitude onto max_normal (epsilon guards all-zero groups).
    max_abs_vals = torch.max(torch.abs(groups), dim=1, keepdim=True)[0]
    group_scale = max_normal / (max_abs_vals + 1e-8)
    scaled = torch.clamp(groups * group_scale, -max_normal, max_normal)

    dequantized_scaled = torch.zeros_like(scaled)
    # Values below the smallest normal are flushed to zero (no subnormals).
    normal_mask = ~(torch.abs(scaled) < min_normal)

    values = scaled[normal_mask]
    magnitude = torch.abs(values)
    exponent = torch.floor(torch.log2(magnitude))
    power = 2 ** exponent
    steps = 2 ** mantissa_bits
    # Round the mantissa to the nearest representable step. A result of ``steps`` is a carry
    # into the next power of two (e.g. 1.97 -> 2.0 rather than 1.875), so only the lower
    # bound is clamped; that clamp absorbs log2 rounding making the mantissa dip below 1.
    # The carry cannot exceed max_normal as long as max_normal is itself representable,
    # which holds for the E4M3/E5M2 constants used by the FP8 wrapper.
    mantissa_steps = torch.round((magnitude / power - 1.0) * steps).clamp(min=0)
    dequantized_scaled[normal_mask] = torch.sign(values) * (1.0 + mantissa_steps / steps) * power

    # Undo the group scaling, drop the padding and restore the input shape.
    dequantized_flat = (dequantized_scaled / group_scale).view(-1)[:num_elements]
    return dequantized_flat.view(in_features, out_features)


def fake_groupwise_channel_fp4_quantization_2d(
    input: torch.Tensor, fp4_format="e2m1", group_size=128, max_chunk_size=1000000
):
    """
    FP4 quantization with group-wise scale operations for 2D weight tensors.
    Quantizes values in groups using FP4 format with per-group scaling.
    Note: Subnormal values are flushed to zero.

    A group is a run of ``group_size`` consecutive elements of the row-major flattened tensor,
    so a group may span row boundaries. Large tensors are processed in chunks whose boundaries
    coincide with group boundaries, so chunking does not change which elements share a scale.

    Args:
        input: Input tensor of shape [in_features, out_features]
        fp4_format: FP4 format, either "e2m1" or "e3m0"
        group_size: Size of each group for group-wise quantization
        max_chunk_size: Maximum number of elements to process at once for memory efficiency
            (rounded down to a multiple of ``group_size``, but never below one group)

    Returns:
        Dequantized tensor with the same shape and dtype as ``input``

    Raises:
        ValueError: if ``fp4_format`` is not supported.
    """
    in_features, out_features = input.shape
    dtype = input.dtype

    # The quantizer only needs the largest finite magnitude, the smallest normal magnitude
    # and the mantissa width. The exponent width and bias are listed for reference.
    if fp4_format == "e2m1":
        # E2M1: 1 sign bit, 2 exponent bits (bias 1), 1 mantissa bit
        # max_normal = (2 - 2**-1) * 2**1 = 3.0, min_normal = 2**-1 = 0.5
        max_normal, min_normal, mantissa_bits = 3.0, 2**(-1), 1
    elif fp4_format == "e3m0":
        # E3M0: 1 sign bit, 3 exponent bits (bias 3), 0 mantissa bits (powers of two only)
        # max_normal = 1.0 * 2**3 = 8.0, min_normal = 2**-2 = 0.25
        max_normal, min_normal, mantissa_bits = 8.0, 2**(-2), 0
    else:
        raise ValueError(f"Unsupported FP4 format: {fp4_format}")

    total_elements = in_features * out_features

    if total_elements <= max_chunk_size:
        # The helper computes in float32; cast back to the documented output dtype.
        result = _process_fp4_quantization_chunk(input, fp4_format, group_size, max_normal, min_normal, mantissa_bits)
        return result.to(dtype)

    # Align chunks to whole groups. Otherwise a group would be split across two chunks and
    # quantized with two different scales, giving a different result than the unchunked path.
    chunk_size = max(group_size, (max_chunk_size // group_size) * group_size)
    input_flat = input.reshape(-1)
    result_chunks = []

    for i, start_idx in enumerate(range(0, total_elements, chunk_size)):
        chunk = input_flat[start_idx:start_idx + chunk_size]
        # The helper flattens its 2D input row-major, so a single row keeps the element order.
        processed_chunk = _process_fp4_quantization_chunk(
            chunk.unsqueeze(0), fp4_format, group_size, max_normal, min_normal, mantissa_bits
        )
        # Cast each chunk right away so the float32 intermediates are not all kept alive.
        result_chunks.append(processed_chunk.reshape(-1).to(dtype))
        del chunk, processed_chunk

        # Periodically release cached memory.
        if (i + 1) % 5 == 0:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            gc.collect()

    result = torch.cat(result_chunks).reshape(in_features, out_features)
    del input_flat, result_chunks

    # Final memory cleanup
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

    return result


def _process_fp4_quantization_chunk(input: torch.Tensor, fp4_format: str, group_size: int, 
                                   max_normal: float, min_normal: float, mantissa_bits: int) -> torch.Tensor:
    """
    Fake-quantize a 2D tensor to FP4 with one scale per group of ``group_size`` elements.

    The tensor is flattened row-major, zero-padded to a multiple of ``group_size`` and split
    into consecutive groups. Each group is scaled so its largest magnitude maps to
    ``max_normal``. Every value is then rounded to the nearest (in linear terms) number with
    ``mantissa_bits`` explicit mantissa bits, with magnitudes below ``min_normal`` flushed to
    zero. Finally the group scale is divided out again. ``mantissa_bits == 0`` (E3M0) rounds
    to the nearest power of two with the same formula.

    Args:
        input: 2D tensor to quantize
        fp4_format: Format name. Not used here, because the numeric arguments fully describe
            the format; kept for interface compatibility.
        group_size: Number of consecutive elements sharing one scale
        max_normal: Largest finite magnitude of the format
        min_normal: Smallest normal magnitude of the format
        mantissa_bits: Number of explicit mantissa bits (0 allowed)

    Returns:
        float32 tensor with the same shape as ``input``.
    """
    in_features, out_features = input.shape
    num_elements = in_features * out_features
    flat = input.reshape(-1)

    # Pad to whole groups. Zero padding never raises a group's max magnitude and is cropped below.
    remainder = num_elements % group_size
    if remainder != 0:
        padding = torch.zeros(group_size - remainder, device=input.device, dtype=input.dtype)
        flat = torch.cat([flat, padding])

    groups = flat.view(-1, group_size).float()

    # Per-group scale mapping the largest magnitude onto max_normal (epsilon guards all-zero groups).
    max_abs_vals = torch.max(torch.abs(groups), dim=1, keepdim=True)[0]
    group_scale = max_normal / (max_abs_vals + 1e-8)
    scaled = torch.clamp(groups * group_scale, -max_normal, max_normal)

    dequantized_scaled = torch.zeros_like(scaled)
    # Values below the smallest normal are flushed to zero (no subnormals).
    normal_mask = ~(torch.abs(scaled) < min_normal)

    values = scaled[normal_mask]
    magnitude = torch.abs(values)
    exponent = torch.floor(torch.log2(magnitude))
    power = 2 ** exponent
    steps = 2 ** mantissa_bits
    # Round the mantissa to the nearest representable step. A result of ``steps`` is a carry
    # into the next power of two (e.g. E2M1: 1.8 -> 2.0 rather than 1.5; E3M0: 1.45 -> 1.0
    # and 1.6 -> 2.0), so only the lower bound is clamped; that clamp absorbs log2 rounding
    # making the mantissa dip below 1. The carry cannot exceed max_normal as long as
    # max_normal is itself representable, which holds for the E2M1/E3M0 constants used by
    # the FP4 wrapper.
    mantissa_steps = torch.round((magnitude / power - 1.0) * steps).clamp(min=0)
    dequantized_scaled[normal_mask] = torch.sign(values) * (1.0 + mantissa_steps / steps) * power

    # Undo the group scaling, drop the padding and restore the input shape.
    dequantized_flat = (dequantized_scaled / group_scale).view(-1)[:num_elements]
    return dequantized_flat.view(in_features, out_features)


def power_iteration_2d(weight: torch.Tensor, actual_rank: int, loop: int = 3) -> torch.Tensor:
    """
    Approximate a 2D matrix with rank at most ``actual_rank`` using randomized subspace (power) iteration.
    Adapted from fake_poweriteration_group for 2D matrices.

    Starting from a random basis ``P`` of shape [dim2, actual_rank], each iteration computes
    ``Q = orth(W @ orth(P))`` and ``P = W^T @ Q``. The result is ``Q @ P^T = Q Q^T W``, the
    projection of ``W`` onto the column space of ``Q``. The basis is re-orthonormalized every
    iteration. This spans the same subspace as orthonormalizing only once at the end (in exact
    arithmetic), but keeps the iterates from overflowing or collapsing onto the dominant
    singular vector in floating point.

    Args:
        weight: Weight tensor to approximate (2D) [dim1, dim2]
        actual_rank: Number of basis vectors, i.e. the target rank
        loop: Number of power iterations (>= 1)

    Returns:
        Low-rank approximated weight tensor with the input's shape, device and dtype

    Raises:
        ValueError: if ``weight`` is not 2D or ``loop < 1``.
    """
    if weight.dim() != 2:
        raise ValueError(f"Weight tensor must be 2D, got {weight.dim()}D")
    if loop < 1:
        # With no iterations the result would be the product of two random matrices.
        raise ValueError(f"loop must be >= 1, got {loop}")

    device = weight.device
    dtype = weight.dtype
    dim1, dim2 = weight.shape

    # Compute in float32 regardless of the input dtype.
    input_tensor = weight.float()

    # Random starting basis, created directly on the target device.
    p_base = torch.rand(dim2, actual_rank, device=device, dtype=input_tensor.dtype)

    for _ in range(loop):
        p_base = torch.linalg.qr(p_base).Q
        q_base = torch.linalg.qr(input_tensor @ p_base).Q
        p_base = input_tensor.T @ q_base

    # Q @ (W^T Q)^T = Q Q^T W
    low_rank_weight = q_base @ p_base.T

    return low_rank_weight.to(device=device, dtype=dtype)


def svd_approx(weight: torch.Tensor, rank: int, mode: str = "svd", loop: int = 3) -> torch.Tensor:
    """
    Return a rank-``rank`` approximation of a 2D tensor.

    Args:
        weight: 2D tensor to approximate (floating point; ``torch.linalg.svd`` does not
            accept half precision)
        rank: Number of singular components to keep
        mode: 'svd' for truncated exact SVD, 'approx' for power iteration (``power_iteration_2d``)
        loop: Number of power iterations (only used for 'approx')

    Returns:
        Tensor with the same shape as ``weight`` and rank at most ``rank``.

    Raises:
        ValueError: for an unknown ``mode``.
    """
    if mode == "approx":
        return power_iteration_2d(weight, rank, loop)
    if mode != "svd":
        raise ValueError(f"Unknown mode: {mode}. Choose 'svd' or 'approx'.")

    # Reduced SVD: weight = U @ diag(S) @ Vh (torch.svd is deprecated in favour of torch.linalg.svd).
    U, S, Vh = torch.linalg.svd(weight, full_matrices=False)

    # Scale the kept columns of U by their singular values instead of building diag(S).
    return (U[:, :rank] * S[:rank]) @ Vh[:rank, :]


def Hadamard_channel_func(weight: torch.Tensor, target_func) -> torch.Tensor:
    """
    Run ``target_func`` on ``weight`` after a Hadamard transform along its columns.

    The column count is zero-padded to ``next_power_of_2(weight.shape[1])``. The padded
    matrix is right-multiplied by the "hadamard" matrix from ``transform_cache`` and then
    passed to ``target_func``. The result is right-multiplied by the same matrix, divided by
    the padded size and cropped back to the original column count. If the padded size
    exceeds 16384, no transform is applied and ``target_func(weight)`` is returned directly.

    NOTE(review): the inverse step assumes the cached matrix H satisfies H @ H == size * I
    (an unnormalized symmetric +-1 Hadamard matrix of the padded size) — confirm in hadamard.py.
    """
    num_cols = weight.shape[1]
    padded_cols = next_power_of_2(num_cols)

    # Above the size limit, operate directly in the original space.
    if padded_cols > 16384:
        return target_func(weight)

    hadamard = torch.from_numpy(transform_cache.get_or_register("hadamard", num_cols))
    hadamard = hadamard.to(weight.device, dtype=weight.dtype)

    # Forward transform, apply target_func in the rotated space, inverse transform.
    rotated = F.pad(weight, (0, padded_cols - num_cols)) @ hadamard
    approx_rotated = target_func(rotated)
    unrotated = (approx_rotated @ hadamard) / padded_cols

    # Drop the padded columns.
    return unrotated[:, :num_cols]

def Hadamard_token_func(weight: torch.Tensor, target_func) -> torch.Tensor:
    """
    Run ``target_func`` on ``weight`` after a Hadamard transform along its rows.

    The row count is zero-padded to ``next_power_of_2(weight.shape[0])``. The padded matrix
    is left-multiplied by the "hadamard" matrix from ``transform_cache`` and then passed to
    ``target_func``. The result is left-multiplied by the same matrix, divided by the padded
    size and cropped back to the original row count. If the padded size exceeds 16384, no
    transform is applied and ``target_func(weight)`` is returned directly.

    NOTE(review): the inverse step assumes the cached matrix H satisfies H @ H == size * I
    (an unnormalized symmetric +-1 Hadamard matrix of the padded size) — confirm in hadamard.py.
    """
    num_rows = weight.shape[0]
    padded_rows = next_power_of_2(num_rows)

    # Above the size limit, operate directly in the original space.
    if padded_rows > 16384:
        return target_func(weight)

    hadamard = torch.from_numpy(transform_cache.get_or_register("hadamard", num_rows))
    hadamard = hadamard.to(weight.device, dtype=weight.dtype)

    # Pad rows at the bottom, forward transform, apply target_func, inverse transform.
    rotated = hadamard @ F.pad(weight, (0, 0, 0, padded_rows - num_rows))
    approx_rotated = target_func(rotated)
    unrotated = (hadamard @ approx_rotated) / padded_rows

    # Drop the padded rows.
    return unrotated[:num_rows, :]


def _low_rank_in_transform_space(weight_float: torch.Tensor, transform: str, rank: int,
                                 mode: str, loop: int) -> torch.Tensor:
    """
    Return a rank-``rank`` approximation of ``weight_float`` computed in a transformed space.

    ``transform`` selects the invertible map applied before ``svd_approx`` and undone after it:
    'none', 'hadamard', 'pca' or 'cov'.

    Raises:
        ValueError: for an unknown ``transform``.
    """
    def low_rank_func(tensor: torch.Tensor) -> torch.Tensor:
        return svd_approx(tensor, rank, mode, loop)

    if transform == "none":
        return low_rank_func(weight_float)

    if transform == "hadamard":
        return Hadamard_channel_func(weight_float, low_rank_func)

    if transform == "pca":
        if weight_float.shape[0] < weight_float.shape[1]:
            # Wide matrices fall back to the Hadamard transform.
            return Hadamard_channel_func(weight_float, low_rank_func)
        # SVD of the Gram matrix W^T W. The forward map is W @ UT and the inverse is (.) @ U.
        # NOTE(review): UT @ U is the identity only if U == UT.mT, which holds for the
        # nonzero-eigenvalue directions of this symmetric matrix but is not guaranteed for
        # its null space (rank-deficient W) — confirm this is acceptable.
        cov_matrix = weight_float.mT @ weight_float
        U, _, UT = torch.linalg.svd(cov_matrix, full_matrices=False)
        return low_rank_func(weight_float @ UT) @ U

    if transform == "cov":
        try:
            tall = weight_float.shape[0] >= weight_float.shape[1]
            # Gram matrix over the smaller dimension, with jitter so Cholesky succeeds.
            cov_matrix = weight_float.mT @ weight_float if tall else weight_float @ weight_float.mT
            cov_dim = cov_matrix.shape[0]
            cov_matrix = cov_matrix + 1e-5 * torch.eye(cov_dim, device=weight_float.device, dtype=weight_float.dtype)

            S = torch.linalg.cholesky(cov_matrix)
            S_inv = torch.linalg.inv(S)

            if tall:
                # Transform: X @ S_inv, approximate, transform back: result @ S
                return low_rank_func(weight_float @ S_inv) @ S
            # Transform: S_inv @ X, approximate, transform back: S @ result
            return S @ low_rank_func(S_inv @ weight_float)
        except torch.linalg.LinAlgError as err:
            logger.warning(f"'cov' transform failed ({err}); falling back to low-rank approximation without transform")
            return low_rank_func(weight_float)

    raise ValueError(f"Unknown transform: {transform}. Choose 'none', 'hadamard', 'pca', or 'cov'.")


def apply_low_rank_approximation(weight: torch.Tensor, rank_percentage: float, mode: str = "svd", loop: int = 3, 
                               use_quantized_residual: bool = False, weight_quant_bits: int = 8, transform: str = "none", 
                               sparsity: float = 0.0, activation_scale: Optional[torch.Tensor] = None) -> tuple[torch.Tensor, int, dict]:
    """
    Apply low-rank approximation to a weight tensor using either SVD or power iteration.
    Optionally uses quantized residual (SVD + Q(W - SVD(W))) for better approximation.

    Pipeline:
      1. rank = max(1, int(R * rank_percentage / 100)), where R is the larger matrix
         dimension, or the smaller one once the larger reaches 16384. The approximation
         itself is computed with 4 * rank.
      2. If ``activation_scale`` is given, an AWQ-style per-column scale is searched. It is
         used only if the search reports an error <= 1.0.
      3. Low-rank approximation in the space selected by ``transform``.
      4. The low-rank approximation is fake-quantized to 4 bits (per-row asymmetric).
      5. Optionally, the residual W - L is fake-quantized to ``weight_quant_bits`` bits and
         added back. Entries removed by ``remove_outliers_2d`` (controlled by ``sparsity``)
         are kept at full precision.
      6. The AWQ scale, if used, is divided out again.

    Args:
        weight: Weight tensor to approximate (2D)
        rank_percentage: Percentage of rank to keep (0-100)
        mode: 'svd' for exact SVD, 'approx' for power iteration
        loop: Number of power iteration loops (only used for 'approx' mode)
        use_quantized_residual: Whether to use quantized residual for better approximation
        weight_quant_bits: Number of bits for residual quantization
        transform: Transform to apply before low-rank approximation ('none', 'hadamard', 'pca', 'cov')
        sparsity: Sparsity ratio for outlier removal from the residual (0.0 to 1.0)
        activation_scale: Optional activation scale tensor for AWQ-style quantization

    Returns:
        Tuple of (approximated weight in the input's device/dtype, rank from step 1
        (before the 4x multiplication), compression_info dict)

    Raises:
        ValueError: if ``weight`` is not 2D, or ``transform`` / ``mode`` is unknown.
    """
    if weight.dim() != 2:
        raise ValueError(f"Weight tensor must be 2D, got {weight.dim()}D")
    if transform not in ("none", "hadamard", "pca", "cov"):
        # Validate before the potentially expensive AWQ scale search below.
        raise ValueError(f"Unknown transform: {transform}. Choose 'none', 'hadamard', 'pca', or 'cov'.")

    device = weight.device
    dtype = weight.dtype

    # Target rank derived from the percentage; this is the rank reported to the caller.
    max_rank = max(weight.shape) if max(weight.shape) < 16384 else min(weight.shape)
    target_rank = max(1, int(max_rank * rank_percentage / 100.0))
    # The approximation is computed with 4x the target rank and stored as INT4 below.
    # NOTE(review): presumably this spends the memory saved by INT4 storage on extra rank — confirm.
    expanded_rank = target_rank * 4

    # Quantizer and low-rank functions handed to the AWQ scale search.
    def weight_quantize_func(tensor):
        quantized, _, _ = per_channel_asymmetric_quantization(tensor, weight_quant_bits)
        return quantized

    def low_rank_func(tensor):
        return svd_approx(tensor, expanded_rank, mode, loop)

    compression_info = {
        'use_quantized_residual': use_quantized_residual,
        'weight_quant_bits': weight_quant_bits if use_quantized_residual else None,
        'residual_error': None,
        'transform': transform
    }

    weight_float = weight.float()

    # Optional AWQ-style per-column scaling. awq_scale stays None if the search is rejected.
    awq_scale = None
    if activation_scale is not None:
        try:
            from .awq import find_optimal_scale_for_weight
        except ImportError:
            from awq import find_optimal_scale_for_weight

        optimal_scale, best_error = find_optimal_scale_for_weight(
            weight_float, activation_scale, weight_quantize_func,
            low_rank_func, expanded_rank, n_grid=20
        )

        if best_error <= 1.0:
            awq_scale = optimal_scale.view(1, -1).to(device).to(weight_float.dtype)
            weight_float = weight_float * awq_scale
        else:
            logger.warning(f"AWQ aborted due to high error {best_error:.6f}")

    svd_approximation = _low_rank_in_transform_space(weight_float, transform, expanded_rank, mode, loop)

    # Simulate INT4 storage of the low-rank component.
    svd_approximation, _, _ = per_channel_asymmetric_quantization(svd_approximation, 4)

    if use_quantized_residual:
        # Residual error W - Q(L)
        residual = weight_float - svd_approximation

        # Remove outliers (if sparsity > 0); they are kept at full precision below.
        residual_inliers, smallest_value, smallest_indices, largest_value, largest_indices = remove_outliers_2d(residual, sparsity)

        quantized_residual, _, _ = per_channel_asymmetric_quantization(residual_inliers, weight_quant_bits)

        # Put the original outlier values back after quantization.
        quantized_residual = restore_outliers_2d(quantized_residual, smallest_value, smallest_indices, largest_value, largest_indices)

        # Final approximation = Q(L) + Q(W - Q(L))
        low_rank_weight = svd_approximation + quantized_residual

        # Statistics are measured against the full residual (outliers included), since that
        # is what quantized_residual approximates once the outliers are restored.
        compression_info['residual_error'] = {
            'original_norm': torch.norm(residual).item(),
            'quantized_norm': torch.norm(quantized_residual).item(),
            'quantization_error': torch.norm(residual - quantized_residual).item()
        }
    else:
        low_rank_weight = svd_approximation

    if activation_scale is not None:
        if awq_scale is not None:
            # Remove the scale to get back to the original space.
            low_rank_weight = low_rank_weight / awq_scale
            scale_norm = torch.norm(awq_scale).item()
        else:
            # Search rejected: the neutral all-ones scale was used, whose norm is sqrt(numel).
            scale_norm = float(activation_scale.numel()) ** 0.5
        compression_info['awq_scale_applied'] = awq_scale is not None
        compression_info['awq_scale_norm'] = scale_norm

    # Move back to original device and dtype
    return low_rank_weight.to(device=device, dtype=dtype), target_rank, compression_info


def _layer_storage_bytes(m: int, n: int, element_size_bytes: int, actual_rank: int,
                         use_quantized_residual: bool, weight_quant_bits: int) -> float:
    """
    Estimate the storage footprint, in bytes, of one compressed ``m x n`` weight.

    Args:
        m: Number of rows of the weight matrix.
        n: Number of columns of the weight matrix.
        element_size_bytes: Bytes per element of the original weight dtype.
        actual_rank: Rank reported for the low-rank factors.
        use_quantized_residual: Whether a quantized residual is stored as well.
        weight_quant_bits: Bit width of each quantized residual value.

    Returns:
        Estimated byte count of the low-rank factors plus, when enabled, the
        quantized residual values and their per-row scale and zero-point.
    """
    # Low-rank storage: U and V factors kept at the original element size.
    svd_bytes = (m + n) * actual_rank * element_size_bytes

    if not use_quantized_residual:
        return float(svd_bytes)

    # Each residual entry takes `weight_quant_bits` bits, i.e. bits / 8 bytes.
    quantized_bytes = m * n * weight_quant_bits / 8.0
    scale_bytes = m * element_size_bytes  # one scale per row, original dtype
    zero_point_bytes = m * 4              # one int32 zero-point per row
    return svd_bytes + quantized_bytes + scale_bytes + zero_point_bytes


def apply_low_rank_to_model(model: torch.nn.Module, rank_percentage: float, mode: str = "svd", loop: int = 3,
                          use_quantized_residual: bool = False, weight_quant_bits: int = 8, transform: str = "none", 
                          sparsity: float = 0.0, activation_scales: Optional[Dict[str, torch.Tensor]] = None):
    """
    Apply low-rank approximation to all linear layers in the model.

    Every ``torch.nn.Linear`` found via ``model.named_modules()`` has its weight
    replaced in place with the output of ``apply_low_rank_approximation``.
    Per-layer and global storage estimates are written to the module logger.
    
    Args:
        model: The model to modify (weights are overwritten in place)
        rank_percentage: Percentage of rank to keep (0-100)
        mode: 'svd' for exact SVD, 'approx' for power iteration
        loop: Number of power iteration loops (only used for 'approx' mode)
        use_quantized_residual: Whether to use quantized residual for better approximation
        weight_quant_bits: Number of bits for residual quantization
        transform: Transform to apply before low-rank approximation ('none', 'hadamard', 'pca')
        sparsity: Sparsity ratio for outlier removal (0.0 to 1.0)
        activation_scales: Optional dictionary mapping layer names to activation scales for AWQ

    Returns:
        None. The model is modified in place.
    """
    method_str = f"{mode}"
    if use_quantized_residual:
        method_str += f" + {weight_quant_bits}-bit quantized residual"
    if transform != "none":
        method_str += f" + {transform} transform"
    
    logger.info(f"Applying low-rank approximation ({method_str}) with {rank_percentage}% rank retention to model weights...")
    
    modified_layers = 0
    total_original_bytes = 0
    total_compressed_bytes = 0
    
    for name, module in model.named_modules():
        if not isinstance(module, torch.nn.Linear):
            continue

        original_weight = module.weight.data.clone()

        # Get activation scale for this layer if available
        activation_scale = None
        if activation_scales is not None and name in activation_scales:
            activation_scale = activation_scales[name]
            logger.info(f"Using activation scale for layer {name}")
        
        # Apply low-rank approximation
        low_rank_weight, actual_rank, compression_info = apply_low_rank_approximation(
            original_weight, rank_percentage, mode, loop, use_quantized_residual, weight_quant_bits, transform, sparsity, activation_scale
        )
        module.weight.data = low_rank_weight
        
        # Memory accounting (bytes -> MB)
        m, n = original_weight.shape
        element_size_bytes = original_weight.element_size()
        original_bytes = original_weight.numel() * element_size_bytes
        # NOTE(review): outliers kept aside when sparsity > 0 (values + indices)
        # are not included in this estimate.
        compressed_bytes = _layer_storage_bytes(
            m, n, element_size_bytes, actual_rank, use_quantized_residual, weight_quant_bits
        )

        original_mb = original_bytes / (1024.0 * 1024.0)
        compressed_mb = compressed_bytes / (1024.0 * 1024.0)
        compression_rate = (compressed_bytes / original_bytes) if original_bytes > 0 else 0.0

        total_original_bytes += original_bytes
        total_compressed_bytes += compressed_bytes

        # Per-layer output: [layer], [shape], [orig MB], [compressed MB], [compression rate]
        logger.info(f"{name}, {original_weight.shape}, {original_mb:.3f} MB, {compressed_mb:.3f} MB, {compression_rate:.3f}")
        modified_layers += 1
    
    # Global memory reporting
    logger.info(f"Successfully applied low-rank approximation to {modified_layers} layers")
    total_original_gb = total_original_bytes / (1024.0 * 1024.0 * 1024.0)
    total_compressed_gb = total_compressed_bytes / (1024.0 * 1024.0 * 1024.0)
    global_compression_rate = (total_compressed_bytes / total_original_bytes) if total_original_bytes > 0 else 0.0
    logger.info("Final Global Info")
    logger.info(f"{total_original_gb:.3f} GB, {total_compressed_gb:.3f} GB, {global_compression_rate:.3f}")
