#!/usr/bin/env python3
"""
Unified Baseline Training Script for GATv2-NS3 Hybrid IDS

This script trains various baseline models (GIN, GraphSAGE, MLP, RandomForest, 
LogisticRegression, XGBoost) on both NSL-KDD and Cisco datasets for comparison
with the main GATv2-NS3 hybrid approach.
"""

import argparse
import yaml
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from pathlib import Path
import time
from typing import Dict, List, Optional, Tuple, Union
import json
import pickle
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
import warnings

try:
    from tqdm import tqdm
    TQDM_AVAILABLE = True
except ImportError:
    TQDM_AVAILABLE = False
    print("Warning: tqdm not available. Progress bars will be disabled.")

from ..data.nsl_kdd import NSLKDDDatasetLoader
from ..data.cisco_dataset import CiscoDatasetLoader
from ..data.attack_pattern_generator import inject_realistic_attacks_into_graphs
from ..models.gin_ids import GIN_IDS
from ..models.graphsage_ids import GraphSAGE_IDS
from ..models.mlp_ids import MLP_IDS
from ..models.random_forest_ids import RandomForest_IDS
from ..models.logistic_regression_ids import LogisticRegression_IDS
from ..models.xgboost_ids import XGBoostIDS
from ..utils.common import (
    GraphData, get_logger, set_seed, ensure_dir, to_device
)


def load_config(config_path: str) -> dict:
    """Load the run configuration from a YAML file.

    Args:
        config_path: Path to the YAML configuration file.

    Returns:
        The parsed top-level object. An empty document, for which
        ``yaml.safe_load`` returns ``None``, is normalised to an empty dict so
        callers can always use ``config.get(...)``.

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the content is not valid YAML.
    """
    # Explicit encoding: the platform default is not UTF-8 everywhere.
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    return {} if config is None else config


def create_nsl_kdd_graphs(dataset_loader: NSLKDDDatasetLoader,
                         config: dict) -> Tuple[List[GraphData], List[GraphData], List[GraphData]]:
    """Build train/validation/test graph lists from the NSL-KDD dataset.

    Args:
        dataset_loader: Loader providing ``load_train_data``, ``load_test_data``
            and ``convert_to_graphs``.
        config: Run configuration; ``config['dataset']['val_split']`` (default
            0.15) is the fraction of training graphs held out for validation.

    Returns:
        ``(train_graphs, val_graphs, test_graphs)``. The validation graphs are
        the leading slice of the converted training graphs (no shuffling is
        done here).
    """
    log = get_logger("nsl_kdd_graphs")

    # Both splits are converted with the same graph-construction options.
    graph_options = {
        "graph_construction_method": "knn",
        "k": 10,
        "feature_similarity_threshold": 0.7,
    }

    train_samples, train_targets = dataset_loader.load_train_data()
    test_samples, test_targets = dataset_loader.load_test_data()
    log.info(f"Loaded NSL-KDD: {len(train_samples)} train, {len(test_samples)} test samples")

    converted_train = dataset_loader.convert_to_graphs(
        train_samples, train_targets, **graph_options
    )
    test_graphs = dataset_loader.convert_to_graphs(
        test_samples, test_targets, **graph_options
    )

    # Hold out the first `holdout` training graphs for validation.
    val_fraction = config.get('dataset', {}).get('val_split', 0.15)
    holdout = int(len(converted_train) * val_fraction)
    val_graphs, train_graphs = converted_train[:holdout], converted_train[holdout:]

    log.info(f"Created graphs: {len(train_graphs)} train, {len(val_graphs)} val, {len(test_graphs)} test")

    return train_graphs, val_graphs, test_graphs


def create_cisco_dataset_with_synthetic_attacks(config: Dict, attack_ratio: float,
                                               cisco_data_path: str, logger):
    """Load pickled Cisco enterprise graphs, split them and inject attacks.

    Args:
        config: Run configuration; ``config['seed']`` (default 42) is the base
            seed handed to the attack injector.
        attack_ratio: Forwarded as ``attack_ratio`` to
            ``inject_realistic_attacks_into_graphs``.
        cisco_data_path: Either a pickle file, or a directory containing
            ``cisco_graphs_small.pkl`` or ``cisco_graphs_processed.pkl``
            (tried in that order).
        logger: Logger used for progress messages.

    Returns:
        ``(train, val, test)`` graph collections as returned by
        ``inject_realistic_attacks_into_graphs`` for each split.

    Raises:
        RuntimeError: If no pickle file is found or it holds no graphs.
    """
    logger.info("Creating Cisco dataset with synthetic attack injection...")

    # Always resolve to a Path: the original kept a plain str when the argument
    # already pointed at a file, so the `.exists()` check below raised
    # AttributeError instead of proceeding / reporting a missing file.
    data_path = Path(cisco_data_path)
    if str(cisco_data_path).endswith('.pkl') or data_path.is_file():
        pickle_path = data_path
    else:
        pickle_path = data_path / "cisco_graphs_small.pkl"
        if not pickle_path.exists():
            pickle_path = data_path / "cisco_graphs_processed.pkl"

    if not pickle_path.exists():
        raise RuntimeError(f"Cisco graphs pickle file not found: {pickle_path}")

    logger.info(f"Loading Cisco graphs from: {pickle_path}")
    # NOTE(security): pickle.load executes arbitrary code embedded in the file;
    # only point this at pickles produced by a trusted preprocessing run.
    with open(pickle_path, 'rb') as f:
        enterprise_graphs = pickle.load(f)

    if not enterprise_graphs:
        raise RuntimeError("Failed to load Cisco enterprise graphs")

    logger.info(f"Loaded {len(enterprise_graphs)} enterprise graphs from Cisco dataset")

    # Split enterprises
    n_enterprises = len(enterprise_graphs)

    if n_enterprises < 3:
        # Too few graphs for disjoint splits; reuse them (non-empty is
        # guaranteed by the check above).
        train_graphs = enterprise_graphs
        val_graphs = enterprise_graphs[:1]
        test_graphs = enterprise_graphs
        logger.warning(f"Small dataset ({n_enterprises} graphs). Using overlapping splits.")
    else:
        # Roughly 64% / 18% / remainder, with at least one graph per split.
        train_end = max(1, int(0.64 * n_enterprises))
        val_end = train_end + max(1, int(0.18 * n_enterprises))

        train_graphs = enterprise_graphs[:train_end]
        val_graphs = enterprise_graphs[train_end:val_end]
        test_graphs = enterprise_graphs[val_end:] if val_end < n_enterprises else enterprise_graphs[-1:]

    logger.info(f"Enterprise split: {len(train_graphs)} train, {len(val_graphs)} val, {len(test_graphs)} test")

    # Apply synthetic attack injection; each split gets its own seed offset.
    base_seed = config.get("seed", 42)
    train_graphs_with_attacks = inject_realistic_attacks_into_graphs(
        train_graphs, attack_ratio=attack_ratio, seed=base_seed
    )
    val_graphs_with_attacks = inject_realistic_attacks_into_graphs(
        val_graphs, attack_ratio=attack_ratio, seed=base_seed + 1000
    )
    test_graphs_with_attacks = inject_realistic_attacks_into_graphs(
        test_graphs, attack_ratio=attack_ratio, seed=base_seed + 2000
    )

    return train_graphs_with_attacks, val_graphs_with_attacks, test_graphs_with_attacks


def graphs_to_tabular(graphs: List[GraphData]) -> Tuple[np.ndarray, np.ndarray]:
    """Flatten a list of graphs into one feature matrix and one label vector.

    Each graph contributes its node features (``graph.x``) as rows and its
    node labels (``graph.y_node``) as entries, in list order.

    Returns:
        ``(X, y)`` where ``X`` stacks all node-feature arrays vertically and
        ``y`` concatenates all node-label arrays.
    """
    feature_blocks = [g.x.cpu().numpy() for g in graphs]
    label_blocks = [g.y_node.cpu().numpy() for g in graphs]
    return np.vstack(feature_blocks), np.hstack(label_blocks)


def create_model(model_name: str, config: dict, in_dim_node: int, in_dim_edge: int = 0,
                num_classes: int = 2, device: torch.device = None):
    """Instantiate the baseline model selected by ``model_name``.

    Args:
        model_name: Case-insensitive model identifier (e.g. ``'gin'``,
            ``'sage'``, ``'mlp'``, ``'rf'``, ``'lr'``, ``'xgb'``).
        config: Run configuration; hyper-parameters are read from
            ``config['model']`` with defaults.
        in_dim_node: Node feature dimension.
        in_dim_edge: Accepted for interface compatibility; not used here.
        num_classes: Number of output classes for the neural models.
        device: If given, neural models are moved to this device.

    Returns:
        The constructed model object.

    Raises:
        ValueError: If ``model_name`` matches no known model.
    """
    settings = config.get('model', {})
    hidden = settings.get('hidden', 64)
    layers = settings.get('layers', 2)
    dropout = settings.get('dropout', 0.1)

    key = model_name.lower()

    # Classical (non-PyTorch) baselines: returned directly, no device handling.
    if key in ('randomforest', 'rf', 'random_forest'):
        return RandomForest_IDS(
            random_state=42,
            n_estimators=settings.get('n_estimators', 100),
            class_weight='balanced'
        )
    if key in ('logisticregression', 'lr', 'logistic'):
        return LogisticRegression_IDS(
            random_state=42,
            max_iter=1000,
            class_weight='balanced'
        )
    if key in ('xgboost', 'xgb'):
        return XGBoostIDS(
            random_state=42,
            n_estimators=settings.get('n_estimators', 100),
            learning_rate=settings.get('learning_rate', 0.01)
        )

    # Neural baselines: build first, then place on the device in one spot.
    if key in ('gin', 'gin_ids'):
        network = GIN_IDS(
            in_dim_node=in_dim_node,
            hidden=hidden,
            layers=layers,
            dropout=dropout,
            num_classes=num_classes
        )
    elif key in ('graphsage', 'sage', 'graphsage_ids'):
        network = GraphSAGE_IDS(
            in_dim_node=in_dim_node,
            hidden=hidden,
            layers=layers,
            dropout=dropout,
            num_classes=num_classes
        )
    elif key in ('mlp', 'mlp_ids'):
        network = MLP_IDS(
            in_dim=in_dim_node,
            hidden=hidden,
            num_classes=num_classes,
            dropout=dropout
        )
    else:
        raise ValueError(f"Unknown model name: {model_name}")

    return network.to(device) if device else network


def _build_criterion(train_labels: torch.Tensor, num_outputs: int,
                     device: torch.device) -> nn.CrossEntropyLoss:
    """Build a cross-entropy loss with inverse-frequency class weights.

    The weight vector always has ``num_outputs`` entries (one per model
    output), indexed by class id. Classes absent from ``train_labels`` get
    weight 1.0. If fewer than two classes are present, an unweighted loss is
    returned.
    """
    counts = torch.bincount(train_labels.long().flatten().cpu(), minlength=num_outputs)
    if counts.numel() > num_outputs:
        raise ValueError(
            f"Training labels contain class id {counts.numel() - 1}, "
            f"but the model only has {num_outputs} outputs"
        )

    seen = counts > 0
    num_present = int(seen.sum().item())
    if num_present <= 1:
        return nn.CrossEntropyLoss()

    total = int(counts.sum().item())
    weights = torch.ones(num_outputs, dtype=torch.float32)
    # weight_c = N / (num_present_classes * count_c), computed in float64.
    weights[seen] = (total / (num_present * counts[seen].double())).float()
    return nn.CrossEntropyLoss(weight=weights.to(device))


def train_pytorch_model(model: nn.Module, train_graphs: List[GraphData],
                       val_graphs: List[GraphData], config: dict,
                       device: torch.device, logger, dataset_type: str = 'cisco') -> Dict:
    """Train a PyTorch node-classification model with early stopping.

    One optimizer step is taken per training graph. After each epoch the model
    is evaluated on ``val_graphs``; training stops once the validation F1 has
    not improved for ``patience`` consecutive epochs. When at least one epoch
    improved on the initial best F1 of 0.0, the weights of the best epoch are
    loaded back into ``model`` before returning, so the reported
    ``best_val_f1`` matches the model that is evaluated afterwards.

    Args:
        model: Model whose forward pass returns a mapping with key
            ``"node_logits"``.
        train_graphs: Non-empty list of training graphs with ``y_node`` labels.
        val_graphs: Validation graphs.
        config: Run configuration; reads ``config['training']`` keys
            ``learning_rate``, ``weight_decay``, ``epochs`` and optional
            ``patience`` (default 5).
        device: Device the graphs are moved to.
        logger: Logger for per-epoch progress.
        dataset_type: Forwarded to ``evaluate_pytorch_model``.

    Returns:
        Dict with ``training_history`` (per-epoch lists), ``best_val_f1`` and
        ``total_epochs`` (number of epochs actually run).

    Raises:
        ValueError: If ``train_graphs`` is empty, or a training label exceeds
            the number of model outputs.
    """
    if not train_graphs:
        raise ValueError("train_graphs must contain at least one graph")

    training_cfg = config['training']

    # Setup optimizer
    optimizer = optim.Adam(
        model.parameters(),
        lr=float(training_cfg['learning_rate']),
        weight_decay=float(training_cfg['weight_decay'])
    )

    # Class-imbalance handling: the criterion is created on the first batch,
    # once the number of model outputs is known, so the weight vector length
    # always matches the logits even if a class is missing from the training set.
    all_labels = torch.cat([g.y_node for g in train_graphs])
    criterion = None

    epochs = training_cfg['epochs']
    patience = training_cfg.get('patience', 5)
    best_f1 = 0.0
    best_state = None
    patience_counter = 0
    epochs_run = 0  # stays 0 if `epochs` is 0 (loop variable would be unbound)

    training_history = {
        'train_loss': [],
        'train_accuracy': [],
        'val_accuracy': [],
        'val_f1': []
    }

    for epoch in range(epochs):
        epochs_run = epoch + 1

        # Training
        model.train()
        total_loss = 0.0
        total_correct = 0
        total_samples = 0

        graph_iter = train_graphs
        if TQDM_AVAILABLE:
            graph_iter = tqdm(train_graphs, total=len(train_graphs), desc=f"Epoch {epoch+1}")

        for graph in graph_iter:
            graph = to_device(graph, device)
            targets = graph.y_node.long()

            optimizer.zero_grad()
            logits = model(graph)["node_logits"]

            if criterion is None:
                criterion = _build_criterion(all_labels, logits.shape[1], device)

            loss = criterion(logits, targets)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            total_correct += (logits.argmax(dim=1) == targets).sum().item()
            total_samples += targets.size(0)

        train_accuracy = total_correct / total_samples if total_samples else 0.0
        avg_loss = total_loss / len(train_graphs)

        # Validation
        val_metrics = evaluate_pytorch_model(model, val_graphs, device, dataset_type)

        # Logging
        logger.info(f"Epoch {epoch+1}/{epochs}: "
                   f"Loss={avg_loss:.4f}, "
                   f"Train_Acc={train_accuracy:.4f}, "
                   f"Val_Acc={val_metrics['accuracy']:.4f}, "
                   f"Val_F1={val_metrics['f1']:.4f}")

        # Save history
        training_history['train_loss'].append(avg_loss)
        training_history['train_accuracy'].append(train_accuracy)
        training_history['val_accuracy'].append(val_metrics['accuracy'])
        training_history['val_f1'].append(val_metrics['f1'])

        # Early stopping
        if val_metrics['f1'] > best_f1:
            best_f1 = val_metrics['f1']
            patience_counter = 0
            # Snapshot the best weights so they can be restored after training.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
        else:
            patience_counter += 1

        if patience_counter >= patience:
            logger.info(f"Early stopping at epoch {epoch+1}")
            break

    if best_state is not None:
        model.load_state_dict(best_state)

    return {
        'training_history': training_history,
        'best_val_f1': best_f1,
        'total_epochs': epochs_run
    }


def train_sklearn_model(model, train_graphs: List[GraphData],
                       val_graphs: List[GraphData], config: dict, logger, dataset_type: str = 'cisco') -> Dict:
    """Fit a scikit-learn style model on node features and validate it.

    Graphs are flattened with ``graphs_to_tabular``; ``model.fit`` is timed
    and the fitted model is scored on the validation nodes.

    Returns:
        Dict with ``training_time`` (seconds spent in ``fit``) and
        ``val_metrics`` (result of ``evaluate_sklearn_model``).
    """
    features_train, labels_train = graphs_to_tabular(train_graphs)
    features_val, labels_val = graphs_to_tabular(val_graphs)

    n_samples, n_features = features_train.shape[0], features_train.shape[1]
    logger.info(f"Training sklearn model on {n_samples} samples with {n_features} features")

    # Time only the fit call.
    fit_started = time.time()
    model.fit(features_train, labels_train)
    elapsed = time.time() - fit_started

    scores = evaluate_sklearn_model(model, features_val, labels_val, dataset_type)

    logger.info(f"Training completed in {elapsed:.2f}s: "
               f"Val_Acc={scores['accuracy']:.4f}, "
               f"Val_F1={scores['f1']:.4f}")

    return {'training_time': elapsed, 'val_metrics': scores}


def evaluate_pytorch_model(model: nn.Module, graphs: List[GraphData],
                          device: torch.device, dataset_type: str = 'cisco') -> Dict:
    """Evaluate a PyTorch model on a list of graphs.

    Puts the model in eval mode (it is left in eval mode on return), predicts
    the arg-max class for every node of every graph under ``torch.no_grad()``
    and scores the pooled predictions with ``calculate_metrics``.

    Args:
        model: Model whose forward pass returns a mapping with key
            ``"node_logits"``.
        graphs: Graphs to evaluate; each provides ``y_node`` labels.
        device: Device the graphs are moved to.
        dataset_type: Forwarded to ``calculate_metrics``.

    Returns:
        The metrics dict produced by ``calculate_metrics``.
    """
    model.eval()

    all_predictions = []
    all_labels = []

    with torch.no_grad():
        for graph in graphs:
            graph = to_device(graph, device)
            logits = model(graph)["node_logits"]

            # Arg-max of the logits is the predicted class; the softmax the
            # original computed here was never used and has been dropped.
            all_predictions.extend(logits.argmax(dim=1).cpu().tolist())
            all_labels.extend(graph.y_node.cpu().tolist())

    return calculate_metrics(all_labels, all_predictions, dataset_type)


def evaluate_sklearn_model(model, X: np.ndarray, y: np.ndarray, dataset_type: str = 'cisco') -> Dict:
    """Score ``model.predict(X)`` against ``y`` using ``calculate_metrics``."""
    return calculate_metrics(y, model.predict(X), dataset_type)


def calculate_metrics(y_true, y_pred, dataset_type='cisco') -> Dict:
    """Calculate evaluation metrics, with a per-class breakdown for NSL-KDD.

    Args:
        y_true: Ground-truth class labels (sequence or array).
        y_pred: Predicted class labels (hard predictions, not scores).
        dataset_type: ``'nsl_kdd'`` additionally adds per-class metrics for
            the five classes normal/dos/probe/r2l/u2r (class ids 0-4).

    Returns:
        Dict with ``accuracy``, weighted ``precision``/``recall``/``f1`` and
        ``auc``. ``auc`` is computed from the hard labels and only when both
        ``y_true`` and ``y_pred`` contain exactly two distinct values;
        otherwise it falls back to the accuracy. For NSL-KDD the dict also has
        ``class_accuracies`` (fraction of each true class predicted correctly)
        and, when they can be computed, ``class_precision``, ``class_recall``
        and ``class_f1``.
    """
    # Convert once instead of on every use below.
    y_true_arr = np.asarray(y_true)
    y_pred_arr = np.asarray(y_pred)

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        accuracy = accuracy_score(y_true_arr, y_pred_arr)
        precision = precision_score(y_true_arr, y_pred_arr, average='weighted', zero_division=0)
        recall = recall_score(y_true_arr, y_pred_arr, average='weighted', zero_division=0)
        f1 = f1_score(y_true_arr, y_pred_arr, average='weighted', zero_division=0)

        auc = accuracy  # Fallback when AUC is not defined / cannot be computed
        if len(np.unique(y_true_arr)) == 2 and len(np.unique(y_pred_arr)) == 2:
            try:
                auc = roc_auc_score(y_true_arr, y_pred_arr)
            except Exception:
                # Best-effort metric: keep the fallback. Unlike a bare
                # `except:`, this no longer swallows KeyboardInterrupt/SystemExit.
                auc = accuracy

    metrics = {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'auc': auc
    }

    # Add per-class metrics for NSL-KDD (multi-class)
    if dataset_type == 'nsl_kdd':
        attack_types = ['normal', 'dos', 'probe', 'r2l', 'u2r']
        class_ids = list(range(len(attack_types)))

        # Per-class accuracy: correctly predicted samples of a class divided
        # by the number of true samples of that class (0.0 if the class is absent).
        class_accuracies = {}
        for class_idx, attack_type in zip(class_ids, attack_types):
            is_class = (y_true_arr == class_idx)
            class_total = is_class.sum()
            if class_total > 0:
                class_correct = (is_class & (y_pred_arr == class_idx)).sum()
                class_accuracies[attack_type] = class_correct / class_total
            else:
                class_accuracies[attack_type] = 0.0

        metrics['class_accuracies'] = class_accuracies

        # Add per-class precision, recall, F1 (all three or none).
        try:
            class_precision = precision_score(y_true_arr, y_pred_arr, average=None, zero_division=0, labels=class_ids)
            class_recall = recall_score(y_true_arr, y_pred_arr, average=None, zero_division=0, labels=class_ids)
            class_f1 = f1_score(y_true_arr, y_pred_arr, average=None, zero_division=0, labels=class_ids)
        except Exception as exc:
            # Still best-effort, but no longer silent.
            warnings.warn(f"Per-class precision/recall/F1 skipped: {exc}", RuntimeWarning)
        else:
            metrics['class_precision'] = dict(zip(attack_types, class_precision))
            metrics['class_recall'] = dict(zip(attack_types, class_recall))
            metrics['class_f1'] = dict(zip(attack_types, class_f1))

    return metrics


def main():
    """Command-line entry point: train one baseline model and save results.

    Parses the CLI arguments, loads the YAML config and the selected dataset,
    builds and trains the requested model, evaluates it on the test split and
    writes ``results.json`` and ``evaluation_report.txt`` into ``--output_dir``.
    If the dataset cannot be loaded or yields no training graphs, the error is
    logged and the function returns without training.
    """
    parser = argparse.ArgumentParser(description='Train baseline models for comparison')
    parser.add_argument('--config', type=str, required=True,
                       help='Path to configuration file')
    parser.add_argument('--model', type=str, required=True,
                       choices=['gin', 'graphsage', 'mlp', 'randomforest', 'logistic', 'xgboost'],
                       help='Model to train')
    parser.add_argument('--dataset', type=str, required=True,
                       choices=['nsl_kdd', 'cisco'],
                       help='Dataset to use')
    parser.add_argument('--nsl_kdd_path', type=str, default='data/nsl_kdd',
                       help='Path to NSL-KDD dataset')
    parser.add_argument('--cisco_data_path', type=str, default='data/cisco_small',
                       help='Path to Cisco dataset')
    parser.add_argument('--output_dir', type=str, required=True,
                       help='Output directory for results')
    parser.add_argument('--epochs', type=int, default=None,
                       help='Number of epochs (overrides config)')
    parser.add_argument('--attack_ratio', type=float, default=0.10,
                       help='Attack ratio for Cisco dataset')

    args = parser.parse_args()

    # Load configuration (an empty config file is treated as an empty mapping)
    config = load_config(args.config) or {}

    # Override epochs if specified; create the section if the config lacks it
    # (sklearn-only configs need no 'training' section otherwise).
    if args.epochs is not None:
        config.setdefault('training', {})['epochs'] = args.epochs

    # Setup. Use the same configurable seed as the Cisco attack injection
    # (default 42) instead of a hard-coded value.
    set_seed(config.get("seed", 42))
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger = get_logger(f"{args.model}_{args.dataset}_train")

    logger.info(f"🎯 Training {args.model.upper()} on {args.dataset.upper()}")
    logger.info(f"Device: {device}")
    logger.info(f"Output: {args.output_dir}")

    # Create output directory
    ensure_dir(args.output_dir)

    # Load dataset
    if args.dataset == 'nsl_kdd':
        try:
            dataset_loader = NSLKDDDatasetLoader(args.nsl_kdd_path)
            train_graphs, val_graphs, test_graphs = create_nsl_kdd_graphs(dataset_loader, config)
        except Exception as e:
            logger.error(f"Failed to load NSL-KDD dataset: {e}")
            return
    else:  # cisco
        try:
            train_graphs, val_graphs, test_graphs = create_cisco_dataset_with_synthetic_attacks(
                config, args.attack_ratio, args.cisco_data_path, logger
            )
        except Exception as e:
            logger.error(f"Failed to load Cisco dataset: {e}")
            return

    # Without training graphs neither the input dimensions nor the model can
    # be determined; fail with a clear message instead of an IndexError.
    if not train_graphs:
        logger.error("No training graphs were produced; aborting.")
        return

    # Get dimensions
    sample_graph = train_graphs[0]
    in_dim_node = sample_graph.x.shape[1]
    in_dim_edge = sample_graph.edge_attr.shape[1] if sample_graph.edge_attr is not None else 0

    # Determine number of classes
    if args.dataset == 'nsl_kdd':
        num_classes = len(config.get('dataset', {}).get('attack_types', ['normal', 'dos', 'probe', 'r2l', 'u2r']))
    else:
        num_classes = 2  # Binary for Cisco

    logger.info(f"Input dimensions: node={in_dim_node}, edge={in_dim_edge}, classes={num_classes}")

    # Create model
    model = create_model(args.model, config, in_dim_node, in_dim_edge, num_classes, device)
    logger.info(f"Created {args.model} model")

    # Train model
    start_time = time.time()

    if args.model in ['randomforest', 'logistic', 'xgboost']:
        # Sklearn models
        train_results = train_sklearn_model(model, train_graphs, val_graphs, config, logger, args.dataset)

        # Final evaluation on test set
        X_test, y_test = graphs_to_tabular(test_graphs)
        test_metrics = evaluate_sklearn_model(model, X_test, y_test, args.dataset)

    else:
        # PyTorch models
        train_results = train_pytorch_model(model, train_graphs, val_graphs, config, device, logger, args.dataset)

        # Final evaluation on test set
        test_metrics = evaluate_pytorch_model(model, test_graphs, device, args.dataset)

    # Covers training plus the final test evaluation.
    total_time = time.time() - start_time

    has_class_breakdown = args.dataset == 'nsl_kdd' and 'class_accuracies' in test_metrics

    # Log final results
    logger.info("Final Test Results:")
    logger.info(f"  Accuracy: {test_metrics['accuracy']:.4f}")
    logger.info(f"  F1 Score: {test_metrics['f1']:.4f}")
    logger.info(f"  Precision: {test_metrics['precision']:.4f}")
    logger.info(f"  Recall: {test_metrics['recall']:.4f}")
    logger.info(f"  AUC: {test_metrics['auc']:.4f}")
    logger.info(f"  Total Time: {total_time:.2f}s")

    # Log per-class results for NSL-KDD
    if has_class_breakdown:
        logger.info("Per-class Test Results:")
        for attack_type, acc in test_metrics['class_accuracies'].items():
            logger.info(f"  {attack_type}: {acc:.4f}")

        if 'class_f1' in test_metrics:
            logger.info("Per-class F1 Scores:")
            for attack_type, f1 in test_metrics['class_f1'].items():
                logger.info(f"  {attack_type}: {f1:.4f}")

    # Save results
    results = {
        'model': args.model,
        'dataset': args.dataset,
        'config': config,
        'train_results': train_results,
        'test_metrics': test_metrics,
        'total_time': total_time,
        'args': vars(args)
    }

    # default=str: anything json cannot encode natively is written as its str().
    results_path = Path(args.output_dir) / 'results.json'
    with open(results_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, default=str)

    # Save simple report
    report_path = Path(args.output_dir) / 'evaluation_report.txt'
    with open(report_path, 'w', encoding='utf-8') as f:
        f.write(f"{args.model.upper()} on {args.dataset.upper()} - Evaluation Report\n")
        f.write("=" * 50 + "\n\n")
        f.write(f"Model: {args.model}\n")
        f.write(f"Dataset: {args.dataset}\n")
        f.write(f"Accuracy: {test_metrics['accuracy']:.4f}\n")
        f.write(f"F1 Score: {test_metrics['f1']:.4f}\n")
        f.write(f"Precision: {test_metrics['precision']:.4f}\n")
        f.write(f"Recall: {test_metrics['recall']:.4f}\n")
        f.write(f"AUC: {test_metrics['auc']:.4f}\n")
        f.write(f"Training Time: {total_time:.2f}s\n")

        # Add per-class results for NSL-KDD
        if has_class_breakdown:
            f.write("\nPer-class Accuracy:\n")
            f.write("-" * 20 + "\n")
            for attack_type, acc in test_metrics['class_accuracies'].items():
                f.write(f"{attack_type}: {acc:.4f}\n")

            if 'class_f1' in test_metrics:
                f.write("\nPer-class F1 Score:\n")
                f.write("-" * 20 + "\n")
                for attack_type, f1 in test_metrics['class_f1'].items():
                    f.write(f"{attack_type}: {f1:.4f}\n")

    logger.info(f"Training completed! Results saved to {args.output_dir}")


# Script entry point. NOTE: this module uses package-relative imports
# (``from ..data...``, ``from ..models...``), so it has to be executed as a
# module inside its package (``python -m <package>.<module>``); running the
# file directly fails at import time.
if __name__ == "__main__":
    main()
