#!/usr/bin/env python3
"""
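Network Graph Builder for creating realistic network topologies from real data.
Converts network flow data into meaningful graph structures that preserve
network semantics and attack patterns. NSL-KDD records carry no IP addresses,
so host addresses for that dataset are synthesised (partly at random) from
connection features.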
"""

import ipaddress
import warnings
import zlib
from dataclasses import dataclass
from typing import List, Dict, Tuple, Optional, Any

import networkx as nx
import numpy as np
import pandas as pd
import torch
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler, LabelEncoder

from ..utils.common import GraphData, get_logger


@dataclass
class NetworkFlow:
    """
    A single directed flow between two endpoints, with its ground-truth label.

    Field order matters: the generated constructor accepts these fields
    positionally in the order declared below.
    """
    # --- endpoints ---
    src_ip: str
    dst_ip: str
    src_port: int
    dst_port: int
    # --- transport and volume ---
    protocol: str
    bytes_sent: int
    packets_sent: int
    duration: float
    flags: List[str]
    # --- ordering key (units depend on the loader that built the flow) ---
    timestamp: float
    # --- ground truth ---
    is_attack: bool = False
    attack_type: str = "normal"


@dataclass
class NetworkHost:
    """
    Aggregated traffic statistics for one host (IP address).

    ``protocols_used`` and ``ports_accessed`` default to ``None`` and are
    replaced by fresh empty lists in ``__post_init__``, so instances never
    share a mutable default.
    """
    ip_address: str
    host_id: int
    subnet: str
    total_bytes_sent: int = 0
    total_bytes_received: int = 0
    total_flows_sent: int = 0
    total_flows_received: int = 0
    unique_destinations: int = 0
    unique_sources: int = 0
    # Filled in by NetworkGraphBuilder.build_host_profiles as bytes sent per
    # flow sent (not per packet).
    avg_packet_size: float = 0.0
    protocols_used: Optional[List[str]] = None
    ports_accessed: Optional[List[int]] = None
    is_server: bool = False
    is_compromised: bool = False
    attack_flows: int = 0

    def __post_init__(self) -> None:
        # Give every instance its own lists; a shared [] default would leak
        # state between hosts.
        if self.protocols_used is None:
            self.protocols_used = []
        if self.ports_accessed is None:
            self.ports_accessed = []


class NetworkGraphBuilder:
    """
    Builds network graphs from flow-based intrusion detection datasets.

    Workflow: ``load_nsl_kdd_flows(df)`` -> ``build_host_profiles()`` ->
    ``create_network_graph()`` or ``create_temporal_graphs()``. Every IP
    address seen in the flows becomes a node (numbered in first-seen order).
    An ordered host pair with at least ``min_flow_threshold`` flows becomes an
    edge, which is stored in both directions.
    """

    # Address pools used to synthesise IPs for datasets without addresses.
    _INTERNAL_SUBNETS = (
        "192.168.1.0/24",   # Common home/office network
        "10.0.0.0/24",      # Corporate internal
        "172.16.0.0/24",    # Private network
    )
    _EXTERNAL_SUBNETS = (
        "203.0.113.0/24",   # Documentation/test range
        "198.51.100.0/24",  # Documentation range
        "8.8.8.0/24",       # Public DNS range
    )

    # Services whose records are treated as internal-to-internal traffic.
    _INTERNAL_SERVICES = frozenset({'domain_u', 'private', 'netbios_ns'})

    # NSL-KDD service name -> port. Entries mapped to 0 have no port assigned
    # here; names missing from the table get a random unprivileged port
    # (see _map_service_to_port).
    _SERVICE_PORTS: Dict[str, int] = {
        'http': 80,
        'smtp': 25,
        'ftp': 21,
        'telnet': 23,
        'ssh': 22,
        'domain_u': 53,
        'domain': 53,
        'pop_3': 110,
        'finger': 79,
        'imap4': 143,
        'nntp': 119,
        'private': 1024,
        'netbios_ns': 137,
        'netbios_dgm': 138,
        'netbios_ssn': 139,
        'ldap': 389,
        'https': 443,
        'shell': 514,
        'login': 513,
        'klogin': 543,
        'kshell': 544,
        'exec': 512,
        'printer': 515,
        'sunrpc': 111,
        'auth': 113,
        'uucp': 540,
        'vmnet': 175,
        'courier': 530,
        'csnet_ns': 105,
        'rje': 77,
        'hostname': 101,
        'iso_tsap': 102,
        'x11': 6000,
        'IRC': 194,
        'Z39_50': 210,
        'sql_net': 1521,
        'bgp': 179,
        'ctf': 84,
        'supdup': 95,
        'uucp_path': 117,
        'netstat': 15,
        'discard': 9,
        'systat': 11,
        'daytime': 13,
        'echo': 7,
        'tim_i': 752,
        'tftp_u': 69,
        'link': 87,
        'remote_job': 77,
        'gopher': 70,
        'name': 42,
        'whois': 43,
        'mtp': 57,
        'urp_i': 0,  # Unknown
        'pm_dump': 0,  # Unknown
        'red_i': 0,  # Unknown
        'urh_i': 0,  # Unknown
        'http_443': 443,
        'http_2784': 2784,
        'harvest': 0,  # Unknown
        'aol': 5190,
        'http_8001': 8001,
        'ftp_data': 20,
        'ecr_i': 0,  # Unknown
    }

    # NSL-KDD connection-status flag -> coarse flag tokens. The codes appear to
    # follow Bro/Zeek ``conn_state`` values. Treat this as read-only:
    # _parse_flags returns copies of these lists.
    _FLAG_MEANINGS: Dict[str, List[str]] = {
        'SF': ['SYN', 'FIN'],               # Normal connection
        'S0': ['SYN'],                      # Connection attempt seen, no reply
        'REJ': ['REJECT'],                  # Connection attempt rejected
        'RSTR': ['RESET'],                  # Connection reset
        'RSTO': ['RESET', 'ORIG'],          # Connection reset by originator
        'RSTOS0': ['SYN', 'RESET', 'ORIG'], # Originator SYN then RST, no SYN/ACK seen
        'SH': ['SYN', 'HALF'],              # SYN seen, no SYN/ACK
        'S1': ['SYN', 'ESTABLISHED'],       # Connection established
        'S2': ['SYN', 'CLOSE'],             # Connection closed
        'S3': ['SYN', 'RESET'],             # Connection reset
        'OTH': ['OTHER'],                   # Other
    }

    # Private (RFC 1918) IPv4 blocks, in the order of subnet feature slots 1..3.
    _RFC1918_BLOCKS = (
        ipaddress.ip_network("10.0.0.0/8"),
        ipaddress.ip_network("172.16.0.0/12"),
        ipaddress.ip_network("192.168.0.0/16"),
    )

    # Protocols that get a one-hot slot in the node feature vector.
    _COMMON_PROTOCOLS = ('tcp', 'udp', 'icmp')

    # Number of columns produced by _create_edge_features.
    _EDGE_FEATURE_DIM = 10

    def __init__(self,
                 subnet_aggregation: bool = True,
                 temporal_windows: bool = True,
                 window_size_minutes: int = 5,
                 min_flow_threshold: int = 2):
        """
        Initialize the network graph builder.

        Args:
            subnet_aggregation: Whether to aggregate hosts by subnet. Stored,
                but not read by any method of this class.
            temporal_windows: Whether to create temporal graph snapshots.
                Stored, but not read by any method of this class.
            window_size_minutes: Size of temporal windows in minutes. Stored,
                but not read; create_temporal_graphs splits the timeline by
                window count instead.
            min_flow_threshold: Default minimum number of flows (in one
                direction) required to create an edge.
        """
        self.subnet_aggregation = subnet_aggregation
        self.temporal_windows = temporal_windows
        self.window_size_minutes = window_size_minutes
        self.min_flow_threshold = min_flow_threshold
        self.logger = get_logger("network_graph_builder")

        # Internal state
        self.hosts: Dict[str, NetworkHost] = {}
        self.flows: List[NetworkFlow] = []
        self.scaler = StandardScaler()  # refit on every _create_node_features call
        self.label_encoders = {}  # NOTE(review): not used within this class

    def load_nsl_kdd_flows(self, df: pd.DataFrame) -> List[NetworkFlow]:
        """
        Convert NSL-KDD connection records into ``NetworkFlow`` objects.

        NSL-KDD has no IP addresses or ports. Addresses are synthesised by
        ``_generate_realistic_ips``, and both ports are set to the port of
        the record's ``service``. The row index label is used as the
        timestamp, so it must be convertible to float.

        Args:
            df: NSL-KDD records. Columns read (with defaults when absent):
                service, protocol_type, src_bytes, count, duration, flag,
                is_attack (1 = attack), label.

        Returns:
            The flows, which also replace ``self.flows``.
        """
        flows = []

        for idx, row in df.iterrows():
            src_ip, dst_ip = self._generate_realistic_ips(row, idx)

            # Look the port up once so src/dst agree even for services that
            # fall back to a random port.
            # NOTE(review): using the service port as the source port is a
            # simplification; real clients use ephemeral source ports.
            service_port = self._map_service_to_port(row.get('service', 'http'))

            flows.append(NetworkFlow(
                src_ip=src_ip,
                dst_ip=dst_ip,
                src_port=service_port,
                dst_port=service_port,
                protocol=row.get('protocol_type', 'tcp'),
                bytes_sent=int(row.get('src_bytes', 0)),
                # NOTE(review): NSL-KDD documents `count` as a connection
                # count, not packets; confirm this proxy is intended.
                packets_sent=int(row.get('count', 1)),
                duration=float(row.get('duration', 0)),
                flags=self._parse_flags(row.get('flag', 'SF')),
                timestamp=float(idx),  # Use row index as timestamp
                is_attack=bool(row.get('is_attack', 0) == 1),
                attack_type=row.get('label', 'normal'),
            ))

        self.flows = flows
        self.logger.info(f"Loaded {len(flows)} flows from NSL-KDD dataset")
        return flows

    def _generate_realistic_ips(self, row: pd.Series, idx: int) -> Tuple[str, str]:
        """
        Synthesise a ``(src_ip, dst_ip)`` pair for one NSL-KDD record.

        A record looks internal if its service is internal-only, it has more
        than 100 destination hosts, or its same_srv_rate is above 0.8. Such
        records get two internal addresses, in the same subnet 70% of the
        time. All other records go from an internal to an external subnet.
        Subnets are chosen with numpy's global RNG. The host inside each
        subnet is derived deterministically from ``idx``.
        """
        service = row.get('service', 'http')
        is_internal_comm = (
            service in self._INTERNAL_SERVICES or
            row.get('dst_host_count', 1) > 100 or  # Scanning behavior
            row.get('same_srv_rate', 0) > 0.8
        )

        src_subnet = np.random.choice(self._INTERNAL_SUBNETS)
        if is_internal_comm:
            stay_in_subnet = np.random.random() > 0.3
            dst_subnet = src_subnet if stay_in_subnet else np.random.choice(self._INTERNAL_SUBNETS)
        else:
            dst_subnet = np.random.choice(self._EXTERNAL_SUBNETS)

        return (self._host_in_subnet(src_subnet, f"src_{idx}"),
                self._host_in_subnet(dst_subnet, f"dst_{idx}"))

    @staticmethod
    def _host_in_subnet(subnet: str, key: str) -> str:
        """Pick a usable host address of ``subnet`` deterministically from ``key``."""
        network = ipaddress.ip_network(subnet, strict=False)
        # Skip the network and broadcast addresses, as network.hosts() does.
        usable = network.num_addresses - 2
        # zlib.crc32 is stable across runs. The built-in hash() of a str is
        # salted per process (PYTHONHASHSEED), so it cannot give consistency.
        offset = zlib.crc32(key.encode("utf-8")) % usable
        return str(network.network_address + 1 + offset)

    def _map_service_to_port(self, service: str) -> int:
        """
        Map an NSL-KDD service name to a port number.

        Unknown services get a random port in [1024, 65535) from numpy's
        global RNG. The draw happens only on a miss, so known services do not
        consume RNG state.
        """
        port = self._SERVICE_PORTS.get(service)
        if port is None:
            port = int(np.random.randint(1024, 65535))
        return port

    def _parse_flags(self, flag_str: str) -> List[str]:
        """Parse an NSL-KDD connection flag into tokens (``['UNKNOWN']`` if unmapped)."""
        # Copy so callers mutating flow.flags cannot corrupt the shared table.
        return list(self._FLAG_MEANINGS.get(flag_str, ['UNKNOWN']))

    def build_host_profiles(self) -> Dict[str, NetworkHost]:
        """
        Aggregate ``self.flows`` into one ``NetworkHost`` per IP address.

        Hosts are numbered (``host_id``) in first-seen order. Only the source
        of an attack flow is marked compromised. A host counts as a server
        when it received more than twice as many flows as it sent and has
        more than 5 distinct ports recorded.

        Returns:
            Mapping IP -> host profile, which also replaces ``self.hosts``.
        """
        hosts: Dict[str, NetworkHost] = {}
        # Distinct peers per host, collected in the same pass over the flows
        # instead of rescanning every flow for every host.
        destinations: Dict[str, set] = {}
        sources: Dict[str, set] = {}

        for flow in self.flows:
            for ip in (flow.src_ip, flow.dst_ip):
                if ip not in hosts:
                    hosts[ip] = NetworkHost(
                        ip_address=ip,
                        host_id=len(hosts),
                        subnet=self._get_subnet(ip),
                        protocols_used=[],
                        ports_accessed=[],
                    )

            src_host = hosts[flow.src_ip]
            dst_host = hosts[flow.dst_ip]
            destinations.setdefault(flow.src_ip, set()).add(flow.dst_ip)
            sources.setdefault(flow.dst_ip, set()).add(flow.src_ip)

            # Update source host statistics
            src_host.total_bytes_sent += flow.bytes_sent
            src_host.total_flows_sent += 1
            if flow.protocol not in src_host.protocols_used:
                src_host.protocols_used.append(flow.protocol)
            if flow.dst_port not in src_host.ports_accessed:
                src_host.ports_accessed.append(flow.dst_port)

            # Update destination host statistics
            dst_host.total_bytes_received += flow.bytes_sent
            dst_host.total_flows_received += 1
            if flow.protocol not in dst_host.protocols_used:
                dst_host.protocols_used.append(flow.protocol)
            if flow.src_port not in dst_host.ports_accessed:
                dst_host.ports_accessed.append(flow.src_port)

            if flow.is_attack:
                src_host.attack_flows += 1
                src_host.is_compromised = True

        # Compute derived statistics
        for ip, host in hosts.items():
            if host.total_flows_sent > 0:
                # Bytes per flow sent; packet counts are not tracked per host.
                host.avg_packet_size = host.total_bytes_sent / host.total_flows_sent

            host.is_server = (
                host.total_flows_received > host.total_flows_sent * 2 and
                len(host.ports_accessed) > 5
            )

            host.unique_destinations = len(destinations.get(ip, ()))
            host.unique_sources = len(sources.get(ip, ()))

        self.hosts = hosts
        self.logger.info(f"Built profiles for {len(hosts)} hosts")

        compromised_hosts = sum(1 for h in hosts.values() if h.is_compromised)
        server_hosts = sum(1 for h in hosts.values() if h.is_server)
        self.logger.info(f"Compromised hosts: {compromised_hosts}, Server hosts: {server_hosts}")

        return hosts

    def _get_subnet(self, ip_address: str) -> str:
        """
        Return the aggregation subnet of an IPv4 address.

        Private addresses (per ``ipaddress.is_private``) map to their /24 and
        all others to their /16. Invalid and non-IPv4 addresses map to
        ``"unknown"``.
        """
        try:
            ip = ipaddress.ip_address(ip_address)
        except ValueError:
            return "unknown"
        if ip.version != 4:
            return "unknown"
        prefix = 24 if ip.is_private else 16
        return str(ipaddress.ip_network(f"{ip}/{prefix}", strict=False))

    def create_network_graph(self,
                             time_window: Optional[Tuple[float, float]] = None,
                             min_flows_for_edge: Optional[int] = None) -> GraphData:
        """
        Create a network graph from host profiles and flows.

        Only the edges depend on ``time_window``. Nodes, node features and
        node labels (1 = compromised) always cover every host in
        ``self.hosts``.

        Args:
            time_window: Optional inclusive (start, end) timestamp range used
                to filter flows.
            min_flows_for_edge: Minimum number of flows from one host to
                another required to create an edge. ``None`` uses
                ``self.min_flow_threshold``; 0 is honoured.

        Returns:
            GraphData object representing the network graph

        Raises:
            ValueError: if build_host_profiles() has not produced any hosts.
        """
        if not self.hosts:
            raise ValueError("No host profiles available. Call build_host_profiles() first.")

        min_flows = self.min_flow_threshold if min_flows_for_edge is None else min_flows_for_edge

        filtered_flows = self.flows
        if time_window:
            start_time, end_time = time_window
            filtered_flows = [f for f in self.flows if start_time <= f.timestamp <= end_time]

        # Node index == position in self.hosts (insertion order), matching host_id.
        ip_to_node = {ip: idx for idx, ip in enumerate(self.hosts)}
        num_nodes = len(self.hosts)

        edge_counts: Dict[Tuple[int, int], int] = {}
        edge_attributes: Dict[Tuple[int, int], Dict[str, Any]] = {}

        for flow in filtered_flows:
            src_node = ip_to_node.get(flow.src_ip)
            dst_node = ip_to_node.get(flow.dst_ip)
            # Skip flows with unknown endpoints and self-loops.
            if src_node is None or dst_node is None or src_node == dst_node:
                continue

            edge_key = (src_node, dst_node)
            attrs = edge_attributes.get(edge_key)
            if attrs is None:
                attrs = {
                    'total_bytes': 0,
                    'total_packets': 0,
                    'total_duration': 0.0,
                    'protocols': set(),
                    'attack_flows': 0,
                    'normal_flows': 0,
                    'flags': set(),
                }
                edge_attributes[edge_key] = attrs
                edge_counts[edge_key] = 0

            edge_counts[edge_key] += 1
            attrs['total_bytes'] += flow.bytes_sent
            attrs['total_packets'] += flow.packets_sent
            attrs['total_duration'] += flow.duration
            attrs['protocols'].add(flow.protocol)
            attrs['flags'].update(flow.flags)
            attrs['attack_flows' if flow.is_attack else 'normal_flows'] += 1

        valid_edges = [edge for edge, count in edge_counts.items() if count >= min_flows]

        self.logger.info(f"Created graph with {num_nodes} nodes and {len(valid_edges)} edges")

        if valid_edges:
            edge_index = torch.tensor(valid_edges, dtype=torch.long).t()
            # Append the reversed edges so the graph is undirected; edge
            # features are duplicated in the same order. NOTE(review): if both
            # (a, b) and (b, a) pass the threshold, the pair appears twice in
            # each direction.
            edge_index = torch.cat([edge_index, edge_index[[1, 0]]], dim=1)
        else:
            edge_index = torch.zeros(2, 0, dtype=torch.long)

        node_features = self._create_node_features()
        edge_features = self._create_edge_features(valid_edges, edge_attributes)

        # Binary node labels: 1 if the host is the source of any attack flow.
        node_labels = torch.tensor([
            1 if host.is_compromised else 0
            for host in self.hosts.values()
        ], dtype=torch.long)

        # Deterministic id (the built-in hash() of a str varies per process).
        graph_id = (f"network_graph_{time_window[0]}_{time_window[1]}"
                    if time_window else "network_graph_full")

        graph_data = GraphData(
            x=node_features,
            edge_index=edge_index,
            edge_attr=edge_features,
            y_node=node_labels,
            graph_id=graph_id,
            window_idx=0
        )

        attack_ratio = float(node_labels.float().mean().item())
        self.logger.info(f"Graph created - Attack ratio: {attack_ratio:.3f}")

        return graph_data

    def _create_node_features(self) -> torch.Tensor:
        """
        Build the standardised node feature matrix, one row per host.

        Columns: 11 traffic statistics, a one-hot slot for each of
        tcp/udp/icmp, then the 4 subnet flags from _encode_subnet_features.
        With more than one host, the columns are standardised by refitting
        ``self.scaler``.
        """
        features = []

        for host in self.hosts.values():
            feature_vector = [
                float(host.total_bytes_sent),
                float(host.total_bytes_received),
                float(host.total_flows_sent),
                float(host.total_flows_received),
                float(host.unique_destinations),
                float(host.unique_sources),
                float(host.avg_packet_size),
                float(len(host.protocols_used)),
                float(len(host.ports_accessed)),
                float(host.is_server),
                float(host.attack_flows),
            ]
            feature_vector.extend(
                float(protocol in host.protocols_used) for protocol in self._COMMON_PROTOCOLS
            )
            feature_vector.extend(self._encode_subnet_features(host.subnet))
            features.append(feature_vector)

        features_tensor = torch.tensor(features, dtype=torch.float)

        if features_tensor.shape[0] > 1:
            features_normalized = torch.tensor(
                self.scaler.fit_transform(features_tensor.numpy()),
                dtype=torch.float
            )
        else:
            features_normalized = features_tensor

        self.logger.info(f"Created node features with shape: {features_normalized.shape}")
        return features_normalized

    def _encode_subnet_features(self, subnet: str) -> List[float]:
        """
        Encode a subnet string as four flags:
        ``[is_private, in 10.0.0.0/8, in 172.16.0.0/12, in 192.168.0.0/16]``.

        Unparseable subnets (including ``"unknown"``) give all zeros.
        """
        features = [0.0, 0.0, 0.0, 0.0]
        if subnet == "unknown":
            return features

        try:
            network = ipaddress.ip_network(subnet, strict=False)
        except ValueError:
            return features

        # NOTE(review): ipaddress also reports documentation ranges such as
        # 198.51.100.0/24 and 203.0.113.0/24 (used for synthetic "external"
        # hosts) as private. Confirm this is acceptable for the feature.
        if not network.is_private:
            return features

        features[0] = 1.0
        if network.version == 4:
            for slot, block in enumerate(self._RFC1918_BLOCKS, start=1):
                # subnet_of checks the whole block, e.g. all of 172.16-172.31.
                if network.subnet_of(block):
                    features[slot] = 1.0
                    break
        return features

    def _create_edge_features(self,
                              edges: List[Tuple[int, int]],
                              edge_attrs: Dict[Tuple[int, int], Dict]) -> torch.Tensor:
        """
        Build the edge feature matrix for ``edges`` plus their reversed copies.

        Each edge gets 10 columns: total bytes, total packets, total duration,
        flow count, attack flow count, distinct protocols, distinct flags,
        attack ratio, bytes per packet, and flows per unit of duration. Rows
        are stacked twice to line up with the forward+reverse edge_index built
        in create_network_graph.
        """
        if not edges:
            return torch.zeros(0, self._EDGE_FEATURE_DIM, dtype=torch.float)

        features = []
        for edge in edges:
            attrs = edge_attrs.get(edge, {})
            attack_flows = attrs.get('attack_flows', 0)
            total_flows = attack_flows + attrs.get('normal_flows', 0)
            total_bytes = attrs.get('total_bytes', 0)

            features.append([
                float(total_bytes),
                float(attrs.get('total_packets', 0)),
                float(attrs.get('total_duration', 0)),
                float(total_flows),
                float(attack_flows),
                float(len(attrs.get('protocols', set()))),
                float(len(attrs.get('flags', set()))),
                float(attack_flows / max(1, total_flows)),
                float(total_bytes / max(1, attrs.get('total_packets', 1))),
                # Durations of 0 are clamped to 0.001 to avoid division by zero.
                float(total_flows / max(0.001, attrs.get('total_duration', 1))),
            ])

        edge_features = torch.tensor(features, dtype=torch.float)
        # Duplicate features for reverse edges (undirected graph)
        edge_features = torch.cat([edge_features, edge_features], dim=0)

        self.logger.info(f"Created edge features with shape: {edge_features.shape}")
        return edge_features

    def create_temporal_graphs(self,
                               num_windows: int = 10,
                               overlap_ratio: float = 0.2) -> List[GraphData]:
        """
        Create ``num_windows`` graph snapshots over sliding time windows.

        The windows have equal length, and consecutive windows overlap by
        ``overlap_ratio`` of a window. Together they span exactly from the
        earliest to the latest flow timestamp. A window whose graph cannot be
        built is logged and skipped.

        Raises:
            ValueError: if no flows are loaded, ``num_windows < 1``, or
                ``overlap_ratio`` is outside [0, 1).
        """
        if not self.flows:
            raise ValueError("No flows loaded")
        if num_windows < 1:
            raise ValueError(f"num_windows must be at least 1, got {num_windows}")
        if not 0.0 <= overlap_ratio < 1.0:
            raise ValueError(f"overlap_ratio must be in [0, 1), got {overlap_ratio}")

        timestamps = [f.timestamp for f in self.flows]
        min_time, max_time = min(timestamps), max(timestamps)
        total_duration = max_time - min_time

        # Pick the window length so that num_windows windows advancing by
        # stride = window * (1 - overlap) exactly cover [min_time, max_time]:
        # window + (num_windows - 1) * stride == total_duration.
        stride_fraction = 1.0 - overlap_ratio
        window_duration = total_duration / (1.0 + (num_windows - 1) * stride_fraction)
        stride = window_duration * stride_fraction

        graphs = []
        for i in range(num_windows):
            start_time = min_time + i * stride
            # Pin the last window to max_time so float rounding cannot drop the final flows.
            end_time = max_time if i == num_windows - 1 else start_time + window_duration

            try:
                graph = self.create_network_graph(time_window=(start_time, end_time))
            except Exception as e:
                # Best effort: one bad window should not abort the whole series.
                self.logger.warning(f"Failed to create temporal graph {i}: {e}")
                continue
            graph.graph_id = f"temporal_graph_{i}"
            graph.window_idx = i
            graphs.append(graph)

        self.logger.info(f"Created {len(graphs)} temporal graph snapshots")
        return graphs


# Utility functions for integration with existing codebase
def create_realistic_nsl_kdd_graphs(df: pd.DataFrame,
                                  num_temporal_windows: int = 10,
                                  min_flows_per_edge: int = 2) -> List[GraphData]:
    """
    Create realistic network graphs from NSL-KDD dataset.

    Args:
        df: NSL-KDD dataframe with preprocessed data. Rows are marked as
            attacks via an ``is_attack`` column equal to 1; without that
            column every row is treated as normal.
        num_temporal_windows: Number of temporal snapshots to create. Values
            of 1 or less produce a single graph over all flows.
        min_flows_per_edge: Minimum flows required to create an edge

    Returns:
        List of GraphData objects representing network snapshots
    """
    builder = NetworkGraphBuilder(
        subnet_aggregation=True,
        temporal_windows=True,
        window_size_minutes=5,
        min_flow_threshold=min_flows_per_edge
    )

    # Both calls populate builder state (flows, then host profiles); their
    # return values are not needed here.
    builder.load_nsl_kdd_flows(df)
    builder.build_host_profiles()

    if num_temporal_windows <= 1:
        return [builder.create_network_graph()]
    return builder.create_temporal_graphs(num_windows=num_temporal_windows)


if __name__ == "__main__":
    # Smoke-test the network graph builder on NSL-KDD.
    # NOTE: the relative import below only resolves when this file is run as a
    # module inside its package (``python -m <package>.<module>``), not as a
    # plain script. pandas is already imported at module level.
    from ..data.real import RealDatasetLoader

    loader = RealDatasetLoader("nsl_kdd")
    loader.download_dataset()
    df = loader.load_dataset()

    if df is not None:
        # Only the labels are needed; the preprocessed features are unused here.
        _, y, _ = loader.preprocess_dataset(df)

        # Add labels back to dataframe for graph building
        df['is_attack'] = y

        graphs = create_realistic_nsl_kdd_graphs(df, num_temporal_windows=5)

        print(f"Created {len(graphs)} realistic network graphs")
        for i, graph in enumerate(graphs):
            # edge_index stores every undirected edge twice (both directions).
            print(f"Graph {i}: {graph.x.shape[0]} nodes, {graph.edge_index.shape[1]//2} edges")
            # Fraction of nodes labelled compromised.
            attack_ratio = graph.y_node.float().mean().item()
            print(f"  Attack ratio: {attack_ratio:.3f}")
