#!/usr/bin/env python3
"""
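Cisco Dataset Preprocessing Script
Parses the gzipped Cisco Secure Workload edge files, attaches synthetic node and
edge features, and pickles the resulting graphs (plus JSON metadata) for training.
Run this script before training to speed up data loading. Because it uses a
package-relative import, run it as a module (``python -m ...``), not as a file.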
"""

import argparse
import gzip
import json
import os
import pickle
from pathlib import Path
from typing import List, Dict, Any
import time

import numpy as np
import torch
from tqdm import tqdm

from ..utils.common import get_logger, ensure_dir


def parse_args():
    """Parse the command-line options of the preprocessing script.

    Returns:
        argparse.Namespace with ``cisco_data_path`` (required string),
        ``output_dir`` (string, default ``data/cisco_processed``) and
        ``max_graphs`` (int, default 22).
    """
    cli = argparse.ArgumentParser(
        description="Preprocess Cisco dataset for faster training"
    )
    option_specs = (
        ("--cisco_data_path",
         dict(type=str, required=True,
              help="Path to Cisco dataset directory")),
        ("--output_dir",
         dict(type=str, default="data/cisco_processed",
              help="Output directory for processed data")),
        ("--max_graphs",
         dict(type=int, default=22,
              help="Maximum number of graphs to process (default: 22)")),
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()


def _read_cisco_adjacency(gz_file_path: Path):
    """Read one gzipped Cisco edge file into per-graph adjacency and port data.

    Returns:
        A pair ``(wload_to_graph, wload_to_port_info)``. ``wload_to_graph`` maps
        each graph id (first column) to ``{node: Counter({neighbour: count})}``;
        every kept line is recorded in both directions. ``wload_to_port_info``
        maps graph ids that had a port column to ``{port: {(src, dst), ...}}``.

    Raises:
        Whatever ``gzip.open`` or line iteration raises (missing file, corrupt
        archive, undecodable text). The caller converts this into an empty result.
    """
    from collections import Counter, defaultdict

    wload_to_graph = {}
    wload_to_port_info = {}
    with gzip.open(gz_file_path, mode='rt') as fopen:
        for line in fopen:
            if line.startswith('#'):  # skip comment lines
                continue
            parts = line.split()
            if len(parts) < 3:
                continue
            wload_id, v1, v2 = parts[0], parts[1], parts[2]  # graph id, source, destination

            v_to_u = wload_to_graph.get(wload_id)
            if v_to_u is None:
                v_to_u = wload_to_graph[wload_id] = defaultdict(Counter)

            if len(parts) > 3:
                stats = wload_to_port_info.get(wload_id)
                if stats is None:
                    stats = wload_to_port_info[wload_id] = defaultdict(set)
                ports_added = False
                for port_tuple in parts[3].split(','):
                    # Only entries containing 'p' count as ports; the port key
                    # is the text before the first '-'.
                    if 'p' not in port_tuple:
                        continue
                    port_part = port_tuple.split('-')[0]
                    if not port_part:
                        continue
                    stats[port_part].add((v1, v2))
                    ports_added = True
                if not ports_added:
                    # A port column without any usable entry drops the edge.
                    continue

            v_to_u[v1][v2] += 1
            v_to_u[v2][v1] += 1
    return wload_to_graph, wload_to_port_info


def _sort_node_ids(node_ids) -> List[str]:
    """Sort raw node IDs numerically if every ID parses as an int, else as strings.

    Only the sort key is numeric. The original strings are returned unchanged,
    so IDs such as ``"007"`` still match the edge endpoints they came from.
    """
    try:
        return sorted(node_ids, key=int)
    except ValueError:
        return sorted(node_ids)


def _numeric_id_range(nodes):
    """Return ``(max, min)`` of the node IDs as ints, or ``(None, None)`` if any is non-numeric."""
    try:
        numeric_ids = [int(n) for n in nodes]
    except ValueError:
        return None, None
    if not numeric_ids:
        return 0, 0
    return max(numeric_ids), min(numeric_ids)


def extract_and_parse_cisco_file(gz_file_path: Path, logger) -> Dict[str, Dict[str, Any]]:
    """Extract and parse a single Cisco .gz edge file into one record per graph id.

    Lines starting with ``#`` and lines with fewer than three whitespace-separated
    fields are ignored. Other lines are read as
    ``<graph_id> <src> <dst> [<comma-separated port entries>]``.

    Args:
        gz_file_path: Path to a gzip-compressed text edge file.
        logger: Logger used for debug/info/warning/error messages.

    Returns:
        Mapping ``graph_id -> record``. Each record holds:

        - ``graph_id`` and ``file_source`` (the file name).
        - ``num_nodes`` and ``num_edges``.
        - ``edges``: ``(src, dst)`` pairs using contiguous IDs
          ``0..num_nodes-1``. Both directions of every link are present and
          self-loops are dropped.
        - ``original_node_mapping``: new id -> raw id string.
        - ``port_info``.
        - ``statistics``: ``max_node_id`` / ``min_node_id`` are ``None`` when
          the raw IDs are not all integers.

        Graphs without any non-self-loop edge are skipped with a warning.
        Returns ``{}`` if the file cannot be read (the error is logged).
    """
    logger.debug(f"Processing {gz_file_path.name}...")

    try:
        wload_to_graph, wload_to_port_info = _read_cisco_adjacency(gz_file_path)
    except Exception as e:  # best effort: one unreadable file must not abort the run
        logger.error(f"Failed to process {gz_file_path}: {e}")
        return {}

    graphs_data = {}
    for wload_id, graph_dict in wload_to_graph.items():
        edges = [
            (src, dst)
            for src, targets in graph_dict.items()
            for dst in targets
            if src != dst  # Skip self-loops for now
        ]
        if not edges:
            logger.warning(f"No valid edges found for graph {wload_id}")
            continue

        # Contiguous node IDs starting from 0. Every endpoint is in `nodes`,
        # so the remapping below cannot miss a key.
        nodes = _sort_node_ids({node for edge in edges for node in edge})
        node_mapping = {old_id: new_id for new_id, old_id in enumerate(nodes)}
        remapped_edges = [(node_mapping[src], node_mapping[dst]) for src, dst in edges]
        max_node_id, min_node_id = _numeric_id_range(nodes)

        graphs_data[wload_id] = {
            'graph_id': wload_id,
            'file_source': str(gz_file_path.name),
            'num_nodes': len(nodes),
            'num_edges': len(remapped_edges),
            'edges': remapped_edges,
            'original_node_mapping': {v: k for k, v in node_mapping.items()},  # new_id -> old_id
            'port_info': wload_to_port_info.get(wload_id, {}),
            'statistics': {
                # Mean of (in + out) degree over the directed edge list.
                'avg_degree': 2 * len(remapped_edges) / len(nodes),
                'max_node_id': max_node_id,
                'min_node_id': min_node_id,
            },
        }

        logger.info(f"Processed graph {wload_id}: {len(nodes)} nodes, {len(remapped_edges)} edges")

    return graphs_data


def create_enterprise_features(graph_data: Dict[str, Any]) -> np.ndarray:
    """Create realistic node features for enterprise networks.

    Only the degree-based columns are derived from the graph. The traffic,
    timing and connection columns are random draws from NumPy's global random
    state, so seed it beforehand if you need reproducible output.

    Args:
        graph_data: Mapping with ``'num_nodes'`` (int) and ``'edges'``, an
            iterable of ``(src, dst)`` integer node indices.

    Returns:
        ``float32`` array of shape ``(num_nodes, 18)``. The column order is
        listed next to ``node_feature_vector`` below. An empty graph yields
        shape ``(0, 18)``.
    """
    num_feature_columns = 18  # must equal len(node_feature_vector) below

    num_nodes = graph_data['num_nodes']
    edges = graph_data['edges']

    if num_nodes == 0:
        # np.array([]) would have shape (0,), breaking callers that read shape[1].
        return np.zeros((0, num_feature_columns), dtype=np.float32)

    # Calculate basic network statistics. np.add.at accumulates repeated
    # indices, matching an element-by-element `+= 1` loop.
    edge_array = np.asarray(edges, dtype=np.int64).reshape(-1, 2)
    in_degree = np.zeros(num_nodes)
    out_degree = np.zeros(num_nodes)
    np.add.at(out_degree, edge_array[:, 0], 1)
    np.add.at(in_degree, edge_array[:, 1], 1)
    total_degree = in_degree + out_degree

    # Graph-wide quantities: computed once rather than once per node.
    max_degree = max(1, np.max(total_degree))
    mean_degree = np.mean(total_degree)

    features = []
    for node_id in range(num_nodes):
        # Basic degree statistics
        in_deg = in_degree[node_id]
        out_deg = out_degree[node_id]
        total_deg = total_degree[node_id]

        # Normalized degree features
        in_deg_norm = in_deg / max_degree
        out_deg_norm = out_deg / max_degree
        total_deg_norm = total_deg / max_degree

        # Random draws below keep their original per-node order so that a
        # given global seed produces the same values as before.

        # Communication patterns (simulated)
        tcp_ratio = np.random.beta(2, 1)  # Bias towards TCP
        udp_ratio = 1 - tcp_ratio
        icmp_ratio = np.random.exponential(0.05)  # Small ICMP traffic

        # Traffic volume indicators (log-normal distribution)
        bytes_in = np.random.lognormal(10, 2)
        bytes_out = np.random.lognormal(9, 2)

        # Temporal patterns (simulate business hours activity)
        hour_of_day = np.random.randint(0, 24)
        business_hours = 1.0 if 9 <= hour_of_day <= 17 else 0.3

        # Connection patterns
        failed_conn_ratio = np.random.beta(1, 10)  # Most connections succeed
        avg_session_duration = np.random.lognormal(2, 1)  # Session length

        # Service type indicators
        is_server = float(in_deg > out_deg)  # More incoming than outgoing
        is_client = float(out_deg > in_deg)  # More outgoing than incoming
        is_balanced = float(abs(in_deg - out_deg) <= 2)  # Balanced communication

        # Port diversity (simulated)
        port_diversity = min(1.0, total_deg / 10.0)  # Normalized port usage

        node_feature_vector = [
            in_deg_norm,                        # Normalized in-degree
            out_deg_norm,                       # Normalized out-degree
            total_deg_norm,                     # Normalized total degree
            tcp_ratio,                          # TCP traffic ratio
            udp_ratio,                          # UDP traffic ratio
            icmp_ratio,                         # ICMP traffic ratio
            np.log1p(bytes_in),                 # Log bytes in
            np.log1p(bytes_out),                # Log bytes out
            business_hours,                     # Business hours activity
            failed_conn_ratio,                  # Failed connection ratio
            np.log1p(avg_session_duration),     # Log session duration
            port_diversity,                     # Port diversity
            np.sin(2 * np.pi * hour_of_day / 24),  # Hour sine encoding
            np.cos(2 * np.pi * hour_of_day / 24),  # Hour cosine encoding
            is_server,                          # Server indicator
            is_client,                          # Client indicator
            is_balanced,                        # Balanced communication indicator
            float(total_deg > mean_degree),     # Above average activity
        ]

        features.append(node_feature_vector)

    return np.array(features, dtype=np.float32)


def create_edge_features(graph_data: Dict[str, Any]) -> np.ndarray:
    """Synthesize one flow-style feature row per edge of a graph.

    Nothing here is read from the dataset. Every value is drawn from NumPy's
    global random state: volume, timing, protocol (TCP/UDP/ICMP), ports,
    TCP flags and link quality.

    Args:
        graph_data: Mapping whose ``'edges'`` entry is a sequence of
            ``(src, dst)`` pairs. Only the number of pairs matters.

    Returns:
        ``float32`` array of shape ``(num_edges, 14)``, or ``None`` when the
        graph has no edges.
    """
    edge_list = graph_data['edges']
    if len(edge_list) == 0:
        return None

    # Destination-port menus keyed by IP protocol number: (ports, probabilities).
    port_menus = {
        6: ([80, 443, 22, 3389, 445, 993, 995],
            [0.3, 0.3, 0.1, 0.1, 0.05, 0.1, 0.05]),   # TCP
        17: ([53, 123, 161, 514, 1194],
             [0.4, 0.2, 0.2, 0.1, 0.1]),              # UDP
    }

    rows = []
    for _src, _dst in edge_list:
        flow_bytes = np.random.lognormal(8, 2)
        flow_packets = np.random.poisson(max(1, flow_bytes / 1500))  # ~1500 bytes per packet
        flow_duration = np.random.lognormal(1, 1)
        packet_gap = np.random.exponential(0.1)

        proto = np.random.choice([6, 17, 1], p=[0.7, 0.25, 0.05])  # TCP, UDP, ICMP
        menu = port_menus.get(int(proto))
        if menu is None:  # ICMP: no destination port
            dst_port = 0
        else:
            candidate_ports, port_weights = menu
            dst_port = np.random.choice(candidate_ports, p=port_weights)
        src_port = 0 if proto == 1 else np.random.randint(1024, 65536)

        # TCP flags are only sampled for TCP flows; everything else gets zeros.
        if proto == 6:
            syn_flag = 1.0
            ack_flag = 1.0 if np.random.random() > 0.1 else 0.0
            fin_flag = 1.0 if np.random.random() > 0.8 else 0.0
            rst_flag = 1.0 if np.random.random() > 0.95 else 0.0
        else:
            syn_flag = ack_flag = fin_flag = rst_flag = 0.0

        jitter = np.random.exponential(0.01)
        loss_rate = np.random.beta(1, 100)

        rows.append([
            np.log1p(flow_bytes),            # log bytes transferred
            np.log1p(flow_packets),          # log packet count
            np.log1p(flow_duration),         # log flow duration
            packet_gap,                      # inter-arrival time
            proto / 17.0,                    # protocol number scaled by 17
            src_port / 65536.0,              # scaled source port
            dst_port / 65536.0,              # scaled destination port
            syn_flag,
            ack_flag,
            fin_flag,
            rst_flag,
            jitter,
            loss_rate,
            float(flow_bytes > 10000),       # large-transfer indicator
        ])

    return np.array(rows, dtype=np.float32)


def _find_cisco_graph_files(cisco_dir: Path) -> List[Path]:
    """List the dataset's ``*.txt.gz`` edge files in processing order.

    Files from ``dir_20_graphs/dir_day1`` .. ``dir_day4`` come first, each
    directory sorted by name, followed by
    ``dir_g22_extra_graph_with_gt/dir_edges``. Missing directories are skipped.
    """
    search_dirs = [cisco_dir / "dir_20_graphs" / f"dir_day{day}" for day in range(1, 5)]
    search_dirs.append(cisco_dir / "dir_g22_extra_graph_with_gt" / "dir_edges")

    graph_files: List[Path] = []
    for directory in search_dirs:
        if directory.exists():
            graph_files.extend(sorted(directory.glob("*.txt.gz")))
    return graph_files


def _to_processed_graph(graph_data: Dict[str, Any], enterprise_id: int) -> Dict[str, Any]:
    """Convert one parsed graph record into the tensor dict that gets pickled.

    Node features are generated before edge features. Both draw from NumPy's
    global random state, so this order determines the sampled values.
    """
    node_features = create_enterprise_features(graph_data)
    edge_features = create_edge_features(graph_data)

    # reshape(-1, 2) keeps edge_index at [2, num_edges] even for an empty edge list.
    edge_index = torch.tensor(graph_data['edges'], dtype=torch.long).reshape(-1, 2).t()
    x = torch.tensor(node_features, dtype=torch.float32)
    edge_attr = torch.tensor(edge_features, dtype=torch.float32) if edge_features is not None else None

    # Create labels (initially all normal - will be overwritten by synthetic attacks)
    y_node = torch.zeros(graph_data['num_nodes'], dtype=torch.long)

    return {
        'graph_id': graph_data['graph_id'],
        'x': x,
        'edge_index': edge_index,
        'edge_attr': edge_attr,
        'y_node': y_node,
        'window_idx': enterprise_id,  # Use enterprise_id as window_idx
        'file_source': graph_data['file_source'],
        'num_nodes': graph_data['num_nodes'],
        'num_edges': graph_data['num_edges'],
        'statistics': graph_data['statistics'],
        'enterprise_id': enterprise_id,
    }


def main():
    """Preprocess the Cisco dataset and write the results to ``--output_dir``.

    Parses the ``.gz`` edge files and adds synthetic node and edge features.
    Writes ``cisco_graphs_processed.pkl`` (a list of per-graph dicts) and
    ``metadata.json``. At most ``--max_graphs`` graphs are kept, taken in file
    order and then in each file's graph order.
    """
    args = parse_args()
    logger = get_logger("cisco_preprocessor")

    # Set up paths
    cisco_data_path = Path(args.cisco_data_path)
    output_dir = Path(args.output_dir)
    ensure_dir(str(output_dir))

    # Check if Cisco dataset exists
    cisco_dir = cisco_data_path / "Cisco_22_networks"
    if not cisco_dir.exists():
        logger.error(f"Cisco dataset not found at {cisco_dir}")
        logger.info("Please download the dataset first using:")
        logger.info("  wget https://snap.stanford.edu/data/CiscoSecureWorkload_22_networks.zip")
        return

    logger.info("Starting Cisco dataset preprocessing...")
    logger.info(f"Input: {cisco_dir}")
    logger.info(f"Output: {output_dir}")

    # Files are not limited here: one file may contain several graphs.
    graph_files = _find_cisco_graph_files(cisco_dir)
    logger.info(f"Found {len(graph_files)} graph files to process")

    processed_graphs = []
    # Files that actually contributed graphs. Because one file can yield many
    # graphs, this cannot be derived by slicing graph_files.
    source_files = []
    start_time = time.time()

    for gz_file in tqdm(graph_files, desc="Processing graphs"):
        if len(processed_graphs) >= args.max_graphs:
            break
        graphs_from_file = extract_and_parse_cisco_file(gz_file, logger)

        graphs_before = len(processed_graphs)
        for graph_data in graphs_from_file.values():
            # Check before appending so that --max_graphs <= 0 yields no graphs.
            if len(processed_graphs) >= args.max_graphs:
                break
            processed_graphs.append(_to_processed_graph(graph_data, len(processed_graphs)))
        if len(processed_graphs) > graphs_before:
            source_files.append(str(gz_file))

    processing_time = time.time() - start_time
    logger.info(f"Processed {len(processed_graphs)} graphs in {processing_time:.2f} seconds")
    if not processed_graphs:
        logger.warning("No graphs were extracted; output files will be empty")

    # Save processed data
    output_file = output_dir / "cisco_graphs_processed.pkl"
    logger.info(f"Saving processed graphs to {output_file}")

    with open(output_file, 'wb') as f:
        pickle.dump(processed_graphs, f, protocol=pickle.HIGHEST_PROTOCOL)

    # Save metadata
    node_counts = [g['num_nodes'] for g in processed_graphs]
    edge_counts = [g['num_edges'] for g in processed_graphs]
    first_graph = processed_graphs[0] if processed_graphs else None
    metadata = {
        'num_graphs': len(processed_graphs),
        'processing_time': processing_time,
        'source_files': source_files,
        'node_feature_dim': first_graph['x'].shape[1] if first_graph is not None else 0,
        'edge_feature_dim': (
            first_graph['edge_attr'].shape[1]
            if first_graph is not None and first_graph['edge_attr'] is not None
            else 0
        ),
        'graph_statistics': {
            'nodes': node_counts,
            'edges': edge_counts,
            # np.mean([]) is NaN (not valid JSON), so an empty run reports 0.0.
            'avg_nodes': float(np.mean(node_counts)) if node_counts else 0.0,
            'avg_edges': float(np.mean(edge_counts)) if edge_counts else 0.0,
        },
    }

    metadata_file = output_dir / "metadata.json"
    with open(metadata_file, 'w') as f:
        json.dump(metadata, f, indent=2, default=str)

    logger.info("Preprocessing completed successfully!")
    logger.info(f"Processed {len(processed_graphs)} enterprise graphs")
    logger.info(f"Average nodes per graph: {metadata['graph_statistics']['avg_nodes']:.1f}")
    logger.info(f"Average edges per graph: {metadata['graph_statistics']['avg_edges']:.1f}")
    logger.info(f"Node feature dimension: {metadata['node_feature_dim']}")
    logger.info(f"Edge feature dimension: {metadata['edge_feature_dim']}")
    logger.info("Files saved:")
    logger.info(f"  - Graphs: {output_file}")
    logger.info(f"  - Metadata: {metadata_file}")

    # Print usage instructions
    logger.info("\n" + "=" * 60)
    logger.info("USAGE INSTRUCTIONS")
    logger.info("=" * 60)
    logger.info("To use the preprocessed data in training, run:")
    logger.info("python -m gatv2_ns3_ids.scripts.train_cisco_synthetic \\")
    logger.info("  --config configuration_study_configs/config_cisco_synthetic.yaml \\")
    logger.info(f"  --cisco_data_path {args.output_dir} \\")
    logger.info("  --output_dir outputs/cisco_synthetic \\")
    logger.info("  --epochs 50")


if __name__ == "__main__":
    # Run the preprocessing pipeline only when executed as a script/module,
    # not when this file is imported.
    main()
