"""
Real network intrusion detection datasets for GATv2-NS3 IDS.
Supports common datasets like CIC-IDS2017, NSL-KDD, UNSW-NB15.
"""

import os
import pandas as pd
import numpy as np
import torch
from typing import List, Tuple, Dict, Optional
import urllib.request
import zipfile
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.model_selection import train_test_split
import networkx as nx

from ..utils.common import GraphData, ensure_dir, get_logger


class NSLKDDDatasetLoader:
    """
    Enhanced NSL-KDD Dataset Loader for Proposal 2: Curiosity Loop Integration
    
    Converts tabular NSL-KDD data to graph format suitable for GATv2 with
    Curiosity Loop feedback system. Supports multi-class attack classification
    and attention-aware graph construction.
    """

    def __init__(self, data_dir: str = "data/nsl_kdd"):
        self.data_dir = data_dir
        self.logger = get_logger("nsl_kdd_loader")
        ensure_dir(data_dir)

        # NSL-KDD configuration for Proposal 2
        self.nsl_kdd_config = {
            "train_url": "https://raw.githubusercontent.com/defcom17/NSL_KDD/master/KDDTrain%2B.txt",
            "test_url": "https://raw.githubusercontent.com/defcom17/NSL_KDD/master/KDDTest%2B.txt",
            "train_filename": "KDDTrain+.txt",
            "test_filename": "KDDTest+.txt",
            "columns": self._get_nsl_kdd_columns(),
            "attack_types": self._get_nsl_kdd_attack_types()
        }
        
        # Attack type to class mapping for multi-class classification
        self.attack_type_to_class = {
            'normal': 0,
            'dos': 1,
            'probe': 2,
            'r2l': 3,
            'u2r': 4
        }

    def _get_nsl_kdd_columns(self) -> List[str]:
        """Get NSL-KDD column names (includes difficulty level)."""
        return [
            'duration', 'protocol_type', 'service', 'flag', 'src_bytes', 'dst_bytes',
            'land', 'wrong_fragment', 'urgent', 'hot', 'num_failed_logins', 'logged_in',
            'num_compromised', 'root_shell', 'su_attempted', 'num_root', 'num_file_creations',
            'num_shells', 'num_access_files', 'num_outbound_cmds', 'is_host_login',
            'is_guest_login', 'count', 'srv_count', 'serror_rate', 'srv_serror_rate',
            'rerror_rate', 'srv_rerror_rate', 'same_srv_rate', 'diff_srv_rate',
            'srv_diff_host_rate', 'dst_host_count', 'dst_host_srv_count',
            'dst_host_same_srv_rate', 'dst_host_diff_srv_rate', 'dst_host_same_src_port_rate',
            'dst_host_srv_diff_host_rate', 'dst_host_serror_rate', 'dst_host_srv_serror_rate',
            'dst_host_rerror_rate', 'dst_host_srv_rerror_rate', 'label', 'difficulty_level'
        ]


    def _get_nsl_kdd_attack_types(self) -> Dict[str, str]:
        """Get NSL-KDD attack type mappings."""
        return {
            'normal': 'normal',
            'smurf': 'dos', 'neptune': 'dos', 'back': 'dos', 'teardrop': 'dos', 'pod': 'dos', 'land': 'dos',
            'satan': 'probe', 'ipsweep': 'probe', 'nmap': 'probe', 'portsweep': 'probe',
            'ftp_write': 'r2l', 'guess_passwd': 'r2l', 'imap': 'r2l', 'multihop': 'r2l',
            'phf': 'r2l', 'spy': 'r2l', 'warezclient': 'r2l', 'warezmaster': 'r2l',
            'buffer_overflow': 'u2r', 'loadmodule': 'u2r', 'perl': 'u2r', 'rootkit': 'u2r'
        }

    def download_dataset(self) -> bool:
        """Download the NSL-KDD dataset."""
        return self._download_nsl_kdd(self.nsl_kdd_config)

    def _download_nsl_kdd(self, config: Dict) -> bool:
        """Download NSL-KDD dataset (both train and test files)."""
        try:
            train_filepath = os.path.join(self.data_dir, config["train_filename"])
            test_filepath = os.path.join(self.data_dir, config["test_filename"])
            
            # Download training set
            if not os.path.exists(train_filepath):
                self.logger.info(f"Downloading NSL-KDD training set from {config['train_url']}")
                urllib.request.urlretrieve(config["train_url"], train_filepath)
                self.logger.info(f"Downloaded training set to {train_filepath}")
            
            # Download test set
            if not os.path.exists(test_filepath):
                self.logger.info(f"Downloading NSL-KDD test set from {config['test_url']}")
                urllib.request.urlretrieve(config["test_url"], test_filepath)
                self.logger.info(f"Downloaded test set to {test_filepath}")
            
            return True
        except Exception as e:
            self.logger.error(f"Failed to download NSL-KDD: {e}")
            return False


    def load_dataset(self, split: str = "train") -> Optional[pd.DataFrame]:
        """Load the NSL-KDD dataset into a pandas DataFrame.
        
        Args:
            split: Specify 'train' or 'test' for NSL-KDD dataset.
        """
        return self._load_nsl_kdd(split)

    def _load_nsl_kdd(self, split: str = "train") -> Optional[pd.DataFrame]:
        """Load NSL-KDD dataset."""
        config = self.nsl_kdd_config
        
        if split == "train":
            filepath = os.path.join(self.data_dir, config["train_filename"])
        elif split == "test":
            filepath = os.path.join(self.data_dir, config["test_filename"])
        else:
            self.logger.error(f"Invalid split: {split}. Use 'train' or 'test'")
            return None

        if not os.path.exists(filepath):
            self.logger.error(f"Dataset file not found: {filepath}")
            return None

        try:
            # Load the data (no header in the original file)
            df = pd.read_csv(filepath, header=None, names=config["columns"])
            
            # Clean label names (remove trailing periods if present)
            df['label'] = df['label'].str.replace('.', '', regex=False)
            
            self.logger.info(f"Loaded NSL-KDD {split} dataset with shape: {df.shape}")
            return df
        except Exception as e:
            self.logger.error(f"Failed to load NSL-KDD {split}: {e}")
            return None


    def preprocess_dataset(self, df: pd.DataFrame) -> Tuple[np.ndarray, np.ndarray, List[str]]:
        """Preprocess the NSL-KDD dataset for model training."""
        return self._preprocess_nsl_kdd(df)

    def _preprocess_nsl_kdd(self, df: pd.DataFrame) -> Tuple[np.ndarray, np.ndarray, List[str]]:
        """Preprocess NSL-KDD dataset."""
        # Create binary labels (normal vs attack) - labels are already cleaned
        df['is_attack'] = df['label'].apply(lambda x: 0 if x == 'normal' else 1)

        # Encode categorical features
        categorical_cols = ['protocol_type', 'service', 'flag']
        le_dict = {}

        for col in categorical_cols:
            le = LabelEncoder()
            df[col] = le.fit_transform(df[col])
            le_dict[col] = le

        # Select numeric features (exclude label, difficulty_level, and is_attack)
        numeric_cols = [col for col in df.columns if col not in ['label', 'difficulty_level', 'is_attack']]
        X = df[numeric_cols].values
        y = df['is_attack'].values

        # Scale features
        scaler = StandardScaler()
        X_scaled = scaler.fit_transform(X)

        self.logger.info(f"Preprocessed NSL-KDD data: X shape {X_scaled.shape}, y shape {y.shape}")
        self.logger.info(f"Attack ratio: {y.mean():.3f}")
        self.logger.info(f"Label distribution: {df['label'].value_counts().head()}")

        return X_scaled, y, numeric_cols


    def load_train_data(self) -> Tuple[np.ndarray, np.ndarray]:
        """Load and preprocess NSL-KDD training data for Proposal 2."""
        # Download if needed
        if not self._download_nsl_kdd(self.nsl_kdd_config):
            raise RuntimeError("Failed to download NSL-KDD dataset")
        
        # Load training data
        df = self._load_nsl_kdd("train")
        if df is None:
            raise RuntimeError("Failed to load NSL-KDD training data")
        
        # Preprocess with multi-class labels
        X, y_multiclass = self._preprocess_nsl_kdd_multiclass(df)
        
        return X, y_multiclass
    
    def load_test_data(self) -> Tuple[np.ndarray, np.ndarray]:
        """Load and preprocess NSL-KDD test data for Proposal 2."""
        # Load test data
        df = self._load_nsl_kdd("test")
        if df is None:
            raise RuntimeError("Failed to load NSL-KDD test data")
        
        # Preprocess with multi-class labels
        X, y_multiclass = self._preprocess_nsl_kdd_multiclass(df)
        
        return X, y_multiclass
    
    def _preprocess_nsl_kdd_multiclass(self, df: pd.DataFrame) -> Tuple[np.ndarray, np.ndarray]:
        """Preprocess NSL-KDD for multi-class classification (Proposal 2)."""
        # Map specific attacks to general categories
        attack_mapping = self._get_nsl_kdd_attack_types()
        df['attack_category'] = df['label'].map(attack_mapping)
        
        # Handle unknown attacks (map to 'dos' as default)
        df['attack_category'] = df['attack_category'].fillna('dos')
        
        # Convert to class indices
        df['class_label'] = df['attack_category'].map(self.attack_type_to_class)
        
        # Encode categorical features
        categorical_cols = ['protocol_type', 'service', 'flag']
        for col in categorical_cols:
            le = LabelEncoder()
            df[col] = le.fit_transform(df[col])
        
        # Select numeric features
        numeric_cols = [col for col in df.columns 
                       if col not in ['label', 'difficulty_level', 'attack_category', 'class_label']]
        X = df[numeric_cols].values
        y = df['class_label'].values
        
        # Scale features
        scaler = StandardScaler()
        X_scaled = scaler.fit_transform(X)
        
        self.logger.info(f"Preprocessed NSL-KDD (multi-class): X shape {X_scaled.shape}, y shape {y.shape}")
        self.logger.info(f"Class distribution: {np.bincount(y)}")
        
        return X_scaled, y
    
    def convert_to_graphs(self, X: np.ndarray, y: np.ndarray, 
                         graph_construction_method: str = "knn",
                         k: int = 10, 
                         feature_similarity_threshold: float = 0.7) -> List[GraphData]:
        """
        Convert NSL-KDD tabular data to graphs for Proposal 2.
        
        Args:
            X: Feature matrix
            y: Multi-class labels
            graph_construction_method: Method for graph construction ("knn", "similarity", "random")
            k: Number of neighbors for KNN graph
            feature_similarity_threshold: Threshold for similarity-based connections
        """
        if graph_construction_method == "knn":
            return self._create_knn_graphs(X, y, k)
        elif graph_construction_method == "similarity":
            return self._create_similarity_graphs(X, y, feature_similarity_threshold)
        elif graph_construction_method == "random":
            return self._create_random_graphs(X, y)
        else:
            raise ValueError(f"Unknown graph construction method: {graph_construction_method}")
    
    def _create_knn_graphs(self, X: np.ndarray, y: np.ndarray, k: int = 10) -> List[GraphData]:
        """Create graphs using k-nearest neighbors approach."""
        from sklearn.neighbors import NearestNeighbors
        
        graphs = []
        samples_per_graph = 200  # Reasonable size for Curiosity Loop
        num_graphs = len(X) // samples_per_graph
        
        for i in range(num_graphs):
            start_idx = i * samples_per_graph
            end_idx = min((i + 1) * samples_per_graph, len(X))
            
            graph_X = X[start_idx:end_idx]
            graph_y = y[start_idx:end_idx]
            
            if len(graph_X) < k:
                continue
            
            # Build KNN graph
            nbrs = NearestNeighbors(n_neighbors=min(k, len(graph_X)), algorithm='ball_tree')
            nbrs.fit(graph_X)
            distances, indices = nbrs.kneighbors(graph_X)
            
            # Create edge list
            edges = []
            for node_idx, neighbors in enumerate(indices):
                for neighbor_idx in neighbors[1:]:  # Skip self
                    edges.append([node_idx, neighbor_idx])
            
            if not edges:
                continue
            
            # Convert to PyTorch format
            edge_index = torch.tensor(edges, dtype=torch.long).t()
            # Make undirected
            edge_index = torch.cat([edge_index, edge_index[[1, 0]]], dim=1)
            
            # Create edge attributes (distances)
            edge_distances = []
            for node_idx, (dists, neighbors) in enumerate(zip(distances, indices)):
                for dist, neighbor_idx in zip(dists[1:], neighbors[1:]):
                    edge_distances.append([dist, 1.0, 0.5])  # [distance, weight, type]
            
            edge_attr = torch.tensor(edge_distances * 2, dtype=torch.float)  # *2 for undirected
            
            # Node features and labels
            node_features = torch.tensor(graph_X, dtype=torch.float)
            node_labels = torch.tensor(graph_y, dtype=torch.long)
            
            graph_data = GraphData(
                x=node_features,
                edge_index=edge_index,
                edge_attr=edge_attr,
                y_node=node_labels,
                graph_id=f"nsl_kdd_knn_{i}",
                window_idx=i
            )
            
            graphs.append(graph_data)
        
        self.logger.info(f"Created {len(graphs)} KNN graphs from NSL-KDD data")
        return graphs
    
    def _create_similarity_graphs(self, X: np.ndarray, y: np.ndarray, 
                                threshold: float = 0.7) -> List[GraphData]:
        """Create graphs based on feature similarity."""
        from sklearn.metrics.pairwise import cosine_similarity
        
        graphs = []
        samples_per_graph = 150  # Smaller for similarity-based approach
        num_graphs = len(X) // samples_per_graph
        
        for i in range(num_graphs):
            start_idx = i * samples_per_graph
            end_idx = min((i + 1) * samples_per_graph, len(X))
            
            graph_X = X[start_idx:end_idx]
            graph_y = y[start_idx:end_idx]
            
            # Calculate similarity matrix
            similarity_matrix = cosine_similarity(graph_X)
            
            # Create edges based on similarity threshold
            edges = []
            edge_weights = []
            
            for i in range(len(graph_X)):
                for j in range(i + 1, len(graph_X)):
                    if similarity_matrix[i, j] > threshold:
                        edges.append([i, j])
                        edge_weights.append(similarity_matrix[i, j])
            
            if not edges:
                continue
            
            # Convert to PyTorch format
            edge_index = torch.tensor(edges, dtype=torch.long).t()
            # Make undirected
            edge_index = torch.cat([edge_index, edge_index[[1, 0]]], dim=1)
            
            # Edge attributes
            edge_attr_list = [[w, w, 1.0] for w in edge_weights]
            edge_attr = torch.tensor(edge_attr_list * 2, dtype=torch.float)  # *2 for undirected
            
            # Node features and labels
            node_features = torch.tensor(graph_X, dtype=torch.float)
            node_labels = torch.tensor(graph_y, dtype=torch.long)
            
            graph_data = GraphData(
                x=node_features,
                edge_index=edge_index,
                edge_attr=edge_attr,
                y_node=node_labels,
                graph_id=f"nsl_kdd_sim_{i}",
                window_idx=i
            )
            
            graphs.append(graph_data)
        
        self.logger.info(f"Created {len(graphs)} similarity graphs from NSL-KDD data")
        return graphs
    
    def _create_random_graphs(self, X: np.ndarray, y: np.ndarray) -> List[GraphData]:
        """Create graphs with random connections (baseline)."""
        graphs = []
        samples_per_graph = 100
        num_graphs = len(X) // samples_per_graph
        
        for i in range(num_graphs):
            start_idx = i * samples_per_graph
            end_idx = min((i + 1) * samples_per_graph, len(X))
            
            graph_X = X[start_idx:end_idx]
            graph_y = y[start_idx:end_idx]
            
            num_nodes = len(graph_X)
            
            # Create random graph
            G = nx.erdos_renyi_graph(num_nodes, p=0.1, seed=42 + i)
            
            edges = list(G.edges())
            if not edges:
                continue
            
            # Convert to PyTorch format
            edge_index = torch.tensor(edges, dtype=torch.long).t()
            # Make undirected
            edge_index = torch.cat([edge_index, edge_index[[1, 0]]], dim=1)
            
            # Simple edge attributes
            num_edges = edge_index.shape[1]
            edge_attr = torch.ones(num_edges, 3)
            
            # Node features and labels
            node_features = torch.tensor(graph_X, dtype=torch.float)
            node_labels = torch.tensor(graph_y, dtype=torch.long)
            
            graph_data = GraphData(
                x=node_features,
                edge_index=edge_index,
                edge_attr=edge_attr,
                y_node=node_labels,
                graph_id=f"nsl_kdd_random_{i}",
                window_idx=i
            )
            
            graphs.append(graph_data)
        
        self.logger.info(f"Created {len(graphs)} random graphs from NSL-KDD data")
        return graphs

    def create_graph_dataset(self, X: np.ndarray, y: np.ndarray,
                           num_graphs: int = 50, nodes_per_graph: int = 100,
                           seed: int = 42) -> List[GraphData]:
        """Convert tabular data to graph format for GATv2."""
        np.random.seed(seed)
        graphs = []

        # Shuffle the data to ensure balanced distribution across graphs
        indices = np.arange(len(X))
        np.random.shuffle(indices)
        X_shuffled = X[indices]
        y_shuffled = y[indices]

        # Calculate samples per graph
        samples_per_graph = len(X_shuffled) // num_graphs
        feature_dim = X_shuffled.shape[1]

        for i in range(num_graphs):
            start_idx = i * samples_per_graph
            end_idx = min((i + 1) * samples_per_graph, len(X_shuffled))

            graph_X = X_shuffled[start_idx:end_idx]
            graph_y = y_shuffled[start_idx:end_idx]

            # Create random graph structure
            num_nodes = min(len(graph_X), nodes_per_graph)
            graph_X = graph_X[:num_nodes]
            graph_y = graph_y[:num_nodes]

            # Create Erdős–Rényi graph
            p = min(0.1, 10.0 / num_nodes)  # Adjust connectivity
            G = nx.erdos_renyi_graph(num_nodes, p, seed=seed + i)

            # Convert to PyTorch format
            edges = list(G.edges())
            if edges:
                edge_index = torch.tensor(edges, dtype=torch.long).t()
                # Add reverse edges
                edge_index = torch.cat([edge_index, edge_index[[1, 0]]], dim=1)
            else:
                edge_index = torch.zeros(2, 0, dtype=torch.long)

            # Create edge attributes (simple connectivity)
            num_edges = edge_index.shape[1]
            if num_edges > 0:
                edge_attr = torch.ones(num_edges, 3)  # Simple edge features
            else:
                edge_attr = torch.zeros(0, 3)

            # Convert node features
            node_features = torch.tensor(graph_X, dtype=torch.float)
            node_labels = torch.tensor(graph_y, dtype=torch.long)

            graph_data = GraphData(
                x=node_features,
                edge_index=edge_index,
                edge_attr=edge_attr,
                y_node=node_labels,
                graph_id=f"nsl_kdd_{i}",
                window_idx=i
            )

            graphs.append(graph_data)

        self.logger.info(f"Created {len(graphs)} graphs from real dataset")
        return graphs


def create_nsl_kdd_dataset(num_graphs: int = 50,
                       nodes_per_graph: int = 100, save_path: str = "data/nsl_kdd_graphs.pkl",
                       split: str = "train", seed: int = 42) -> List[GraphData]:
    """Create and save a graph dataset from NSL-KDD network data.

    A download is attempted first; when it fails the function continues and
    relies on files that already exist locally.

    Args:
        num_graphs: Number of graphs to create
        nodes_per_graph: Number of nodes per graph
        save_path: Path to save the processed graphs (pickle file)
        split: For NSL-KDD, specify 'train' or 'test'
        seed: Random seed

    Returns:
        The list of graphs that was written to ``save_path``.

    Raises:
        RuntimeError: If the requested split cannot be loaded.
    """
    import pickle

    loader = NSLKDDDatasetLoader()

    # Try to download NSL-KDD dataset (if not already available). A failure
    # is only a warning here because the files may already exist locally.
    try:
        if not loader._download_nsl_kdd(loader.nsl_kdd_config):
            loader.logger.warning("NSL-KDD download did not complete, checking if data exists locally...")
    except Exception as e:
        loader.logger.warning(f"Download attempt failed: {e}, checking if data exists locally...")

    # Load and preprocess
    df = loader._load_nsl_kdd(split)
    if df is None:
        loader.logger.error("Failed to load NSL-KDD dataset")
        raise RuntimeError(
            "Failed to process NSL-KDD dataset. "
            "Please verify the dataset file is valid and accessible."
        )

    X, y = loader._preprocess_nsl_kdd_multiclass(df)

    # Create graph dataset
    graphs = loader.create_graph_dataset(X, y, num_graphs, nodes_per_graph, seed)

    # Save to disk. os.path.dirname returns '' for a bare file name; in that
    # case the file goes to the current directory and there is nothing to create.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        ensure_dir(save_dir)
    with open(save_path, "wb") as f:
        pickle.dump(graphs, f)

    loader.logger.info(f"Saved {len(graphs)} graphs to {save_path}")
    return graphs




def load_nsl_kdd_dataset(path: str) -> List[GraphData]:
    """Read back a graph list that was stored as a pickle file.

    Args:
        path: Location of the pickle file.

    Returns:
        The unpickled object.

    Warning:
        Unpickling can execute arbitrary code. Only load files from a
        trusted source.
    """
    import pickle

    with open(path, "rb") as handle:
        return pickle.load(handle)


if __name__ == "__main__":
    # Manual smoke test: load both splits, report their sizes and class
    # counts, then build and save a small graph dataset from the training
    # split. Any failure is logged rather than raised.
    logger = get_logger("nsl_kdd_main")

    logger.info("Testing NSL-KDD dataset...")
    loader = NSLKDDDatasetLoader()

    try:
        # Training split
        train_X, train_y = loader.load_train_data()
        train_rows, train_cols = train_X.shape[0], train_X.shape[1]
        logger.info(f"NSL-KDD Train data: {train_rows} samples, {train_cols} features")
        logger.info(f"Train class distribution: {np.bincount(train_y)}")

        # Test split
        test_X, test_y = loader.load_test_data()
        test_rows, test_cols = test_X.shape[0], test_X.shape[1]
        logger.info(f"NSL-KDD Test data: {test_rows} samples, {test_cols} features")
        logger.info(f"Test class distribution: {np.bincount(test_y)}")

        # Graph dataset built from the training split
        logger.info("Creating NSL-KDD graph dataset...")
        graph_options = dict(
            num_graphs=20,
            nodes_per_graph=80,
            save_path="data/nsl_kdd_train_graphs.pkl",
            split="train",
        )
        graphs = create_nsl_kdd_dataset(**graph_options)

        logger.info(f"Created {len(graphs)} graphs from NSL-KDD training set")

    except Exception as e:
        logger.error(f"Failed to test NSL-KDD: {e}")
