"""
Edge tracking utility for SPOTLIGHT anomaly detection integration.
Tracks all edges (normal + attack) during temporal graph testing.
"""

import json
import torch
import torch.nn.functional as F
import numpy as np
import logging
import os
import matplotlib.pyplot as plt
import seaborn as sns
from typing import List, Tuple, Dict
from datetime import datetime
from collections import Counter


class EdgeTracker:
    """
    Pure edge tracking class for data collection and export utilities.

    Stores all edges (normal + attack) seen during temporal graph testing.
    Each edge is tagged with the batch ID that was current when it was added.
    Raw edge messages are kept as CPU tensors for later feature-based analysis.
    """

    def __init__(self, dataset_name: str, attack_type: str = "none"):
        self.dataset_name = dataset_name
        self.attack_type = attack_type
        self.edges = []  # List of edge dicts (keys listed in _append_edges)
        self.attack_batch_ids = set()  # Batch IDs in which attack edges were added
        self.total_normal_edges = 0
        self.total_attack_edges = 0
        self.current_batch_id = 0  # Batch ID stamped onto newly added edges

        # Message storage for analysis: one detached CPU tensor per add_* call
        self.normal_messages = []
        self.attack_messages = []

        # Create output directory (relative to the current working directory)
        self.output_dir = "spotlight_data"
        os.makedirs(self.output_dir, exist_ok=True)

        logging.info(
            f"EdgeTracker initialized for {dataset_name} with attack: {attack_type}"
        )

    def set_current_batch_id(self, batch_id: int):
        """Set the batch ID that subsequently added edges are tagged with."""
        self.current_batch_id = batch_id

    def _append_edges(self, src, dst, msg, weights, is_attack, attack_type):
        """Record one edge dict per element of ``src`` and update the counters.

        Tensors are converted to Python lists once, instead of calling
        ``.item()`` per element. This avoids one device synchronization per
        edge when the inputs live on a GPU.

        All new edge dicts are built before anything is mutated. If ``dst`` or
        ``weights`` is shorter than ``src``, the IndexError leaves the tracker
        unchanged.
        """
        src_ids = src.tolist()
        dst_ids = dst.tolist()
        weight_vals = weights.tolist()
        batch_id = self.current_batch_id
        has_message = msg is not None

        new_edges = [
            {
                "src": src_id,
                "dst": dst_ids[i],
                "batch_id": batch_id,  # Batch ID is used instead of the timestamp
                "weight": weight_vals[i],
                "is_attack": is_attack,
                "attack_type": attack_type,
                "has_message": has_message,
            }
            for i, src_id in enumerate(src_ids)
        ]
        self.edges.extend(new_edges)

        if is_attack:
            self.total_attack_edges += len(new_edges)
            # Only mark the batch as attacked if at least one edge was added
            if new_edges:
                self.attack_batch_ids.add(batch_id)
        else:
            self.total_normal_edges += len(new_edges)

    def add_normal_edges(
        self,
        src: torch.Tensor,
        dst: torch.Tensor,
        t: torch.Tensor,
        msg: torch.Tensor = None,
        weights: torch.Tensor = None,
    ):
        """
        Add normal (ground truth) edges to tracking.

        Args:
            src: Source nodes (1-D tensor)
            dst: Destination nodes, aligned with ``src``
            t: Timestamps. Accepted for interface compatibility but not stored;
                edges are tagged with the current batch ID instead.
            msg: Edge messages (optional). A non-empty tensor is stored
                (detached, on CPU) for feature analysis.
            weights: Edge weights (optional, defaults to 1.0)
        """
        if weights is None:
            weights = torch.ones(len(src))

        # Store messages for anomaly detection
        if msg is not None and len(msg) > 0:
            self.normal_messages.append(msg.detach().cpu())

        self._append_edges(src, dst, msg, weights, is_attack=False, attack_type=None)

    def add_attack_edges(
        self,
        src: torch.Tensor,
        dst: torch.Tensor,
        t: torch.Tensor,
        msg: torch.Tensor = None,
        weights: torch.Tensor = None,
        attack_type: str = None,
    ):
        """
        Add attack (adversarial) edges to tracking.

        Args:
            src: Source nodes (1-D tensor)
            dst: Destination nodes, aligned with ``src``
            t: Timestamps. Accepted for interface compatibility but not stored;
                edges are tagged with the current batch ID instead.
            msg: Edge messages (optional). A non-empty tensor is stored
                (detached, on CPU) for feature analysis.
            weights: Edge weights (optional, defaults to 1.0)
            attack_type: Type of attack (memstranding, grbcd, etc.). Falls back
                to the tracker's ``attack_type`` when falsy.
        """
        if weights is None:
            weights = torch.ones(len(src))

        attack_type = attack_type or self.attack_type

        # Store attack messages for anomaly detection
        if msg is not None and len(msg) > 0:
            self.attack_messages.append(msg.detach().cpu())

        self._append_edges(
            src, dst, msg, weights, is_attack=True, attack_type=attack_type
        )

    def export_messages(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Export stored messages for analysis.

        Returns:
            Tuple of (normal_messages, attack_messages). Each is the stored
            tensors concatenated along dim 0, or ``torch.empty(0, 0)`` when
            nothing was stored.
        """
        normal_msgs = (
            torch.cat(self.normal_messages, dim=0)
            if self.normal_messages
            else torch.empty(0, 0)
        )
        attack_msgs = (
            torch.cat(self.attack_messages, dim=0)
            if self.attack_messages
            else torch.empty(0, 0)
        )

        logging.info(
            f"Exported {len(normal_msgs)} normal messages "
            f"(shape {tuple(normal_msgs.shape)}), {len(attack_msgs)} attack messages"
        )
        return normal_msgs, attack_msgs

    def get_attack_batch_ids(self) -> List[int]:
        """Get sorted list of attack batch IDs."""
        return sorted(self.attack_batch_ids)

    def get_batch_statistics(self) -> Dict:
        """
        Get statistics about edges per batch.

        Returns:
            Mapping ``batch_id -> {"total_edges", "normal_edges", "attack_edges",
            "attack_types"}``. ``attack_types`` is a sorted list of the distinct
            non-empty attack labels seen in that batch, so it is
            JSON-serializable and deterministic.
        """
        batch_stats = {}

        for edge in self.edges:
            batch_id = edge["batch_id"]
            if batch_id not in batch_stats:
                batch_stats[batch_id] = {
                    "total_edges": 0,
                    "normal_edges": 0,
                    "attack_edges": 0,
                    "attack_types": set(),
                }

            stats = batch_stats[batch_id]
            stats["total_edges"] += 1
            if edge["is_attack"]:
                stats["attack_edges"] += 1
                if edge["attack_type"]:
                    stats["attack_types"].add(edge["attack_type"])
            else:
                stats["normal_edges"] += 1

        # Convert sets to lists for JSON serialization
        for stats in batch_stats.values():
            stats["attack_types"] = sorted(stats["attack_types"], key=str)

        return batch_stats

    def export_for_spotlight(self, filename: str = None) -> str:
        """
        Export tracked edges in SPOTLIGHT format.

        Format: one ``batch_id source_node destination_node weight`` line per
        edge, ordered by batch ID. The sort is stable, so insertion order is
        kept within a batch.

        Args:
            filename: Output filename inside ``output_dir``
                (auto-generated from dataset, attack type and time if None)

        Returns:
            Path to exported file
        """
        if filename is None:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            filename = f"{self.dataset_name}_{self.attack_type}_{timestamp}.txt"

        filepath = os.path.join(self.output_dir, filename)

        # Sort edges by batch_id for temporal consistency
        sorted_edges = sorted(self.edges, key=lambda edge: edge["batch_id"])

        with open(filepath, "w", encoding="utf-8") as f:
            f.writelines(
                f"{edge['batch_id']} {edge['src']} {edge['dst']} {edge['weight']}\n"
                for edge in sorted_edges
            )

        logging.info(
            f"Exported {len(sorted_edges)} edges to {os.path.abspath(filepath)}, path exists: {os.path.exists(filepath)}"
        )
        logging.info(
            f"Normal edges: {self.total_normal_edges}, Attack edges: {self.total_attack_edges}"
        )

        return filepath

    def get_attack_summary(self) -> Dict:
        """
        Get summary of tracked attacks.

        Returns:
            Dictionary with edge counts, sorted attack batch IDs, the
            percentage of attack edges (0 when no edges are tracked), and the
            number of distinct attack batch IDs.
        """
        total_edges = len(self.edges)
        return {
            "total_edges": total_edges,
            "normal_edges": self.total_normal_edges,
            "attack_edges": self.total_attack_edges,
            "attack_batch_ids": sorted(self.attack_batch_ids),
            "attack_percentage": (
                (self.total_attack_edges / total_edges) * 100 if total_edges else 0
            ),
            "unique_attack_batch_ids": len(self.attack_batch_ids),
        }

    def export_attack_metadata(self, filename: str = None) -> str:
        """
        Export metadata about attacks for analysis.

        Args:
            filename: Output filename inside ``output_dir``
                (auto-generated if None)

        Returns:
            Path to exported metadata file
        """
        if filename is None:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            filename = (
                f"{self.dataset_name}_{self.attack_type}_metadata_{timestamp}.json"
            )

        filepath = os.path.join(self.output_dir, filename)

        metadata = self.get_attack_summary()
        metadata["dataset_name"] = self.dataset_name
        metadata["attack_type"] = self.attack_type

        # Batch IDs may be numpy scalars depending on the caller; normalize so
        # json.dump does not fail on them.
        with open(filepath, "w", encoding="utf-8") as f:
            json.dump(self._convert_numpy_types(metadata), f, indent=2)

        logging.info(f"Exported attack metadata to {filepath}")
        return filepath

    def _convert_numpy_types(self, obj):
        """
        Recursively convert numpy/torch values to JSON-serializable Python types.

        - Dicts are rebuilt with converted values.
        - Lists, tuples and sets become lists.
        - numpy bool/integer/floating scalars become ``bool``/``int``/``float``.
        - numpy arrays and torch tensors become (nested) lists via ``tolist()``.
        - Anything else is returned unchanged.
        """
        if isinstance(obj, dict):
            return {key: self._convert_numpy_types(value) for key, value in obj.items()}
        if isinstance(obj, (list, tuple, set, frozenset)):
            return [self._convert_numpy_types(item) for item in obj]
        # np.bool_ is not a subclass of np.integer and json cannot encode it
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, (np.ndarray, torch.Tensor)):
            return obj.tolist()
        return obj


def run_spotlight_detection(
    edgelist_file: str,
    dataset_name: str,
    attack_batch_ids: List[int],
    K: int = 50,
    is_bipartite: bool = False,
    total_timestamps: int = 1000,
) -> Dict:
    """
    Run SPOTLIGHT detection on exported edgelist and compute detection metrics.

    This function never raises. Any failure, including SPOTLIGHT not being
    importable, is logged with its traceback, and a dict with an ``"error"``
    key and zeroed rates is returned instead.

    Args:
        edgelist_file: Path to temporal edgelist file
        dataset_name: Name of dataset
        attack_batch_ids: Batch IDs in which attacks occurred (ground truth).
            ``None`` is treated as an empty list.
        K: Number of SPOTLIGHT sketches
        is_bipartite: Whether the graph is bipartite
        total_timestamps: Total number of time steps passed to SPOTLIGHT's
            precision/recall computation (previously hard-coded to 1000,
            which remains the default)

    Returns:
        Dictionary with detection results
    """
    try:
        attack_batch_ids = [] if attack_batch_ids is None else list(attack_batch_ids)

        from modules.spotlight import run_SPOTLIGHT
        from modules.spotlight import metrics as spotlight_metrics

        logging.info(
            f"Running SPOTLIGHT detection on {edgelist_file} (bipartite: {is_bipartite})"
        )

        # Run SPOTLIGHT detection
        outliers, detection_time, scores = run_SPOTLIGHT(
            edgelist_file, K=K, use_rrcf=True, is_bipartite=is_bipartite
        )

        # Compute detection metrics (only meaningful with ground-truth attacks)
        if attack_batch_ids:
            accuracy = spotlight_metrics.compute_accuracy(outliers, attack_batch_ids)
            precision, recall, f1 = spotlight_metrics.compute_precision_recall_f1(
                outliers, attack_batch_ids, total_timestamps=total_timestamps
            )
            detection_delay = spotlight_metrics.compute_detection_delay(
                outliers, attack_batch_ids
            )
        else:
            accuracy = precision = recall = f1 = detection_delay = 0.0

        # Attack batches that SPOTLIGHT flagged; sorted for deterministic output
        detected_attacks = sorted(set(outliers) & set(attack_batch_ids))

        results = {
            "dataset_name": dataset_name,
            "total_outliers_detected": len(outliers),
            "attack_batch_ids": attack_batch_ids,
            "detected_outliers": list(outliers),
            "detected_attack_batch_ids": detected_attacks,
            "detection_accuracy": accuracy,
            "precision": precision,
            "recall": recall,
            "f1_score": f1,
            "detection_delay": detection_delay,
            "detection_time": detection_time,
            "attack_detection_rate": (
                len(detected_attacks) / len(attack_batch_ids) if attack_batch_ids else 0
            ),
            "anomaly_scores": (
                scores.tolist() if hasattr(scores, "tolist") else list(scores)
            ),
        }

        logging.info("SPOTLIGHT Detection Results:")
        logging.info(f"  Accuracy: {accuracy:.3f}")
        logging.info(f"  Attack Detection Rate: {results['attack_detection_rate']:.3f}")
        logging.info(
            f"  Detected {len(detected_attacks)}/{len(attack_batch_ids)} attack batch IDs"
        )

        return results

    except Exception as e:
        # Deliberate best-effort: report the failure instead of crashing the run
        logging.exception(f"SPOTLIGHT detection failed: {e}")
        return {
            "error": str(e),
            "detection_accuracy": 0.0,
            "attack_detection_rate": 0.0,
        }


class UnifiedAnomalyDetector:
    """
    Unified anomaly detection class that runs both SPOTLIGHT (structural)
    and feature-based anomaly detection on EdgeTracker data.
    """

    def __init__(self, feature_threshold: float = 2.5):
        """Create an unfitted detector.

        Args:
            feature_threshold: Max-absolute-z-score cut-off above which a
                message is flagged as anomalous.
        """
        self.feature_threshold = feature_threshold
        # Baseline statistics; populated by _fit_feature_detector.
        self.is_fitted = False
        self.mean = self.std = self.cov_matrix = None

    def analyze(
        self,
        edge_tracker: EdgeTracker,
        edgelist_file: str = None,
        attack_timestamps: List[int] = None,
        generate_plot: bool = True,
        is_bipartite: bool = False,
    ) -> Dict:
        """
        Run the full unified anomaly detection pipeline on EdgeTracker data.

        The pipeline runs in order: structural (SPOTLIGHT) detection,
        feature-based detection, score combination, JSON export, and an
        optional plot.

        Args:
            edge_tracker: EdgeTracker instance with collected data
            edgelist_file: Currently unused. The structural step always
                re-exports the tracker's edges.
            attack_timestamps: Ground-truth attack batch IDs used to score the
                detectors in the combination step. When None or empty, the
                combined precision/recall/F1 values are all 0.0.
            generate_plot: Whether to generate and save visualization plots
            is_bipartite: Whether the graph is bipartite

        Returns:
            Dictionary with per-detector results, the combined analysis, the
            exported results path and, when a plot was saved, its path.
        """
        name = edge_tracker.dataset_name
        logging.info(f"Starting unified anomaly detection for {name}")

        summary = {
            "dataset_name": name,
            "attack_type": edge_tracker.attack_type,
            "timestamp": datetime.now().isoformat(),
            "edge_summary": edge_tracker.get_attack_summary(),
        }

        summary["structural_detection"] = self._detect_structural_anomalies(
            edge_tracker, is_bipartite=is_bipartite
        )
        summary["feature_detection"] = self._detect_feature_anomalies(edge_tracker)
        summary["combined_analysis"] = self._combine_detection_scores(
            summary, attack_timestamps
        )

        # The right-hand side runs first, so the exported JSON does not
        # contain its own path.
        summary["results_file"] = self._export_results(edge_tracker, summary)

        if generate_plot:
            plot_path = self.plot_anomaly_detection_results(
                edge_tracker, summary, save_plot=True
            )
            if plot_path:
                summary["plot_file"] = plot_path

        logging.info(f"Unified anomaly detection completed for {name}")
        return summary

    def _detect_structural_anomalies(
        self, edge_tracker: EdgeTracker, is_bipartite: bool = False
    ) -> Dict:
        """
        Run SPOTLIGHT detection on the tracker's edge structure.

        Exports the tracker's edges to a SPOTLIGHT edgelist, then calls
        ``run_spotlight_detection`` with the tracker's attack batch IDs.

        Returns:
            The SPOTLIGHT result dict with an added ``"status"`` key:
            - ``"success"`` when detection ran;
            - ``"error"`` when ``run_spotlight_detection`` reported an error,
              or when exporting raised;
            - ``"no_attacks_detected"`` (SPOTLIGHT not run) when the tracker
              holds no attack edges.
        """
        try:
            # Export edges for SPOTLIGHT
            edgelist_file = edge_tracker.export_for_spotlight()
            attack_batch_ids = edge_tracker.get_attack_batch_ids()

            if not attack_batch_ids:
                return {
                    "status": "no_attacks_detected",
                    "attack_detection_rate": 0.0,
                    "message": "No attacks to detect",
                }

            # Run SPOTLIGHT detection
            spotlight_results = run_spotlight_detection(
                edgelist_file,
                edge_tracker.dataset_name,
                attack_batch_ids,
                is_bipartite=is_bipartite,
            )
            # run_spotlight_detection catches its own failures and returns a
            # dict with an "error" key; don't mislabel that as a success.
            spotlight_results["status"] = (
                "error" if "error" in spotlight_results else "success"
            )
            return spotlight_results

        except Exception as e:
            logging.error(f"Structural anomaly detection failed: {e}")
            return {"status": "error", "error": str(e), "attack_detection_rate": 0.0}

    def _detect_feature_anomalies(self, edge_tracker: EdgeTracker) -> Dict:
        """
        Run statistical detection on edge message features.

        The detector is fitted on the tracker's normal messages. Attack
        messages are then scored against that baseline, and a min/max range
        analysis is added.

        Returns:
            Result dict whose ``"status"`` is ``"success"``,
            ``"no_normal_messages"``, ``"no_attack_messages"`` or ``"error"``.
            Every variant carries a ``"feature_anomaly_rate"`` key.
        """
        try:
            # Get messages from edge tracker
            normal_msgs, attack_msgs = edge_tracker.export_messages()

            if len(normal_msgs) == 0:
                return {
                    "status": "no_normal_messages",
                    "message": "No normal messages available for baseline",
                    "feature_anomaly_rate": 0.0,
                }

            if len(attack_msgs) == 0:
                return {
                    "status": "no_attack_messages",
                    "message": "No attack messages to analyze",
                    "feature_anomaly_rate": 0.0,
                }

            # Integer messages would make torch.mean fail and 1-D messages have
            # no feature axis; normalize both to [N, D] floating tensors.
            normal_msgs = self._as_float_matrix(normal_msgs)
            attack_msgs = self._as_float_matrix(attack_msgs)

            # Fit detector on normal messages
            self._fit_feature_detector(normal_msgs)

            # Detect anomalies in attack messages
            anomaly_scores = self._compute_anomaly_scores(attack_msgs)
            anomaly_flags = self._get_anomaly_flags(anomaly_scores)

            feature_anomaly_rate = float(torch.mean(anomaly_flags.float()))

            # Enhanced range-based analysis
            range_analysis = self._analyze_feature_ranges(normal_msgs, attack_msgs)

            results = {
                "status": "success",
                "normal_message_count": len(normal_msgs),
                "attack_message_count": len(attack_msgs),
                "feature_dimension": normal_msgs.shape[1],
                "anomaly_scores": {k: v.tolist() for k, v in anomaly_scores.items()},
                "anomaly_flags": anomaly_flags.tolist(),
                "feature_anomaly_rate": feature_anomaly_rate,
                "anomalous_messages": int(torch.sum(anomaly_flags)),
                "anomaly_percentage": feature_anomaly_rate * 100,
                "range_analysis": range_analysis,
            }

            logging.info(
                f"Feature anomaly detection: {results['anomaly_percentage']:.1f}% of attack messages flagged as anomalous"
            )
            logging.info(
                f"Range analysis: {range_analysis['attack_edges_beyond_range']}/{len(attack_msgs)} "
                f"({range_analysis['attack_edges_beyond_range_pct']:.1f}%) attack edges beyond normal range"
            )
            return results

        except Exception as e:
            logging.error(f"Feature anomaly detection failed: {e}")
            return {"status": "error", "error": str(e), "feature_anomaly_rate": 0.0}

    @staticmethod
    def _as_float_matrix(messages: torch.Tensor) -> torch.Tensor:
        """Return ``messages`` as a 2-D floating-point tensor.

        A 1-D tensor of N scalars becomes an [N, 1] column. Non-floating dtypes
        are cast to float32; floating tensors keep their original precision.
        """
        if messages.dim() == 1:
            messages = messages.unsqueeze(1)
        if not torch.is_floating_point(messages):
            messages = messages.float()
        return messages

    def _fit_feature_detector(self, normal_messages: torch.Tensor):
        """
        Fit the statistical baseline (mean, std, covariance) on normal messages.

        Does nothing when ``normal_messages`` is empty. Assumes a 2-D
        [N, D] tensor (indexes ``shape[1]``).

        With more than one message, the covariance is ``torch.cov`` plus a
        1e-6 ridge term. With a single message, the identity matrix is used.
        """
        if len(normal_messages) == 0:
            return

        feature_dim = normal_messages.shape[1]

        self.mean = torch.mean(normal_messages, dim=0)
        # NOTE(review): with a single message the unbiased std is NaN, so all
        # z-scores become NaN and nothing is flagged downstream.
        self.std = torch.std(normal_messages, dim=0)

        if len(normal_messages) > 1:
            # torch.cov squeezes its result: one feature gives a 0-d tensor.
            # Reshape to (D, D) so regularization and pinv always see a matrix.
            cov = torch.cov(normal_messages.T).reshape(feature_dim, feature_dim)
            # Add regularization for numerical stability (matching dtype/device)
            ridge = torch.eye(feature_dim, dtype=cov.dtype, device=cov.device) * 1e-6
            self.cov_matrix = cov + ridge
        else:
            self.cov_matrix = torch.eye(
                feature_dim,
                dtype=normal_messages.dtype,
                device=normal_messages.device,
            )

        self.is_fitted = True
        logging.info(
            f"Feature detector fitted on {len(normal_messages)} normal messages"
        )

    def _analyze_feature_ranges(
        self, normal_msgs: torch.Tensor, attack_msgs: torch.Tensor
    ) -> Dict:
        """
        Analyze how attack features deviate from the normal per-dimension range.

        The normal range of each dimension is [min, max] over ``normal_msgs``.
        An attack value is "beyond range" when it is strictly below that min or
        strictly above that max. Its extent is the distance to the violated
        bound.

        Args:
            normal_msgs: Normal message features [N_normal, D] (1-D input is
                treated as a single feature)
            attack_msgs: Attack message features [N_attack, D] (same)

        Returns:
            Dictionary with range analysis results. ``deviation_extents_per_edge``
            lists, for each attack edge with at least one out-of-range
            dimension (in row order), its largest extent.
        """
        if len(normal_msgs) == 0 or len(attack_msgs) == 0:
            return {
                "status": "insufficient_data",
                "attack_edges_beyond_range": 0,
                "attack_edges_beyond_range_pct": 0.0,
                "avg_dimensions_beyond_range": 0.0,
                "max_deviation_extent": 0.0,
                "avg_deviation_extent": 0.0,
            }

        # Treat 1-D inputs as a single feature column
        if normal_msgs.dim() == 1:
            normal_msgs = normal_msgs.unsqueeze(1)
        if attack_msgs.dim() == 1:
            attack_msgs = attack_msgs.unsqueeze(1)

        # Compute min/max for each dimension from normal messages
        normal_min = torch.min(normal_msgs, dim=0).values  # [D]
        normal_max = torch.max(normal_msgs, dim=0).values  # [D]

        # Which attack features are beyond the normal range ([N_attack, D])
        beyond_min = attack_msgs < normal_min.unsqueeze(0)
        beyond_max = attack_msgs > normal_max.unsqueeze(0)
        beyond_range = beyond_min | beyond_max

        # Attack edges with at least one dimension beyond range
        edge_has_deviation = torch.any(beyond_range, dim=1)  # [N_attack]
        attack_edges_beyond_range = edge_has_deviation.sum().item()
        attack_edges_beyond_range_pct = (
            attack_edges_beyond_range / len(attack_msgs)
        ) * 100

        # Average number of dimensions beyond range per attack edge
        avg_dimensions_beyond_range = (
            torch.sum(beyond_range, dim=1).float().mean().item()
        )

        # Vectorized extent: distance below min where the value undershoots,
        # distance above max otherwise. Extents are strictly positive exactly
        # where beyond_range holds; zeroing the rest lets a row-wise max pick
        # each edge's largest real deviation (NaN inputs are never in range
        # violations, so they are masked out too).
        deviation_extents = []
        if edge_has_deviation.any():
            extents = torch.where(
                beyond_min, normal_min - attack_msgs, attack_msgs - normal_max
            )
            extents = torch.where(beyond_range, extents, torch.zeros_like(extents))
            per_edge_max = extents.max(dim=1).values
            deviation_extents = per_edge_max[edge_has_deviation].tolist()

        max_deviation_extent = max(deviation_extents) if deviation_extents else 0.0
        avg_deviation_extent = (
            sum(deviation_extents) / len(deviation_extents)
            if deviation_extents
            else 0.0
        )

        # Additional statistics (guard against a zero-width feature axis)
        total_dimensions_beyond_range = torch.sum(beyond_range).item()
        total_possible_dimensions = beyond_range.numel()
        proportion_dimensions_beyond_range = (
            (total_dimensions_beyond_range / total_possible_dimensions) * 100
            if total_possible_dimensions
            else 0.0
        )

        return {
            "status": "success",
            "normal_range_min": normal_min.tolist(),
            "normal_range_max": normal_max.tolist(),
            "attack_edges_beyond_range": attack_edges_beyond_range,
            "attack_edges_beyond_range_pct": attack_edges_beyond_range_pct,
            "avg_dimensions_beyond_range": avg_dimensions_beyond_range,
            "total_dimensions_beyond_range": total_dimensions_beyond_range,
            "proportion_dimensions_beyond_range": proportion_dimensions_beyond_range,
            "max_deviation_extent": max_deviation_extent,
            "avg_deviation_extent": avg_deviation_extent,
            "deviation_extents_per_edge": deviation_extents,
        }

    def _compute_anomaly_scores(
        self, messages: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        """
        Compute several per-message anomaly scores against the fitted baseline.

        Assumes ``messages`` is [N, D] with the same D as the fitted baseline
        (TODO confirm at call sites).

        Returns:
            Dict with one [N] tensor per key: ``z_score_max``,
            ``z_score_mean``, ``mahalanobis``, ``cosine_anomaly`` and
            ``l2_distance``. Returns an empty dict when the detector is not
            fitted or ``messages`` is empty.
        """
        if not self.is_fitted or len(messages) == 0:
            return {}

        scores = {}
        diff = messages - self.mean

        # 1. Z-score based detection (epsilon guards zero-variance dimensions)
        z_scores = torch.abs(diff / (self.std + 1e-8))
        scores["z_score_max"] = torch.max(z_scores, dim=1)[0]
        scores["z_score_mean"] = torch.mean(z_scores, dim=1)

        # 2. Mahalanobis distance
        try:
            inv_cov = torch.linalg.pinv(self.cov_matrix)
            quad_form = torch.sum((diff @ inv_cov) * diff, dim=1)
            # Pseudo-inverse round-off can make the quadratic form slightly
            # negative; clamp so sqrt never produces NaN.
            scores["mahalanobis"] = torch.sqrt(torch.clamp(quad_form, min=0.0))
        except Exception as e:
            # Best-effort: keep the other scores, but don't hide the failure
            logging.warning(f"Mahalanobis distance unavailable, using zeros: {e}")
            scores["mahalanobis"] = torch.zeros(len(messages))

        # 3. Cosine similarity anomaly
        cosine_sim = F.cosine_similarity(messages, self.mean.unsqueeze(0))
        scores["cosine_anomaly"] = 1 - cosine_sim

        # 4. L2 distance to mean
        scores["l2_distance"] = torch.norm(diff, dim=1)

        return scores

    def _get_anomaly_flags(self, scores: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Flag messages whose max absolute z-score exceeds ``feature_threshold``.

        Returns an empty bool tensor when ``scores`` has no ``"z_score_max"``
        entry, e.g. when scoring returned an empty dict.
        """
        try:
            z_max = scores["z_score_max"]
        except KeyError:
            return torch.zeros(0, dtype=torch.bool)
        # The z-score is the sole primary indicator
        return z_max > self.feature_threshold

    def _combine_detection_scores(
        self, results: Dict, attack_timestamps: List[int] = None
    ) -> Dict:
        """
        Score each detector against the ground truth and combine them.

        Each detector's flagged batch IDs are compared with
        ``attack_timestamps`` to give precision/recall/F1. Each detector's F1
        is its score. The scores are combined with equal weights (1/2 each)
        and with a 0.7/0.3 weighting.

        Args:
            results: Dict holding ``"structural_detection"`` and/or
                ``"feature_detection"`` result dicts
            attack_timestamps: Ground-truth attack batch IDs. When None or
                empty, all metrics are 0.0.

        Returns:
            Dict of per-detector metrics, ensemble scores, a confidence label
            and a recommendation label.
        """
        structural = results.get("structural_detection", {})
        feature = results.get("feature_detection", {})

        structural_precision = structural_recall = structural_f1 = 0.0
        feature_precision = feature_recall = feature_f1 = 0.0

        if attack_timestamps:
            attack_set = set(attack_timestamps)

            # run_spotlight_detection stores flagged batches under
            # "detected_outliers"; "outliers" is still accepted for callers
            # that supply it directly.
            structural_outliers = structural.get("detected_outliers") or structural.get(
                "outliers"
            )
            if structural_outliers:
                (
                    structural_precision,
                    structural_recall,
                    structural_f1,
                ) = self._precision_recall_f1(set(structural_outliers), attack_set)

            # NOTE(review): the feature detector reports per-message flags, not
            # batch IDs, so it does not set "outliers" itself; feature metrics
            # stay 0.0 unless such a key is provided.
            feature_outliers = feature.get("outliers")
            if feature_outliers:
                (
                    feature_precision,
                    feature_recall,
                    feature_f1,
                ) = self._precision_recall_f1(set(feature_outliers), attack_set)

        structural_score = structural_f1
        feature_score = feature_f1

        # Equal weight ensemble (1/2 each) of ground-truth detection performance
        equal_weight_ensemble = (structural_score + feature_score) / 2

        # Previous weighted ensemble (for comparison)
        weighted_score = 0.7 * structural_score + 0.3 * feature_score

        combined_analysis = {
            "structural_detection_rate": structural_score,
            "structural_precision": structural_precision,
            "structural_recall": structural_recall,
            "structural_f1_score": structural_f1,
            "feature_anomaly_rate": feature_score,
            "feature_precision": feature_precision,
            "feature_recall": feature_recall,
            "feature_f1_score": feature_f1,
            "equal_weight_ensemble_score": equal_weight_ensemble,
            "weighted_ensemble_score": weighted_score,
            "detection_confidence": self._compute_confidence(
                structural_score, feature_score
            ),
            "recommendation": self._get_recommendation(structural_score, feature_score),
        }

        logging.info(
            f"Combined Detection (vs Ground Truth) - Structural F1: {structural_f1:.3f}, Feature F1: {feature_f1:.3f}, Equal Weight Ensemble: {equal_weight_ensemble:.3f}"
        )

        return combined_analysis

    @staticmethod
    def _precision_recall_f1(predicted: set, truth: set) -> Tuple[float, float, float]:
        """Return (precision, recall, f1) of ``predicted`` IDs vs ``truth`` IDs.

        Each ratio is 0.0 when its denominator is zero.
        """
        tp = len(truth & predicted)
        fp = len(predicted - truth)
        fn = len(truth - predicted)

        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        f1 = (
            2 * precision * recall / (precision + recall)
            if (precision + recall) > 0
            else 0.0
        )
        return precision, recall, f1

    def _convert_numpy_types(self, obj):
        """
        Recursively convert numpy/torch values to JSON-serializable Python types.

        - Dicts are rebuilt with converted values.
        - Lists, tuples and sets become lists.
        - numpy bool/integer/floating scalars become ``bool``/``int``/``float``.
        - numpy arrays and torch tensors become (nested) lists via ``tolist()``.
        - Anything else is returned unchanged.
        """
        if isinstance(obj, dict):
            return {key: self._convert_numpy_types(value) for key, value in obj.items()}
        if isinstance(obj, (list, tuple, set, frozenset)):
            return [self._convert_numpy_types(item) for item in obj]
        # np.bool_ is not a subclass of np.integer and json cannot encode it
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, (np.ndarray, torch.Tensor)):
            return obj.tolist()
        return obj

    def _compute_confidence(self, structural_score: float, feature_score: float) -> str:
        """Label confidence from how closely the two detector scores agree.

        A spread under 0.2 counts as agreement. The label is then decided by
        the mean score: above 0.7 is an attack, below 0.3 is normal, and
        anything else is medium. A spread under 0.4 is a medium disagreement;
        anything wider is a low-confidence disagreement.
        """
        pair = (structural_score, feature_score)
        mean_score = sum(pair) / len(pair)
        spread = max(pair) - min(pair)

        if spread < 0.2:
            if mean_score > 0.7:
                return "high_confidence_attack"
            if mean_score < 0.3:
                return "high_confidence_normal"
            return "medium_confidence"
        if spread < 0.4:
            return "medium_confidence_disagreement"
        return "low_confidence_disagreement"

    def _get_recommendation(self, structural_score: float, feature_score: float) -> str:
        """Map the mean of the two detector scores onto a recommendation label.

        The first threshold the mean strictly exceeds wins. The thresholds are
        checked from highest (0.8) to lowest (0.2), with a fallback of
        ``"no_attack_detected"``.
        """
        mean_score = (structural_score + feature_score) / 2
        tiers = (
            (0.8, "strong_attack_detected"),
            (0.6, "likely_attack_detected"),
            (0.4, "possible_attack_detected"),
            (0.2, "weak_attack_signal"),
        )
        for threshold, label in tiers:
            if mean_score > threshold:
                return label
        return "no_attack_detected"

    def _export_results(self, edge_tracker: EdgeTracker, results: Dict) -> str:
        """Write ``results`` as indented JSON into the tracker's output directory.

        The filename combines dataset name, attack type and a second-resolution
        timestamp. Values are passed through ``_convert_numpy_types`` first so
        numpy scalars/arrays serialize.

        Returns:
            Path of the written JSON file.
        """
        stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        out_path = os.path.join(
            edge_tracker.output_dir,
            f"{edge_tracker.dataset_name}_{edge_tracker.attack_type}_unified_anomaly_{stamp}.json",
        )

        payload = self._convert_numpy_types(results)
        with open(out_path, "w") as handle:
            json.dump(payload, handle, indent=2)

        logging.info(f"Exported unified anomaly detection results to {out_path}")
        return out_path

    def plot_anomaly_detection_results(
        self,
        edge_tracker: EdgeTracker,
        results: Dict,
        save_plot: bool = True,
        *,
        score_headroom: float = 5.0,
    ) -> str:
        """
        Plot anomaly detection results showing edges detected as anomalous per batch.

        Draws three panels side by side:
          1. normal vs attack edge counts per batch, from
             ``edge_tracker.get_batch_statistics()``;
          2. batches flagged by SPOTLIGHT
             (``results["structural_detection"]["detected_outliers"]``) vs
             ground-truth attack batches (``edge_tracker.get_attack_batch_ids()``);
          3. SPOTLIGHT anomaly scores per batch
             (``results["structural_detection"]["anomaly_scores"]``), with the
             optional ``threshold`` drawn as a dashed line.
        A summary box with batch-level precision/recall (plus the feature range
        analysis when available) is drawn below the panels.

        Plotting is best-effort: any exception is logged and "" is returned.

        Args:
            edge_tracker: EdgeTracker instance with collected data
            results: Results from anomaly detection analysis
            save_plot: Whether to save the plot to file (otherwise ``plt.show()``)
            score_headroom: Extra space added above the highest plotted score /
                threshold on the score panel's y-axis (keyword-only).

        Returns:
            Path to saved plot file (if save_plot=True) or empty string
        """
        fig = None
        try:
            batch_stats = edge_tracker.get_batch_statistics()
            if not batch_stats:
                logging.warning("No batch statistics available for plotting")
                return ""

            batch_ids = sorted(batch_stats.keys())
            positions = np.arange(len(batch_ids))
            normal_edges = [batch_stats[b]["normal_edges"] for b in batch_ids]
            attack_edges = [batch_stats[b]["attack_edges"] for b in batch_ids]

            # Treat an explicit None the same as a missing section.
            structural = results.get("structural_detection") or {}
            # Avoid truthiness tests here: the value may be a numpy array.
            outliers = structural.get("detected_outliers")
            spotlight_outliers = set() if outliers is None else set(outliers)
            attack_batch_ids = set(edge_tracker.get_attack_batch_ids())

            spotlight_detected = [b in spotlight_outliers for b in batch_ids]
            ground_truth_attacks = [b in attack_batch_ids for b in batch_ids]

            fig, axes = plt.subplots(1, 3, figsize=(20, 6))
            fig.suptitle(
                f"Anomaly Detection Results - {edge_tracker.dataset_name}", fontsize=16
            )

            self._plot_edge_distribution_panel(
                axes[0], batch_ids, positions, normal_edges, attack_edges
            )
            self._plot_detection_panel(
                axes[1], batch_ids, positions, spotlight_detected, ground_truth_attacks
            )
            self._plot_score_panel(
                axes[2], batch_ids, structural, ground_truth_attacks, score_headroom
            )

            tp, fp, fn, precision, recall = self._batch_detection_metrics(
                spotlight_detected, ground_truth_attacks
            )
            summary_text = (
                f"SPOTLIGHT Performance:\nPrecision: {precision:.3f}\nRecall: {recall:.3f}\n"
                f"TP: {tp}, FP: {fp}, FN: {fn}{self._feature_range_summary(results)}"
            )
            fig.text(
                0.02,
                0.02,
                summary_text,
                fontsize=10,
                va="bottom",
                bbox=dict(boxstyle="round,pad=0.3", facecolor="lightgray", alpha=0.8),
            )

            # Reserve a bottom band sized to the summary box (about 0.035 of the
            # 6-inch figure height per text line) and a top band for the
            # suptitle, so neither overlaps the axes.
            n_lines = summary_text.count("\n") + 1
            bottom_margin = min(0.45, 0.05 + 0.035 * n_lines)
            fig.tight_layout(rect=(0, bottom_margin, 1, 0.95))

            if not save_plot:
                plt.show()
                return ""

            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            plot_filename = f"{edge_tracker.dataset_name}_{edge_tracker.attack_type}_anomaly_plot_{timestamp}.png"
            plot_filepath = os.path.join(edge_tracker.output_dir, plot_filename)

            fig.savefig(plot_filepath, dpi=300, bbox_inches="tight")
            logging.info(f"Anomaly detection plot saved to {plot_filepath}")
            plt.close(fig)
            return plot_filepath

        except Exception as e:
            # Best-effort: a plotting failure must not abort the analysis run,
            # but log the full traceback and release the figure.
            logging.exception(f"Failed to create anomaly detection plot: {e}")
            if fig is not None:
                plt.close(fig)
            return ""

    @staticmethod
    def _plot_edge_distribution_panel(ax, batch_ids, positions, normal_edges, attack_edges):
        """Draw side-by-side bars of normal vs attack edge counts per batch."""
        width = 0.35
        ax.bar(
            positions - width / 2,
            normal_edges,
            width,
            label="Normal Edges",
            color="skyblue",
            alpha=0.8,
        )
        ax.bar(
            positions + width / 2,
            attack_edges,
            width,
            label="Attack Edges",
            color="red",
            alpha=0.8,
        )
        ax.set_xlabel("Batch ID")
        ax.set_ylabel("Number of Edges")
        ax.set_title("Edge Distribution per Batch")
        ax.set_xticks(positions)
        ax.set_xticklabels(batch_ids)
        ax.legend()
        ax.grid(True, alpha=0.3)

    @staticmethod
    def _plot_detection_panel(ax, batch_ids, positions, detected, actual):
        """Draw per-batch markers for ground truth, SPOTLIGHT detections and their overlap."""
        true_positive = [d and a for d, a in zip(detected, actual)]
        # (row height, per-batch mask, marker style). Rows sit at different
        # heights so coinciding markers stay visible; draw order sets legend order.
        rows = (
            (1.0, actual, dict(color="red", s=100, marker="o",
                               label="Ground Truth Attacks", alpha=0.8, zorder=3)),
            (0.5, detected, dict(color="blue", s=100, marker="s",
                                 label="SPOTLIGHT Detected", alpha=0.8, zorder=3)),
            (0.75, true_positive, dict(color="green", s=120, marker="D",
                                       label="True Positives", alpha=0.9, zorder=4)),
        )
        for y, mask, style in rows:
            xs = [p for p, flag in zip(positions, mask) if flag]
            ax.scatter(xs, [y] * len(xs), **style)

        for p in positions:
            ax.axvline(x=p, color="lightgray", alpha=0.3, linewidth=0.5)

        ax.set_xlabel("Batch ID")
        ax.set_ylabel("Detection Status")
        ax.set_title("SPOTLIGHT Detection vs Ground Truth Attacks")
        ax.set_xticks(positions)
        ax.set_xticklabels(batch_ids)
        ax.set_ylim(-0.1, 1.3)
        ax.set_yticks([0.5, 0.75, 1.0])
        ax.set_yticklabels(["SPOTLIGHT", "True Positives", "Ground Truth"])
        ax.legend()
        ax.grid(True, alpha=0.3)

    @staticmethod
    def _plot_score_panel(ax, batch_ids, structural, ground_truth_attacks, score_headroom):
        """Draw SPOTLIGHT anomaly scores per batch, ground-truth markers and the threshold."""
        raw_scores = structural.get("anomaly_scores")
        if raw_scores is None:
            # If scores are not available, plot zeros so the panel still renders.
            scores = [0.0] * len(batch_ids)
            logging.warning("SPOTLIGHT anomaly scores not available, using dummy scores")
        else:
            # Normalize list/tuple/ndarray input to a flat list of floats; a bare
            # truthiness test on a multi-element ndarray would raise.
            scores = np.asarray(raw_scores, dtype=float).ravel().tolist()

        # Scores are assumed to be positionally aligned with the sorted batch
        # IDs (as before); on a length mismatch, plot the overlapping prefix
        # instead of failing the whole figure.
        if len(scores) != len(batch_ids):
            logging.warning(
                f"SPOTLIGHT returned {len(scores)} scores for {len(batch_ids)} batches; "
                f"plotting the first {min(len(scores), len(batch_ids))} positionally"
            )
        n_scored = min(len(scores), len(batch_ids))
        scored_ids = batch_ids[:n_scored]
        scores = scores[:n_scored]

        ax.plot(
            scored_ids,
            scores,
            color="blue",
            linewidth=2,
            marker="o",
            markersize=4,
            alpha=0.8,
            label="SPOTLIGHT Scores",
        )

        # Highlight ground-truth attack batches on the score curve.
        gt_points = [
            (b, s) for b, s, gt in zip(scored_ids, scores, ground_truth_attacks) if gt
        ]
        if gt_points:
            gt_x, gt_y = zip(*gt_points)
            ax.scatter(
                gt_x,
                gt_y,
                color="red",
                s=100,
                marker="*",
                label="Ground Truth Attacks",
                alpha=0.9,
                zorder=3,
            )

        threshold = structural.get("threshold")
        if threshold is not None:
            threshold = float(threshold)
            ax.axhline(
                y=threshold,
                color="gray",
                linestyle="--",
                alpha=0.7,
                label=f"Threshold ({threshold:.3f})",
            )

        ax.set_xlabel("Batch ID")
        ax.set_ylabel("Anomaly Score")
        ax.set_title("SPOTLIGHT Anomaly Scores Over Time")
        ax.grid(True, alpha=0.3)
        ax.legend()

        # Top of the axis is `score_headroom` above the highest score (or the
        # threshold, so its line is never clipped); the bottom drops below 0
        # only for negative values. NaN/inf are ignored since set_ylim rejects them.
        visible = [s for s in scores if np.isfinite(s)]
        if threshold is not None and np.isfinite(threshold):
            visible.append(threshold)
        y_top = max(visible, default=0.0) + score_headroom
        y_bottom = min(min(visible, default=0.0), 0.0)
        ax.set_ylim(y_bottom, y_top)

    @staticmethod
    def _batch_detection_metrics(
        detected: List[bool], actual: List[bool]
    ) -> Tuple[int, int, int, float, float]:
        """Return (TP, FP, FN, precision, recall) for per-batch flags; undefined ratios are 0."""
        tp = sum(1 for d, a in zip(detected, actual) if d and a)
        fp = sum(1 for d, a in zip(detected, actual) if d and not a)
        fn = sum(1 for d, a in zip(detected, actual) if not d and a)
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0
        return tp, fp, fn, precision, recall

    @staticmethod
    def _feature_range_summary(results: Dict) -> str:
        """Return the feature-range lines for the plot summary, or "" if not successful."""
        feature = results.get("feature_detection") or {}
        range_analysis = feature.get("range_analysis") or {}
        if range_analysis.get("status") != "success":
            return ""
        attack_count = feature.get("attack_message_count", "N/A")
        return (
            "\nFeature Range Analysis:"
            f"\nAttack edges beyond range: {range_analysis['attack_edges_beyond_range']}/{attack_count}"
            f" ({range_analysis['attack_edges_beyond_range_pct']:.1f}%)"
            f"\nAvg dimensions beyond range: {range_analysis['avg_dimensions_beyond_range']:.2f}"
            f"\nMax deviation extent: {range_analysis['max_deviation_extent']:.3f}"
        )
