from typing import List, Tuple, Optional, Dict, Any
import torch
import torch.nn as nn
import numpy as np
from sklearn.linear_model import LinearRegression
from dataclasses import dataclass
import json
import os
from pathlib import Path

from utils.layer_utils import LayerConfig, LayerSchema, TransformerConfig
from inference.inference import InferenceMetric


@dataclass
class ModuleLatencyData:
    """One measured latency sample for a single Transformer sub-module.

    Attributes:
        module_type: Kind of module that was timed, ``"MHA"`` or ``"FFN"``.
        sequence_length: Sequence length used for the measurement.
        d_model: Model (hidden) dimension.
        h: Number of active attention heads (relevant for MHA samples).
        d_ffn: FFN intermediate dimension (relevant for FFN samples).
        latency_ms: Measured latency in milliseconds.
    """

    module_type: str
    sequence_length: int
    d_model: int
    h: int
    d_ffn: int
    latency_ms: float


class TwoStageLatencyModel:
    """
    Two-stage learned latency model for Transformer modules.

    Stage 1: Module-level latency prediction using linear regression
    Stage 2: Aggregation for total model latency using another linear model

    Calibration samples may be added in any order: end-to-end (Stage 2)
    samples that arrive before Stage 1 has been trained are kept as raw
    configurations and converted into module-level predictions when
    ``train_stage2`` runs.
    """

    # Sequence length assumed by predict_total_latency when the caller gives none.
    DEFAULT_SEQUENCE_LENGTH = 2048

    def __init__(self, model_name: str = "default"):
        """Create an untrained model.

        Args:
            model_name: Free-form label stored in the saved metadata.
        """
        self.model_name = model_name
        self.f_mha = LinearRegression()  # Regressor for Multi-Head Attention
        self.f_ffn = LinearRegression()  # Regressor for Feed-Forward Network
        self.aggregation_model = LinearRegression()  # Stage 2 aggregation model

        # Training data storage.
        # mha/ffn entries: (feature row of shape (1, 4), latency_ms)
        # aggregation entries: (list of per-module predictions, total_latency_ms)
        self.mha_training_data = []
        self.ffn_training_data = []
        self.aggregation_training_data = []
        # End-to-end samples received before Stage 1 was trained:
        # (model_config snapshot, total_latency_ms). Resolved in train_stage2.
        self._pending_aggregation_data = []

        # Default sequence length for predict_total_latency.
        self.sequence_length = self.DEFAULT_SEQUENCE_LENGTH

        # Model state
        self.stage1_trained = False
        self.stage2_trained = False

    @staticmethod
    def _build_features(module_type: str, sequence_length: int, d_model: int,
                        h: int = 0, d_ffn: int = 0) -> np.ndarray:
        """Build the (1, 4) feature row [S, d_model, h, d_ffn] for one module.

        The feature that does not apply to the module type is forced to 0 so
        that training and prediction always see the same layout.

        Raises:
            ValueError: If ``module_type`` is neither "MHA" nor "FFN".
        """
        if module_type == "MHA":
            row = [sequence_length, d_model, h, 0]  # d_ffn not used for MHA
        elif module_type == "FFN":
            row = [sequence_length, d_model, 0, d_ffn]  # h not used for FFN
        else:
            raise ValueError(f"Unknown module type: {module_type}")
        return np.array([row])

    @staticmethod
    def _stack_samples(samples):
        """Turn [(features, latency), ...] into a 2-D X of shape (N, 4) and a 1-D y."""
        # Each stored feature row is (1, 4); stacking them row-wise gives the
        # 2-D design matrix the regressor requires.
        features = np.vstack([np.asarray(f).reshape(1, -1) for f, _ in samples])
        targets = np.array([latency for _, latency in samples])
        return features, targets

    def add_module_measurement(self, module_data: ModuleLatencyData):
        """Add a module-level latency measurement for training.

        Raises:
            ValueError: If ``module_data.module_type`` is not "MHA" or "FFN".
        """
        features = self._extract_module_features(module_data)
        sample = (features, module_data.latency_ms)

        if module_data.module_type == "MHA":
            self.mha_training_data.append(sample)
        elif module_data.module_type == "FFN":
            self.ffn_training_data.append(sample)
        else:
            raise ValueError(f"Unknown module type: {module_data.module_type}")

    def add_aggregation_measurement(self, model_config: Dict[str, Any], total_latency_ms: float):
        """Add an end-to-end model latency measurement for Stage 2 training.

        Args:
            model_config: Dict with keys 'sequence_length', 'd_model' and
                'layers' (a list of dicts read via 'num_heads' and 'd_ffn').
            total_latency_ms: Measured end-to-end latency in milliseconds.

        If Stage 1 is already trained the module-level predictions are computed
        immediately; otherwise a snapshot of the configuration is stored and
        resolved when ``train_stage2`` is called.

        Raises:
            KeyError: If the config has layers but lacks 'sequence_length' or
                'd_model'.
        """
        if self.stage1_trained:
            module_predictions = self._predict_config_modules(model_config)
            self.aggregation_training_data.append((module_predictions, total_latency_ms))
            return

        # Fail fast on malformed configs instead of at train_stage2 time.
        if model_config.get('layers'):
            for required_key in ('sequence_length', 'd_model'):
                if required_key not in model_config:
                    raise KeyError(required_key)

        # Snapshot the config so later mutation by the caller cannot change
        # the sample that will be used for training.
        import copy
        self._pending_aggregation_data.append((copy.deepcopy(model_config), total_latency_ms))

    def _predict_config_modules(self, model_config: Dict[str, Any]) -> list:
        """Return [mha_0, ffn_0, mha_1, ffn_1, ...] Stage 1 predictions for a config dict."""
        module_predictions = []
        for layer_config in model_config.get('layers', []):
            module_predictions.append(self.predict_module_latency(
                "MHA", model_config['sequence_length'], model_config['d_model'],
                h=layer_config.get('num_heads', 0)
            ))
            module_predictions.append(self.predict_module_latency(
                "FFN", model_config['sequence_length'], model_config['d_model'],
                d_ffn=layer_config.get('d_ffn', 0)
            ))
        return module_predictions

    def _resolve_pending_aggregation_data(self):
        """Convert deferred end-to-end samples into module-prediction samples."""
        # Build everything first so a failure leaves both lists untouched.
        resolved = [
            (self._predict_config_modules(config), total_latency)
            for config, total_latency in self._pending_aggregation_data
        ]
        self.aggregation_training_data.extend(resolved)
        self._pending_aggregation_data.clear()

    def _extract_module_features(self, module_data: ModuleLatencyData) -> np.ndarray:
        """Extract feature vector x^(ℓ) as defined in Equation (14).

        Returns:
            Array of shape (1, 4): [S, d_model, h, d_ffn], with h zeroed for
            FFN samples and d_ffn zeroed for MHA samples.

        Raises:
            ValueError: If the module type is unknown.
        """
        return self._build_features(
            module_data.module_type,
            module_data.sequence_length,  # S
            module_data.d_model,          # d_model
            module_data.h,                # h^(ℓ) for MHA, 0 for FFN
            module_data.d_ffn,            # d_ffn^(ℓ) for FFN, 0 for MHA
        )

    def train_stage1(self):
        """Train Stage 1: Module-level latency regressors.

        Raises:
            ValueError: If there are no MHA samples or no FFN samples.
        """
        if not self.mha_training_data or not self.ffn_training_data:
            raise ValueError("Insufficient training data for Stage 1")

        # Train MHA regressor
        mha_features, mha_latencies = self._stack_samples(self.mha_training_data)
        self.f_mha.fit(mha_features, mha_latencies)

        # Train FFN regressor
        ffn_features, ffn_latencies = self._stack_samples(self.ffn_training_data)
        self.f_ffn.fit(ffn_features, ffn_latencies)

        self.stage1_trained = True
        print(f"Stage 1 training completed. MHA R²: {self.f_mha.score(mha_features, mha_latencies):.3f}, "
              f"FFN R²: {self.f_ffn.score(ffn_features, ffn_latencies):.3f}")

    def train_stage2(self):
        """Train Stage 2: Aggregation model for total model latency.

        Raises:
            ValueError: If Stage 1 is untrained, if there is no aggregation
                data, or if the samples do not all describe the same non-zero
                number of modules.
        """
        if not self.stage1_trained:
            raise ValueError("Stage 1 must be trained before Stage 2")

        # Samples added before Stage 1 existed can only be featurized now.
        self._resolve_pending_aggregation_data()

        if not self.aggregation_training_data:
            raise ValueError("Insufficient aggregation training data for Stage 2")

        # A single linear model needs one fixed-width feature row per sample.
        widths = {len(predictions) for predictions, _ in self.aggregation_training_data}
        if len(widths) != 1 or 0 in widths:
            raise ValueError(
                "Aggregation samples must all contain the same non-zero number of "
                f"module predictions; got counts {sorted(widths)}"
            )

        # Prepare features: [α₀, α₁, α₂, ..., α_B] where α₀ is intercept
        # and α_b corresponds to module b's predicted latency.
        # Feature vector per sample: [1, pred_1, pred_2, ..., pred_B]
        aggregation_features = np.array([
            [1.0] + list(predictions)
            for predictions, _ in self.aggregation_training_data
        ])
        total_latencies = np.array([
            total_latency for _, total_latency in self.aggregation_training_data
        ])

        self.aggregation_model.fit(aggregation_features, total_latencies)
        self.stage2_trained = True

        print(f"Stage 2 training completed. Aggregation R²: {self.aggregation_model.score(aggregation_features, total_latencies):.3f}")

    def predict_module_latency(self, module_type: str, sequence_length: int,
                             d_model: int, h: int = 0, d_ffn: int = 0) -> float:
        """Predict latency for a single module using Stage 1 model.

        Args:
            module_type: "MHA" or "FFN".
            sequence_length: Sequence length S.
            d_model: Model dimension.
            h: Number of heads (used for MHA only).
            d_ffn: FFN intermediate dimension (used for FFN only).

        Raises:
            ValueError: If Stage 1 is untrained or the module type is unknown.
        """
        if not self.stage1_trained:
            raise ValueError("Stage 1 model not trained")

        features = self._build_features(module_type, sequence_length, d_model, h, d_ffn)
        regressor = self.f_mha if module_type == "MHA" else self.f_ffn
        return float(regressor.predict(features)[0])

    def predict_total_latency(self, model_layers: List[Tuple[str, nn.Module, LayerSchema]],
                            layers_configs: List[LayerConfig],
                            sequence_length: Optional[int] = None) -> float:
        """Predict total model latency using both stages.

        Args:
            model_layers: Layer tuples; only their count is used, to bound how
                many entries of ``layers_configs`` are consumed.
            layers_configs: Per-layer configs read via ``d_model``,
                ``num_heads`` and ``intermediate_dimension``.
            sequence_length: Sequence length to predict for. Defaults to
                ``self.sequence_length`` (2048 unless changed).

        Raises:
            ValueError: If either stage has not been trained.
        """
        if not self.stage1_trained or not self.stage2_trained:
            raise ValueError("Both Stage 1 and Stage 2 models must be trained")

        if sequence_length is None:
            sequence_length = self.sequence_length
        d_model = layers_configs[0].d_model if layers_configs else 512

        # Predict latency for each module; layers and configs are paired by
        # position and any surplus on either side is ignored.
        module_predictions = []
        for _, layer_config in zip(model_layers, layers_configs):
            module_predictions.append(self.predict_module_latency(
                "MHA", sequence_length, d_model,
                h=layer_config.num_heads
            ))
            module_predictions.append(self.predict_module_latency(
                "FFN", sequence_length, d_model,
                d_ffn=layer_config.intermediate_dimension
            ))

        # Aggregate using Stage 2 model: [α₀, α₁, α₂, ..., α_B]
        features = np.array([1.0] + module_predictions).reshape(1, -1)

        return float(self.aggregation_model.predict(features)[0])

    def save_model(self, save_path: str):
        """Save the trained models to disk.

        Writes ``metadata.json`` plus one ``.joblib`` file per trained
        regressor into the directory ``save_path`` (created if missing).
        """
        save_path = Path(save_path)
        save_path.mkdir(parents=True, exist_ok=True)

        if self.stage1_trained or self.stage2_trained:
            # Imported lazily, and once for both stages, so saving a Stage-2
            # model never depends on the Stage-1 branch having run.
            import joblib

        # Save Stage 1 models
        if self.stage1_trained:
            joblib.dump(self.f_mha, save_path / "f_mha.joblib")
            joblib.dump(self.f_ffn, save_path / "f_ffn.joblib")

        # Save Stage 2 model
        if self.stage2_trained:
            joblib.dump(self.aggregation_model, save_path / "aggregation_model.joblib")

        # Save metadata
        metadata = {
            "model_name": self.model_name,
            "stage1_trained": self.stage1_trained,
            "stage2_trained": self.stage2_trained,
            "mha_samples": len(self.mha_training_data),
            "ffn_samples": len(self.ffn_training_data),
            "aggregation_samples": (
                len(self.aggregation_training_data) + len(self._pending_aggregation_data)
            ),
        }

        with open(save_path / "metadata.json", "w") as f:
            json.dump(metadata, f, indent=2)

    def load_model(self, load_path: str):
        """Load the trained models from disk.

        Only the regressors and the model name are restored; training samples
        are not persisted.

        Raises:
            FileNotFoundError: If ``load_path`` does not exist.
        """
        load_path = Path(load_path)

        if not load_path.exists():
            raise FileNotFoundError(f"Model path not found: {load_path}")

        # Load metadata
        with open(load_path / "metadata.json", "r") as f:
            metadata = json.load(f)

        self.model_name = metadata["model_name"]

        if metadata["stage1_trained"] or metadata["stage2_trained"]:
            # SECURITY: joblib.load unpickles the files, which can execute
            # arbitrary code. Only load model directories from trusted sources.
            import joblib

        # Load Stage 1 models
        if metadata["stage1_trained"]:
            self.f_mha = joblib.load(load_path / "f_mha.joblib")
            self.f_ffn = joblib.load(load_path / "f_ffn.joblib")
            self.stage1_trained = True

        # Load Stage 2 model
        if metadata["stage2_trained"]:
            self.aggregation_model = joblib.load(load_path / "aggregation_model.joblib")
            self.stage2_trained = True


class LatencyInferenceMetric(InferenceMetric):
    """
    Implementation of inference metric based on two-stage learned latency model.

    This implementation follows the HAP-E paper's latency estimation approach:
    - Stage 1: Module-level latency prediction using linear regression
    - Stage 2: Aggregation for total model latency using another linear model

    When the latency model is not fully trained, a parameter-count heuristic
    is used instead. Every computed latency, learned or heuristic, is stored
    in ``original_inference`` / ``pruned_inference`` so the speedup helpers
    always work from the latest values.
    """

    def __init__(self, target_speedup: float = 3.0, model_name: str = "default"):
        """Create the metric with an untrained latency model.

        Args:
            target_speedup: Desired original/pruned latency ratio, forwarded
                to the base class.
            model_name: Label passed to the underlying TwoStageLatencyModel.
        """
        super().__init__(target_speedup)
        self.latency_model = TwoStageLatencyModel(model_name)
        self.sequence_length = 2048  # Default sequence length
        # Mirror the value onto the latency model.
        # NOTE(review): this only affects predictions if the model's
        # predict_total_latency consults a `sequence_length` attribute - confirm.
        self.latency_model.sequence_length = self.sequence_length

    def set_sequence_length(self, sequence_length: int):
        """Set the sequence length for latency prediction.

        The value is stored on this metric and mirrored onto
        ``self.latency_model.sequence_length``.
        NOTE(review): whether predictions change depends on the latency model
        reading that attribute - confirm against TwoStageLatencyModel.
        """
        self.sequence_length = sequence_length
        self.latency_model.sequence_length = sequence_length

    def add_calibration_data(self, module_data: ModuleLatencyData):
        """Add calibration data for training the latency model."""
        self.latency_model.add_module_measurement(module_data)

    def add_aggregation_calibration_data(self, model_config: Dict[str, Any], total_latency_ms: float):
        """Add end-to-end calibration data for Stage 2 training."""
        self.latency_model.add_aggregation_measurement(model_config, total_latency_ms)

    def train_latency_model(self):
        """Train both stages of the latency model (Stage 1 first, then Stage 2)."""
        self.latency_model.train_stage1()
        self.latency_model.train_stage2()

    def _latency_model_ready(self) -> bool:
        """Return True when both stages of the latency model are trained."""
        return self.latency_model.stage1_trained and self.latency_model.stage2_trained

    def compute_original_inference(self, model_layers: List[Tuple[str, nn.Module, LayerSchema]]) -> float:
        """
        Compute the latency-based inference metric for the original model.

        The result is also stored in ``self.original_inference``.

        Args:
            model_layers: List of tuples containing (layer_name, layer_module, layer_schema)

        Returns:
            float: The latency in milliseconds for the original model
        """
        original_configs = []
        if self._latency_model_ready():
            # Extract original layer configurations from modules that expose
            # a `config` carrying all three sizes; other layers are skipped.
            for _, layer_module, _ in model_layers:
                config = getattr(layer_module, 'config', None)
                if config is None:
                    continue
                if (hasattr(config, 'hidden_size')
                        and hasattr(config, 'num_attention_heads')
                        and hasattr(config, 'intermediate_size')):
                    original_configs.append(TransformerConfig(
                        d_model=config.hidden_size,
                        num_heads=config.num_attention_heads,
                        intermediate_dimension=config.intermediate_size
                    ))

        if original_configs:
            # Predict latency using trained model
            latency = self.latency_model.predict_total_latency(model_layers, original_configs)
        else:
            # Fallback: untrained model, or no layer exposed a usable config.
            latency = self._estimate_latency_fallback(model_layers)

        self.original_inference = latency
        return latency

    def compute_pruned_inference(self, model_layers: List[Tuple[str, nn.Module, LayerSchema]],
                                layers_configs: List[LayerConfig]) -> float:
        """
        Compute the latency-based inference metric for the pruned model.

        The result is also stored in ``self.pruned_inference``.

        Args:
            model_layers: List of tuples containing (layer_name, layer_module, layer_schema)
            layers_configs: List of layer configurations after pruning

        Returns:
            float: The latency in milliseconds for the pruned model
        """
        if self._latency_model_ready():
            # Predict latency using trained model
            latency = self.latency_model.predict_total_latency(model_layers, layers_configs)
        else:
            # Fallback: estimate based on model size
            latency = self._estimate_latency_fallback(model_layers, layers_configs)

        self.pruned_inference = latency
        return latency

    def _estimate_latency_fallback(self, model_layers: List[Tuple[str, nn.Module, LayerSchema]],
                                 layers_configs: Optional[List[LayerConfig]] = None) -> float:
        """
        Fallback latency estimation based on model size when no trained model is available.

        This is a simple heuristic that estimates latency from the number of
        parameters held by the given layer modules (~1ms per 1M parameters,
        never below 1ms).

        NOTE(review): ``layers_configs`` is accepted but not used, so the
        estimate only reflects pruning if the modules themselves have been
        shrunk - confirm this is intended.
        """
        total_params = sum(
            p.numel()
            for _, layer_module, _ in model_layers
            for p in layer_module.parameters()
        )

        # Simple heuristic: assume ~1ms per 1M parameters
        estimated_latency = total_params / 1_000_000.0

        return max(1.0, estimated_latency)  # Minimum 1ms latency

    def get_current_speedup(self) -> float:
        """
        Get the current latency speedup ratio.
        Returns a value >= 1, where:
        - 1.0 means no speedup (same latency)
        - 2.0 means 2x speedup (half the latency)
        - 3.0 means 3x speedup (one-third the latency)

        A slowdown (pruned latency above original) is reported as 1.0.

        Returns:
            float: The latency speedup ratio (>= 1.0)
        """
        # Non-positive latencies mean the value was never computed or is
        # invalid; report "no speedup" rather than dividing by them.
        if self.pruned_inference <= 0 or self.original_inference <= 0:
            return 1.0

        # Calculate speedup: original_latency / pruned_latency, floored at 1.0
        return max(1.0, self.original_inference / self.pruned_inference)

    def is_target_speedup_achieved(self) -> bool:
        """
        Check if the target latency speedup has been achieved.

        Returns:
            bool: True if target speedup is achieved, False otherwise
        """
        # For latency, target_speedup represents the desired speedup ratio
        # e.g., target_speedup = 2.0 means we want 2x speedup
        return self.get_current_speedup() >= self.target_speedup

    def save_latency_model(self, save_path: str):
        """Save the trained latency model to disk."""
        self.latency_model.save_model(save_path)

    def load_latency_model(self, load_path: str):
        """Load a trained latency model from disk."""
        self.latency_model.load_model(load_path)