#!/usr/bin/env python3
"""
Cisco Secure Workload Networks Dataset Loader
Handles the 22 disjoint enterprise network graphs from SNAP/UCI.
"""

import os
import pandas as pd
import numpy as np
import torch
import networkx as nx
from typing import List, Tuple, Dict, Optional
import zipfile
import urllib.request
from pathlib import Path
import json

from ..utils.common import GraphData, ensure_dir, get_logger


class CiscoDatasetLoader:
    """Loader for the Cisco Secure Workload Networks dataset.

    Downloads the archive from SNAP or UCI, locates either a preprocessed
    pickle or the raw ``*.txt.gz`` edge files, and converts each enterprise
    network into a ``GraphData`` with synthetic node and edge features.
    """

    # Number of enterprise graphs in the dataset description; also the cap on
    # how many raw files _load_raw_graphs loads.
    EXPECTED_NUM_GRAPHS = 22

    # Width of the matrix built by _create_enterprise_node_features.
    NUM_NODE_FEATURES = 26

    # Name of the raw dataset directory read by _load_raw_graphs.
    RAW_DIR_NAME = "Cisco_22_networks"

    def __init__(self, data_path: str = "data/cisco"):
        """Create a loader rooted at ``data_path``.

        Args:
            data_path: Directory for the downloaded archive, the extracted
                files and/or the preprocessed pickle. It is passed to
                ``ensure_dir`` on construction.
        """
        self.data_path = Path(data_path)
        self.logger = get_logger("cisco_data")
        ensure_dir(str(self.data_path))

        # Dataset URLs
        self.dataset_urls = {
            "snap": "https://snap.stanford.edu/data/CiscoSecureWorkload_22_networks.zip",
            "uci": "https://archive.ics.uci.edu/static/public/735/cisco+secure+workload+networks+of+computing+hosts.zip"
        }

    def download_dataset(self, source: str = "snap") -> bool:
        """Download the Cisco dataset archive from SNAP or UCI and extract it.

        The archive is saved as ``cisco_dataset_<source>.zip`` under
        ``data_path`` and extracted into ``data_path/extracted``. Each step is
        skipped when its output already exists. A failed step leaves no
        partial output behind, so a retry starts clean.

        Args:
            source: ``"snap"`` or ``"uci"``.

        Returns:
            True on success (or when everything was already present). False
            for an unknown source or any download/extraction error.
        """
        import shutil

        if source not in self.dataset_urls:
            self.logger.error(f"Unknown source: {source}. Use 'snap' or 'uci'")
            return False

        url = self.dataset_urls[source]
        zip_path = self.data_path / f"cisco_dataset_{source}.zip"

        # Check if already downloaded
        if zip_path.exists():
            self.logger.info(f"Dataset already downloaded: {zip_path}")
        else:
            # Download to a side file and move it into place only when it is
            # complete. Otherwise an interrupted download would be mistaken
            # for a finished one by the exists() check above on the next call.
            part_path = zip_path.with_name(zip_path.name + ".part")
            try:
                self.logger.info(f"Downloading Cisco dataset from {source.upper()}...")
                urllib.request.urlretrieve(url, part_path)
                part_path.replace(zip_path)
                self.logger.info(f"Downloaded to: {zip_path}")
            except Exception as e:
                self._remove_file_quietly(part_path)
                self.logger.error(f"Failed to download dataset: {e}")
                return False

        # Extract if not already extracted
        extract_path = self.data_path / "extracted"
        if not extract_path.exists():
            # Extract into a staging directory and rename it on success, so a
            # crash mid-extraction cannot leave a half-populated "extracted"
            # directory that would suppress extraction forever.
            staging_path = self.data_path / "extracted.partial"
            shutil.rmtree(staging_path, ignore_errors=True)
            try:
                self.logger.info("Extracting dataset...")
                with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                    zip_ref.extractall(staging_path)
                staging_path.rename(extract_path)
                self.logger.info(f"Extracted to: {extract_path}")
            except Exception as e:
                shutil.rmtree(staging_path, ignore_errors=True)
                if isinstance(e, zipfile.BadZipFile):
                    # Corrupt archive: drop it so the next call re-downloads.
                    self._remove_file_quietly(zip_path)
                self.logger.error(f"Failed to extract dataset: {e}")
                return False

        return True

    @staticmethod
    def _remove_file_quietly(path: Path) -> None:
        """Best-effort delete of ``path``; OS errors (e.g. missing file) are ignored."""
        try:
            path.unlink()
        except OSError:
            pass

    def load_enterprise_graphs(self) -> List[GraphData]:
        """
        Load all enterprise graphs from the Cisco dataset.

        Uses ``data_path/cisco_graphs_processed.pkl`` when it exists.
        Otherwise falls back to parsing the raw ``.gz`` files.

        Returns:
            List of GraphData objects representing enterprise networks
        """

        # First, try to load preprocessed data
        processed_file = self.data_path / "cisco_graphs_processed.pkl"
        if processed_file.exists():
            return self._load_preprocessed_graphs(processed_file)

        # Fall back to raw processing
        self.logger.info("Preprocessed data not found, processing raw .gz files...")
        self.logger.info("For faster loading, run the preprocessing script first:")
        self.logger.info(f"  python -m gatv2_ns3_ids.scripts.preprocess_cisco_dataset --cisco_data_path {self.data_path}")

        return self._load_raw_graphs()

    def _load_preprocessed_graphs(self, processed_file: Path) -> List[GraphData]:
        """Load graphs from a preprocessed pickle file.

        The pickle is expected to be an iterable of dicts. Required keys are
        ``x``, ``edge_index``, ``edge_attr``, ``y_node`` and ``graph_id``.
        Optional keys are ``window_idx`` (default 0) and ``metadata``.

        Any error while loading or converting is logged, and the method
        falls back to ``_load_raw_graphs()``.
        """
        import pickle

        self.logger.info(f"Loading preprocessed graphs from {processed_file}")

        try:
            # SECURITY: pickle.load can execute arbitrary code. Only load
            # files produced locally by the preprocessing script, never a
            # pickle obtained from an untrusted source.
            with open(processed_file, 'rb') as f:
                processed_graphs = pickle.load(f)

            # Convert back to GraphData objects
            enterprise_graphs = []
            for graph_data in processed_graphs:
                graph_kwargs = {
                    "x": graph_data['x'],
                    "edge_index": graph_data['edge_index'],
                    "edge_attr": graph_data['edge_attr'],
                    "y_node": graph_data['y_node'],
                    "graph_id": graph_data['graph_id'],
                    "window_idx": graph_data.get('window_idx', 0),
                }
                # Keep metadata when the pickle carries it, matching the
                # graphs built by _networkx_to_graph_data.
                if 'metadata' in graph_data:
                    graph_kwargs["metadata"] = graph_data['metadata']
                enterprise_graphs.append(GraphData(**graph_kwargs))

            self.logger.info(f"Successfully loaded {len(enterprise_graphs)} preprocessed graphs")
            return enterprise_graphs

        except Exception as e:
            self.logger.error(f"Failed to load preprocessed data: {e}")
            self.logger.info("Falling back to raw processing...")
            return self._load_raw_graphs()

    def _find_dataset_root(self) -> Optional[Path]:
        """Locate the raw ``Cisco_22_networks`` directory.

        Checks, in order:

        1. ``data_path/Cisco_22_networks``.
        2. ``data_path/extracted/Cisco_22_networks``, where
           ``download_dataset()`` unpacks the archive.
        3. The parent of any ``dir_20_graphs`` directory under
           ``data_path/extracted``.

        Returns:
            The directory, or None if no candidate exists.
        """
        extracted = self.data_path / "extracted"
        for candidate in (self.data_path / self.RAW_DIR_NAME, extracted / self.RAW_DIR_NAME):
            if candidate.is_dir():
                return candidate
        if extracted.is_dir():
            for marker in sorted(extracted.rglob("dir_20_graphs")):
                if marker.is_dir():
                    return marker.parent
        return None

    def _load_raw_graphs(self) -> List[GraphData]:
        """Load graphs from raw ``*.txt.gz`` files (slower than the pickle).

        Files are collected from ``<root>/dir_20_graphs/dir_day{1..4}`` and
        sorted by path. The first ``EXPECTED_NUM_GRAPHS`` files are loaded,
        and each file's position in that list is its enterprise id. Files
        that fail to load are logged and skipped.

        Returns:
            The successfully loaded graphs (empty if the dataset is missing).
        """
        cisco_dir = self._find_dataset_root()
        if cisco_dir is None:
            self.logger.error("Dataset not found. Please run download_dataset() first.")
            return []

        # Load from dir_20_graphs (4 days of data)
        graph_files: List[Path] = []
        for day_dir in ("dir_day1", "dir_day2", "dir_day3", "dir_day4"):
            day_path = cisco_dir / "dir_20_graphs" / day_dir
            if day_path.exists():
                graph_files.extend(day_path.glob("*.txt.gz"))

        if not graph_files:
            self.logger.error("No graph files found in extracted dataset")
            return []

        self.logger.info(f"Found {len(graph_files)} potential graph files")

        # NOTE(review): sorting full paths orders files by day directory, so
        # this cap fills up from dir_day1 before any later day. If each day
        # holds 20 graphs, the selection is day 1 plus two graphs from day 2.
        # Confirm this mix is intended.
        selected_files = sorted(graph_files)[:self.EXPECTED_NUM_GRAPHS]

        enterprise_graphs = []
        for i, graph_file in enumerate(selected_files):
            try:
                graph = self._load_single_graph(graph_file, enterprise_id=i)
            except Exception as e:
                self.logger.warning(f"Failed to load graph {graph_file}: {e}")
                continue
            if graph is not None:
                enterprise_graphs.append(graph)
                self.logger.info(f"Loaded enterprise {i}: {graph.x.shape[0]} nodes, {graph.edge_index.shape[1]} edges")

        if len(enterprise_graphs) < self.EXPECTED_NUM_GRAPHS:
            self.logger.warning(f"Expected 22 enterprise graphs, but loaded {len(enterprise_graphs)}")

        self.logger.info(f"Successfully loaded {len(enterprise_graphs)} enterprise graphs")
        return enterprise_graphs

    def _load_single_graph(self, graph_file: Path, enterprise_id: int) -> Optional[GraphData]:
        """Load one enterprise graph, choosing the parser from the file suffix.

        Suffix mapping:

        - ``.gz``: Cisco tab-separated format.
        - ``.json``: JSON format.
        - ``.csv`` / ``.tsv``: tabular format.
        - ``.txt`` / ``.edges``: whitespace-separated edge list.

        Returns:
            The graph, or None if the suffix is unsupported or parsing fails
            (the problem is logged).
        """
        loaders = {
            ".gz": self._load_cisco_format,
            ".json": self._load_json_format,
            ".csv": self._load_tabular_format,
            ".tsv": self._load_tabular_format,
            ".txt": self._load_edge_list_format,
            ".edges": self._load_edge_list_format,
        }
        loader = loaders.get(graph_file.suffix.lower())
        if loader is None:
            self.logger.warning(f"Unsupported file format: {graph_file.suffix}")
            return None

        try:
            return loader(graph_file, enterprise_id)
        except Exception as e:
            self.logger.error(f"Error loading {graph_file}: {e}")
            return None

    def _load_cisco_format(self, graph_file: Path, enterprise_id: int) -> GraphData:
        """Load a graph from a gzipped Cisco edge file.

        Each line is split on tabs. Columns 1 and 2 are parsed as integer
        source/destination node ids; further columns are ignored. Blank
        lines, lines with fewer than three columns, and lines with
        non-integer ids are skipped.

        Raises:
            ValueError: if no valid edge was found.
        """
        import gzip

        edges = []

        # Parse Cisco format: g<graph_id> <src> <dst> <port_info>
        with gzip.open(graph_file, 'rt', encoding='utf-8', errors='replace') as f:
            for raw_line in f:
                parts = raw_line.strip().split('\t')
                if len(parts) < 3:
                    continue
                try:
                    edges.append((int(parts[1]), int(parts[2])))
                except ValueError:
                    continue

        if not edges:
            raise ValueError("No valid edges found in Cisco format")

        # Create NetworkX graph (duplicate edges collapse into one)
        G = nx.DiGraph()
        G.add_edges_from(edges)

        return self._networkx_to_graph_data(G, enterprise_id)

    def _load_edge_list_format(self, graph_file: Path, enterprise_id: int) -> GraphData:
        """Load a graph from a whitespace-separated edge list.

        Lines starting with ``#`` are comments. The first two tokens of every
        other line are parsed as integer node ids; unparseable lines are
        skipped.

        Raises:
            ValueError: if no valid edge was found.
        """
        edges = []
        with open(graph_file, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line or line.startswith('#'):
                    continue
                parts = line.split()
                if len(parts) < 2:
                    continue
                try:
                    edges.append((int(parts[0]), int(parts[1])))
                except ValueError:
                    continue

        if not edges:
            raise ValueError("No valid edges found")

        # Build graph
        G = nx.DiGraph()
        G.add_edges_from(edges)

        # Convert to PyTorch Geometric format
        return self._networkx_to_graph_data(G, enterprise_id)

    def _load_tabular_format(self, graph_file: Path, enterprise_id: int) -> GraphData:
        """Load a graph from a headerless tabular file (CSV/TSV).

        Tries tab, comma and space separators in turn, and keeps the first
        parse that yields at least two columns. The first two columns of each
        row are used as integer node ids; rows that do not convert are
        skipped.

        Raises:
            ValueError: if no separator works or no valid edge was found.
        """
        # Try different separators
        for sep in ['\t', ',', ' ']:
            try:
                df = pd.read_csv(graph_file, sep=sep, header=None)
            except Exception:
                continue
            if df.shape[1] >= 2:
                break
        else:
            raise ValueError("Could not parse tabular format")

        # Extract edges
        edges = []
        for _, row in df.iterrows():
            try:
                edges.append((int(row.iloc[0]), int(row.iloc[1])))
            except (ValueError, TypeError, IndexError):
                continue

        if not edges:
            raise ValueError("No valid edges found in tabular data")

        # Build graph
        G = nx.DiGraph()
        G.add_edges_from(edges)

        return self._networkx_to_graph_data(G, enterprise_id)

    def _load_json_format(self, graph_file: Path, enterprise_id: int) -> GraphData:
        """Load a graph from JSON.

        Accepted layouts:

        - A top-level list of ``[src, dst, ...]`` entries (entries that are
          not lists/tuples of length >= 2 are skipped).
        - A dict with an ``edges`` or ``links`` list of
          ``{"source": ..., "target": ...}`` objects.

        Raises:
            ValueError: for any other layout or when no edges are found.
        """
        with open(graph_file, 'r', encoding='utf-8') as f:
            data = json.load(f)

        # Check the container type first: `'edges' in data` on a list is a
        # membership test, and on a scalar it raises TypeError.
        if isinstance(data, list):
            edges = [(e[0], e[1]) for e in data if isinstance(e, (list, tuple)) and len(e) >= 2]
        elif isinstance(data, dict) and 'edges' in data:
            edges = [(e['source'], e['target']) for e in data['edges']]
        elif isinstance(data, dict) and 'links' in data:
            edges = [(e['source'], e['target']) for e in data['links']]
        else:
            raise ValueError("Unsupported JSON structure")

        if not edges:
            raise ValueError("No edges found in JSON data")

        # Build graph
        G = nx.DiGraph()
        G.add_edges_from(edges)

        return self._networkx_to_graph_data(G, enterprise_id)

    def _networkx_to_graph_data(self, G: nx.DiGraph, enterprise_id: int) -> GraphData:
        """Convert a NetworkX digraph into a ``GraphData``.

        Nodes are relabelled to ``0..N-1`` in ``G.nodes()`` order. Node and
        edge features are generated synthetically, and all node labels start
        at 0.

        Raises:
            ValueError: if ``G`` has no nodes.
        """
        # Relabel nodes to be contiguous integers starting from 0
        mapping = {node: i for i, node in enumerate(G.nodes())}
        G = nx.relabel_nodes(G, mapping)

        num_nodes = G.number_of_nodes()
        num_edges = G.number_of_edges()

        if num_nodes == 0:
            raise ValueError("Graph has no nodes")

        # Create edge index with shape (2, num_edges). .t() yields a
        # transposed view, so make it contiguous for downstream consumers.
        if num_edges > 0:
            edge_index = torch.tensor(list(G.edges()), dtype=torch.long).t().contiguous()
        else:
            edge_index = torch.zeros((2, 0), dtype=torch.long)

        # Create node features (enterprise network characteristics)
        node_features = self._create_enterprise_node_features(G, num_nodes)

        # Create edge features (communication characteristics)
        edge_features = self._create_enterprise_edge_features(G, num_edges)

        # No ground truth labels initially (will be added by synthetic injection)
        node_labels = torch.zeros(num_nodes, dtype=torch.long)

        return GraphData(
            x=node_features,
            edge_index=edge_index,
            edge_attr=edge_features,
            y_node=node_labels,
            graph_id=f"enterprise_{enterprise_id}",
            window_idx=0,
            metadata={
                "enterprise_id": enterprise_id,
                "num_nodes": num_nodes,
                "num_edges": num_edges,
                "source": "cisco_dataset"
            }
        )

    def _create_enterprise_node_features(self, G: nx.DiGraph, num_nodes: int) -> torch.Tensor:
        """Build a ``(num_nodes, NUM_NODE_FEATURES)`` float32 node feature matrix.

        Columns:

        - 0-5: in/out/total degree, clustering, betweenness, closeness.
        - 6: port diversity.
        - 7-25: synthetic protocol, traffic, temporal, connection and security
          indicators, drawn from distributions conditioned on the node's
          degree and betweenness.

        Each node draws from its own ``RandomState``, seeded only by its id,
        so the features are reproducible. The global NumPy RNG is not
        touched. Betweenness sampling also uses a fixed seed.

        Args:
            G: Graph whose nodes are expected to be labelled
                ``0..num_nodes-1``. Ids missing from ``G`` get zero topology
                values.
            num_nodes: Number of rows to produce.
        """
        if num_nodes == 0:
            return torch.zeros((0, self.NUM_NODE_FEATURES), dtype=torch.float32)

        # Centrality is only computed for graphs with at most 1000 nodes.
        # Larger graphs, or any failure, fall back to 0.0 for every node.
        betweenness_centrality: Dict[int, float] = {}
        closeness_centrality: Dict[int, float] = {}
        if num_nodes <= 1000:
            try:
                # Fixed seed: with k given, networkx samples k source nodes,
                # which would otherwise make this non-deterministic.
                betweenness_centrality = nx.betweenness_centrality(G, k=min(100, num_nodes), seed=42)
                closeness_centrality = nx.closeness_centrality(G)
            except Exception:
                betweenness_centrality = {}
                closeness_centrality = {}

        # Clustering coefficient on the undirected view, computed once for
        # all nodes. Building G.to_undirected() per node copied the whole
        # graph n times.
        try:
            clustering_by_node = nx.clustering(G.to_undirected())
        except Exception:
            clustering_by_node = {}

        # Degree percentiles used to infer node roles. They are the same for
        # every node, so compute them once rather than inside the loop.
        all_degrees = [deg for _, deg in G.degree()]
        server_cutoff = np.percentile(all_degrees, 90)
        workstation_cutoff = np.percentile(all_degrees, 50)
        always_on_cutoff = np.percentile(all_degrees, 80)

        features = []
        for node in range(num_nodes):
            # Basic network statistics
            present = G.has_node(node)
            in_degree = G.in_degree(node) if present else 0
            out_degree = G.out_degree(node) if present else 0
            total_degree = in_degree + out_degree

            clustering = clustering_by_node.get(node, 0.0)
            betweenness = betweenness_centrality.get(node, 0.0)
            closeness = closeness_centrality.get(node, 0.0)

            # Private per-node RNG seeded only by the node id. 7919 is prime
            # and coprime to 10**6, so ids below 10**6 get distinct seeds. A
            # local legacy RandomState yields the same stream as seeding the
            # global RNG with this value, without mutating global state.
            rng = np.random.RandomState((node * 7919) % 1000000)

            # Protocol distribution (based on node role inferred from degree)
            if total_degree > server_cutoff:
                # High-degree nodes (servers) - more TCP, less UDP
                tcp_ratio = rng.beta(5, 1)  # Heavily biased towards TCP
                udp_ratio = rng.beta(1, 3)  # Lower UDP
            elif total_degree > workstation_cutoff:
                # Medium-degree nodes (workstations) - balanced
                tcp_ratio = rng.beta(3, 2)
                udp_ratio = rng.beta(2, 3)
            else:
                # Low-degree nodes (endpoints) - more varied
                tcp_ratio = rng.beta(2, 2)
                udp_ratio = rng.beta(2, 2)

            icmp_ratio = rng.exponential(0.02)  # Small ICMP traffic

            # Normalize protocol ratios
            total_proto = tcp_ratio + udp_ratio + icmp_ratio
            tcp_ratio /= total_proto
            udp_ratio /= total_proto
            icmp_ratio /= total_proto

            # Traffic volume indicators (based on centrality and degree)
            base_traffic = 8 + 2 * np.log1p(total_degree)  # Higher degree = more traffic
            centrality_boost = 2 * (betweenness + closeness)

            bytes_in = rng.lognormal(base_traffic + centrality_boost, 1.5)
            bytes_out = rng.lognormal(base_traffic + centrality_boost * 0.8, 1.5)

            # Temporal patterns: high-degree nodes are consistently active,
            # the rest are more bursty.
            if total_degree > always_on_cutoff:
                business_hours_activity = 0.8 + 0.2 * rng.random_sample()
                off_hours_activity = 0.4 + 0.3 * rng.random_sample()
            else:
                business_hours_activity = 0.3 + 0.6 * rng.random_sample()
                off_hours_activity = 0.1 + 0.2 * rng.random_sample()

            # Connection patterns (based on node role)
            if betweenness > 0.01:  # Critical nodes have better reliability
                failed_conn_ratio = rng.beta(1, 20)  # Very low failure rate
                avg_session_duration = rng.lognormal(3, 0.8)  # Longer sessions
            else:
                failed_conn_ratio = rng.beta(2, 15)  # Slightly higher failure rate
                avg_session_duration = rng.lognormal(2, 1.2)  # More variable sessions

            # Port usage patterns (simulate service types).
            # NOTE(review): with replace=False the 3 ports are always
            # distinct, so port_diversity is always 1.0. The draw is kept so
            # that every later feature keeps its value.
            common_ports = rng.choice([80, 443, 22, 3389, 445, 53], size=3, replace=False)
            port_diversity = len(set(common_ports)) / 3.0

            # Security-related features (baseline - will be modified by attack injection)
            auth_failures = rng.poisson(0.5)  # Low baseline auth failures
            privilege_escalations = 0.0  # No baseline privilege escalations

            # Network behavior anomalies
            dns_queries_per_hour = rng.poisson(10 + total_degree)  # Based on activity
            unique_destinations = min(total_degree + rng.poisson(5), 50)  # Capped at 50

            # Traffic timing patterns
            off_hours_traffic_ratio = off_hours_activity / (business_hours_activity + 1e-6)
            burst_traffic_events = rng.poisson(0.2)  # Low baseline bursts

            # Connection behavior
            short_lived_connections = rng.beta(2, 8)  # Most connections are normal length
            connection_retry_rate = failed_conn_ratio * 2  # Retries related to failures

            # Data transfer patterns
            upload_download_ratio = bytes_out / (bytes_in + 1e-6)
            large_file_transfers = rng.poisson(0.1)  # Rare large transfers

            features.append([
                # Topology features (7 features)
                float(in_degree), float(out_degree), float(total_degree),
                float(clustering), float(betweenness), float(closeness),
                float(port_diversity),

                # Protocol features (3 features)
                float(tcp_ratio), float(udp_ratio), float(icmp_ratio),

                # Traffic volume features (2 features)
                float(np.log1p(bytes_in)), float(np.log1p(bytes_out)),

                # Temporal features (3 features)
                float(business_hours_activity), float(off_hours_activity),
                float(off_hours_traffic_ratio),

                # Connection features (4 features)
                float(failed_conn_ratio), float(np.log1p(avg_session_duration)),
                float(short_lived_connections), float(connection_retry_rate),

                # Security features (7 features) - KEY FOR ATTACK DETECTION
                float(auth_failures), float(privilege_escalations),
                float(np.log1p(dns_queries_per_hour)), float(unique_destinations),
                float(burst_traffic_events), float(upload_download_ratio),
                float(large_file_transfers),
            ])

        return torch.tensor(features, dtype=torch.float32)

    def _create_enterprise_edge_features(self, G: nx.DiGraph, num_edges: int) -> Optional[torch.Tensor]:
        """Build synthetic flow features: one 14-column float32 row per edge of ``G``.

        Rows follow ``G.edges()`` order. Values are drawn from the *global*
        NumPy RNG, so seed ``np.random`` beforehand if reproducible edge
        features are needed.

        Returns:
            The feature tensor, or None when ``num_edges`` is 0.
        """
        if num_edges == 0:
            return None

        # Threshold for the "large transfer" flag (the median of 1000 and
        # 10000 bytes). It is loop-invariant, so compute it once.
        large_transfer_threshold = np.median([1000, 10000])

        features = []
        for _src, _dst in G.edges():
            # Communication volume
            bytes_transferred = np.random.lognormal(8, 2)  # Bytes in this flow
            packets_count = np.random.poisson(max(1, bytes_transferred / 1500))  # Packets

            # Timing characteristics
            duration = np.random.lognormal(1, 1)  # Flow duration
            inter_arrival_time = np.random.exponential(0.1)  # Between packets

            # Protocol characteristics
            protocol_type = np.random.choice([6, 17, 1], p=[0.7, 0.25, 0.05])  # TCP, UDP, ICMP

            # Port information (simulate common enterprise services)
            if protocol_type == 6:  # TCP
                dst_port = np.random.choice([80, 443, 22, 3389, 445, 993, 995], p=[0.3, 0.3, 0.1, 0.1, 0.05, 0.1, 0.05])
            elif protocol_type == 17:  # UDP
                dst_port = np.random.choice([53, 123, 161, 514, 1194], p=[0.4, 0.2, 0.2, 0.1, 0.1])
            else:  # ICMP
                dst_port = 0

            src_port = np.random.randint(1024, 65536) if protocol_type != 1 else 0

            # Connection flags (simulate TCP flags). `and` short-circuits, so
            # non-TCP flows consume no random draws here.
            syn_flag = 1.0 if protocol_type == 6 else 0.0
            ack_flag = 1.0 if protocol_type == 6 and np.random.random() > 0.1 else 0.0
            fin_flag = 1.0 if protocol_type == 6 and np.random.random() > 0.8 else 0.0
            rst_flag = 1.0 if protocol_type == 6 and np.random.random() > 0.95 else 0.0

            # Quality metrics
            jitter = np.random.exponential(0.01)  # Network jitter
            packet_loss = np.random.beta(1, 100)  # Packet loss rate

            features.append([
                np.log1p(bytes_transferred),    # Log bytes transferred
                np.log1p(packets_count),        # Log packet count
                np.log1p(duration),             # Log flow duration
                inter_arrival_time,             # Inter-arrival time
                protocol_type / 17.0,           # Normalized protocol
                src_port / 65536.0,             # Normalized source port
                dst_port / 65536.0,             # Normalized destination port
                syn_flag, ack_flag, fin_flag, rst_flag,  # TCP flags
                jitter,                         # Network jitter
                packet_loss,                    # Packet loss rate
                float(bytes_transferred > large_transfer_threshold),  # Large transfer indicator
            ])

        return torch.tensor(features, dtype=torch.float32)

    def get_ground_truth_groupings(self) -> Optional[Dict[int, Dict[str, List[int]]]]:
        """
        Locate ground truth grouping files for the enterprises that have them.

        Searches ``data_path/extracted`` and the raw dataset directory (as
        found by ``_find_dataset_root``) for regular files named
        ``*ground*truth*`` or ``*grouping*``. Parsing is not implemented yet:
        matches are only logged, so this currently always returns None.

        Returns:
            Dictionary mapping enterprise_id to grouping information, or None
            when nothing was parsed.
        """
        search_roots: List[Path] = []
        for root in (self.data_path / "extracted", self._find_dataset_root()):
            if root is not None and root.is_dir():
                search_roots.append(root)

        # Deduplicate: both patterns can match the same file, and the raw
        # dataset directory may itself live inside "extracted".
        gt_files: List[Path] = []
        seen = set()
        for root in search_roots:
            for pattern in ("*ground*truth*", "*grouping*"):
                for candidate in sorted(root.rglob(pattern)):
                    resolved = candidate.resolve()
                    if resolved in seen or not candidate.is_file():
                        continue
                    seen.add(resolved)
                    gt_files.append(candidate)

        if not gt_files:
            self.logger.warning("No ground truth grouping files found")
            return None

        groupings: Dict[int, Dict[str, List[int]]] = {}
        for gt_file in gt_files:
            # According to the dataset description, only 2 out of 22 graphs
            # have ground truth.
            # TODO: Implement actual parsing based on file format
            self.logger.info(f"Found ground truth file: {gt_file}")

        return groupings if groupings else None


if __name__ == "__main__":
    # Demo: fetch the SNAP archive if needed, load every enterprise graph and
    # print a short size summary for the first five of them.
    demo_loader = CiscoDatasetLoader()
    downloaded = demo_loader.download_dataset(source="snap")

    if downloaded:
        loaded_graphs = demo_loader.load_enterprise_graphs()
        print(f"Loaded {len(loaded_graphs)} enterprise graphs")

        for idx, g in enumerate(loaded_graphs[:5]):
            n_nodes = g.x.shape[0]
            n_edges = g.edge_index.shape[1]
            print(f"Enterprise {idx}: {n_nodes} nodes, {n_edges} edges")

            # Edge features are optional (None for edge-less graphs).
            n_edge_feats = 0 if g.edge_attr is None else g.edge_attr.shape[1]
            print(f"  Node features: {g.x.shape[1]}, Edge features: {n_edge_feats}")
