#!/usr/bin/env python3
"""
Realistic Attack Pattern Generator for Network Intrusion Detection.
Generates sophisticated attack patterns that mimic real-world network intrusions
instead of simple degree-based patterns.
"""

import numpy as np
import pandas as pd
import networkx as nx
import torch
from typing import List, Dict, Tuple, Optional, Any, Set
from dataclasses import dataclass
from enum import Enum
import random
from collections import defaultdict

from ..utils.common import GraphData, get_logger


class AttackType(Enum):
    """Categories of network activity the generator can label.

    The first five members mirror the NSL-KDD label families (normal, DoS,
    probe, R2L, U2R); the rest name additional campaign stages.
    """

    NORMAL = "normal"              # benign traffic
    DOS = "dos"                    # denial of service
    PROBE = "probe"                # reconnaissance / scanning
    R2L = "r2l"                    # remote-to-local access
    U2R = "u2r"                    # user-to-root privilege escalation
    BOTNET = "botnet"              # botnet communication
    LATERAL_MOVEMENT = "lateral"   # movement between internal hosts
    DATA_EXFILTRATION = "exfil"    # theft of data out of the network
    COMMAND_CONTROL = "c2"         # command & control traffic


@dataclass
class AttackPattern:
    """Static template describing one attack family and its observable traits.

    ``AttackInstance`` objects point at one of these templates and add the
    concrete nodes, timing and intensity of a particular occurrence.
    """

    # Attack family this template describes.
    attack_type: AttackType
    # Human-readable title and one-line summary.
    name: str
    description: str
    # Free-form tag naming how victims are chosen (e.g. "systematic_sweep").
    target_selection: str
    # Free-form flags/levels characterising the traffic produced.
    traffic_signature: Dict
    # Free-form tag for the timing profile (e.g. "burst", "periodic").
    temporal_pattern: str
    # Free-form flags describing the attack's footprint on the network.
    network_footprint: Dict
    # Heuristic score: 0.0 means easy to detect, 1.0 means hard.
    detection_difficulty: float


@dataclass
class AttackInstance:
    """One concrete occurrence of an ``AttackPattern`` inside a generated scenario."""

    # Template for the attack family this occurrence belongs to.
    pattern: AttackPattern
    # Node indices acting as attackers, and as victims/peers respectively.
    source_nodes: List[int]
    target_nodes: List[int]
    # Start offset and length; the generators' comments treat these as seconds.
    start_time: float
    duration: float
    # Relative strength; the generators draw it from sub-ranges of [0, 1].
    intensity: float
    # Attack-specific extras (vectors, rates, protocols, ...).
    metadata: Dict[str, Any]


class AttackPatternGenerator:
    """
    Generates realistic attack patterns for network graphs.
    Based on actual attack behaviors observed in real networks.
    """
    
    def __init__(self, seed: int = 42):
        """Create a generator whose random draws are reproducible from ``seed``.

        Args:
            seed: Seed for the private NumPy ``RandomState`` (``self.rng``)
                that the sampling steps of this class draw from.
        """
        self.logger = get_logger("attack_pattern_generator")

        self.seed = seed
        self.rng = np.random.RandomState(seed)

        # Templates are built once; every generated AttackInstance references them.
        self.attack_patterns = self._initialize_attack_patterns()
        
    def _initialize_attack_patterns(self) -> Dict[AttackType, AttackPattern]:
        """Build the catalogue of attack templates, keyed by their ``AttackType``.

        Only seven families get a template. ``AttackType.NORMAL`` and
        ``AttackType.COMMAND_CONTROL`` have no entry, so indexing the returned
        dict with them raises ``KeyError``.

        Returns:
            Mapping from attack family to its ``AttackPattern``, in the order
            DoS, probe, R2L, U2R, botnet, lateral movement, data exfiltration.
        """
        catalogue = (
            # Many sources flooding a few well-connected hosts.
            AttackPattern(
                attack_type=AttackType.DOS,
                name="Distributed Denial of Service",
                description="Overwhelm target with traffic from multiple sources",
                target_selection="high_degree_servers",
                traffic_signature={
                    "high_packet_rate": True,
                    "small_packet_size": True,
                    "short_connections": True,
                    "syn_flood": True,
                    "bandwidth_consumption": "high",
                },
                temporal_pattern="burst",
                network_footprint={
                    "many_to_one": True,
                    "connection_failures": "high",
                    "response_time_degradation": True,
                },
                detection_difficulty=0.3,  # comparatively easy to spot
            ),
            # One scanner sweeping many hosts.
            AttackPattern(
                attack_type=AttackType.PROBE,
                name="Network Reconnaissance",
                description="Scan network to discover services and vulnerabilities",
                target_selection="systematic_sweep",
                traffic_signature={
                    "port_scanning": True,
                    "service_enumeration": True,
                    "failed_connections": "high",
                    "diverse_ports": True,
                    "small_payloads": True,
                },
                temporal_pattern="systematic",
                network_footprint={
                    "one_to_many": True,
                    "sequential_targets": True,
                    "connection_attempts": "many",
                },
                detection_difficulty=0.4,
            ),
            # Outside attacker breaking into exposed services.
            AttackPattern(
                attack_type=AttackType.R2L,
                name="Remote Access Intrusion",
                description="Gain unauthorized access from external network",
                target_selection="vulnerable_services",
                traffic_signature={
                    "brute_force_attempts": True,
                    "authentication_failures": "high",
                    "exploit_payloads": True,
                    "encrypted_channels": True,
                },
                temporal_pattern="persistent",
                network_footprint={
                    "external_to_internal": True,
                    "service_specific": True,
                    "escalation_attempts": True,
                },
                detection_difficulty=0.6,
            ),
            # Local privilege escalation on an already-compromised host.
            AttackPattern(
                attack_type=AttackType.U2R,
                name="Privilege Escalation",
                description="Escalate privileges from normal user to administrator",
                target_selection="compromised_internal",
                traffic_signature={
                    "system_calls": "unusual",
                    "file_access_patterns": "suspicious",
                    "process_creation": "abnormal",
                    "registry_modifications": True,
                },
                temporal_pattern="stealthy",
                network_footprint={
                    "internal_lateral": True,
                    "administrative_actions": True,
                    "credential_harvesting": True,
                },
                detection_difficulty=0.8,
            ),
            # Infected hosts beaconing to a controller.
            AttackPattern(
                attack_type=AttackType.BOTNET,
                name="Botnet Communication",
                description="Infected hosts communicating with command & control",
                target_selection="random_internal",
                traffic_signature={
                    "periodic_beacons": True,
                    "dns_queries": "suspicious",
                    "encrypted_traffic": True,
                    "peer_to_peer": True,
                },
                temporal_pattern="periodic",
                network_footprint={
                    "outbound_connections": "regular",
                    "domain_generation": True,
                    "traffic_tunneling": True,
                },
                detection_difficulty=0.7,
            ),
            # Spreading from a foothold to neighbouring hosts.
            AttackPattern(
                attack_type=AttackType.LATERAL_MOVEMENT,
                name="Lateral Movement",
                description="Move through network after initial compromise",
                target_selection="network_neighbors",
                traffic_signature={
                    "credential_reuse": True,
                    "smb_traffic": "unusual",
                    "rdp_connections": "suspicious",
                    "wmi_usage": "abnormal",
                },
                temporal_pattern="exploratory",
                network_footprint={
                    "internal_scanning": True,
                    "service_enumeration": True,
                    "file_sharing_abuse": True,
                },
                detection_difficulty=0.9,
            ),
            # Moving stolen data out of the network.
            AttackPattern(
                attack_type=AttackType.DATA_EXFILTRATION,
                name="Data Exfiltration",
                description="Steal sensitive data from the network",
                target_selection="data_repositories",
                traffic_signature={
                    "large_uploads": True,
                    "unusual_protocols": True,
                    "compression": True,
                    "encryption": True,
                    "off_hours_activity": True,
                },
                temporal_pattern="scheduled",
                network_footprint={
                    "internal_to_external": True,
                    "database_access": "unusual",
                    "file_transfers": "large",
                },
                detection_difficulty=0.8,
            ),
        )

        # Key each template by its own family so the two can never disagree.
        return {template.attack_type: template for template in catalogue}

    def _modify_node_features_for_attack(self,
                                         original_features: torch.Tensor,
                                         attack_type: AttackType,
                                         role: str) -> torch.Tensor:
        """
        Return a perturbed copy of one node's features that mimics attack behaviour.

        Args:
            original_features: Features of a single node. Its ``len`` is used as
                the feature count, so a 1-D tensor is assumed (TODO confirm).
                The input is cloned and never modified in place.
            attack_type: Family whose perturbation profile is applied. Families
                without a profile here (NORMAL, COMMAND_CONTROL) only receive
                the role adjustment.
            role: ``"target"`` additionally scales index 15 by 1.2 when that
                index exists; any other value adds nothing further.

        Returns:
            The perturbed copy, clamped element-wise to ``[0, 100]``.

        Steps that refer to an index past the end of the vector are skipped, so
        shorter feature layouts are tolerated. The feature names in the comments
        are the original author's labels; the extractor defining them is not part
        of this block.
        """
        feats = original_features.clone()
        width = len(feats)

        def apply(index, op, *operands):
            # Ignore indices this feature layout does not have.
            if index >= width:
                return
            if op == "multiply":
                feats[index] *= operands[0]
            elif op == "add":
                feats[index] += operands[0]
            elif op == "set_min":
                feats[index] = min(feats[index] + operands[0], operands[1])
            elif op == "set_max":
                feats[index] = max(feats[index], operands[0])

        # Each step is (index, operation, operand, min_width). A step runs only
        # when the vector has at least ``min_width`` entries; non-zero widths gate
        # the optional "security" features (index 19 and up). Note that lateral
        # movement gates index 19 on width >= 22, unlike the other families.
        profiles = {
            AttackType.DOS: (
                (10, "multiply", 0.85, 0),  # incoming traffic
                (11, "multiply", 1.4, 0),   # outgoing traffic
                (15, "multiply", 1.8, 0),   # connection failure rate
                (17, "multiply", 1.3, 0),   # short-lived connections
                (1, "multiply", 1.3, 0),    # out-degree
                (2, "multiply", 1.3, 0),    # total degree
                (6, "multiply", 1.2, 7),    # port diversity
            ),
            AttackType.PROBE: (
                (15, "multiply", 1.6, 0),   # failed connection ratio
                (9, "multiply", 1.3, 0),    # ICMP traffic
                (1, "multiply", 1.5, 0),    # out-degree
                (2, "multiply", 1.5, 0),    # total degree
                (6, "multiply", 1.3, 7),    # port diversity
            ),
            AttackType.R2L: (
                (13, "multiply", 1.4, 0),   # off-hours activity
                (15, "multiply", 1.6, 0),   # failed connections
                (0, "multiply", 1.2, 0),    # in-degree
                (19, "add", 1.5, 19),       # authentication failures
            ),
            AttackType.U2R: (
                (13, "multiply", 1.3, 0),   # off-hours activity
                (4, "multiply", 1.4, 0),    # betweenness centrality
                (20, "add", 0.5, 20),       # privilege escalations
                (19, "add", 1.0, 20),       # authentication failures
            ),
            AttackType.LATERAL_MOVEMENT: (
                (7, "multiply", 1.2, 0),    # TCP traffic
                (2, "multiply", 1.3, 0),    # total degree
                (5, "multiply", 1.1, 0),    # closeness
                (22, "add", 2.0, 22),       # new unique destinations
                (19, "add", 0.5, 22),       # authentication failures
            ),
            AttackType.DATA_EXFILTRATION: (
                (11, "multiply", 1.8, 0),   # outbound traffic
                (13, "multiply", 1.5, 0),   # off-hours activity
                (1, "multiply", 1.3, 0),    # out-degree
                (24, "multiply", 2.0, 24),  # upload/download ratio
                (25, "add", 1.0, 24),       # large file transfers
            ),
            AttackType.BOTNET: (
                (13, "set_max", 0.55, 0),   # floor on off-hours activity
                (2, "multiply", 1.2, 0),    # total degree
                (21, "add", 1.0, 21),       # DNS queries (C2 traffic)
                (23, "add", 0.5, 21),       # burst events
            ),
        }

        for index, op, operand, min_width in profiles.get(attack_type, ()):
            if width >= min_width:
                apply(index, op, operand)

        if role == "target" and width > 15:
            apply(15, "multiply", 1.2)  # victims show a slightly higher failure rate

        # NOTE(review): negative inputs (e.g. standardised features) become 0 here.
        return torch.clamp(feats, min=0.0, max=100.0)

    def generate_attack_scenario(self,
                                 graph_data: GraphData,
                                 attack_types: Optional[List[AttackType]] = None,
                                 attack_ratio: float = 0.15,
                                 multi_stage: bool = True) -> Tuple[GraphData, List[AttackInstance]]:
        """
        Generate a realistic attack scenario on the given network graph.

        Args:
            graph_data: Network graph to inject attacks into. Its node features
                are cloned before being perturbed rather than modified in place.
            attack_types: Attack families to generate. ``None`` draws a weighted
                mix of 2-4 families via ``_select_realistic_attack_mix``.
            attack_ratio: Minimum fraction of nodes to label as compromised. If
                the generated attacks compromise fewer than
                ``max(1, int(attack_ratio * num_nodes))`` nodes, random extra
                nodes are labelled (recorded as DoS sources) to make up the
                difference. The attacks themselves may compromise more.
            multi_stage: If True, chain the attacks into one campaign
                (``_generate_multi_stage_campaign``); otherwise generate each
                family independently (``_generate_independent_attacks``).

        Returns:
            ``(modified_graph, attack_instances)``. ``modified_graph`` is a new
            ``GraphData`` with perturbed features, ``y_node`` set to 1 for
            compromised nodes and 0 otherwise, and an ``attack_metadata`` dict.
            A node counts as compromised when it is a source of any attack, a
            target of an R2L / U2R / lateral-movement attack, or a filler node.
            ``attack_metadata["attack_types_present"]`` lists the families of
            ``attack_instances`` in first-seen order; filler nodes do not add
            an entry.
        """
        if attack_types is None:
            # Select realistic attack type distribution
            attack_types = self._select_realistic_attack_mix()

        num_nodes = graph_data.x.shape[0]
        num_attack_nodes = max(1, int(attack_ratio * num_nodes))

        # Every node starts out normal (label 0) with empty attack metadata.
        node_labels = torch.zeros(num_nodes, dtype=torch.long)
        node_attack_metadata = [{"attack_types": [], "is_compromised": False} for _ in range(num_nodes)]

        if multi_stage:
            attack_instances = self._generate_multi_stage_campaign(
                graph_data, attack_types, num_attack_nodes
            )
        else:
            attack_instances = self._generate_independent_attacks(
                graph_data, attack_types, num_attack_nodes
            )

        compromised_nodes = set()
        modified_features = graph_data.x.clone()

        def mark_compromised(node_id, attack_type, role):
            """Label one node malicious, record why, and perturb its features."""
            node_labels[node_id] = 1
            meta = node_attack_metadata[node_id]
            meta["is_compromised"] = True
            meta["attack_types"].append(attack_type)
            compromised_nodes.add(node_id)
            modified_features[node_id] = self._modify_node_features_for_attack(
                modified_features[node_id], attack_type, role
            )

        # In these families the victims end up under attacker control.
        victim_compromising = (AttackType.R2L, AttackType.U2R, AttackType.LATERAL_MOVEMENT)

        for attack in attack_instances:
            attack_type = attack.pattern.attack_type
            for node_id in attack.source_nodes:
                if node_id < num_nodes:  # ignore out-of-range node ids
                    mark_compromised(node_id, attack_type, "source")
            if attack_type in victim_compromising:
                for node_id in attack.target_nodes:
                    if node_id < num_nodes:
                        mark_compromised(node_id, attack_type, "target")

        # Top up with random filler nodes if the attacks fell short of the ratio.
        shortfall = num_attack_nodes - len(compromised_nodes)
        if shortfall > 0:
            available_nodes = [i for i in range(num_nodes) if i not in compromised_nodes]
            if available_nodes:
                fillers = self.rng.choice(
                    available_nodes,
                    size=min(shortfall, len(available_nodes)),
                    replace=False
                )
                for node_id in fillers:
                    mark_compromised(node_id, AttackType.DOS, "source")  # default attack type

        modified_graph = GraphData(
            x=modified_features,
            edge_index=graph_data.edge_index,
            edge_attr=graph_data.edge_attr,
            y_node=node_labels,
            graph_id=f"{graph_data.graph_id}_with_attacks",
            window_idx=graph_data.window_idx
        )

        # Enum members hash by name (a randomised str hash), so iterating a set of
        # them gives a per-process order; dict.fromkeys keeps first-seen order and
        # makes this list reproducible for a fixed seed.
        attack_types_present = list(dict.fromkeys(a.pattern.attack_type for a in attack_instances))

        modified_graph.attack_metadata = {
            "attack_instances": attack_instances,
            "node_metadata": node_attack_metadata,
            "attack_ratio": float(node_labels.float().mean().item()),
            "attack_types_present": attack_types_present
        }

        self.logger.info(f"Generated {len(attack_instances)} attack instances")
        self.logger.info(f"Attack ratio: {modified_graph.attack_metadata['attack_ratio']:.3f}")

        return modified_graph, attack_instances

    def _select_realistic_attack_mix(self) -> List[AttackType]:
        """Draw 2-4 distinct attack families, weighted by how common each one is.

        Returns:
            Distinct ``AttackType`` members sampled without replacement from
            ``self.rng`` with the weights below.
        """
        # (family, sampling weight); the weights sum to 1.0 as numpy requires.
        weighted_families = (
            (AttackType.DOS, 0.25),                # very common
            (AttackType.PROBE, 0.30),              # most common
            (AttackType.R2L, 0.20),                # common
            (AttackType.U2R, 0.05),                # rarer but high impact
            (AttackType.BOTNET, 0.10),             # increasingly common
            (AttackType.LATERAL_MOVEMENT, 0.05),   # advanced persistent threats
            (AttackType.DATA_EXFILTRATION, 0.05),  # targeted attacks
        )
        families = [family for family, _ in weighted_families]
        weights = [weight for _, weight in weighted_families]

        # randint's upper bound is exclusive, so this yields 2, 3 or 4.
        how_many = self.rng.randint(2, 5)
        picked = self.rng.choice(families, size=how_many, replace=False, p=weights)
        return picked.tolist()

    def _generate_multi_stage_campaign(self,
                                       graph_data: GraphData,
                                       attack_types: List[AttackType],
                                       num_attack_nodes: int) -> List[AttackInstance]:
        """Chain the requested attack families into one kill-chain style campaign.

        Stages run in a fixed order: probe, R2L, U2R, lateral movement, botnet,
        exfiltration, DoS. The foothold (nodes the attacker controls) is seeded
        only by the R2L victims and grows through U2R and lateral movement.
        Without an R2L stage the foothold stays empty, so the U2R,
        lateral-movement, botnet and exfiltration stages are skipped. Probe and
        DoS always run when requested.

        Args:
            graph_data: Graph the campaign runs on.
            attack_types: Families to include; only membership is checked.
            num_attack_nodes: Spread budget handed to lateral movement.

        Returns:
            The generated ``AttackInstance`` objects in stage order.
        """
        campaign = []
        foothold = set()

        # Reconnaissance needs no foothold.
        if AttackType.PROBE in attack_types:
            campaign.append(self._generate_probe_attack(graph_data))

        # Initial compromise: the R2L victims become the foothold.
        if AttackType.R2L in attack_types:
            entry_points = self._select_vulnerable_targets(graph_data, max_targets=2)
            intrusion = self._generate_r2l_attack(graph_data, entry_points)
            campaign.append(intrusion)
            foothold.update(intrusion.target_nodes)

        # Privilege escalation on (at most) two foothold nodes.
        if AttackType.U2R in attack_types and foothold:
            escalation = self._generate_u2r_attack(graph_data, list(foothold)[:2])
            campaign.append(escalation)
            foothold.update(escalation.target_nodes)

        # Spread from every foothold node to its neighbours.
        if AttackType.LATERAL_MOVEMENT in attack_types and foothold:
            spread = self._generate_lateral_movement(graph_data, list(foothold), num_attack_nodes)
            campaign.append(spread)
            foothold.update(spread.target_nodes)

        # Persistence: every foothold node beacons to C2.
        if AttackType.BOTNET in attack_types and foothold:
            campaign.append(self._generate_botnet_communication(graph_data, list(foothold)))

        # Exfiltrate from a small subset of the foothold.
        if AttackType.DATA_EXFILTRATION in attack_types and foothold:
            campaign.append(self._generate_data_exfiltration(graph_data, list(foothold)[:3]))

        # DoS last (diversion or finale); an empty foothold lets it pick random sources.
        if AttackType.DOS in attack_types:
            campaign.append(self._generate_dos_attack(graph_data, list(foothold)))

        return campaign

    def _generate_independent_attacks(self,
                                      graph_data: GraphData,
                                      attack_types: List[AttackType],
                                      num_attack_nodes: int) -> List[AttackInstance]:
        """Generate one unrelated attack per requested family.

        Each family picks its own participants. Families without a builder
        (e.g. NORMAL, COMMAND_CONTROL) are skipped. Builders are invoked lazily
        in ``attack_types`` order, so random draws happen in that order.

        Returns:
            The generated ``AttackInstance`` objects, one per supported family.
        """
        builders = {
            AttackType.DOS: lambda: self._generate_dos_attack(graph_data),
            AttackType.PROBE: lambda: self._generate_probe_attack(graph_data),
            AttackType.R2L: lambda: self._generate_r2l_attack(
                graph_data, self._select_vulnerable_targets(graph_data, max_targets=3)
            ),
            AttackType.U2R: lambda: self._generate_u2r_attack(
                graph_data, self._select_random_internal_nodes(graph_data, max_nodes=2)
            ),
            AttackType.BOTNET: lambda: self._generate_botnet_communication(
                graph_data, self._select_random_internal_nodes(graph_data, max_nodes=5)
            ),
            AttackType.LATERAL_MOVEMENT: lambda: self._generate_lateral_movement(
                graph_data,
                self._select_random_internal_nodes(graph_data, max_nodes=2),
                num_attack_nodes,
            ),
            AttackType.DATA_EXFILTRATION: lambda: self._generate_data_exfiltration(
                graph_data, self._select_random_internal_nodes(graph_data, max_nodes=3)
            ),
        }
        return [builders[kind]() for kind in attack_types if kind in builders]

    def _generate_dos_attack(self, graph_data: GraphData, compromised_nodes: List[int] = None) -> AttackInstance:
        """Generate a (distributed) denial-of-service attack.

        Args:
            graph_data: Graph to attack.
            compromised_nodes: Nodes already under attacker control. If
                non-empty, its first ten entries become the sources. Otherwise
                ``min(10, num_nodes // 4)`` distinct random nodes are drawn; they
                may overlap the targets.

        Returns:
            A DOS ``AttackInstance`` whose targets are the three highest-degree
            nodes according to ``_compute_node_degrees``.
        """
        # The best-connected nodes stand in for the servers being flooded.
        degrees = self._compute_node_degrees(graph_data)
        victims = [int(node) for node in np.argsort(degrees)[-3:]]

        if compromised_nodes:
            attackers = compromised_nodes[:10]
        else:
            num_nodes = graph_data.x.shape[0]
            attackers = self.rng.choice(num_nodes, size=min(10, num_nodes // 4), replace=False).tolist()

        start_time = self.rng.uniform(0, 10)
        duration = self.rng.uniform(30, 300)      # 30 seconds to 5 minutes
        intensity = self.rng.uniform(0.7, 1.0)    # high intensity
        packet_rate = self.rng.uniform(1000, 10000)

        return AttackInstance(
            pattern=self.attack_patterns[AttackType.DOS],
            source_nodes=attackers,
            target_nodes=victims,
            start_time=start_time,
            duration=duration,
            intensity=intensity,
            metadata={
                "attack_vector": "tcp_syn_flood",
                "packet_rate": packet_rate,
                "target_services": ["http", "https", "dns"]
            }
        )

    def _generate_probe_attack(self, graph_data: GraphData) -> AttackInstance:
        """Generate a network reconnaissance (scanning) attack.

        One scanner (graphs under 40 nodes) or two scanners sweep
        ``min(num_nodes // 2, 50)`` nodes. Scanners and targets are drawn
        independently, so they may overlap.

        Returns:
            A PROBE ``AttackInstance``. Node lists and metadata values are native
            Python types (``int``/``str``/``float``), so the instance can be
            serialised (e.g. with ``json``) without a numpy-aware encoder.
        """
        num_nodes = graph_data.x.shape[0]

        num_scanners = min(2, max(1, num_nodes // 20))
        source_nodes = self.rng.choice(num_nodes, size=num_scanners, replace=False).tolist()

        # Scan up to half the network, capped at 50 hosts.
        num_targets = min(num_nodes // 2, 50)
        target_nodes = self.rng.choice(num_nodes, size=num_targets, replace=False).tolist()

        # Draws stay in a fixed order so a given seed reproduces the same scenario.
        start_time = self.rng.uniform(0, 5)
        duration = self.rng.uniform(60, 600)  # 1-10 minutes
        intensity = self.rng.uniform(0.3, 0.7)
        scan_type = str(self.rng.choice(["port_scan", "service_scan", "os_fingerprint"]))
        num_ports = self.rng.randint(3, 8)  # 3-7 ports (upper bound exclusive)
        # .tolist() yields plain ints; numpy.int64 values are rejected by json.
        ports_scanned = self.rng.choice([21, 22, 23, 25, 53, 80, 110, 143, 443, 993, 995],
                                        size=num_ports, replace=False).tolist()
        scan_rate = self.rng.uniform(10, 100)  # packets per second

        return AttackInstance(
            pattern=self.attack_patterns[AttackType.PROBE],
            source_nodes=source_nodes,
            target_nodes=target_nodes,
            start_time=start_time,
            duration=duration,
            intensity=intensity,
            metadata={
                "scan_type": scan_type,
                "ports_scanned": ports_scanned,
                "scan_rate": scan_rate
            }
        )

    def _generate_r2l_attack(self, graph_data: GraphData, target_nodes: List[int]) -> AttackInstance:
        """Generate a remote-to-local intrusion against ``target_nodes``.

        Args:
            graph_data: Graph the attack runs on; only its node count is used.
            target_nodes: Victims of the intrusion, passed through unchanged.

        Returns:
            An R2L ``AttackInstance`` with a single attacker. The "external"
            attacker is simply a random node of the graph and may coincide with a
            target. Node ids and metadata values are native Python types.
        """
        # Single attacker, stored as a plain int like the other generators' node lists.
        source_nodes = [int(self.rng.choice(graph_data.x.shape[0]))]

        start_time = self.rng.uniform(10, 60)
        duration = self.rng.uniform(300, 3600)  # 5 minutes to 1 hour
        intensity = self.rng.uniform(0.4, 0.8)
        metadata = {
            "attack_vector": str(self.rng.choice(["brute_force", "exploit", "phishing", "watering_hole"])),
            "target_services": str(self.rng.choice(["ssh", "ftp", "http", "telnet"])),
            "credential_attempts": int(self.rng.randint(100, 10000)),
            "success_rate": self.rng.uniform(0.001, 0.01)
        }

        return AttackInstance(
            pattern=self.attack_patterns[AttackType.R2L],
            source_nodes=source_nodes,
            target_nodes=target_nodes,
            start_time=start_time,
            duration=duration,
            intensity=intensity,
            metadata=metadata
        )

    def _generate_u2r_attack(self, graph_data: GraphData, source_nodes: List[int]) -> AttackInstance:
        """Generate a user-to-root privilege escalation on the given hosts.

        Escalation is local, so every source is also its own target:
        ``target_nodes`` is ``source_nodes.copy()``. ``graph_data`` is accepted
        for signature parity with the other generators but is not read.
        """
        start_time = self.rng.uniform(60, 300)
        duration = self.rng.uniform(60, 1800)    # 1-30 minutes
        intensity = self.rng.uniform(0.2, 0.6)   # kept low: stealthy
        escalation_method = self.rng.choice(["buffer_overflow", "privilege_bug", "rootkit", "kernel_exploit"])
        target_process = self.rng.choice(["suid_binary", "system_service", "kernel_module"])
        stealth_level = self.rng.uniform(0.7, 1.0)

        return AttackInstance(
            pattern=self.attack_patterns[AttackType.U2R],
            source_nodes=source_nodes,
            target_nodes=source_nodes.copy(),
            start_time=start_time,
            duration=duration,
            intensity=intensity,
            metadata={
                "escalation_method": escalation_method,
                "target_process": target_process,
                "stealth_level": stealth_level
            }
        )

    def _generate_lateral_movement(self,
                                   graph_data: GraphData,
                                   source_nodes: List[int],
                                   max_spread: int) -> AttackInstance:
        """Generate lateral movement from already-compromised nodes to their neighbours.

        Args:
            graph_data: Graph whose ``edge_index`` defines reachability; row 0 is
                treated as the edge source and row 1 as the destination.
            source_nodes: Nodes the attacker already controls.
            max_spread: Upper bound on the number of new targets; negative values
                are treated as 0.

        Returns:
            A LATERAL_MOVEMENT ``AttackInstance``. Targets are out-neighbours of
            the sources, excluding the sources themselves. They are de-duplicated
            in discovery order (neighbours of earlier sources first) and truncated
            to ``max_spread``.
        """
        edge_index = graph_data.edge_index

        # Nodes already held (including self-loops) are not new footholds; keeping
        # them out stops them from consuming the spread budget.
        held = {int(node) for node in source_nodes}

        # dict.fromkeys de-duplicates while keeping discovery order. list(set(...))
        # would order by hash-table slot, which is unrelated to the graph structure.
        reachable = dict.fromkeys(
            neighbour
            for source in source_nodes
            for neighbour in edge_index[1][edge_index[0] == source].tolist()
            if neighbour not in held
        )
        target_nodes = list(reachable)[:max(0, max_spread)]

        tools = ["mimikatz", "psexec", "wmi", "powershell"]

        start_time = self.rng.uniform(300, 1800)  # after initial compromise
        duration = self.rng.uniform(1800, 7200)   # 30 minutes to 2 hours
        intensity = self.rng.uniform(0.1, 0.4)    # very stealthy
        metadata = {
            "movement_technique": str(self.rng.choice(["pass_the_hash", "golden_ticket", "rdp_hijack", "smb_relay"])),
            "credentials_harvested": int(self.rng.randint(1, 10)),
            # randint's high bound is exclusive: +1 lets 1..len(tools) tools be
            # chosen, so the last tool is reachable.
            "tools_used": tools[:self.rng.randint(1, len(tools) + 1)]
        }

        return AttackInstance(
            pattern=self.attack_patterns[AttackType.LATERAL_MOVEMENT],
            source_nodes=source_nodes,
            target_nodes=target_nodes,
            start_time=start_time,
            duration=duration,
            intensity=intensity,
            metadata=metadata
        )

    def _generate_botnet_communication(self, graph_data: GraphData, bot_nodes: List[int]) -> AttackInstance:
        """Generate a botnet command-and-control communication pattern.

        A single node is chosen as the C2 server, and every bot communicates
        with it. The C2 node is drawn uniformly from the nodes that are not
        bots, so a bot never beacons to itself. Only when every node is a bot
        does the draw fall back to the full node range.

        Args:
            graph_data: Graph whose ``x.shape[0]`` gives the number of nodes.
            bot_nodes: Node ids acting as infected bots (the attack sources).

        Returns:
            AttackInstance using the BOTNET pattern, with one C2 node (a plain
            Python ``int``) as its only target.
        """
        num_nodes = graph_data.x.shape[0]
        bots = {int(node) for node in bot_nodes}
        candidates = [node for node in range(num_nodes) if node not in bots]

        # The graph has no explicit notion of "external" hosts; the closest
        # approximation available here is "any node that is not a bot".
        # int(): keep node ids as Python ints, as the other selectors
        # (which return .tolist()) do.
        if candidates:
            c2_node = int(self.rng.choice(candidates))
        else:
            c2_node = int(self.rng.choice(num_nodes))

        return AttackInstance(
            pattern=self.attack_patterns[AttackType.BOTNET],
            source_nodes=bot_nodes,
            target_nodes=[c2_node],
            start_time=self.rng.uniform(0, 3600),
            duration=self.rng.uniform(3600, 86400),  # 1-24 hours
            intensity=self.rng.uniform(0.1, 0.3),    # Low intensity, persistent
            metadata={
                "c2_protocol": self.rng.choice(["http", "https", "dns", "irc", "p2p"]),
                "beacon_interval": self.rng.uniform(300, 3600),  # 5 minutes to 1 hour
                "payload_type": self.rng.choice(["commands", "updates", "stolen_data"]),
                "encryption": True,
                "domain_generation": self.rng.choice([True, False])
            }
        )

    def _generate_data_exfiltration(self, graph_data: GraphData, source_nodes: List[int]) -> AttackInstance:
        """Generate a data-exfiltration attack.

        A single destination node is drawn uniformly from the nodes that are
        not sources, so data is never "exfiltrated" to one of the hosts it is
        stolen from. Only when every node is a source does the draw fall back
        to the full node range.

        Args:
            graph_data: Graph whose ``x.shape[0]`` gives the number of nodes.
            source_nodes: Node ids the data is exfiltrated from.

        Returns:
            AttackInstance using the DATA_EXFILTRATION pattern, with one
            destination node (a plain Python ``int``) as its only target.
        """
        num_nodes = graph_data.x.shape[0]
        sources = {int(node) for node in source_nodes}
        candidates = [node for node in range(num_nodes) if node not in sources]

        # The graph has no explicit notion of "external" hosts; the closest
        # approximation available here is "any node that is not a source".
        if candidates:
            destination = int(self.rng.choice(candidates))
        else:
            destination = int(self.rng.choice(num_nodes))
        target_nodes = [destination]

        return AttackInstance(
            pattern=self.attack_patterns[AttackType.DATA_EXFILTRATION],
            source_nodes=source_nodes,
            target_nodes=target_nodes,
            start_time=self.rng.uniform(1800, 7200),  # After establishing presence
            duration=self.rng.uniform(600, 3600),     # 10 minutes to 1 hour
            intensity=self.rng.uniform(0.3, 0.7),
            metadata={
                "exfil_method": self.rng.choice(["ftp", "http_post", "dns_tunneling", "email", "cloud_storage"]),
                "data_volume_mb": self.rng.uniform(10, 1000),
                "compression": True,
                "encryption": True,
                "staging_location": self.rng.choice(["temp_dir", "system_folder", "registry"])
            }
        )

    def _compute_node_degrees(self, graph_data: GraphData) -> np.ndarray:
        """Compute the total (row 0 + row 1) degree of every node.

        Each edge adds 1 to the node in row 0 and 1 to the node in row 1 of
        ``edge_index``, so a self-loop counts twice. Endpoint ids outside
        ``[0, num_nodes)`` are ignored.

        Args:
            graph_data: Graph providing ``x`` (row count = number of nodes) and
                ``edge_index`` (a torch tensor or array-like with two rows).

        Returns:
            float64 array of length ``graph_data.x.shape[0]``.
        """
        num_nodes = graph_data.x.shape[0]
        edge_index = graph_data.edge_index

        if edge_index.shape[1] == 0:
            return np.zeros(num_nodes)

        # Accept a torch tensor on any device as well as a plain array;
        # a bare .numpy() fails for CUDA tensors.
        if isinstance(edge_index, torch.Tensor):
            edge_index = edge_index.detach().cpu().numpy()
        endpoints = np.asarray(edge_index[:2]).ravel()

        # Keep only valid node ids (bincount rejects negatives, and ids beyond
        # num_nodes would grow the output).
        endpoints = endpoints[(endpoints >= 0) & (endpoints < num_nodes)]

        # One vectorised pass instead of rescanning all edges once per node.
        return np.bincount(endpoints.astype(np.int64), minlength=num_nodes).astype(np.float64)

    def _select_vulnerable_targets(self, graph_data: GraphData, max_targets: int = 5) -> List[int]:
        """Select nodes that are likely to be vulnerable (servers, high-degree nodes).

        Roughly the top 10% of nodes by degree are chosen: at least one node
        (when the graph has any) and at most ``max_targets``.

        Args:
            graph_data: Graph whose node degrees are ranked.
            max_targets: Upper bound on the number of returned nodes; values
                <= 0 yield an empty list.

        Returns:
            Node ids in ascending order of degree (the highest-degree node is
            last).
        """
        degrees = self._compute_node_degrees(graph_data)

        # Select high-degree nodes as they're more likely to be servers
        num_targets = min(max_targets, max(1, len(degrees) // 10))
        if num_targets <= 0:
            # Guard: `arr[-0:]` is the whole array, so a zero count would
            # otherwise return every node instead of none.
            return []

        high_degree_indices = np.argsort(degrees)[-num_targets:]
        return high_degree_indices.tolist()

    def _select_random_internal_nodes(self, graph_data: GraphData, max_nodes: int = 5) -> List[int]:
        """Pick distinct nodes uniformly at random.

        Args:
            graph_data: Graph whose ``x.shape[0]`` gives the population size.
            max_nodes: Number of nodes requested; capped at the node count.

        Returns:
            Sampled node ids as a plain Python list (sampled without
            replacement).
        """
        population = graph_data.x.shape[0]
        # Cap the sample size so sampling without replacement is possible.
        sample_size = population if max_nodes > population else max_nodes

        picks = self.rng.choice(population, size=sample_size, replace=False)
        return picks.tolist()

    def inject_realistic_attacks(self, 
                               graphs: List[GraphData],
                               attack_ratio: float = 0.15,
                               temporal_correlation: bool = True) -> List[GraphData]:
        """
        Inject realistic attacks into a sequence of graphs (e.g. temporal snapshots).

        When ``temporal_correlation`` is set and there is more than one graph,
        a single multi-stage campaign is spread across the snapshots.
        Otherwise, each graph receives an independent single-stage scenario.

        Args:
            graphs: Graph data objects to attack; an empty list is returned as-is
            attack_ratio: Target proportion of nodes involved in attacks
            temporal_correlation: Whether attacks should be correlated across time

        Returns:
            List of graphs with realistic attack patterns injected
        """
        if not graphs:
            return graphs

        run_campaign = temporal_correlation and len(graphs) > 1

        if run_campaign:
            # One correlated campaign across all time windows.
            attacked_graphs = self._inject_temporal_campaign(graphs, attack_ratio)
        else:
            # Independent scenario per graph; the attack instances are discarded.
            attacked_graphs = []
            for snapshot in graphs:
                attacked, _instances = self.generate_attack_scenario(
                    snapshot,
                    attack_ratio=attack_ratio,
                    multi_stage=False,
                )
                attacked_graphs.append(attacked)

        self.logger.info(f"Injected realistic attacks into {len(attacked_graphs)} graphs")
        return attacked_graphs

    def _inject_temporal_campaign(self, 
                                graphs: List[GraphData], 
                                attack_ratio: float) -> List[GraphData]:
        """Spread one correlated, multi-stage attack campaign over temporal graphs.

        Stage per window index ``i`` (with ``n = len(graphs)``):
          * ``i == 0``              -> PROBE + R2L (reconnaissance / initial compromise)
          * ``i < n // 3``          -> U2R + LATERAL_MOVEMENT (escalation)
          * ``i < 2 * n // 3``      -> BOTNET + DATA_EXFILTRATION (persistence)
          * otherwise               -> DOS + DATA_EXFILTRATION (final stage)

        Every attack's source nodes, plus the target nodes of R2L / U2R /
        LATERAL_MOVEMENT attacks, are remembered. From the second window on,
        all remembered nodes that fit within ``y_node`` are labelled 1.

        Args:
            graphs: Ordered snapshots to attack.
            attack_ratio: Passed through to ``generate_attack_scenario``.

        Returns:
            The attacked snapshots, in input order.
        """
        window_count = len(graphs)
        # NOTE(review): with fewer than 6 windows, n // 3 <= 1, so the
        # escalation stage (U2R / lateral movement) is never reached.
        escalation_end = window_count // 3
        persistence_end = 2 * window_count // 3

        # Attack types whose targets stay compromised in later windows.
        sticky_target_types = {AttackType.R2L, AttackType.U2R, AttackType.LATERAL_MOVEMENT}

        compromised = set()
        attacked_graphs = []

        for step, snapshot in enumerate(graphs):
            if step == 0:
                stage_types = [AttackType.PROBE, AttackType.R2L]
            elif step < escalation_end:
                stage_types = [AttackType.U2R, AttackType.LATERAL_MOVEMENT]
            elif step < persistence_end:
                stage_types = [AttackType.BOTNET, AttackType.DATA_EXFILTRATION]
            else:
                stage_types = [AttackType.DOS, AttackType.DATA_EXFILTRATION]

            attacked, instances = self.generate_attack_scenario(
                snapshot,
                attack_types=stage_types,
                attack_ratio=attack_ratio,
                multi_stage=True,
            )

            # Remember every node this window's attacks touched persistently.
            for hit in instances:
                compromised.update(hit.source_nodes)
                if hit.pattern.attack_type in sticky_target_types:
                    compromised.update(hit.target_nodes)

            # Carry earlier compromises forward (skipped for the first window).
            if step > 0 and compromised:
                label_count = attacked.y_node.shape[0]
                for node_id in compromised:
                    if node_id < label_count:
                        attacked.y_node[node_id] = 1

            attacked_graphs.append(attacked)

        return attacked_graphs


# Utility functions for integration
def inject_realistic_attacks_into_graphs(graphs: List[GraphData], 
                                       attack_ratio: float = 0.15,
                                       seed: int = 42,
                                       temporal_correlation: bool = True) -> List[GraphData]:
    """
    Convenience function to inject realistic attacks into graph datasets.

    Builds a fresh, seeded ``AttackPatternGenerator`` and delegates to its
    ``inject_realistic_attacks`` method.

    Args:
        graphs: List of graph data objects
        attack_ratio: Target proportion of nodes to involve in attacks
        seed: Random seed for reproducibility
        temporal_correlation: When True (the default, matching the previous
            behaviour) and more than one graph is given, a single correlated
            campaign spans the graphs; when False, each graph is attacked
            independently.

    Returns:
        List of graphs with realistic attack patterns
    """
    generator = AttackPatternGenerator(seed=seed)
    return generator.inject_realistic_attacks(
        graphs,
        attack_ratio=attack_ratio,
        temporal_correlation=temporal_correlation,
    )


if __name__ == "__main__":
    # Smoke test: build a small graph from NSL-KDD data, inject attacks,
    # and print a summary.
    # The relative imports below only resolve when this file runs as part of
    # its package (e.g. `python -m <package>.<module>`), not as a bare script.
    from ..data.network_graph_builder import NetworkGraphBuilder
    from ..data.real import RealDatasetLoader

    dataset_loader = RealDatasetLoader("nsl_kdd")
    dataset_loader.download_dataset()
    frame = dataset_loader.load_dataset()

    if frame is not None:
        # Only the labels are needed here; the features are discarded.
        _features, labels, _ = dataset_loader.preprocess_dataset(frame)
        frame['is_attack'] = labels

        graph_builder = NetworkGraphBuilder()
        # These steps take no data arguments after the first, so they
        # presumably build on builder state left by the previous call; keep
        # the order. 1000 rows keeps the test graph small.
        flow_records = graph_builder.load_nsl_kdd_flows(frame.head(1000))
        host_profiles = graph_builder.build_host_profiles()
        base_graph = graph_builder.create_network_graph()

        demo_generator = AttackPatternGenerator(seed=42)
        attacked_graph, attack_instances = demo_generator.generate_attack_scenario(
            base_graph, attack_ratio=0.2, multi_stage=True
        )

        before_labels = base_graph.y_node
        after_labels = attacked_graph.y_node
        print(f"Original graph: {before_labels.sum()}/{before_labels.shape[0]} attack nodes")
        print(f"Attacked graph: {after_labels.sum()}/{after_labels.shape[0]} attack nodes")
        print(f"Generated {len(attack_instances)} attack instances:")

        for instance in attack_instances:
            print(f"  - {instance.pattern.name}: {len(instance.source_nodes)} sources -> {len(instance.target_nodes)} targets")
