#!/usr/bin/env python3
"""
Curiosity Loop Feedback System (Proposal 2 Implementation)

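Dynamically adjusts NS-3 simulation fidelity based on GNN attention uncertainty.
When attention entropy exceeds the configured thresholds, simulations run at
progressively higher fidelity (low, high, then 'forensic mode'), which adds
packet tracing, QoS monitoring, debug logging, finer time granularity and more
perturbation runs.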
"""

import numpy as np
import torch
import torch.nn.functional as F
from typing import Dict, Any, List, Tuple, Optional
from dataclasses import dataclass
import networkx as nx
from scipy.stats import spearmanr, pearsonr
import time

from .ns3_client import NS3Client, NetworkTopology, TrafficPattern
from .base import NS3ClientBase, SimReport
from .sim_cache import SimCache
from ..utils.common import GraphData, get_logger, attention_entropy


@dataclass
class FidelityLevel:
    """
    Bundle of knobs describing how detailed (and costly) a simulation run is.

    The field defaults describe a medium setting; the three class-method
    presets below are the configurations actually requested by the loop.
    """
    name: str
    packet_tracing: bool = False
    qos_monitoring: bool = False
    debug_logging: bool = False
    time_granularity: str = "millisecond"  # microsecond, millisecond, second
    perturbation_count: int = 3  # number of edge-removal runs
    traffic_patterns: int = 2  # number of generated traffic flows

    @classmethod
    def _instrumented(cls, name, perturbations, patterns):
        """Preset with tracing, QoS monitoring, debug logs and microsecond timing."""
        return cls(
            name=name,
            packet_tracing=True,
            qos_monitoring=True,
            debug_logging=True,
            time_granularity="microsecond",
            perturbation_count=perturbations,
            traffic_patterns=patterns,
        )

    @classmethod
    def low_fidelity(cls):
        """Cheapest preset: no instrumentation, millisecond timing, 2 perturbations, 1 flow."""
        return cls(name="low", perturbation_count=2, traffic_patterns=1)

    @classmethod
    def high_fidelity(cls):
        """Instrumented preset with 5 perturbations and 4 flows."""
        return cls._instrumented("high", perturbations=5, patterns=4)

    @classmethod
    def forensic_mode(cls):
        """Most detailed preset: instrumented, 8 perturbations and 6 flows."""
        return cls._instrumented("forensic", perturbations=8, patterns=6)


@dataclass
class CuriosityIteration:
    """
    Record of a single pass through the curiosity loop.

    A pass may or may not have run a simulation; when it did not,
    ``simulation_result`` is None.
    """
    iteration: int  # zero-based index within the loop
    fidelity_level: FidelityLevel  # configuration associated with this pass
    attention_entropy: float  # entropy measured on the attention entering this pass
    simulation_result: Optional[Dict[str, Any]]  # simulation output, or None if none was produced
    refined_attention: Optional[torch.Tensor]  # attention to carry into the next pass
    convergence_achieved: bool  # whether this pass counts as converged
    resource_cost: float  # budget cost charged for this pass


@dataclass
class CuriosityLoopResult:
    """
    Aggregate outcome of one full curiosity-loop invocation.

    Holds every per-pass record plus the attention and training losses that
    the loop ended with.
    """
    iterations: List[CuriosityIteration]  # per-pass records, in order
    final_attention: torch.Tensor  # attention after the last refinement
    final_fidelity_loss: torch.Tensor  # loss comparing attention with simulated importance
    final_sparsity_loss: torch.Tensor  # penalty term on attention magnitude
    total_resource_cost: float  # sum of per-pass costs
    convergence_achieved: bool  # convergence flag of the final pass
    metadata: Dict[str, Any]  # entropy, iteration count and budget bookkeeping


class CuriosityLoopFeedback:
    """
    Curiosity Loop Feedback System implementing Proposal 2.

    Key Features:
    1. Adaptive Fidelity: Start low-fidelity, upgrade based on uncertainty
    2. Curiosity Loop: Iterative refinement until convergence
    3. Forensic Mode: Enhanced simulation detail for high uncertainty
    4. Resource Optimization: Smart allocation based on model needs

    Budgeting is cost-based: ``budget_per_epoch`` is treated as a number of
    cost units, and every simulation attempt is charged the cost of its
    fidelity level as computed by ``_calculate_resource_cost``.
    """

    def __init__(self,
                 ns3_client: NS3Client,
                 cache: SimCache,
                 uncertainty_threshold: float = 0.3,
                 high_uncertainty_threshold: float = 0.7,
                 forensic_threshold: float = 0.9,
                 top_k_edges: int = 5,
                 budget_per_epoch: int = 5,
                 max_curiosity_iterations: int = 3,
                 convergence_tolerance: float = 0.05,
                 max_nodes_per_simulation: int = 150):
        """
        Args:
            ns3_client: Client whose ``run_scenario(spec)`` executes an NS-3 scenario.
            cache: Simulation cache; stored as ``self.cache`` but not read by
                the methods of this class.
            uncertainty_threshold: Entropy at or above which a simulation is
                requested (below it, no simulation runs).
            high_uncertainty_threshold: Entropy at or above which "high"
                fidelity is requested.
            forensic_threshold: Entropy at or above which "forensic" fidelity
                is requested.
            top_k_edges: Stored as ``self.top_k_edges``; not read by this
                class's methods.
            budget_per_epoch: Per-epoch budget, in cost units.
            max_curiosity_iterations: Upper bound on refinement passes per call.
            convergence_tolerance: L2 change in attention below which a pass
                counts as converged.
            max_nodes_per_simulation: Stored as ``self.max_nodes_per_simulation``;
                not read by this class's methods.
        """
        self.ns3_client = ns3_client
        self.cache = cache
        self.uncertainty_threshold = uncertainty_threshold
        self.high_uncertainty_threshold = high_uncertainty_threshold
        self.forensic_threshold = forensic_threshold
        self.top_k_edges = top_k_edges
        self.budget_per_epoch = budget_per_epoch
        self.max_curiosity_iterations = max_curiosity_iterations
        self.convergence_tolerance = convergence_tolerance
        self.max_nodes_per_simulation = max_nodes_per_simulation

        self.logger = get_logger("curiosity_loop")

        # Budget tracking (cost-based)
        self.simulations_this_epoch = 0
        self.current_epoch = -1
        self.total_budget = float(budget_per_epoch)  # cost units available per epoch
        self.budget_used = 0.0  # cost units consumed so far this epoch

        # Reference cost of each fidelity tier
        self.low_fidelity_cost = self._calculate_resource_cost(FidelityLevel.low_fidelity())
        self.high_fidelity_cost = self._calculate_resource_cost(FidelityLevel.high_fidelity())
        self.forensic_cost = self._calculate_resource_cost(FidelityLevel.forensic_mode())

        self.logger.info(f"Budget system initialized - Total: {self.total_budget:.1f}")
        self.logger.info(f"  Low fidelity cost: {self.low_fidelity_cost:.1f}")
        self.logger.info(f"  High fidelity cost: {self.high_fidelity_cost:.1f}")
        self.logger.info(f"  Forensic cost: {self.forensic_cost:.1f}")

        # Curiosity loop state
        self.uncertainty_history = []
        self.convergence_history = []

        # Fixed seed so node sampling and traffic generation are reproducible per instance
        self._rng = np.random.RandomState(42)

    # ------------------------------------------------------------------
    # Uncertainty analysis and budget bookkeeping
    # ------------------------------------------------------------------

    def analyze_attention_uncertainty(self, attention_weights: torch.Tensor,
                                      edge_index: torch.Tensor) -> Tuple[float, bool, str]:
        """
        Measure attention uncertainty and map it to a fidelity tier.

        The weights are softmax-normalised along dim 0 and the Shannon entropy
        (natural log) of the result is compared against the three thresholds.
        Note that the maximum possible entropy is log(N) for N entries along
        dim 0, so for large graphs near-uniform attention lands in the top tier.

        Args:
            attention_weights: Attention scores; an empty tensor yields entropy 0.
            edge_index: Unused; kept for interface compatibility.

        Returns:
            (entropy, needs_simulation, fidelity_level) where fidelity_level is
            one of "none", "low", "high" or "forensic".
        """
        if attention_weights.numel() == 0:
            entropy = 0.0
        else:
            # log_softmax is the numerically stable form of log(softmax(x)).
            log_probs = F.log_softmax(attention_weights.detach(), dim=0)
            entropy = -(log_probs.exp() * log_probs).sum().item()

        self.logger.debug(f"Attention entropy: {entropy:.4f}")
        self.logger.debug(f"Thresholds - Base: {self.uncertainty_threshold}, "
                          f"High: {self.high_uncertainty_threshold}, "
                          f"Forensic: {self.forensic_threshold}")

        if entropy < self.uncertainty_threshold:
            return entropy, False, "none"
        if entropy < self.high_uncertainty_threshold:
            return entropy, True, "low"
        if entropy < self.forensic_threshold:
            return entropy, True, "high"
        return entropy, True, "forensic"

    def check_budget_availability(self, fidelity_level: str) -> bool:
        """Return True if the remaining budget covers one run at ``fidelity_level``."""
        required_cost = self._get_fidelity_cost(fidelity_level)
        remaining_budget = self.total_budget - self.budget_used

        available = remaining_budget >= required_cost
        if not available:
            self.logger.debug(f"Budget insufficient for {fidelity_level}: "
                              f"need {required_cost:.1f}, have {remaining_budget:.1f}")
        return available

    def consume_budget(self, fidelity_level: str) -> float:
        """Charge one run at ``fidelity_level`` against the budget and return its cost."""
        cost = self._get_fidelity_cost(fidelity_level)
        self.budget_used += cost
        self.simulations_this_epoch += 1

        remaining = self.total_budget - self.budget_used
        self.logger.debug(f"Consumed {cost:.1f} budget for {fidelity_level} simulation. "
                          f"Remaining: {remaining:.1f}/{self.total_budget:.1f}")
        return cost

    def _get_fidelity_cost(self, fidelity_level: str) -> float:
        """Return the cost of one run at ``fidelity_level``; unknown levels cost as "low"."""
        costs = {
            "low": self.low_fidelity_cost,
            "high": self.high_fidelity_cost,
            "forensic": self.forensic_cost,
        }
        if fidelity_level not in costs:
            self.logger.warning(f"Unknown fidelity level: {fidelity_level}, using low cost")
            return self.low_fidelity_cost
        return costs[fidelity_level]

    def reset_budget_for_epoch(self, epoch: int):
        """Start a new epoch: record ``epoch`` and clear the consumed budget and run count."""
        self.current_epoch = epoch
        self.budget_used = 0.0
        self.simulations_this_epoch = 0

        self.logger.info(f"Budget reset for epoch {epoch}: {self.total_budget:.1f} cost units available")

    def _utilization_percent(self) -> float:
        """Percentage of the budget consumed; 0.0 when the total budget is not positive."""
        if self.total_budget <= 0:
            return 0.0
        return (self.budget_used / self.total_budget) * 100

    def get_budget_status(self) -> Dict[str, float]:
        """Return a snapshot of total, used and remaining budget, utilisation and run count."""
        return {
            "total_budget": self.total_budget,
            "budget_used": self.budget_used,
            "budget_remaining": self.total_budget - self.budget_used,
            "utilization_percent": self._utilization_percent(),
            "simulations_count": self.simulations_this_epoch,
        }

    @staticmethod
    def _fidelity_from_name(fidelity_level: str) -> FidelityLevel:
        """
        Map a tier name to its ``FidelityLevel`` preset.

        Raises:
            ValueError: if ``fidelity_level`` is not "low", "high" or "forensic".
        """
        factories = {
            "low": FidelityLevel.low_fidelity,
            "high": FidelityLevel.high_fidelity,
            "forensic": FidelityLevel.forensic_mode,
        }
        if fidelity_level not in factories:
            raise ValueError(f"Unknown fidelity level: {fidelity_level}")
        return factories[fidelity_level]()

    # ------------------------------------------------------------------
    # Curiosity loop
    # ------------------------------------------------------------------

    def run_adaptive_simulation(self, graph: GraphData, attention_weights: torch.Tensor,
                                fidelity_level: str) -> Optional[Dict[str, Any]]:
        """
        Run baseline and edge-perturbation simulations at the requested fidelity.

        Key Proposal 2 Feature: Adaptive fidelity based on uncertainty.

        Returns:
            A dict with the baseline KPIs, per-edge perturbation results, the
            derived ground-truth edge importance, the attention/ground-truth
            correlation ("alignment_score"), the run's resource cost and the
            fidelity settings used; or None if any stage fails (errors are
            logged, not raised).
        """
        try:
            fidelity = self._fidelity_from_name(fidelity_level)
            self.logger.info(f"Running {fidelity.name} fidelity simulation")

            topology = self._create_ns3_topology(graph)
            if topology is None:
                return None

            traffic_patterns = self._generate_traffic_patterns(
                topology, count=fidelity.traffic_patterns
            )

            baseline_kpis = self._run_baseline_simulation(topology, traffic_patterns, fidelity)
            if baseline_kpis is None:
                return None

            if baseline_kpis.get("skipped"):
                # Perturbation impact is measured relative to the baseline; a skipped
                # baseline only holds placeholder zeros, so perturbation runs would
                # carry no usable signal.
                self.logger.info("Baseline simulation skipped - not running perturbation simulations")
                perturbation_results = []
            else:
                perturbation_results = self._run_perturbation_simulations(
                    graph, attention_weights, topology, traffic_patterns,
                    fidelity, count=fidelity.perturbation_count
                )

            ground_truth_importance = self._calculate_ground_truth_importance(
                graph, perturbation_results, baseline_kpis
            )
            alignment_score = self._calculate_attention_alignment(
                attention_weights, ground_truth_importance, graph.edge_index
            )

            result = {
                "fidelity_level": fidelity.name,
                "baseline_kpis": baseline_kpis,
                "perturbation_results": perturbation_results,
                "ground_truth_importance": ground_truth_importance,
                "alignment_score": alignment_score,
                "resource_cost": self._calculate_resource_cost(fidelity),
                "metadata": {
                    "packet_tracing": fidelity.packet_tracing,
                    "qos_monitoring": fidelity.qos_monitoring,
                    "debug_logging": fidelity.debug_logging,
                    "time_granularity": fidelity.time_granularity,
                    "perturbation_count": fidelity.perturbation_count,
                    "traffic_patterns": fidelity.traffic_patterns,
                },
            }

            self.logger.info(f"Simulation completed - Alignment: {alignment_score:.4f}, "
                             f"Cost: {result['resource_cost']:.2f}")
            return result

        except Exception as e:
            self.logger.error(f"Simulation failed: {e}")
            return None

    @staticmethod
    def _iteration_without_simulation(iteration: int, entropy: float,
                                      attention_weights: torch.Tensor,
                                      converged: bool) -> CuriosityIteration:
        """Build a record for a pass that ran no simulation and charged nothing."""
        return CuriosityIteration(
            iteration=iteration,
            fidelity_level=FidelityLevel.low_fidelity(),
            attention_entropy=entropy,
            simulation_result=None,
            refined_attention=attention_weights,
            convergence_achieved=converged,
            resource_cost=0.0,
        )

    def curiosity_loop_iteration(self, graph: GraphData, attention_weights: torch.Tensor,
                                 iteration: int) -> CuriosityIteration:
        """
        Run one curiosity-loop pass with adaptive fidelity.

        Core Proposal 2 Feature: Iterative refinement based on uncertainty.

        * Entropy below the base threshold: no simulation; the pass counts as converged.
        * Budget insufficient for the requested tier: no simulation; NOT converged.
        * Simulation attempted: its cost is charged whether or not it succeeds.
          On failure the pass is NOT converged. On success, attention is refined
          and convergence is judged by the size of the change.
        """
        entropy, needs_sim, fidelity_level = self.analyze_attention_uncertainty(
            attention_weights, graph.edge_index
        )

        if not needs_sim:
            return self._iteration_without_simulation(
                iteration, entropy, attention_weights, converged=True
            )

        if not self.check_budget_availability(fidelity_level):
            # Out of budget: stop refining, but do not claim convergence.
            return self._iteration_without_simulation(
                iteration, entropy, attention_weights, converged=False
            )

        simulation_result = self.run_adaptive_simulation(graph, attention_weights, fidelity_level)
        actual_cost = self.consume_budget(fidelity_level)
        fidelity_config = self._fidelity_from_name(fidelity_level)

        if simulation_result is None:
            return CuriosityIteration(
                iteration=iteration,
                fidelity_level=fidelity_config,
                attention_entropy=entropy,
                simulation_result=None,
                refined_attention=attention_weights,
                convergence_achieved=False,
                resource_cost=actual_cost,
            )

        refined_attention = self._refine_attention_weights(
            attention_weights, simulation_result, graph.edge_index
        )
        convergence_achieved = self._check_convergence(attention_weights, refined_attention)

        return CuriosityIteration(
            iteration=iteration,
            fidelity_level=fidelity_config,
            attention_entropy=entropy,
            simulation_result=simulation_result,
            refined_attention=refined_attention,
            convergence_achieved=convergence_achieved,
            resource_cost=actual_cost,
        )

    def analyze_and_simulate(self, attention_weights: torch.Tensor,
                             edge_index: torch.Tensor,
                             graph: GraphData,
                             epoch: Optional[int] = None) -> Optional[CuriosityLoopResult]:
        """
        Main curiosity loop implementation (Proposal 2 core algorithm).

        Repeatedly simulates and refines attention until the attention
        converges, no simulation result can be produced (budget too small for
        the requested tier, or the simulation failed), the remaining budget
        cannot pay for even a low-fidelity run, or ``max_curiosity_iterations``
        passes have run.

        Args:
            attention_weights: Attention scores for the edges of ``graph``.
            edge_index: Edge index used for the initial uncertainty check.
            graph: Graph providing ``x`` and ``edge_index`` for simulation.
            epoch: If given and different from the current epoch, the budget
                is reset before the loop starts.

        Returns:
            A ``CuriosityLoopResult``, or None when the initial uncertainty is
            below the base threshold or no iterations are configured.
        """
        if epoch is not None and epoch != self.current_epoch:
            self.reset_budget_for_epoch(epoch)
            self.logger.info(f"Starting curiosity loop for epoch {epoch}")

        initial_entropy, needs_simulation, _ = self.analyze_attention_uncertainty(
            attention_weights, edge_index
        )
        if not needs_simulation:
            self.logger.debug(f"Low uncertainty ({initial_entropy:.4f}), skipping simulation")
            return None

        iterations: List[CuriosityIteration] = []
        current_attention = attention_weights.clone()

        for i in range(self.max_curiosity_iterations):
            step = self.curiosity_loop_iteration(graph, current_attention, i)
            iterations.append(step)

            if step.refined_attention is not None:
                current_attention = step.refined_attention

            if step.convergence_achieved:
                self.logger.info(f"Curiosity loop converged after {i+1} iterations")
                break
            if step.simulation_result is None:
                # Nothing new to refine from; stop instead of re-requesting.
                self.logger.info(f"Curiosity loop stopped after {i+1} iterations: "
                                 f"no simulation result (budget or simulation failure)")
                break
            if self.total_budget - self.budget_used < self.low_fidelity_cost:
                self.logger.info(f"Budget exhausted after {i+1} iterations")
                break

        if not iterations:
            self.logger.debug("max_curiosity_iterations is 0; no curiosity iterations run")
            return None

        total_cost = sum(step.resource_cost for step in iterations)
        final_iteration = iterations[-1]

        # Train against the most recent simulation that actually produced data,
        # even if a later pass stopped for budget reasons.
        latest_result = next(
            (step.simulation_result for step in reversed(iterations) if step.simulation_result),
            None,
        )
        if latest_result:
            fidelity_loss, sparsity_loss = self._generate_training_losses(
                current_attention, latest_result, graph.edge_index
            )
        else:
            fidelity_loss = torch.tensor(0.0)
            sparsity_loss = torch.tensor(0.0)

        status = self.get_budget_status()
        result = CuriosityLoopResult(
            iterations=iterations,
            final_attention=current_attention,
            final_fidelity_loss=fidelity_loss,
            final_sparsity_loss=sparsity_loss,
            total_resource_cost=total_cost,
            convergence_achieved=final_iteration.convergence_achieved,
            metadata={
                "initial_entropy": initial_entropy,
                "final_entropy": final_iteration.attention_entropy,
                "total_iterations": len(iterations),
                "budget_used": {
                    "budget_used": status["budget_used"],
                    "budget_total": status["total_budget"],
                    "budget_remaining": status["budget_remaining"],
                    "utilization_percent": status["utilization_percent"],
                    "simulations_count": status["simulations_count"],
                },
            },
        )

        self.logger.info(f"Curiosity loop completed: {len(iterations)} iterations, "
                         f"cost: {total_cost:.2f}, converged: {result.convergence_achieved}")
        return result

    def _refine_attention_weights(self, attention_weights: torch.Tensor,
                                  simulation_result: Dict[str, Any],
                                  edge_index: torch.Tensor,
                                  alpha: float = 0.3) -> torch.Tensor:
        """
        Blend attention with simulated importance and renormalise with softmax.

        Args:
            attention_weights: Current attention.
            simulation_result: Output of ``run_adaptive_simulation``.
            edge_index: Unused; kept for interface compatibility.
            alpha: Weight given to the simulated ground truth (0 keeps attention).
        """
        # Ground truth is built on CPU; align it with the attention's device and dtype.
        ground_truth = simulation_result["ground_truth_importance"].to(attention_weights)
        refined_attention = (1 - alpha) * attention_weights + alpha * ground_truth
        return F.softmax(refined_attention, dim=0)

    def _check_convergence(self, old_attention: torch.Tensor,
                           new_attention: torch.Tensor) -> bool:
        """True if the shapes match and the L2 norm of the change is below the tolerance."""
        if old_attention.shape != new_attention.shape:
            return False

        attention_change = torch.norm(new_attention - old_attention).item()
        converged = attention_change < self.convergence_tolerance
        self.logger.debug(f"Attention change: {attention_change:.6f}, "
                          f"tolerance: {self.convergence_tolerance}, converged: {converged}")
        return converged

    def _calculate_resource_cost(self, fidelity: FidelityLevel) -> float:
        """
        Relative cost of one run at ``fidelity``.

        The base cost is 1.0. Each enabled feature multiplies it (tracing x2,
        QoS x1.5, debug x1.3, microsecond timing x2). The result is then scaled
        linearly with the perturbation count (relative to 3) and the
        traffic-pattern count (relative to 2).
        """
        base_cost = 1.0

        if fidelity.packet_tracing:
            base_cost *= 2.0
        if fidelity.qos_monitoring:
            base_cost *= 1.5
        if fidelity.debug_logging:
            base_cost *= 1.3
        if fidelity.time_granularity == "microsecond":
            base_cost *= 2.0

        base_cost *= (fidelity.perturbation_count / 3.0)
        base_cost *= (fidelity.traffic_patterns / 2.0)
        return base_cost

    # ------------------------------------------------------------------
    # NS-3 scenario construction and execution
    # ------------------------------------------------------------------

    def _create_ns3_topology(self, graph: GraphData) -> Optional[NetworkTopology]:
        """
        Build an NS-3 topology from a random subset of at most 20 graph nodes.

        Only edges with both endpoints in the subset become point-to-point
        links, in original edge order. Node ids are renumbered 0..n-1, and
        ``node_id_map`` maps original graph ids to them. Returns None (after
        logging) on failure.
        """
        try:
            total_nodes = graph.x.shape[0]
            # NOTE(review): self.max_nodes_per_simulation is not consulted here;
            # confirm whether the fixed cap of 20 is intended.
            num_nodes = min(total_nodes, 20)
            selected_nodes = self._rng.choice(total_nodes, size=num_nodes, replace=False)

            # Plain ints keep lookups and any serialisation free of numpy scalar types.
            node_id_map = {int(old_id): new_id for new_id, old_id in enumerate(selected_nodes)}

            sources, targets = graph.edge_index.detach().cpu().tolist()
            links = [
                {"src": node_id_map[src], "dst": node_id_map[dst], "type": "p2p"}
                for src, dst in zip(sources, targets)
                if src in node_id_map and dst in node_id_map
            ]
            nodes = [{"id": i, "type": "default"} for i in range(num_nodes)]

            return NetworkTopology(
                nodes=nodes,
                links=links,
                node_id_map=node_id_map
            )

        except Exception as e:
            self.logger.error(f"Failed to create topology: {e}")
            return None

    def _generate_traffic_patterns(self, topology: NetworkTopology, count: int = 2) -> List[TrafficPattern]:
        """
        Generate ``count`` random OnOff flows between distinct topology nodes.

        Each flow sends 1 Mbps of 1024-byte packets from t=0.0 to t=1.0.
        Returns an empty list when the topology has fewer than two nodes,
        because then no flow with distinct endpoints exists.
        """
        num_nodes = len(topology.nodes)
        if num_nodes < 2:
            return []

        patterns = []
        for _ in range(count):
            src = int(self._rng.randint(0, num_nodes))
            # An offset in [1, num_nodes - 1] picks dst uniformly among the other
            # nodes, so no draw is wasted on src == dst.
            dst = (src + int(self._rng.randint(1, num_nodes))) % num_nodes
            patterns.append(TrafficPattern(
                src_node=src,
                dst_node=dst,
                application="OnOff",
                data_rate="1Mbps",
                packet_size=1024,
                start_time=0.0,
                stop_time=1.0
            ))
        return patterns

    @staticmethod
    def _topology_to_dict(topology: NetworkTopology) -> Dict[str, Any]:
        """Serialise a topology into the dict shape passed to ``run_scenario``."""
        return {
            "nodes": topology.nodes,
            "links": topology.links,
            "routing": topology.routing,
            "duration": topology.duration,
        }

    @staticmethod
    def _traffic_to_dicts(traffic_patterns: List[TrafficPattern]) -> List[Dict[str, Any]]:
        """Serialise traffic patterns into the list-of-dicts passed to ``run_scenario``."""
        return [
            {
                "src_node": pattern.src_node,
                "dst_node": pattern.dst_node,
                "application": pattern.application,
                "data_rate": pattern.data_rate,
                "packet_size": pattern.packet_size,
                "start_time": pattern.start_time,
                "stop_time": pattern.stop_time,
            }
            for pattern in traffic_patterns
        ]

    @staticmethod
    def _skip_reason(result: Dict[str, Any]) -> Optional[str]:
        """Return the client's ``simulation_metadata.skipped_reason``, if any."""
        return (result.get("simulation_metadata") or {}).get("skipped_reason")

    @staticmethod
    def _kpis_from_result(result: Dict[str, Any]) -> Dict[str, float]:
        """Extract throughput/latency/packet-loss KPIs (default 0.0) from a client result."""
        return {
            "throughput": result.get("throughput_mbps", 0.0),
            "latency": result.get("latency_ms", 0.0),
            "packet_loss": result.get("drop_rate", 0.0),
        }

    def _run_baseline_simulation(self, topology: NetworkTopology,
                                 traffic_patterns: List[TrafficPattern],
                                 fidelity: FidelityLevel) -> Optional[Dict[str, float]]:
        """
        Run the unperturbed scenario and return its KPIs plus a "skipped" flag.

        If the client reports a skip reason, zeros are returned with
        ``skipped=True``. Returns None (after logging) on error.
        """
        try:
            result = self.ns3_client.run_scenario({
                "topology": self._topology_to_dict(topology),
                "traffic": self._traffic_to_dicts(traffic_patterns),
                "fidelity": fidelity.name,
            })

            skip_reason = self._skip_reason(result)
            if skip_reason:
                self.logger.info(f"Baseline simulation skipped: {skip_reason}")
                return {
                    "throughput": 0.0,
                    "latency": 0.0,
                    "packet_loss": 0.0,
                    "skipped": True
                }

            kpis = self._kpis_from_result(result)
            kpis["skipped"] = False
            return kpis
        except Exception as e:
            self.logger.error(f"Baseline simulation failed: {e}")
            return None

    @staticmethod
    def _skipped_perturbation(edge_idx: int) -> Dict[str, Any]:
        """Placeholder record for a perturbation that produced no measurement."""
        return {
            "edge_index": edge_idx,
            "kpis": {"throughput": 0.0, "latency": 0.0, "packet_loss": 0.0},
            "skipped": True,
        }

    @staticmethod
    def _find_topology_link(graph: GraphData, topology: NetworkTopology,
                            edge_idx: int) -> Optional[int]:
        """
        Position in ``topology.links`` of graph edge ``edge_idx``.

        Returns None when either endpoint was not sampled into the topology.
        Links are matched on the renumbered (src, dst) pair.
        """
        src = int(graph.edge_index[0, edge_idx])
        dst = int(graph.edge_index[1, edge_idx])
        node_id_map = topology.node_id_map
        if src not in node_id_map or dst not in node_id_map:
            return None

        target = (node_id_map[src], node_id_map[dst])
        for position, link in enumerate(topology.links):
            if (link["src"], link["dst"]) == target:
                return position
        return None

    def _run_perturbation_simulations(self, graph: GraphData, attention_weights: torch.Tensor,
                                      topology: NetworkTopology, traffic_patterns: List[TrafficPattern],
                                      fidelity: FidelityLevel, count: int = 3) -> List[Dict[str, Any]]:
        """
        Remove each of the top-``count`` attended edges in turn and re-simulate.

        Graph edges are located in the sampled topology via ``node_id_map``.
        Edges outside the sampled subgraph, and runs the client skips, are
        recorded with ``skipped=True`` and zero KPIs. Per-edge errors are logged
        and that edge is omitted.
        """
        results = []
        top_k_indices = torch.topk(attention_weights, k=min(count, len(attention_weights))).indices
        # Traffic is the same for every perturbation; serialise it once.
        traffic_dicts = self._traffic_to_dicts(traffic_patterns)

        for i, edge_tensor in enumerate(top_k_indices):
            edge_idx = int(edge_tensor.item())
            try:
                link_position = self._find_topology_link(graph, topology, edge_idx)
                if link_position is None:
                    self.logger.debug(f"Edge {edge_idx} is not in the simulated subgraph; "
                                      f"skipping perturbation")
                    results.append(self._skipped_perturbation(edge_idx))
                    continue

                perturbed_topology = self._create_perturbed_topology(topology, link_position)
                result = self.ns3_client.run_scenario({
                    "topology": self._topology_to_dict(perturbed_topology),
                    "traffic": traffic_dicts,
                    "fidelity": fidelity.name,
                })

                skip_reason = self._skip_reason(result)
                if skip_reason:
                    self.logger.info(f"Simulation skipped for edge {edge_idx}: {skip_reason}")
                    results.append(self._skipped_perturbation(edge_idx))
                else:
                    results.append({
                        "edge_index": edge_idx,
                        "kpis": self._kpis_from_result(result),
                        "skipped": False,
                    })

            except Exception as e:
                self.logger.warning(f"Perturbation simulation {i} failed: {e}")

        return results

    def _create_perturbed_topology(self, topology: NetworkTopology, edge_idx: int) -> NetworkTopology:
        """
        Copy ``topology`` without the link at position ``edge_idx`` of ``topology.links``.

        ``edge_idx`` is a position in the topology's link list, not a graph
        edge index. Nodes and ``node_id_map`` are shared with the original;
        routing and duration are not passed to the new ``NetworkTopology``.
        """
        perturbed_links = [link for i, link in enumerate(topology.links) if i != edge_idx]

        return NetworkTopology(
            nodes=topology.nodes,
            links=perturbed_links,
            node_id_map=topology.node_id_map
        )

    # ------------------------------------------------------------------
    # Ground truth, alignment and losses
    # ------------------------------------------------------------------

    @staticmethod
    def _relative_kpi_impact(baseline_kpis: Dict[str, float],
                             perturbed_kpis: Dict[str, float]) -> float:
        """Sum over KPIs of |perturbed - baseline| / (baseline + 1e-6)."""
        # NOTE(review): a zero baseline (e.g. packet_loss == 0) makes its term
        # dominate via the 1e-6 denominator; confirm this weighting is intended.
        total = 0.0
        for key in ("throughput", "latency", "packet_loss"):
            base = baseline_kpis[key]
            total += abs(perturbed_kpis[key] - base) / (base + 1e-6)
        return total

    def _calculate_ground_truth_importance(self, graph: GraphData,
                                           perturbation_results: List[Dict[str, Any]],
                                           baseline_kpis: Dict[str, float]) -> torch.Tensor:
        """
        Derive a per-edge importance vector (CPU, length = number of graph edges).

        Each measured perturbation contributes its relative KPI impact.
        Skipped runs, and all runs when the baseline itself was skipped, count
        as zero impact because their KPIs are placeholders. If every impact is
        zero, importance falls back to the mean endpoint degree, scaled so the
        maximum is 1. Otherwise impacts are scaled to sum to 1.
        """
        num_edges = graph.edge_index.shape[1]
        importance_scores = torch.zeros(num_edges)

        if not baseline_kpis.get("skipped", False):
            for result in perturbation_results:
                if result.get("skipped", False):
                    continue
                edge_idx = result["edge_index"]
                if 0 <= edge_idx < num_edges:
                    importance_scores[edge_idx] = self._relative_kpi_impact(
                        baseline_kpis, result["kpis"]
                    )

        if num_edges == 0:
            return importance_scores

        if not bool(torch.any(importance_scores != 0)):
            self.logger.debug("All edge perturbations have zero impact - using topology-based importance")
            edges = graph.edge_index.detach().cpu().long()
            src, dst = edges[0], edges[1]
            ones = torch.ones(num_edges)
            node_degrees = torch.zeros(graph.x.shape[0])
            node_degrees.index_add_(0, src, ones)
            node_degrees.index_add_(0, dst, ones)

            # Edge importance = average degree of its two endpoints
            importance_scores = (node_degrees[src] + node_degrees[dst]) / 2.0
            max_score = importance_scores.max()
            if max_score > 0:
                importance_scores = importance_scores / max_score
        else:
            total = importance_scores.sum()
            if total > 0:
                importance_scores = importance_scores / total

        return importance_scores

    def _calculate_attention_alignment(self, attention_weights: torch.Tensor,
                                       ground_truth_importance: torch.Tensor,
                                       edge_index: torch.Tensor) -> float:
        """
        Pearson correlation between attention and ground-truth importance.

        Returns 0.0 when lengths differ, either side is constant, or the
        correlation cannot be computed. ``edge_index`` is unused.
        """
        if len(attention_weights) != len(ground_truth_importance):
            return 0.0

        try:
            # .float() so reduced-precision tensors (e.g. bfloat16) convert to numpy
            attn_array = attention_weights.detach().cpu().float().numpy()
            gt_array = ground_truth_importance.detach().cpu().float().numpy()

            if np.var(attn_array) == 0 or np.var(gt_array) == 0:
                self.logger.debug(f"Constant array detected - Attention var: {np.var(attn_array):.6f}, "
                                  f"GT var: {np.var(gt_array):.6f}")
                return 0.0

            correlation, _ = pearsonr(attn_array, gt_array)
            return float(correlation) if not np.isnan(correlation) else 0.0
        except Exception as e:
            self.logger.debug(f"Correlation calculation failed: {e}")
            return 0.0

    def _generate_training_losses(self, attention_weights: torch.Tensor,
                                  simulation_result: Dict[str, Any],
                                  edge_index: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Return (fidelity_loss, sparsity_loss).

        fidelity_loss is the MSE between attention and the simulated ground
        truth; sparsity_loss is the mean of squared attention values.
        ``edge_index`` is unused.
        """
        # Ground truth is built on CPU; align it with the attention's device and dtype.
        ground_truth = simulation_result["ground_truth_importance"].to(attention_weights)

        fidelity_loss = F.mse_loss(attention_weights, ground_truth)
        sparsity_loss = torch.mean(attention_weights ** 2)
        return fidelity_loss, sparsity_loss


# Alias for backward compatibility: code that still refers to the name
# EnhancedSimulationFeedback gets the same class object.
EnhancedSimulationFeedback = CuriosityLoopFeedback
