from abc import ABC, abstractmethod
from typing import Dict, List, Tuple, Optional
from transformers.configuration_utils import PretrainedConfig
import torch
from torch import nn

from utils.layer_utils import LayerConfig, LayerSchema, TransformerConfig


class InferenceMetric(ABC):
    """
    Abstract interface for an inference metric of a model before and after pruning.

    Concrete subclasses supply the four abstract methods. This base class only
    stores the target speedup and two metric values, and exposes read
    accessors for the latter.

    Attributes:
        target_speedup (float): Speedup ratio the pruning run aims for.
        original_inference (float): Metric value for the original model;
            starts at 0.0.
        pruned_inference (float): Metric value for the pruned model;
            starts at 0.0.
    """

    def __init__(self, target_speedup: float = 3.0):
        """
        Store the target speedup and zero both metric values.

        Args:
            target_speedup (float): The target speedup ratio to achieve
        """
        self.target_speedup = target_speedup
        # Nothing in this base class reassigns these two values; subclasses
        # are presumably expected to update them from their compute_* methods
        # (not enforced here -- confirm against the implementations).
        self.original_inference = 0.0
        self.pruned_inference = 0.0

    @abstractmethod
    def compute_original_inference(
        self,
        model_layers: List[Tuple[str, nn.Module, LayerSchema]],
    ) -> float:
        """
        Compute the inference metric for the original (unpruned) model.

        Args:
            model_layers: List of (layer_name, layer_module, layer_schema) tuples

        Returns:
            float: The inference metric value for the original model
        """
        ...

    def get_original_inference(self) -> float:
        """
        Return the stored inference metric of the original (unpruned) model.

        Returns:
            float: Current value of ``self.original_inference``
        """
        return self.original_inference

    @abstractmethod
    def compute_pruned_inference(
        self,
        model_layers: List[Tuple[str, nn.Module, LayerSchema]],
        layers_configs: List[LayerConfig],
    ) -> float:
        """
        Compute the inference metric for the pruned model.

        Args:
            model_layers: List of (layer_name, layer_module, layer_schema) tuples
            layers_configs: List of layer configurations after pruning. The
                signature declares it as a required argument (no default).

        Returns:
            float: The inference metric value for the pruned model
        """
        ...

    def get_pruned_inference(self) -> float:
        """
        Return the stored inference metric of the pruned model.

        Returns:
            float: Current value of ``self.pruned_inference``
        """
        return self.pruned_inference

    @abstractmethod
    def get_current_speedup(self) -> float:
        """
        Get the current speedup ratio.

        Takes no arguments. How the ratio is derived is left to the subclass;
        presumably it is based on the original and pruned metric values held
        on the instance (confirm against the implementations).

        Returns:
            float: The current speedup ratio
        """
        ...

    @abstractmethod
    def is_target_speedup_achieved(self) -> bool:
        """
        Check if the target speedup has been achieved.

        Takes no arguments. The comparison rule is left to the subclass;
        presumably the current speedup is checked against
        ``self.target_speedup`` (confirm against the implementations).

        Returns:
            bool: True if target speedup is achieved, False otherwise
        """
        ...


# Factory function to create inference metrics
def create_inference_metric(metric_type: str = "sparsity", target_speedup: float = 3.0) -> InferenceMetric:
    """
    Build the inference metric implementation selected by ``metric_type``.

    The type name is matched case-insensitively. The implementation class is
    imported inside the selected branch, so only the chosen module is loaded.

    Args:
        metric_type (str): Type of inference metric ("sparsity" or "latency")
        target_speedup (float): Target speedup ratio passed to the metric's constructor

    Returns:
        InferenceMetric: The created inference metric instance

    Raises:
        ValueError: If ``metric_type`` is neither "sparsity" nor "latency"
    """
    kind = metric_type.lower()

    if kind == "sparsity":
        from inference.sparsity import SparsityInferenceMetric as metric_cls
    elif kind == "latency":
        from inference.latency import LatencyInferenceMetric as metric_cls
    else:
        # Report the value exactly as the caller supplied it, not the lowered form.
        raise ValueError(f"Unknown inference metric type: {metric_type}")

    return metric_cls(target_speedup)
        
        
