"""
Activation-aware Weight Quantization (AWQ) for residual quantization.

This module implements AWQ-based activation-aware residual quantization
for improving low-rank approximation with quantized residuals.
"""

import torch
import torch.nn as nn
from typing import Dict, List, Tuple, Optional
import logging
from tqdm import tqdm

logger = logging.getLogger(__name__)


def get_act_scale(x: torch.Tensor) -> torch.Tensor:
    """
    Compute the per-channel activation scale (mean absolute value).

    The last dimension is treated as the channel dimension; all leading
    dimensions of a 3D input are merged into a single row dimension before
    averaging.

    Args:
        x: Activation tensor, either 2D [rows, hidden_dim] or
            3D [batch_size, seq_len, hidden_dim]

    Returns:
        Tensor of shape [hidden_dim] holding mean(|x|) over all rows

    Raises:
        ValueError: If x is neither 2D nor 3D
    """
    rank = x.dim()
    if rank not in (2, 3):
        raise ValueError(f"Expected 2D or 3D tensor, got {x.dim()}D")

    magnitudes = x.abs()
    if rank == 3:
        # Merge batch and sequence dims: [batch_size * seq_len, hidden_dim]
        magnitudes = magnitudes.flatten(0, -2)

    # Average the magnitudes over rows, leaving one value per channel
    return magnitudes.mean(dim=0)


def collect_linear_activations(model: nn.Module, dataloader, num_samples: int = 128) -> Dict[str, torch.Tensor]:
    """
    Collect activation statistics for all linear layers in the model.

    A forward hook is attached to every ``nn.Linear``. On each forward call the
    hook computes the per-channel mean absolute value of the layer's first
    positional input (via ``get_act_scale``) and adds it to a running sum kept
    on the input's device. After calibration the sums are divided by the
    number of hook calls, so every forward call contributes with equal weight.

    Args:
        model: The model to analyze. It is switched to eval mode and is not
            switched back.
        dataloader: Iterable of calibration batches. For models exposing
            ``generate`` only dict batches containing ``'input_ids'`` are used
            (called as ``model(**batch, use_cache=False)``); other batches are
            skipped. For all other models the batch is passed as
            ``model(batch)``.
        num_samples: Number of samples to use for calibration. Iteration stops
            before the first batch that starts at or beyond this count.

    Returns:
        Dictionary mapping layer names (from ``model.named_modules()``) to
        activation scales, moved to the CPU
    """
    handles = []

    # Running sums live on the activation's device to avoid repeated CPU-GPU transfers
    scale_accumulators = {}
    sample_counts = {}

    def hook_fn(name):
        def hook(module, input, output):
            # input is a tuple, we need the first element
            inp = input[0]

            if name not in scale_accumulators:
                # Initialize accumulator on the same device/dtype as the input
                scale_accumulators[name] = torch.zeros(
                    inp.shape[-1], device=inp.device, dtype=inp.dtype
                )
                sample_counts[name] = 0

            # Accumulate on device (no CPU transfer)
            scale_accumulators[name] += get_act_scale(inp)
            sample_counts[name] += 1
        return hook

    def batch_size(batch):
        # Tensors have no ``.get``; only mapping-like batches are probed for 'input_ids'
        candidate = batch.get('input_ids', batch) if hasattr(batch, 'get') else batch
        shape = getattr(candidate, 'shape', None)
        # Anything without a leading dimension counts as a single sample
        return shape[0] if shape else 1

    model.eval()
    samples_processed = 0

    try:
        # Register hooks for all linear layers
        for name, module in model.named_modules():
            if isinstance(module, nn.Linear):
                handles.append(module.register_forward_hook(hook_fn(name)))

        # Run calibration samples
        with torch.no_grad():
            for batch in tqdm(dataloader):
                if samples_processed >= num_samples:
                    break

                try:
                    if hasattr(model, 'generate'):
                        # For generation models, just do a forward pass
                        if isinstance(batch, dict) and 'input_ids' in batch:
                            model(**batch, use_cache=False)
                        else:
                            # Skip if we can't process this batch format
                            logger.debug("Skipping calibration batch with unsupported format")
                            continue
                    else:
                        model(batch)
                except Exception as exc:
                    # Best-effort calibration: report the failure and move on
                    logger.warning("Skipping calibration batch after forward failure: %s", exc)
                    continue

                samples_processed += batch_size(batch)
    finally:
        # Always detach the hooks, even if calibration is interrupted
        for handle in handles:
            handle.remove()

    # Calculate averaged scales (single CPU transfer per layer)
    averaged_scales = {}
    for name, accumulated in scale_accumulators.items():
        if sample_counts[name] > 0:
            averaged_scales[name] = (accumulated / sample_counts[name]).cpu()
        else:
            logger.warning(f"No activation scales collected for layer {name}")

    return averaged_scales


def collect_linear_activations_fast(model: nn.Module, dataloader, num_samples: int = 128) -> Dict[str, torch.Tensor]:
    """
    Fast version of activation collection with optimized batch processing.

    A forward hook on every ``nn.Linear`` adds the per-channel mean absolute
    value of the layer's first positional input to a running sum; the sums are
    divided by the number of hook calls at the end, so every forward call
    contributes with equal weight. Unlike ``collect_linear_activations`` this
    variant shows no progress bar and computes the scale inline.

    Args:
        model: The model to analyze. It is switched to eval mode and is not
            switched back.
        dataloader: Iterable of calibration batches. Dict batches containing
            ``'input_ids'`` are passed as ``model(**batch, use_cache=False)``
            when the model exposes ``generate``; every other batch is passed
            as ``model(batch)``.
        num_samples: Number of samples to use for calibration. Iteration stops
            before the first batch that starts at or beyond this count.

    Returns:
        Dictionary mapping layer names (from ``model.named_modules()``) to
        activation scales, moved to the CPU
    """
    handles = []

    # Running sums, created lazily on the first activation seen per layer
    scale_accumulators = {}
    sample_counts = {}

    def hook_fn(name):
        def hook(module, input, output):
            inp = input[0]

            if name not in scale_accumulators:
                scale_accumulators[name] = torch.zeros(
                    inp.shape[-1], device=inp.device, dtype=inp.dtype
                )
                sample_counts[name] = 0

            if inp.dim() == 3:
                # Mean over batch and sequence dims in a single reduction
                act_scale = inp.abs().mean(dim=(0, 1))
            else:
                act_scale = inp.abs().mean(dim=0)

            scale_accumulators[name] += act_scale
            sample_counts[name] += 1
        return hook

    def batch_size(batch):
        # Tensors have no ``.get``; only mapping-like batches are probed for 'input_ids'
        candidate = batch.get('input_ids', batch) if hasattr(batch, 'get') else batch
        shape = getattr(candidate, 'shape', None)
        # Anything without a leading dimension counts as a single sample
        return shape[0] if shape else 1

    model.eval()
    samples_processed = 0

    try:
        # Register hooks for all linear layers
        for name, module in model.named_modules():
            if isinstance(module, nn.Linear):
                handles.append(module.register_forward_hook(hook_fn(name)))

        with torch.no_grad():
            for batch in dataloader:
                if samples_processed >= num_samples:
                    break

                try:
                    if hasattr(model, 'generate') and isinstance(batch, dict) and 'input_ids' in batch:
                        model(**batch, use_cache=False)
                    else:
                        model(batch)
                except Exception as exc:
                    # Best-effort calibration: report the failure and move on.
                    # KeyboardInterrupt/SystemExit are deliberately not caught.
                    logger.warning("Skipping calibration batch after forward failure: %s", exc)
                    continue

                samples_processed += batch_size(batch)
    finally:
        # Always detach the hooks, even if calibration is interrupted
        for handle in handles:
            handle.remove()

    # Calculate averaged scales (single CPU transfer per layer)
    averaged_scales = {}
    for name, accumulated in scale_accumulators.items():
        if sample_counts[name] > 0:
            averaged_scales[name] = (accumulated / sample_counts[name]).cpu()
        else:
            logger.warning(f"No activation scales collected for layer {name}")

    return averaged_scales


def search_optimal_scale(
    weight: torch.Tensor,
    activation_scale: torch.Tensor,
    weight_quantize_func,
    low_rank_func,
    actual_rank,
    n_grid: int = 20
) -> Tuple[torch.Tensor, float, float]:
    """
    Grid-search the per-channel scale that minimizes the reconstruction error.

    For each exponent r in {0, 1/n_grid, ..., (n_grid-1)/n_grid} the candidate
    scale is s = activation_scale ** r (clamped and normalized), and the weight
    is reconstructed as

        W_recon = (Low_rank(s*W) + Q(s*W - Low_rank(s*W))) / s

    The candidate with the smallest plain mean squared error
    mean((W - W_recon)^2) is kept.

    NOTE: an activation-weighted objective,
    mean(((W - W_recon) * activation_scale)^2), which approximates the layer
    output error, is not what the code currently evaluates.

    Args:
        weight: Original weight tensor [out_features, in_features]
        activation_scale: Activation scale for each input channel [in_features]
        weight_quantize_func: Function to quantize the scaled residual
        low_rank_func: Function to apply low-rank approximation to the scaled weight
        actual_rank: Not used by this function; kept in the signature for callers
        n_grid: Number of grid points to search

    Returns:
        Tuple of (optimal_scales, best_ratio, best_error). optimal_scales is
        cast to the dtype of weight. If no candidate improves on an infinite
        error, the scales are all ones, best_ratio is -1 and best_error is inf.
    """
    target_dtype = weight.dtype

    # Activation statistics must live on the same device as the weight
    activation_scale = activation_scale.to(weight.device)

    best_scales = None
    best_ratio = -1
    best_error = float('inf')

    for step in range(n_grid):
        exponent = step / n_grid

        # Candidate scales, normalized so that max * min == 1
        candidate = activation_scale.pow(exponent).clamp(min=1e-4)
        normalizer = (candidate.max() * candidate.min()).sqrt()
        candidate = (candidate / normalizer).view(1, -1)  # [1, in_features]

        # Scale, split into low-rank part plus quantized residual, then unscale
        scaled = weight * candidate
        low_rank_part = low_rank_func(scaled)
        residual_q = weight_quantize_func(scaled - low_rank_part)
        restored = (low_rank_part + residual_q) / candidate

        # Plain MSE in the original weight space
        mse = (weight - restored).float().pow(2).mean().item()

        if mse < best_error:
            best_error = mse
            best_ratio = exponent
            best_scales = candidate.squeeze(0)  # back to [in_features]

    if best_scales is None:
        # No candidate beat the initial infinite error (or the grid was empty)
        logger.warning("Failed to find optimal scale, using uniform scale")
        best_scales = torch.ones_like(activation_scale)

    return best_scales.to(target_dtype), best_ratio, best_error


# def apply_activation_aware_residual_quant(
#     model: nn.Module,
#     activation_scales: Dict[str, torch.Tensor],
#     weight_quantize_func,
#     layer_name_to_module: Optional[Dict[str, nn.Module]] = None
# ) -> Dict[str, torch.Tensor]:
#     """
#     Apply activation-aware quantization to find optimal scales for each layer.
    
#     Args:
#         model: The model containing linear layers
#         activation_scales: Dictionary of activation scales for each layer
#         weight_quantize_func: Function to quantize weights
#         layer_name_to_module: Optional mapping from layer names to modules
    
#     Returns:
#         Dictionary mapping layer names to optimal scales
#     """
#     optimal_scales = {}
    
#     # Build layer name to module mapping if not provided
#     if layer_name_to_module is None:
#         layer_name_to_module = {}
#         for name, module in model.named_modules():
#             if isinstance(module, nn.Linear):
#                 layer_name_to_module[name] = module
    
#     # Find optimal scale for each layer
#     for name, module in layer_name_to_module.items():
#         if name not in activation_scales:
#             logger.warning(f"No activation scale found for layer {name}, skipping")
#             continue
            
#         # Note: We'll need the low-rank approximated weight
#         # This will be passed from weight_compression.py
#         # For now, we'll store the activation scale
#         optimal_scales[name] = activation_scales[name]
    
#     return optimal_scales


def find_optimal_scale_for_weight(
    weight: torch.Tensor,
    activation_scale: torch.Tensor,
    quantize_func,
    low_rank_func,
    rank_percentage: float,
    n_grid: int = 20
) -> Tuple[torch.Tensor, float]:
    """
    Find the optimal scale for residual quantization using activation information.

    Thin wrapper around search_optimal_scale that logs the selected ratio and
    error and drops the ratio from the result. The original note states that
    this function is called from weight_compression.py.

    Args:
        weight: Original weight tensor
        activation_scale: Activation scale for this layer
        quantize_func: Quantization function
        low_rank_func: Low-rank approximation function
        rank_percentage: Forwarded to search_optimal_scale as its actual_rank argument
        n_grid: Number of grid points to search

    Returns:
        Tuple of (optimal scale tensor, best reconstruction error)
    """
    search_result = search_optimal_scale(
        weight=weight,
        activation_scale=activation_scale,
        weight_quantize_func=quantize_func,
        low_rank_func=low_rank_func,
        actual_rank=rank_percentage,
        n_grid=n_grid,
    )
    optimal_scale, best_ratio, best_error = search_result

    logger.info(f"Found optimal scale with ratio {best_ratio:.3f} and best error {best_error}")

    # The ratio is only logged; callers receive the scale and its error
    return optimal_scale, best_error
