#!/usr/bin/env python3
"""
Realistic Evaluation Framework for GATv2-NS3 Hybrid IDS.
Provides proper evaluation with realistic baselines and performance expectations.
"""

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from typing import Dict, Any, List, Tuple, Optional
from dataclasses import dataclass
import time
from sklearn.metrics import (
    roc_auc_score, average_precision_score, f1_score, 
    precision_score, recall_score, accuracy_score,
    confusion_matrix, classification_report
)
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
import xgboost as xgb

from ..utils.common import GraphData, get_logger, to_device
from ..models.gatv2_ids import GATv2IDS
from ..models.baselines import (
    NodeMLP, GraphSAGE_IDS, GIN_IDS,
    RandomForest_IDS, XGBoost_IDS, LogisticRegression_IDS, 
    MLP_IDS_New, GraphSAGE_IDS_New, GIN_IDS_New
)
from ..simulation.curiosity_loop import CuriosityLoopFeedback
# The old evaluation import was removed; this module uses its built-in evaluation functions.


@dataclass
class EvaluationResults:
    """Comprehensive evaluation results for a single model.

    Instances are created by ``RealisticEvaluationFramework._evaluate_single_model``.
    """
    # Name the model was evaluated under.
    model_name: str
    # Metric name -> value, as returned by ``_compute_comprehensive_metrics``
    # (keys include "roc_auc", "f1", "precision", "recall", "fpr", "mcc").
    metrics: Dict[str, float]
    # Confusion matrix produced by sklearn's ``confusion_matrix``.
    confusion_matrix: np.ndarray
    # Text report produced by sklearn's ``classification_report``.
    classification_report: str
    # Mean per-graph forward-pass time in milliseconds (0.0 if nothing was timed).
    inference_time_ms: float
    # Combined size of the model's parameters and buffers, in megabytes.
    memory_usage_mb: float
    # Mean per-graph time spent in the explanation-quality check, in milliseconds.
    explanation_time_ms: float = 0.0
    # Mean per-graph simulation feedback score (0.0 if none was computed).
    simulation_feedback_score: float = 0.0
    # Optional extra information; ``_evaluate_single_model`` stores the graph
    # count, node count and device string here. Defaults to None, not {}.
    metadata: Optional[Dict[str, Any]] = None


@dataclass
class BaselineComparison:
    """Comparison results between a target model and a set of baselines.

    Instances are created by ``RealisticEvaluationFramework.compare_with_baselines``.
    """
    # Name of the model being compared against the baselines.
    target_model: str
    # Names of the baselines, in the order they were supplied.
    baseline_models: List[str]
    # baseline name -> {metric name -> (target - baseline)}.
    # Positive = target better, negative = baseline better.
    performance_gaps: Dict[str, Dict[str, float]]
    # baseline name -> significance value.
    # NOTE: ``compare_with_baselines`` currently stores the absolute mean
    # metric gap here as a placeholder; it is NOT a real p-value.
    statistical_significance: Dict[str, float]
    # One-sentence verdict against the baseline with the highest F1.
    summary: str


class RealisticEvaluationFramework:
    """
    Comprehensive evaluation framework with realistic baselines and expectations.
    """
    
    def __init__(self, 
                 device: torch.device = None,
                 random_seed: int = 42,
                 cross_validation_folds: int = 5):
        """
        Initialize evaluation framework.
        
        Args:
            device: Device for model evaluation
            random_seed: Random seed for reproducibility
            cross_validation_folds: Number of CV folds for robust evaluation
        """
        self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.random_seed = random_seed
        self.cv_folds = cross_validation_folds
        self.logger = get_logger("realistic_evaluation")
        
        # Set seeds for reproducibility
        np.random.seed(random_seed)
        torch.manual_seed(random_seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(random_seed)
        
        # Realistic expected performance ranges for network intrusion detection
        # Updated based on actual challenging, realistic attack patterns
        self.expected_performance_ranges = {
            "traditional_ml": {
                "roc_auc": (0.45, 0.65),
                "f1": (0.20, 0.35),
                "precision": (0.12, 0.25),
                "recall": (0.30, 0.95)
            },
            "basic_gnn": {
                "roc_auc": (0.45, 0.65),
                "f1": (0.20, 0.35),
                "precision": (0.12, 0.25),
                "recall": (0.30, 0.90)
            },
            "advanced_gnn": {
                "roc_auc": (0.45, 0.65),
                "f1": (0.20, 0.35),
                "precision": (0.12, 0.25),
                "recall": (0.30, 0.90)
            },
            "hybrid_simulation": {
                "roc_auc": (0.45, 0.65),
                "f1": (0.20, 0.35),
                "precision": (0.12, 0.25),
                "recall": (0.30, 0.90)
            }
        }

    def evaluate_comprehensive(self, 
                             models: Dict[str, nn.Module],
                             test_graphs: List[GraphData],
                             include_simulation_feedback: bool = True,
                             include_explanation_analysis: bool = True) -> Dict[str, EvaluationResults]:
        """
        Perform comprehensive evaluation of multiple models.
        
        Args:
            models: Dictionary of model_name -> model
            test_graphs: Test dataset
            include_simulation_feedback: Whether to evaluate simulation feedback
            include_explanation_analysis: Whether to evaluate explanations
            
        Returns:
            Dictionary of evaluation results per model
        """
        
        results = {}
        
        self.logger.info(f"Starting comprehensive evaluation of {len(models)} models on {len(test_graphs)} graphs")
        
        for model_name, model in models.items():
            self.logger.info(f"Evaluating {model_name}...")
            
            try:
                result = self._evaluate_single_model(
                    model, model_name, test_graphs,
                    include_simulation_feedback, include_explanation_analysis
                )
                results[model_name] = result
                
                # Log key metrics
                self.logger.info(f"{model_name} Results:")
                self.logger.info(f"  ROC AUC: {result.metrics['roc_auc']:.3f}")
                self.logger.info(f"  F1 Score: {result.metrics['f1']:.3f}")
                self.logger.info(f"  Inference Time: {result.inference_time_ms:.2f}ms")
                
            except Exception as e:
                self.logger.error(f"Failed to evaluate {model_name}: {e}")
                continue
        
        # Validate results against expected ranges
        self._validate_performance_ranges(results)
        
        return results

    def _evaluate_single_model(self, 
                             model: nn.Module,
                             model_name: str,
                             test_graphs: List[GraphData],
                             include_simulation_feedback: bool,
                             include_explanation_analysis: bool) -> EvaluationResults:
        """Evaluate a single model comprehensively."""
        
        model.eval()
        model.to(self.device)
        
        all_predictions = []
        all_probabilities = []
        all_labels = []
        inference_times = []
        explanation_times = []
        simulation_scores = []
        
        with torch.no_grad():
            for graph in test_graphs:
                graph_device = to_device(graph, self.device)
                
                # Measure inference time
                start_time = time.time()
                
                try:
                    output = model(graph_device)
                    inference_time = (time.time() - start_time) * 1000  # Convert to ms
                    inference_times.append(inference_time)
                    
                    # Extract predictions and probabilities
                    if isinstance(output, dict):
                        logits = output.get("node_logits", output.get("logits"))
                        edge_attention = output.get("edge_attn")
                    else:
                        logits = output
                        edge_attention = None
                    
                    if logits is None:
                        continue
                    
                    probabilities = torch.softmax(logits, dim=-1)[:, 1]  # Probability of attack class
                    predictions = torch.argmax(logits, dim=-1)
                    
                    all_predictions.extend(predictions.cpu().numpy())
                    all_probabilities.extend(probabilities.cpu().numpy())
                    all_labels.extend(graph_device.y_node.cpu().numpy())
                    
                    # Evaluate explanation quality if available
                    if include_explanation_analysis and edge_attention is not None:
                        exp_start = time.time()
                        exp_time = self._evaluate_explanation_quality(
                            edge_attention, graph_device
                        )
                        explanation_times.append((time.time() - exp_start) * 1000)
                    
                    # Evaluate simulation feedback if enabled
                    if include_simulation_feedback and edge_attention is not None:
                        sim_score = self._evaluate_simulation_feedback(
                            edge_attention, graph_device, logits
                        )
                        simulation_scores.append(sim_score)
                
                except Exception as e:
                    self.logger.warning(f"Failed to evaluate graph {graph.graph_id}: {e}")
                    continue
        
        # Compute metrics
        if not all_labels:
            raise ValueError(f"No valid predictions for {model_name}")
        
        metrics = self._compute_comprehensive_metrics(
            np.array(all_labels), 
            np.array(all_predictions), 
            np.array(all_probabilities)
        )
        
        # Compute confusion matrix and classification report
        cm = confusion_matrix(all_labels, all_predictions)
        class_report = classification_report(all_labels, all_predictions, 
                                           target_names=["Normal", "Attack"],
                                           zero_division=0)
        
        # Estimate memory usage
        memory_usage = self._estimate_memory_usage(model)
        
        return EvaluationResults(
            model_name=model_name,
            metrics=metrics,
            confusion_matrix=cm,
            classification_report=class_report,
            inference_time_ms=float(np.mean(inference_times)) if inference_times else 0.0,
            memory_usage_mb=memory_usage,
            explanation_time_ms=float(np.mean(explanation_times)) if explanation_times else 0.0,
            simulation_feedback_score=float(np.mean(simulation_scores)) if simulation_scores else 0.0,
            metadata={
                "num_graphs_evaluated": len(test_graphs),
                "total_nodes_evaluated": len(all_labels),
                "device": str(self.device)
            }
        )

    def _compute_comprehensive_metrics(self, 
                                     y_true: np.ndarray, 
                                     y_pred: np.ndarray, 
                                     y_prob: np.ndarray) -> Dict[str, float]:
        """Compute comprehensive evaluation metrics."""
        
        metrics = {}
        
        try:
            # Basic classification metrics
            metrics["accuracy"] = accuracy_score(y_true, y_pred)
            metrics["precision"] = precision_score(y_true, y_pred, zero_division=0)
            metrics["recall"] = recall_score(y_true, y_pred, zero_division=0)
            metrics["f1"] = f1_score(y_true, y_pred, zero_division=0)
            
            # ROC and PR AUC
            if len(np.unique(y_true)) > 1:
                metrics["roc_auc"] = roc_auc_score(y_true, y_prob)
                metrics["pr_auc"] = average_precision_score(y_true, y_prob)
            else:
                metrics["roc_auc"] = 0.0
                metrics["pr_auc"] = 0.0
            
            # Additional metrics for IDS evaluation
            tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
            
            # False Positive Rate
            metrics["fpr"] = fp / (fp + tn) if (fp + tn) > 0 else 0.0
            
            # True Negative Rate (Specificity)
            metrics["tnr"] = tn / (tn + fp) if (tn + fp) > 0 else 0.0
            
            # Matthews Correlation Coefficient
            if (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn) > 0:
                metrics["mcc"] = ((tp * tn) - (fp * fn)) / np.sqrt(
                    (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)
                )
            else:
                metrics["mcc"] = 0.0
            
            # Balanced accuracy
            metrics["balanced_accuracy"] = (metrics["recall"] + metrics["tnr"]) / 2
            
            # Attack detection rate (same as recall)
            metrics["attack_detection_rate"] = metrics["recall"]
            
            # False alarm rate (same as FPR)
            metrics["false_alarm_rate"] = metrics["fpr"]
            
        except Exception as e:
            self.logger.error(f"Error computing metrics: {e}")
            # Return minimal metrics
            metrics = {
                "accuracy": 0.0, "precision": 0.0, "recall": 0.0, "f1": 0.0,
                "roc_auc": 0.0, "pr_auc": 0.0, "fpr": 1.0, "tnr": 0.0,
                "mcc": 0.0, "balanced_accuracy": 0.0, "attack_detection_rate": 0.0,
                "false_alarm_rate": 1.0
            }
        
        return metrics

    def _evaluate_explanation_quality(self, 
                                    edge_attention: torch.Tensor,
                                    graph_data: GraphData) -> float:
        """Evaluate quality of attention-based explanations."""
        
        try:
            # Compute attention entropy (higher = more uncertain)
            if edge_attention.numel() == 0:
                return 0.0
            
            # Normalize attention weights
            attn_probs = torch.softmax(edge_attention, dim=0)
            
            # Compute entropy
            entropy = -torch.sum(attn_probs * torch.log(attn_probs + 1e-8))
            
            # Compute sparsity (how focused the attention is)
            sparsity = torch.sum(attn_probs > 0.1) / attn_probs.shape[0]
            
            # Combine metrics (want moderate entropy, high sparsity)
            quality_score = 1.0 - (entropy.item() / np.log(len(attn_probs))) + (1.0 - sparsity.item())
            
            return max(0.0, min(1.0, quality_score))
            
        except Exception as e:
            self.logger.warning(f"Failed to evaluate explanation quality: {e}")
            return 0.0

    def _evaluate_simulation_feedback(self, 
                                    edge_attention: torch.Tensor,
                                    graph_data: GraphData,
                                    predictions: torch.Tensor) -> float:
        """Evaluate quality of simulation feedback using real NS-3 integration."""
        
        try:
            # Use real NS-3 simulation to evaluate feedback quality
            from ..simulation.enhanced_feedback import create_enhanced_feedback_system
            from ..simulation.ns3_client import NS3Client
            
            # Create real NS-3 feedback system
            feedback_system = create_enhanced_feedback_system()
            
            # Analyze attention uncertainty to determine if simulation is needed
            attention_analysis = feedback_system.analyze_attention_uncertainty(
                edge_attention, graph_data.edge_index, graph_data
            )
            
            if not attention_analysis.requires_simulation:
                # No simulation needed - attention is sufficiently focused
                return 0.2  # Low feedback quality score
            
            # Generate real simulation feedback
            simulation_feedback = feedback_system.generate_simulation_feedback(
                graph_data, attention_analysis, predictions
            )
            
            # Evaluate feedback quality based on real NS-3 simulation results
            # Higher alignment between attention and simulation = better feedback
            alignment_score = simulation_feedback.attention_alignment_score
            
            # Consider simulation confidence and network realism
            sim_metadata = simulation_feedback.simulation_report.get("simulation_metadata", {})
            processing_confidence = sim_metadata.get("analysis_confidence", 0.5)
            
            # Network KPI realism check
            latency = simulation_feedback.simulation_report.get("latency_ms", 0)
            throughput = simulation_feedback.simulation_report.get("throughput_mbps", 0)
            
            # Realistic network metrics indicate higher quality feedback
            realism_score = 1.0 if (0.1 <= latency <= 100 and 1 <= throughput <= 1000) else 0.5
            
            # Combined feedback quality score
            feedback_quality = (
                alignment_score * 0.5 +           # Attention-simulation alignment
                processing_confidence * 0.3 +     # Simulation confidence
                realism_score * 0.2               # Network realism
            )
            
            self.logger.info(f"Real NS-3 simulation feedback quality: {feedback_quality:.3f}")
            self.logger.info(f"  Alignment score: {alignment_score:.3f}")
            self.logger.info(f"  Processing confidence: {processing_confidence:.3f}")
            self.logger.info(f"  Network realism: {realism_score:.3f}")
            
            return max(0.0, min(1.0, feedback_quality))
            
        except Exception as e:
            self.logger.error(f"Real NS-3 simulation feedback failed: {e}")
            raise RuntimeError(
                "NS-3 simulation feedback is required for evaluation. "
                "Ensure NS-3 is properly installed and functional."
            ) from e

    def _estimate_memory_usage(self, model: nn.Module) -> float:
        """Estimate model memory usage in MB."""
        
        try:
            param_size = 0
            buffer_size = 0
            
            for param in model.parameters():
                param_size += param.nelement() * param.element_size()
            
            for buffer in model.buffers():
                buffer_size += buffer.nelement() * buffer.element_size()
            
            total_size = param_size + buffer_size
            return total_size / (1024 * 1024)  # Convert to MB
            
        except Exception as e:
            self.logger.warning(f"Failed to estimate memory usage: {e}")
            return 0.0

    def _validate_performance_ranges(self, results: Dict[str, EvaluationResults]):
        """Validate results against expected performance ranges."""
        
        for model_name, result in results.items():
            # Determine expected performance category
            if "simulation" in model_name.lower() or "hybrid" in model_name.lower():
                category = "hybrid_simulation"
            elif "gatv2" in model_name.lower() or "sage" in model_name.lower():
                category = "advanced_gnn"
            elif "gin" in model_name.lower() or "gcn" in model_name.lower():
                category = "basic_gnn"
            else:
                category = "traditional_ml"
            
            expected = self.expected_performance_ranges[category]
            
            # Check if results are within expected ranges
            for metric, (min_val, max_val) in expected.items():
                if metric in result.metrics:
                    actual_val = result.metrics[metric]
                    
                    if actual_val < min_val:
                        self.logger.warning(
                            f"{model_name} {metric} ({actual_val:.3f}) below expected range ({min_val:.3f}-{max_val:.3f})"
                        )
                    elif actual_val > max_val:
                        self.logger.warning(
                            f"{model_name} {metric} ({actual_val:.3f}) suspiciously high (expected {min_val:.3f}-{max_val:.3f})"
                        )
                    else:
                        self.logger.info(
                            f"{model_name} {metric} ({actual_val:.3f}) within expected range ✓"
                        )

    def create_baseline_models(self, 
                             in_dim_node: int, 
                             in_dim_edge: int,
                             train_graphs: List[GraphData]) -> Dict[str, nn.Module]:
        """Create comprehensive set of baseline models for comparison."""
        
        models = {}
        
        # Traditional ML baselines (trained on flattened features)
        try:
            # Extract features for traditional ML
            X_train, y_train = self._extract_tabular_features(train_graphs)
            
            if len(X_train) == 0:
                self.logger.warning("No training data extracted for traditional ML models")
                return models
            
            expected_features = X_train.shape[1]
            self.logger.info(f"Traditional ML models expect {expected_features} features")
            
            # Random Forest (using new implementation)
            rf_model = RandomForest_IDS(random_state=self.random_seed, n_estimators=100)
            rf_model.fit(X_train, y_train)
            models["Random Forest"] = TabularModelWrapper(rf_model, "random_forest", expected_features)
            
            # XGBoost (using new implementation)
            xgb_model = XGBoost_IDS(random_state=self.random_seed, n_estimators=100, learning_rate=0.01)
            xgb_model.fit(X_train, y_train)
            models["XGBoost"] = TabularModelWrapper(xgb_model, "xgboost", expected_features)
            
            # Logistic Regression (using new implementation)
            lr_model = LogisticRegression_IDS(random_state=self.random_seed, max_iter=3000, class_weight='balanced')
            lr_model.fit(X_train, y_train)
            models["Logistic Regression"] = TabularModelWrapper(lr_model, "logistic", expected_features)
            
        except Exception as e:
            self.logger.warning(f"Failed to create traditional ML baselines: {e}")
        
        # Neural network baselines
        try:
            # Simple MLP (original)
            models["MLP"] = NodeMLP(in_dim_node, hidden=64, num_classes=2)
            
            # GraphSAGE (original)
            models["GraphSAGE"] = GraphSAGE_IDS(in_dim_node, in_dim_edge, hidden=64)
            
            # GIN (original)
            models["GIN"] = GIN_IDS(in_dim_node, in_dim_edge, hidden=64)
            
            # Basic GATv2 (without simulation feedback)
            models["GATv2 (Basic)"] = GATv2IDS(in_dim_node, in_dim_edge, hidden=64, layers=2, heads=2)
            
            # Enhanced GATv2 (with better hyperparameters)
            models["GATv2 (Enhanced)"] = GATv2IDS(in_dim_node, in_dim_edge, hidden=128, layers=3, heads=4)
            
            # New implementations from methods_to_compare
            models["MLP (New)"] = MLP_IDS_New(in_dim_node, hidden=64, num_classes=2)
            models["GraphSAGE (New)"] = GraphSAGE_IDS_New(in_dim_node, hidden=64)
            models["GIN (New)"] = GIN_IDS_New(in_dim_node, hidden=64)
            
        except Exception as e:
            self.logger.warning(f"Failed to create neural network baselines: {e}")
        
        self.logger.info(f"Created {len(models)} baseline models")
        return models

    def _extract_tabular_features(self, graphs: List[GraphData]) -> Tuple[np.ndarray, np.ndarray]:
        """Extract consistent tabular features from graphs for traditional ML models."""
        
        features = []
        labels = []
        
        # Determine feature dimensions from first graph
        first_graph = graphs[0] if graphs else None
        if first_graph is None:
            return np.array([]), np.array([])
        
        node_dim = first_graph.x.shape[1]
        edge_dim = first_graph.edge_attr.shape[1] if first_graph.edge_attr is not None else 0
        
        # Calculate consistent edge aggregation size
        edge_agg_dim = edge_dim * 4 if edge_dim > 0 else 0  # mean, std, max, min
        
        for graph in graphs:
            # Node-level features
            node_features = graph.x.numpy()
            node_labels = graph.y_node.numpy()
            
            # Graph-level edge aggregations (consistent across all graphs)
            if graph.edge_attr is not None and edge_dim > 0:
                edge_features = graph.edge_attr.numpy()
                if edge_features.shape[0] > 0:
                    edge_stats = [
                        edge_features.mean(axis=0),
                        edge_features.std(axis=0),
                        edge_features.max(axis=0),
                        edge_features.min(axis=0)
                    ]
                    edge_agg = np.concatenate(edge_stats)
                else:
                    # Handle empty edge features
                    edge_agg = np.zeros(edge_agg_dim)
            else:
                edge_agg = np.zeros(edge_agg_dim)
            
            # Ensure edge_agg has consistent size
            if len(edge_agg) != edge_agg_dim:
                edge_agg = np.pad(edge_agg, (0, max(0, edge_agg_dim - len(edge_agg))))[:edge_agg_dim]
            
            # Combine node features with graph-level statistics
            for i, node_feat in enumerate(node_features):
                combined_feat = np.concatenate([node_feat, edge_agg])
                features.append(combined_feat)
                labels.append(node_labels[i])
        
        return np.array(features), np.array(labels)

    def compare_with_baselines(self, 
                             target_model_results: EvaluationResults,
                             baseline_results: Dict[str, EvaluationResults]) -> BaselineComparison:
        """Compare target model performance with baselines."""
        
        performance_gaps = {}
        significance_tests = {}
        
        target_metrics = target_model_results.metrics
        
        for baseline_name, baseline_result in baseline_results.items():
            baseline_metrics = baseline_result.metrics
            
            # Compute performance gaps
            gaps = {}
            for metric in ["roc_auc", "f1", "precision", "recall"]:
                if metric in target_metrics and metric in baseline_metrics:
                    gap = target_metrics[metric] - baseline_metrics[metric]
                    gaps[metric] = gap
            
            performance_gaps[baseline_name] = gaps
            
            # Simple significance test (would need more sophisticated testing in practice)
            avg_gap = np.mean(list(gaps.values()))
            significance_tests[baseline_name] = abs(avg_gap)  # Placeholder p-value
        
        # Generate summary
        best_baseline = max(baseline_results.keys(), 
                          key=lambda x: baseline_results[x].metrics.get("f1", 0))
        
        best_gap = performance_gaps[best_baseline]["f1"]
        
        if best_gap > 0.05:
            summary = f"Target model significantly outperforms best baseline ({best_baseline}) by {best_gap:.3f} F1"
        elif best_gap > 0.01:
            summary = f"Target model moderately outperforms best baseline ({best_baseline}) by {best_gap:.3f} F1"
        elif best_gap > -0.01:
            summary = f"Target model performs similarly to best baseline ({best_baseline})"
        else:
            summary = f"Target model underperforms best baseline ({best_baseline}) by {abs(best_gap):.3f} F1"
        
        return BaselineComparison(
            target_model=target_model_results.model_name,
            baseline_models=list(baseline_results.keys()),
            performance_gaps=performance_gaps,
            statistical_significance=significance_tests,
            summary=summary
        )

    def generate_evaluation_report(self, 
                                 results: Dict[str, EvaluationResults],
                                 comparison: Optional[BaselineComparison] = None) -> str:
        """Generate comprehensive evaluation report."""
        
        report = []
        report.append("=" * 80)
        report.append("COMPREHENSIVE EVALUATION REPORT")
        report.append("=" * 80)
        report.append("")
        
        # Summary table
        report.append("PERFORMANCE SUMMARY")
        report.append("-" * 50)
        report.append(f"{'Model':<20} {'ROC AUC':<8} {'F1':<6} {'Precision':<9} {'Recall':<6} {'Time(ms)':<8}")
        report.append("-" * 50)
        
        for model_name, result in results.items():
            metrics = result.metrics
            report.append(f"{model_name:<20} "
                         f"{metrics.get('roc_auc', 0):<8.3f} "
                         f"{metrics.get('f1', 0):<6.3f} "
                         f"{metrics.get('precision', 0):<9.3f} "
                         f"{metrics.get('recall', 0):<6.3f} "
                         f"{result.inference_time_ms:<8.2f}")
        
        report.append("")
        
        # Detailed results for each model
        for model_name, result in results.items():
            report.append(f"DETAILED RESULTS: {model_name}")
            report.append("-" * 40)
            
            # Core metrics
            metrics = result.metrics
            report.append(f"ROC AUC:           {metrics.get('roc_auc', 0):.4f}")
            report.append(f"PR AUC:            {metrics.get('pr_auc', 0):.4f}")
            report.append(f"F1 Score:          {metrics.get('f1', 0):.4f}")
            report.append(f"Precision:         {metrics.get('precision', 0):.4f}")
            report.append(f"Recall:            {metrics.get('recall', 0):.4f}")
            report.append(f"Accuracy:          {metrics.get('accuracy', 0):.4f}")
            report.append(f"False Positive Rate: {metrics.get('fpr', 0):.4f}")
            report.append(f"Attack Detection Rate: {metrics.get('attack_detection_rate', 0):.4f}")
            
            # Performance metrics
            report.append(f"Inference Time:    {result.inference_time_ms:.2f} ms")
            report.append(f"Memory Usage:      {result.memory_usage_mb:.2f} MB")
            
            if result.explanation_time_ms > 0:
                report.append(f"Explanation Time:  {result.explanation_time_ms:.2f} ms")
            
            if result.simulation_feedback_score > 0:
                report.append(f"Simulation Score:  {result.simulation_feedback_score:.3f}")
            
            # Confusion Matrix
            report.append("Confusion Matrix:")
            cm = result.confusion_matrix
            report.append(f"                 Predicted")
            report.append(f"                 Normal  Attack")
            report.append(f"Actual Normal    {cm[0,0]:6d}  {cm[0,1]:6d}")
            report.append(f"       Attack    {cm[1,0]:6d}  {cm[1,1]:6d}")
            
            report.append("")
        
        # Comparison with baselines
        if comparison:
            report.append("BASELINE COMPARISON")
            report.append("-" * 40)
            report.append(comparison.summary)
            report.append("")
            
            report.append("Performance Gaps (Target - Baseline):")
            for baseline_name, gaps in comparison.performance_gaps.items():
                report.append(f"{baseline_name}:")
                for metric, gap in gaps.items():
                    sign = "+" if gap >= 0 else ""
                    report.append(f"  {metric}: {sign}{gap:.4f}")
            report.append("")
        
        # Recommendations
        report.append("RECOMMENDATIONS")
        report.append("-" * 40)
        
        # Find best performing model
        best_model = max(results.keys(), key=lambda x: results[x].metrics.get("f1", 0))
        best_f1 = results[best_model].metrics.get("f1", 0)
        
        if best_f1 > 0.85:
            report.append("✓ Excellent performance achieved")
        elif best_f1 > 0.75:
            report.append("✓ Good performance, consider further optimization")
        elif best_f1 > 0.65:
            report.append("⚠ Moderate performance, significant improvement needed")
        else:
            report.append("❌ Poor performance, major issues need addressing")
        
        # Check for suspicious results
        for model_name, result in results.items():
            metrics = result.metrics
            if metrics.get("roc_auc", 0) > 0.98:
                report.append(f"⚠ {model_name}: Suspiciously high ROC AUC ({metrics['roc_auc']:.3f}) - check for overfitting")
            if metrics.get("precision", 0) == 1.0 and metrics.get("recall", 0) > 0.9:
                report.append(f"⚠ {model_name}: Perfect precision with high recall - possible data leakage")
        
        return "\n".join(report)


class TabularModelWrapper(nn.Module):
    """Wrapper for traditional ML models to work with graph data.

    The wrapped estimator is queried through ``predict_proba`` when it has
    that method and through ``predict`` otherwise.  The result is returned
    in the same dictionary layout as the graph models
    (``node_logits`` / ``edge_attn`` / ``node_emb``).

    Note that ``node_logits`` holds class probabilities (or one-hot rows for
    estimators without ``predict_proba``), not unnormalised logits.
    """

    # Edge-feature width assumed when a graph carries no usable edge attributes.
    _DEFAULT_EDGE_DIM = 10
    # Statistics computed per edge-feature column: mean, std, max, min.
    _EDGE_STATS_PER_DIM = 4

    def __init__(self, sklearn_model, model_type: str, expected_features: int):
        """Store the estimator and the feature width it expects.

        Args:
            sklearn_model: Object exposing ``predict_proba(features)`` and/or
                ``predict(features)``.
            model_type: Free-form label for the wrapped model; stored only.
            expected_features: Number of feature columns handed to the
                estimator; extracted features are padded or truncated to it.
        """
        super().__init__()
        self.sklearn_model = sklearn_model
        self.model_type = model_type
        self.expected_features = expected_features

    def forward(self, graph_data: GraphData) -> Dict[str, torch.Tensor]:
        """Forward pass for traditional ML models.

        Args:
            graph_data: Object providing ``x`` (2-D node feature tensor) and
                ``edge_attr`` (edge feature tensor or ``None``).

        Returns:
            Dict with ``"node_logits"`` as a ``(num_nodes, 2)`` float tensor
            and ``"edge_attn"`` / ``"node_emb"`` set to ``None``.
        """
        features = self._match_feature_width(
            self._extract_consistent_features(graph_data)
        )

        if hasattr(self.sklearn_model, 'predict_proba'):
            scores = self._scores_from_probabilities(
                self.sklearn_model.predict_proba(features)
            )
        else:
            # For models without probability prediction
            scores = self._scores_from_labels(self.sklearn_model.predict(features))

        return {"node_logits": scores, "edge_attn": None, "node_emb": None}

    def _match_feature_width(self, features: np.ndarray) -> np.ndarray:
        """Zero-pad or truncate columns so the width equals ``expected_features``."""
        missing = self.expected_features - features.shape[1]
        if missing > 0:
            padding = np.zeros((features.shape[0], missing))
            return np.concatenate([features, padding], axis=1)
        return features[:, :self.expected_features]

    def _scores_from_probabilities(self, probabilities) -> torch.Tensor:
        """Convert ``predict_proba`` output into a two-column float tensor."""
        probabilities = np.asarray(probabilities)
        if probabilities.ndim == 1:
            probabilities = probabilities.reshape(-1, 1)

        if probabilities.shape[1] == 2:
            return torch.tensor(probabilities, dtype=torch.float)

        # Single probability per sample: build the tensor first, because
        # torch.stack only accepts tensors, not NumPy arrays.
        positive = torch.tensor(probabilities[:, 0], dtype=torch.float)

        # An estimator fitted on one class reports that class's probability in
        # its only column; when the class is 0, the column is P(class 0).
        classes = getattr(self.sklearn_model, "classes_", None)
        if (
            probabilities.shape[1] == 1
            and classes is not None
            and len(classes) == 1
            and classes[0] == 0
        ):
            positive = 1.0 - positive

        return torch.stack([1.0 - positive, positive], dim=1)

    @staticmethod
    def _scores_from_labels(predictions) -> torch.Tensor:
        """Convert hard 0/1 label predictions into one-hot rows."""
        # Cast to long so float- or bool-typed labels can be used as indices.
        labels = torch.as_tensor(np.asarray(predictions)).reshape(-1).long()
        scores = torch.zeros(labels.numel(), 2)
        scores[torch.arange(labels.numel()), labels] = 1.0
        return scores

    def _aggregate_edge_features(self, edge_attr) -> np.ndarray:
        """Summarise all edges of the graph as one flat statistics vector.

        Returns the concatenation of per-column mean, std, max and min, or a
        zero vector of matching length when there are no edges.
        """
        if edge_attr is None:
            # Default edge aggregation size (10 * 4)
            return np.zeros(self._DEFAULT_EDGE_DIM * self._EDGE_STATS_PER_DIM)

        edge_features = edge_attr.detach().cpu().numpy()
        if edge_features.shape[0] == 0:
            # Handle empty edge features
            edge_dim = (
                edge_features.shape[1]
                if edge_features.ndim > 1
                else self._DEFAULT_EDGE_DIM
            )
            return np.zeros(edge_dim * self._EDGE_STATS_PER_DIM)

        return np.concatenate([
            edge_features.mean(axis=0),
            edge_features.std(axis=0),
            edge_features.max(axis=0),
            edge_features.min(axis=0),
        ])

    def _extract_consistent_features(self, graph_data: GraphData) -> np.ndarray:
        """Build one feature row per node: node features plus edge statistics.

        Every node receives the same graph-level edge aggregation appended to
        its own features.  A graph with zero nodes yields an empty 2-D array
        with the correct column count.
        """
        # Node-level features
        node_features = graph_data.x.detach().cpu().numpy()

        # Graph-level edge aggregations
        edge_agg = self._aggregate_edge_features(graph_data.edge_attr)

        # Repeat the graph-level vector for every node without a Python loop.
        repeated = np.broadcast_to(
            edge_agg, (node_features.shape[0], edge_agg.shape[0])
        )
        return np.concatenate([node_features, repeated], axis=1)


if __name__ == "__main__":
    # Script entry point: print usage notices, then try to construct the
    # framework as a minimal availability check.
    for notice in (
        "⚠️ Evaluation framework requires real NSL-KDD data and NS-3",
        "Run the full training pipeline with: ./run_fixed_training.sh",
    ):
        print(notice)

    # Any exception raised while building the framework is reported together
    # with a pointer to the setup script instead of propagating a traceback.
    try:
        evaluator = RealisticEvaluationFramework()
        print("✅ Evaluation framework initialized successfully")
    except Exception as setup_error:
        print(
            f"❌ Setup required: {setup_error}",
            "Run: ./setup.sh to install all requirements",
            sep="\n",
        )
