from typing import List, Tuple, Optional, cast

from torch import nn

from inference.inference import InferenceMetric
from utils.layer_utils import (
    LayerConfig,
    LayerSchema,
    LayerType,
    TransformerConfig,
    TransformerLayerSchema,
)


class SparsityInferenceMetric(InferenceMetric):
    """
    Implementation of inference metric based on model sparsity.
    
    This metric calculates the effective sparsity by using the updated layer
    configurations after pruning, making it suitable for zero-out column pruning 
    where weights are set to zero but not physically removed from tensors.
    
    The pruned parameter count is based on the remaining heads and dimensions
    in the updated configuration, which correctly reflects the effective model
    capacity after structured pruning.
    """
    
    def compute_original_inference(self, model_layers: List[Tuple[str, nn.Module, LayerSchema]]) -> float:
        """
        Compute the sparsity-based inference metric for the original model.
        For the original model, this represents the total number of parameters.
        
        Args:
            model_layers: List of tuples containing (layer_name, layer_module, layer_schema)
            
        Returns:
            float: The total number of parameters in the original model
        """
        if self.original_inference > 0.0:
            raise ValueError("Original inference has already been computed")
        total_params = 0

        for layer_name, layer, layer_schema in model_layers:
            if layer_schema.layer_type == LayerType.transformer:
                transformer_schema = layer_schema  # type: ignore
                layer_params = self._compute_transformer_original_params(layer, cast(TransformerLayerSchema, transformer_schema))
                total_params += layer_params

            else:
                raise NotImplementedError(f"Layer type {layer_schema.layer_type} not implemented for original inference computation")

        self.original_inference = total_params
        # Initialize pruned_inference to original_inference when no pruning has been done
        self.pruned_inference = total_params
        return float(total_params)
    
    def compute_pruned_inference(self, model_layers: List[Tuple[str, nn.Module, LayerSchema]], 
                                layers_configs: List[LayerConfig]) -> float:
        """
        Compute the sparsity-based inference metric for the pruned model.
        This represents the number of parameters that would remain after pruning.
        
        Args:
            model_layers: List of tuples containing (layer_name, layer_module, layer_schema)
            layers_configs: List of layer configurations after pruning
            
        Returns:
            float: The number of parameters that would remain after pruning
        """
        # Calculate remaining parameters based on layer configurations
        total_remaining_params = 0

        for idx, (layer_name, layer, layer_schema) in enumerate(model_layers):
            if idx < len(layers_configs):
                layer_config = layers_configs[idx]
                if layer_schema.layer_type == LayerType.transformer:
                    transformer_schema = layer_schema  # type: ignore
                    remaining_params = self._compute_transformer_pruned_params(layer, idx, cast(TransformerLayerSchema, transformer_schema), layer_config)
                    total_remaining_params += remaining_params

                else:
                    raise NotImplementedError(f"Layer type {layer_schema.layer_type} not implemented for pruned inference computation")
            else:
                # If no config provided for this layer, count all parameters
                layer_params = sum(p.numel() for p in layer.parameters())
                total_remaining_params += layer_params

        self.pruned_inference = total_remaining_params
        return float(total_remaining_params)
    
    def _compute_transformer_original_params(self, layer: nn.Module, layer_schema: TransformerLayerSchema) -> int:
        """
        Compute the total number of parameters in a transformer layer.
        
        Args:
            layer: The transformer layer module
            layer_schema: The layer schema for the transformer layer
            
        Returns:
            int: Total number of parameters in the transformer layer
        """
        # For transformer layers, we need to calculate parameters based on the layer structure
        # This is a simplified calculation - adjust based on your specific model architecture
        total_params = 0
        
        # Count parameters for attention heads (Q, K, V, O projections)
        for key, layer_specs in layer_schema.layers.items():
            if key in ['q', 'k', 'v', 'o']:
                module = getattr(layer, layer_specs.module_name)
                linear_layer = getattr(module, layer_specs.attribute)
                total_params += linear_layer.weight.numel()
                if linear_layer.bias is not None:
                    total_params += linear_layer.bias.numel()
        
        # Count parameters for FFN layers
        for key, layer_specs in layer_schema.layers.items():
            if key in ['fc1', 'fc2', 'fc3']:
                module = getattr(layer, layer_specs.module_name)
                linear_layer = getattr(module, layer_specs.attribute)
                total_params += linear_layer.weight.numel()
                if linear_layer.bias is not None:
                    total_params += linear_layer.bias.numel()
        
        return total_params
    
    def _compute_transformer_pruned_params(self, layer: nn.Module, idx: int, layer_schema: TransformerLayerSchema, 
                                         layer_config: LayerConfig) -> int:
        """
        Compute remaining parameters by counting pruned blocks and converting to equivalent parameter reduction.
        This approach counts actual pruned heads and FFN dimensions, then calculates sparsity based on block ratios.
        """
        from utils.layer_utils import TransformerConfig
        from typing import cast
        
        transformer_config = cast(TransformerConfig, layer_config)
        
        # Get original dimensions from actual tensor shapes
        original_params = self._compute_transformer_original_params(layer, layer_schema)
        
        # Get original dimensions
        q_spec = layer_schema.layers['q']
        q_module = getattr(layer, q_spec.module_name)
        q_layer = getattr(q_module, q_spec.attribute)
        original_num_heads = q_layer.weight.shape[0] // transformer_config.head_size
        
        fc1_spec = layer_schema.layers['fc1']
        fc1_module = getattr(layer, fc1_spec.module_name)
        fc1_layer = getattr(fc1_module, fc1_spec.attribute)
        original_intermediate_dim = fc1_layer.weight.shape[0]
        
        # Calculate pruned blocks
        pruned_heads = original_num_heads - transformer_config.num_heads
        pruned_ffn_dim = original_intermediate_dim - transformer_config.intermediate_dimension
        
        # Calculate block-based sparsity ratios
        head_sparsity_ratio = pruned_heads / original_num_heads if original_num_heads > 0 else 0.0
        ffn_sparsity_ratio = pruned_ffn_dim / original_intermediate_dim if original_intermediate_dim > 0 else 0.0
        
        # Calculate parameter contributions for each component
        total_head_params = 0  # Q, K, V, O combined
        total_ffn_params = 0   # fc1, fc2, fc3 combined
        
        for key, layer_specs in layer_schema.layers.items():
            module = getattr(layer, layer_specs.module_name)
            linear_layer = getattr(module, layer_specs.attribute)
            
            weight_params = linear_layer.weight.numel()
            bias_params = linear_layer.bias.numel() if linear_layer.bias is not None else 0
            layer_total_params = weight_params + bias_params
            
            if key in ['q', 'k', 'v', 'o']:
                total_head_params += layer_total_params
            elif key in ['fc1', 'fc2', 'fc3']:
                total_ffn_params += layer_total_params
        
        # Calculate remaining parameters based on sparsity ratios
        remaining_head_params = int(total_head_params * (1.0 - head_sparsity_ratio))
        remaining_ffn_params = int(total_ffn_params * (1.0 - ffn_sparsity_ratio))
        remaining_params = remaining_head_params + remaining_ffn_params

        # self.pruned_inference = remaining_params
        
        return remaining_params
    
    def get_current_speedup(self) -> float:
        """
        Get the current pruning percentage (sparsity ratio).
        Returns a value between 0 and 1, where:
        - 0 means no pruning (0% parameters removed)
        - 1 means complete pruning (100% parameters removed)
        
        Args:
            original_inference (float): The total number of parameters in the original model
            pruned_inference (float): The number of parameters remaining after pruning
            
        Returns:
            float: The pruning percentage (0.0 to 1.0)
        """
        if self.original_inference <= 0.:
            return 0.0
        
        # Calculate pruning percentage: (original - pruned) / original
        pruning_percentage = (self.original_inference - self.pruned_inference) / self.original_inference
        
        # Ensure the result is between 0 and 1
        return max(0.0, min(1.0, pruning_percentage))
    
    def is_target_speedup_achieved(self) -> bool:
        """
        Check if the target pruning percentage has been achieved.
        
        Args:
            original_inference (float): The total number of parameters in the original model
            pruned_inference (float): The number of parameters remaining after pruning
            
        Returns:
            bool: True if target pruning percentage is achieved, False otherwise
        """
        current_pruning = self.get_current_speedup()
        
        # For sparsity, target_speedup represents the desired pruning percentage
        # e.g., target_speedup = 0.5 means we want to prune 50% of parameters
        return current_pruning >= self.target_speedup


