"""
Pruning methods for CACTUS.

This module implements various pruning techniques including:
- Global L1/L2 and local (per-layer) L1 magnitude pruning
- Structured/Unstructured pruning
- Mask generation and application
"""

import torch
import torch.nn as nn
import numpy as np
from typing import Dict, List, Tuple, Optional
from abc import ABC, abstractmethod


class BasePruning(ABC):
    """Abstract interface shared by every pruning strategy in this module.

    A subclass decides *which* weights to drop by implementing
    :meth:`create_mask`; actually zeroing them is common to all strategies
    and delegated to the module-level ``apply_pruning_mask``.
    """

    def __init__(self, sparsity: float):
        """
        Args:
            sparsity: Fraction of weights to remove; must lie in the
                half-open interval [0, 1).
        """
        # NOTE: validated with ``assert``, so the check is skipped under ``python -O``.
        assert 0 <= sparsity < 1, "Sparsity must be in [0, 1)"
        self.sparsity = sparsity

    @abstractmethod
    def create_mask(self, model: nn.Module) -> Dict[str, torch.Tensor]:
        """Return a ``{parameter name: mask tensor}`` mapping for ``model``."""

    def apply_mask(self, model: nn.Module, mask: Dict[str, torch.Tensor]) -> nn.Module:
        """Multiply ``model``'s weights by ``mask`` in place and return the model."""
        pruned_model = apply_pruning_mask(model, mask)
        return pruned_model


class GlobalL1Pruning(BasePruning):
    """Global L1 magnitude-based pruning.

    Every prunable parameter (trainable, more than one dimension -- weight
    matrices/kernels, not biases) is scored by ``|w|`` and a single threshold
    is chosen over the whole model, so per-layer sparsity can differ widely.
    """

    def create_mask(self, model: nn.Module) -> Dict[str, torch.Tensor]:
        """Create mask by pruning smallest L1 magnitude weights globally.

        Args:
            model: Model whose weights are scored.

        Returns:
            Dict mapping each prunable parameter name to a float32 mask of the
            same shape (1.0 = keep, 0.0 = prune). Empty when the model has no
            prunable parameters.

        Note:
            Every weight whose magnitude is <= the k-th smallest magnitude
            (k = ``int(sparsity * total_prunable)``) is pruned, so ties at the
            threshold can make the achieved sparsity exceed the target.
        """
        # Keep the Parameter objects themselves so masks are built in one pass
        # instead of rebuilding ``dict(model.named_parameters())`` per layer.
        prunable = [
            (name, param)
            for name, param in model.named_parameters()
            if param.requires_grad and len(param.shape) > 1
        ]
        if not prunable:
            # torch.cat([]) would raise; there is simply nothing to prune.
            return {}

        all_magnitudes = torch.cat(
            [param.data.abs().flatten() for _, param in prunable]
        )
        num_params_to_prune = int(self.sparsity * all_magnitudes.numel())

        if num_params_to_prune <= 0:
            # Zero target: keep every weight. Thresholding at 0.0 would also
            # mark weights that are already exactly zero as pruned.
            return {
                name: torch.ones_like(param.data, dtype=torch.float32)
                for name, param in prunable
            }

        threshold = torch.kthvalue(all_magnitudes, num_params_to_prune).values
        return {
            name: (param.data.abs() > threshold).float()
            for name, param in prunable
        }


class LocalL1Pruning(BasePruning):
    """Per-layer L1 magnitude pruning: each weight tensor is thresholded on its own."""

    def create_mask(self, model: nn.Module) -> Dict[str, torch.Tensor]:
        """Prune the smallest-``|w|`` entries of every weight tensor independently.

        Only trainable parameters with more than one dimension are masked.
        For each one, ``int(sparsity * numel)`` entries are targeted; every
        entry whose magnitude is <= the k-th smallest is dropped (ties included).
        When that count is zero, the tensor gets an all-ones mask.

        Args:
            model: Model whose weights are scored.

        Returns:
            Dict mapping parameter name to a mask of the same shape
            (1 = keep, 0 = prune).
        """
        layer_masks: Dict[str, torch.Tensor] = {}

        for pname, weight in model.named_parameters():
            # Skip frozen tensors and 1-D tensors such as biases.
            if not weight.requires_grad or len(weight.shape) <= 1:
                continue

            magnitudes = weight.data.abs()
            k = int(self.sparsity * magnitudes.numel())

            if k <= 0:
                layer_masks[pname] = torch.ones_like(weight.data)
                continue

            cutoff = torch.kthvalue(magnitudes.flatten(), k).values
            layer_masks[pname] = (magnitudes > cutoff).float()

        return layer_masks


class GlobalL2Pruning(BasePruning):
    """Global L2 magnitude-based pruning.

    Every prunable parameter (trainable, more than one dimension -- weight
    matrices/kernels, not biases) is ranked by its squared magnitude ``w**2``
    and a single threshold is applied across the whole model.
    """

    def create_mask(self, model: nn.Module) -> Dict[str, torch.Tensor]:
        """Create mask by pruning smallest L2 magnitude weights globally.

        Squaring is monotone on non-negative values, so ranking by ``w**2`` is
        identical to ranking by ``|w|``. Scores are therefore computed as
        ``|w|``: actually squaring in low-precision dtypes (e.g. float16)
        underflows small weights to exactly 0 and overflows large ones to inf,
        which merges distinct weights into ties and over-prunes.

        Args:
            model: Model whose weights are scored.

        Returns:
            Dict mapping each prunable parameter name to a float32 mask of the
            same shape (1.0 = keep, 0.0 = prune). Empty when the model has no
            prunable parameters.

        Note:
            Ties at the threshold are all pruned, so the achieved sparsity can
            exceed the target.
        """
        # Keep the Parameter objects so masks are built in one pass instead of
        # rebuilding ``dict(model.named_parameters())`` per layer.
        prunable = [
            (name, param)
            for name, param in model.named_parameters()
            if param.requires_grad and len(param.shape) > 1
        ]
        if not prunable:
            # torch.cat([]) would raise; there is simply nothing to prune.
            return {}

        all_scores = torch.cat([param.data.abs().flatten() for _, param in prunable])
        num_params_to_prune = int(self.sparsity * all_scores.numel())

        if num_params_to_prune <= 0:
            # Zero target: keep every weight, including ones already exactly zero.
            return {
                name: torch.ones_like(param.data, dtype=torch.float32)
                for name, param in prunable
            }

        threshold = torch.kthvalue(all_scores, num_params_to_prune).values
        return {
            name: (param.data.abs() > threshold).float()
            for name, param in prunable
        }


class StructuredL2Pruning(BasePruning):
    """Structured L2 pruning (prune entire filters/neurons).

    For every eligible weight tensor, the slices along dim 0 (output filters
    of a convolution, output neurons of a linear layer) are ranked by their
    L2 norm, and the ``int(sparsity * weight.shape[0])`` smallest slices are
    zeroed as a whole. Each tensor is pruned independently of the others.
    """

    # Module types whose ``weight`` stores output units along dim 0.
    _STRUCTURED_MODULE_TYPES = (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)

    def create_mask(self, model: nn.Module) -> Dict[str, torch.Tensor]:
        """Create mask by pruning entire filters/neurons based on L2 norm.

        A trainable parameter with more than one dimension is pruned
        structurally when any of the following holds:

        * its name contains ``'conv'`` and it is 4-D;
        * its name contains ``'fc'`` or ``'linear'``;
        * it is owned directly by an ``nn.Conv1d/2d/3d`` or ``nn.Linear``
          module. This covers layers inside e.g. ``nn.Sequential`` (named like
          ``'0.weight'``), which name matching alone never recognised.

        Every other such parameter receives an all-ones mask, i.e. it is left
        unpruned.

        Args:
            model: Model whose weights are scored.

        Returns:
            Dict mapping each trainable, >1-D parameter name to a mask with the
            parameter's shape and dtype (1 = keep, 0 = prune).
        """
        owners = self._parameter_owners(model)
        masks = {}

        for name, param in model.named_parameters():
            if not (param.requires_grad and len(param.shape) > 1):
                continue
            if self._is_structured(name, param, owners.get(name)):
                masks[name] = self._output_unit_mask(param.data)
            else:
                # Unrecognised layer: leave it unpruned.
                masks[name] = torch.ones_like(param.data)

        return masks

    @classmethod
    def _is_structured(cls, name, param, owner):
        """Return True if ``param`` should be pruned along its output dim 0."""
        lowered = name.lower()
        if 'conv' in lowered and len(param.shape) == 4:
            return True
        if 'fc' in lowered or 'linear' in lowered:
            return True
        return isinstance(owner, cls._STRUCTURED_MODULE_TYPES)

    @staticmethod
    def _parameter_owners(model):
        """Map each parameter's qualified name to the module that directly owns it."""
        owners = {}
        for module_name, module in model.named_modules():
            for param_name, _ in module.named_parameters(recurse=False):
                qualified = f"{module_name}.{param_name}" if module_name else param_name
                owners[qualified] = module
        return owners

    def _output_unit_mask(self, weight):
        """Zero the ``int(sparsity * shape[0])`` dim-0 slices with the smallest L2 norm."""
        mask = torch.ones_like(weight)
        num_units_to_prune = int(self.sparsity * weight.shape[0])
        if num_units_to_prune > 0:
            # Flatten all but dim 0 and take an explicit p=2 vector norm.
            # Tensor.norm's default p='fro' only supports up to two reduction
            # dims, so the previous norm(dim=(1, 2, 3)) is not a safe spelling.
            unit_norms = weight.flatten(1).norm(p=2, dim=1)
            _, units_to_prune = torch.topk(unit_norms, num_units_to_prune, largest=False)
            mask[units_to_prune] = 0
        return mask


def create_pruning_mask(model: nn.Module, method: str, sparsity: float) -> Dict[str, torch.Tensor]:
    """
    Build a pruning mask for ``model`` using the strategy named by ``method``.

    Args:
        model: PyTorch model to prune
        method: One of 'global_l1', 'local_l1', 'global_l2', 'structured_l2'
        sparsity: Target fraction of weights to prune

    Returns:
        Dictionary mapping parameter names to masks

    Raises:
        ValueError: If ``method`` is not one of the supported names.
        AssertionError: If ``sparsity`` is outside [0, 1) (checked by the
            pruner's constructor).
    """
    # Ordered (name, class) pairs; compared with ``==`` exactly as before.
    strategies = (
        ('global_l1', GlobalL1Pruning),
        ('local_l1', LocalL1Pruning),
        ('global_l2', GlobalL2Pruning),
        ('structured_l2', StructuredL2Pruning),
    )

    for key, strategy_cls in strategies:
        if method == key:
            return strategy_cls(sparsity).create_mask(model)

    raise ValueError(f"Unknown pruning method: {method}")


def apply_pruning_mask(model: nn.Module, mask: Dict[str, torch.Tensor]) -> nn.Module:
    """
    Apply pruning mask to model weights in-place.

    Each parameter whose name is a key of ``mask`` is multiplied element-wise
    by that mask. Parameters without an entry are untouched, and mask entries
    that name no parameter are ignored.

    Args:
        model: PyTorch model
        mask: Dictionary mapping parameter names to binary masks (1 = keep,
            0 = prune). Each mask is first moved to its parameter's device and
            dtype, so a mask built before ``model.to(device)`` still applies.

    Returns:
        The model with masked weights (same object, modified in-place)
    """
    with torch.no_grad():
        for name, param in model.named_parameters():
            if name in mask:
                layer_mask = mask[name].to(device=param.device, dtype=param.dtype)
                # Mutate through ``.data`` (as before) so the parameter's
                # autograd version counter is not bumped.
                param.data.mul_(layer_mask)

    return model


def get_sparsity(model: nn.Module, mask: Optional[Dict[str, torch.Tensor]] = None) -> float:
    """
    Report the fraction of prunable weights that are zero.

    Only trainable parameters with more than one dimension are counted.
    For a parameter that has an entry in ``mask``, the mask's zeros are
    counted; otherwise the parameter's own exact-zero entries are counted.

    Args:
        model: PyTorch model
        mask: Optional ``{parameter name: mask}`` mapping to measure instead
            of the live weights

    Returns:
        Zero fraction in [0, 1]; 0.0 when there is nothing to count
    """
    counted = 0
    zeros = 0

    for pname, weight in model.named_parameters():
        if not weight.requires_grad or len(weight.shape) <= 1:
            continue
        use_mask = mask is not None and pname in mask
        reference = mask[pname] if use_mask else weight.data
        zeros += (reference == 0).sum().item()
        counted += reference.numel()

    if counted <= 0:
        return 0.0
    return zeros / counted


if __name__ == "__main__":
    # Smoke test: prune a fresh CNN7 with each method and print the achieved
    # sparsity. The relative import means this must be run as a module
    # (``python -m <package>.<this_module>``), not as a plain script.
    from ..models import create_cnn7_mnist

    print("Testing pruning methods...")

    baseline = create_cnn7_mnist()
    print(f"Original sparsity: {get_sparsity(baseline):.4f}")

    target_sparsity = 0.5
    for method in ('global_l1', 'local_l1', 'global_l2', 'structured_l2'):
        # Use a freshly built model per method so results do not compound.
        candidate = create_cnn7_mnist()
        pruning_mask = create_pruning_mask(candidate, method, target_sparsity)
        apply_pruning_mask(candidate, pruning_mask)

        achieved = get_sparsity(candidate)
        print(f"{method}: target={target_sparsity:.2f}, actual={achieved:.4f}")

    # NOTE(review): nothing above is asserted; this line only means no exception occurred.
    print("Pruning tests passed!")