import math
import logging

def calculate_gear_kcvt_compression_ratio(args, config):
    """
    Estimate and log how much KV cache memory the GEAR-KCVT method needs.

    Both caches are modelled as: FP16 outlier values, the remaining entries
    packed at ``args.quantize_bit`` bits each, FP16 (mx, mn) quantization
    metadata, and two FP16 low-rank factor matrices for the error term.

    Args:
        args: Supplies ``model_max_length``, ``batch_size``, ``quantize_bit``,
            ``left`` (fraction of entries treated as outliers), ``rank`` (key
            error rank) and ``rankv`` (value error rank).
        config: Supplies ``num_attention_heads`` and ``hidden_size``.

    Returns:
        None. The size breakdown is written with ``logging.info``.
    """
    # Model / run dimensions
    heads = config.num_attention_heads
    model_width = config.hidden_size
    per_head_dim = model_width // heads
    tokens = args.model_max_length
    batch = args.batch_size

    # Compression hyper-parameters
    bits = args.quantize_bit
    outlier_fraction = args.left
    key_rank = args.rank
    value_rank = args.rankv

    fp16_bytes = 2
    bytes_per_mb = 1024**2

    # --- FP16 baseline: batch * heads * tokens * head_dim entries each for K and V ---
    elems_per_cache = batch * heads * tokens * per_head_dim
    baseline_bytes = elems_per_cache * fp16_bytes * 2

    # --- Terms that come out identical for the key and the value cache ---
    # Each outlier is counted at 2 bytes (its FP16 value); index storage is
    # NOT part of this estimate.
    outlier_count = math.floor(elems_per_cache * outlier_fraction)
    outlier_bytes = outlier_count * fp16_bytes
    # Everything that is not an outlier is packed at `bits` bits per entry.
    packed_bytes = ((elems_per_cache - outlier_count) * bits) / 8
    # Low-rank error factors hold (tokens + head_dim) * rank values per (batch, head).
    factor_rows = batch * heads * (tokens + per_head_dim)

    # --- Key cache: one FP16 (mx, mn) pair per (batch, head, head_dim) channel ---
    key_meta_bytes = batch * heads * per_head_dim * 2 * fp16_bytes
    key_quant_bytes = outlier_bytes + packed_bytes + key_meta_bytes
    key_lowrank_bytes = factor_rows * key_rank * fp16_bytes
    key_bytes = key_quant_bytes + key_lowrank_bytes

    # --- Value cache: one FP16 (mx, mn) pair per (batch, token) ---
    value_meta_bytes = batch * tokens * 2 * fp16_bytes
    value_quant_bytes = outlier_bytes + packed_bytes + value_meta_bytes
    value_lowrank_bytes = factor_rows * value_rank * fp16_bytes
    value_bytes = value_quant_bytes + value_lowrank_bytes

    # --- Totals ---
    compressed_bytes = key_bytes + value_bytes
    percent_of_baseline = (compressed_bytes / baseline_bytes) * 100

    def as_mb(num_bytes):
        return num_bytes / bytes_per_mb

    report_lines = (
        "--- GEAR-KCVT Compression Ratio Analysis ---",
        f"Original KV Cache Size: {as_mb(baseline_bytes):.2f} MB",
        f"Compressed KV Cache Size: {as_mb(compressed_bytes):.2f} MB",
        f"  - Key Cache: {as_mb(key_bytes):.2f} MB "
        f"(Quant: {as_mb(key_quant_bytes):.2f} MB, LR-Error: {as_mb(key_lowrank_bytes):.2f} MB)",
        f"  - Value Cache: {as_mb(value_bytes):.2f} MB "
        f"(Quant: {as_mb(value_quant_bytes):.2f} MB, LR-Error: {as_mb(value_lowrank_bytes):.2f} MB)",
        f"Compression Ratio: {percent_of_baseline:.2f}% of original size",
        "---------------------------------------------",
    )
    for line in report_lines:
        logging.info(line)

def calculate_pcc_compact_compression_ratio(args, config):
    """
    Estimate and log the KV cache footprint of the PCC_COV_OUTLIER_COMPACT method.

    Each cache is modelled as a low-rank part (two FP16 factor matrices) plus a
    quantized residual made of outliers (FP16 value + int32 index, 6 bytes
    each), the remaining entries at ``args.quantize_bit`` bits each, and FP16
    (mx, mn) quantization metadata.

    Args:
        args: Supplies ``model_max_length``, ``batch_size``, ``quantize_bit``,
            ``left`` (fraction of entries treated as outliers), ``rank`` (key
            rank) and ``rankv`` (value rank).
        config: Supplies ``num_attention_heads`` and ``hidden_size``.

    Returns:
        None. The size breakdown is written with ``logging.info``.
    """
    # Model / run dimensions
    n_heads = config.num_attention_heads
    width = config.hidden_size
    d_head = width // n_heads
    n_tokens = args.model_max_length
    n_batch = args.batch_size

    # Compression hyper-parameters
    n_bits = args.quantize_bit
    outlier_frac = args.left
    key_rank = args.rank
    value_rank = args.rankv

    # FP16 baseline: 2 bytes per entry, counted once for keys and once for values.
    total_elems = n_batch * n_heads * n_tokens * d_head
    fp16_kv_bytes = total_elems * 2 * 2

    # Residual terms shared by both caches.
    n_outliers = math.floor(total_elems * outlier_frac)
    outlier_cost = n_outliers * (2 + 4)  # FP16 value + int32 index
    packed_cost = ((total_elems - n_outliers) * n_bits) / 8

    def factor_bytes(rank):
        # Two FP16 factors: (n_tokens + d_head) * rank values per (batch, head).
        return n_batch * n_heads * (n_tokens + d_head) * rank * 2

    def residual_bytes(metadata_cost):
        # Outliers + packed low-bit entries + (mx, mn) metadata.
        return outlier_cost + packed_cost + metadata_cost

    # Key cache (pcc_channelQ): one FP16 (mx, mn) pair per (batch, head, channel).
    key_lr = factor_bytes(key_rank)
    key_residual = residual_bytes(n_batch * n_heads * d_head * 2 * 2)
    key_total = key_lr + key_residual

    # Value cache (pcc_tokenQ): one FP16 (mx, mn) pair per (batch, token).
    value_lr = factor_bytes(value_rank)
    value_residual = residual_bytes(n_batch * n_tokens * 2 * 2)
    value_total = value_lr + value_residual

    kv_total = key_total + value_total
    pct = (kv_total / fp16_kv_bytes) * 100

    mb = 1024**2
    logging.info("--- PCC_COV_OUTLIER_COMPACT Compression Ratio Analysis ---")
    logging.info(f"Original KV Cache Size: {fp16_kv_bytes / mb:.2f} MB")
    logging.info(f"Compressed KV Cache Size: {kv_total / mb:.2f} MB")
    logging.info(
        f"  - Key Cache: {key_total / mb:.2f} MB "
        f"(LR-Part: {key_lr / mb:.2f} MB, Quant-Residual: {key_residual / mb:.2f} MB)"
    )
    logging.info(
        f"  - Value Cache: {value_total / mb:.2f} MB "
        f"(LR-Part: {value_lr / mb:.2f} MB, Quant-Residual: {value_residual / mb:.2f} MB)"
    )
    logging.info(f"Compression Ratio: {pct:.2f}% of original size")
    logging.info("---------------------------------------------")


def calculate_kcvt_compression_ratio(args, config):
    """
    Estimate and log the KV cache footprint of the KCVT method.

    KCVT is modelled as plain quantization with no outliers: every entry costs
    ``args.quantize_bit`` bits, keys carry one FP16 (mx, mn) pair per
    (batch, head, channel) and values one FP16 (mx, mn) pair per (batch, token).

    Args:
        args: Supplies ``model_max_length``, ``batch_size`` and ``quantize_bit``.
        config: Supplies ``num_attention_heads`` and ``hidden_size``.

    Returns:
        None. The size breakdown is written with ``logging.info``.
    """
    heads = config.num_attention_heads
    width = config.hidden_size
    channels_per_head = width // heads
    tokens = args.model_max_length
    batch = args.batch_size
    bits = args.quantize_bit

    bytes_per_mb = 1024**2

    # FP16 baseline: 2 bytes per entry, once for keys and once for values.
    entries = batch * heads * tokens * channels_per_head
    fp16_kv_bytes = entries * 2 * 2

    # The packed payload is the same size for keys and values.
    payload_bytes = (entries * bits) / 8

    # Keys: channel-wise groups spanning the whole sequence.
    key_bytes = payload_bytes + batch * heads * channels_per_head * 2 * 2
    # Values: token-wise groups spanning all heads and channels.
    value_bytes = payload_bytes + batch * tokens * 2 * 2

    kv_bytes = key_bytes + value_bytes
    percent = (kv_bytes / fp16_kv_bytes) * 100

    logging.info("--- KCVT Compression Ratio Analysis ---")
    logging.info(f"Original KV Cache Size: {fp16_kv_bytes / bytes_per_mb:.2f} MB")
    logging.info(f"Compressed KV Cache Size: {kv_bytes / bytes_per_mb:.2f} MB")
    logging.info(f"  - Key Cache (Channel-wise Quant): {key_bytes / bytes_per_mb:.2f} MB")
    logging.info(f"  - Value Cache (Token-wise Quant): {value_bytes / bytes_per_mb:.2f} MB")
    logging.info(f"Compression Ratio: {percent:.2f}% of original size")
    logging.info("---------------------------------------------")


def calculate_kivi_v2_compression_ratio(args, config):
    """
    Calculates and logs the KV cache compression ratio for the KIVI_V2 method.

    Every key and value entry is counted at ``args.quantize_bit`` bits, plus one
    FP16 (mx, mn) pair per quantization group. Keys are grouped along the
    sequence axis (per batch/head/channel); values are grouped along the
    ``num_heads * head_dim`` axis (per batch/token).

    Args:
        args: Supplies ``model_max_length``, ``batch_size``, ``quantize_bit``
            and ``group_size`` (number of entries sharing one (mx, mn) pair).
        config: Supplies ``num_attention_heads`` and ``hidden_size``.

    Returns:
        None. The size breakdown is written with ``logging.info``.

    Raises:
        ValueError: If ``args.group_size`` is not positive.
    """
    # Get model dimensions
    num_heads = config.num_attention_heads
    hidden_size = config.hidden_size
    head_dim = hidden_size // num_heads
    seq_len = args.model_max_length
    batch_size = args.batch_size
    quantize_bit = args.quantize_bit
    group_size = args.group_size

    # A zero group size would divide by zero and a negative one would yield
    # negative metadata sizes, so reject both up front with a clear message.
    if group_size <= 0:
        raise ValueError(f"group_size must be positive, got {group_size!r}")

    # --- Original KV Cache Size (FP16) ---
    original_elements = batch_size * num_heads * seq_len * head_dim
    original_size_bytes = original_elements * 2  # FP16 = 2 bytes
    original_total_size_bytes = original_size_bytes * 2  # Key + Value

    # --- Compressed Key Cache Size ---
    # Channel-wise quantization with specified group_size
    size_quant_vals_k = (original_elements * quantize_bit) / 8
    # Ceiling division: all elements are counted as quantized above, so a
    # trailing partial group still needs its own (mx, mn) pair.
    num_groups_k = -(-seq_len // group_size)
    size_metadata_k = batch_size * num_heads * head_dim * num_groups_k * 2 * 2  # mx, mn as FP16
    compressed_size_k = size_quant_vals_k + size_metadata_k

    # --- Compressed Value Cache Size ---
    # Token-wise quantization with specified group_size
    size_quant_vals_v = (original_elements * quantize_bit) / 8
    # Same ceiling rule along the (num_heads * head_dim) axis.
    num_groups_v = -(-(num_heads * head_dim) // group_size)
    size_metadata_v = batch_size * seq_len * num_groups_v * 2 * 2  # mx, mn as FP16
    compressed_size_v = size_quant_vals_v + size_metadata_v

    # --- Final Calculation ---
    compressed_total_size_bytes = compressed_size_k + compressed_size_v
    ratio = (compressed_total_size_bytes / original_total_size_bytes) * 100

    logging.info("--- KIVI_V2 Compression Ratio Analysis ---")
    logging.info(f"Original KV Cache Size: {original_total_size_bytes / 1024**2:.2f} MB")
    logging.info(f"Compressed KV Cache Size: {compressed_total_size_bytes / 1024**2:.2f} MB")
    logging.info(f"  - Key Cache (Channel-wise Quant): {compressed_size_k / 1024**2:.2f} MB")
    logging.info(f"  - Value Cache (Token-wise Quant): {compressed_size_v / 1024**2:.2f} MB")
    logging.info(f"Compression Ratio: {ratio:.2f}% of original size")
    logging.info("---------------------------------------------")


def calculate_palu_50_compression_ratio(args, config):
    """
    Estimate and log the KV cache footprint of the PALU_50 method.

    Two storage models are used, selected by ``args.quantize_bit``:

    * ``quantize_bit == 0``: only the SVD factors are stored, i.e. two FP16
      matrices with ``(seq_len + head_dim) * rank`` values per (batch, head).
    * otherwise: the tensor is stored quantized. Outliers cost 2 bytes each
      (FP16 value only, no index), the remaining entries cost ``quantize_bit``
      bits each, plus FP16 (mx, mn) metadata. The ranks do not enter this case.

    Args:
        args: Supplies ``model_max_length``, ``batch_size``, ``quantize_bit``,
            ``left`` (fraction of entries treated as outliers), ``rank`` (key
            rank) and ``rankv`` (value rank).
        config: Supplies ``num_attention_heads`` and ``hidden_size``.

    Returns:
        None. The size breakdown is written with ``logging.info``.
    """
    # Model / run dimensions
    heads = config.num_attention_heads
    width = config.hidden_size
    d_head = width // heads
    tokens = args.model_max_length
    batch = args.batch_size
    bits = args.quantize_bit
    outlier_fraction = args.left
    key_rank = args.rank
    value_rank = args.rankv

    mb = 1024**2

    # FP16 baseline: 2 bytes per entry, once for keys and once for values.
    elems = batch * heads * tokens * d_head
    baseline_bytes = elems * 2 * 2

    svd_only = bits == 0
    if svd_only:
        # Only the two FP16 factor matrices are kept.
        key_bytes = batch * heads * (tokens + d_head) * key_rank * 2
        value_bytes = batch * heads * (tokens + d_head) * value_rank * 2
    else:
        # Quantized storage; outlier and packed-entry costs are shared by K and V.
        outlier_count = math.floor(elems * outlier_fraction)
        outlier_bytes = outlier_count * 2  # FP16 value only, index not counted
        packed_bytes = ((elems - outlier_count) * bits) / 8

        # Keys: one FP16 (mx, mn) pair per (batch, head, channel).
        key_bytes = outlier_bytes + packed_bytes + batch * heads * d_head * 2 * 2
        # Values: one FP16 (mx, mn) pair per (batch, token).
        value_bytes = outlier_bytes + packed_bytes + batch * tokens * 2 * 2

    compressed_bytes = key_bytes + value_bytes
    percent = (compressed_bytes / baseline_bytes) * 100

    logging.info("--- PALU_50 Compression Ratio Analysis ---")
    logging.info(f"Original KV Cache Size: {baseline_bytes / mb:.2f} MB")
    logging.info(f"Compressed KV Cache Size: {compressed_bytes / mb:.2f} MB")
    if svd_only:
        logging.info("  - SVD approximation applied (quantize_bit is 0).")
        logging.info(f"  - Key Cache (Low-Rank Factors): {key_bytes / mb:.2f} MB")
        logging.info(f"  - Value Cache (Low-Rank Factors): {value_bytes / mb:.2f} MB")
    else:
        logging.info(f"  - Key Cache (Quantized): {key_bytes / mb:.2f} MB")
        logging.info(f"  - Value Cache (Quantized): {value_bytes / mb:.2f} MB")
    logging.info(f"Compression Ratio: {percent:.2f}% of original size")
    logging.info("---------------------------------------------")
