#!/usr/bin/env python3
"""
NSL-KDD Training Script with Proposal 2: Self-Focusing Simulations (Curiosity Loop)

This script trains the GATv2-NS3 Hybrid IDS on the NSL-KDD dataset using the
Curiosity Loop feedback system that dynamically adjusts simulation fidelity
based on attention uncertainty.
"""

import argparse
import yaml
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
from pathlib import Path
import time
from typing import Dict, List, Optional, Tuple
import json

try:
    from tqdm import tqdm
    TQDM_AVAILABLE = True
except ImportError:
    TQDM_AVAILABLE = False
    print("Warning: tqdm not available. Progress bars will be disabled.")

from ..data.nsl_kdd import NSLKDDDatasetLoader
from ..models.gatv2_ids import GATv2IDS
from ..simulation.curiosity_loop import CuriosityLoopFeedback
from ..simulation.ns3_client import NS3Client
from ..simulation.sim_cache import SimCache
from ..utils.common import (
    GraphData, get_logger, set_seed, ensure_dir, to_device
)


def load_config(config_path: str) -> dict:
    """Load the training configuration from a YAML file.

    Args:
        config_path: Path to the YAML configuration file.

    Returns:
        The parsed configuration. An empty (or comment-only) file yields an
        empty dict instead of ``None`` so the result always honours the
        ``dict`` return annotation.

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the content is not valid YAML.
    """
    # Decode explicitly as UTF-8 so parsing does not depend on the platform's
    # default locale encoding.
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)

    # yaml.safe_load returns None for an empty document.
    return {} if config is None else config


def create_nsl_kdd_graphs(dataset_loader: NSLKDDDatasetLoader, 
                         config: dict) -> Tuple[List[GraphData], List[GraphData], List[GraphData]]:
    """
    Build train / validation / test graph lists from the NSL-KDD dataset.

    The loader supplies the raw train and test samples with their labels;
    both sets are then handed to ``dataset_loader.convert_to_graphs`` with
    the same construction settings (method "knn", k=10, feature similarity
    threshold 0.7).

    The validation set is carved off the *front* of the converted training
    graphs. Its size is ``int(len(train_graphs) * val_split)``, where
    ``val_split`` is read from ``config['dataset']['val_split']`` and
    defaults to 0.15. No shuffling is done here.

    Returns:
        A ``(train_graphs, val_graphs, test_graphs)`` tuple.
    """
    log = get_logger("nsl_kdd_graphs")

    # Raw samples and labels from the loader
    train_data, train_labels = dataset_loader.load_train_data()
    test_data, test_labels = dataset_loader.load_test_data()
    log.info(f"Loaded NSL-KDD: {len(train_data)} train, {len(test_data)} test samples")

    # Train and test graphs are built with identical construction settings
    knn_options = dict(
        graph_construction_method="knn",
        k=10,
        feature_similarity_threshold=0.7,
    )
    all_train_graphs = dataset_loader.convert_to_graphs(train_data, train_labels, **knn_options)
    test_graphs = dataset_loader.convert_to_graphs(test_data, test_labels, **knn_options)

    # Leading slice becomes the validation set, the remainder stays as training data
    dataset_cfg = config.get('dataset', {})
    holdout_count = int(len(all_train_graphs) * dataset_cfg.get('val_split', 0.15))
    val_graphs = all_train_graphs[:holdout_count]
    train_graphs = all_train_graphs[holdout_count:]

    log.info(f"Created graphs: {len(train_graphs)} train, {len(val_graphs)} val, {len(test_graphs)} test")

    return train_graphs, val_graphs, test_graphs


def train_epoch_with_curiosity_loop(model: nn.Module, 
                                   graphs: List[GraphData],
                                   optimizer: optim.Optimizer,
                                   device: torch.device,
                                   curiosity_feedback: Optional[CuriosityLoopFeedback],
                                   config: dict,
                                   epoch: int) -> Dict[str, float]:
    """
    Train one epoch with Curiosity Loop feedback (Proposal 2).

    Every graph is one optimisation step:

    1. Forward pass; cross-entropy between ``output["node_logits"]`` and
       ``graph.y_node`` gives the classification loss.
    2. If the model returned ``"edge_attn"`` and ``curiosity_feedback`` is not
       ``None``, ``curiosity_feedback.analyze_and_simulate`` is called. A truthy
       result contributes ``final_fidelity_loss`` (weighted by
       ``training.lambda_fid``, default 0.12) and ``final_sparsity_loss``
       (weighted by ``training.lambda_sim``, default 0.08).
    3. Any exception raised by the curiosity loop is logged as a warning and
       the graph is trained on the losses gathered so far (best effort).

    Args:
        model: Network to train; switched to training mode here.
        graphs: Graphs to iterate over, one optimiser step per graph.
        optimizer: Optimiser that updates ``model``.
        device: Device each graph is moved to before the forward pass.
        curiosity_feedback: Curiosity Loop feedback object, or ``None`` to
            train on the classification loss only.
        config: Parsed configuration; only the ``training`` section is read.
        epoch: Epoch index, forwarded to the curiosity loop and shown in the
            progress bar.

    Returns:
        Per-graph averages for ``loss``, ``cls_loss``, ``fid_loss``,
        ``sim_loss``, ``curiosity_iterations``, ``convergence_rate`` and
        ``resource_cost``, plus node-level ``accuracy``. All values are 0.0
        when ``graphs`` is empty.
    """
    model.train()

    num_graphs = len(graphs)
    if num_graphs == 0:
        # Nothing to train on: report zeroed metrics rather than dividing by
        # zero when averaging below.
        return {
            'loss': 0.0,
            'cls_loss': 0.0,
            'fid_loss': 0.0,
            'sim_loss': 0.0,
            'accuracy': 0.0,
            'curiosity_iterations': 0.0,
            'convergence_rate': 0.0,
            'resource_cost': 0.0
        }

    total_loss = 0.0
    total_cls_loss = 0.0
    total_fid_loss = 0.0
    total_sim_loss = 0.0
    total_correct = 0
    total_samples = 0

    # Curiosity loop metrics
    total_curiosity_iterations = 0
    total_convergence_count = 0
    total_resource_cost = 0.0

    # Loss weights from config
    training_cfg = config.get('training', {})
    lambda_fid = training_cfg.get('lambda_fid', 0.12)
    lambda_sim = training_cfg.get('lambda_sim', 0.08)

    # Multi-class loss function for NSL-KDD
    criterion = nn.CrossEntropyLoss()

    logger = get_logger("train_epoch")

    # Progress bar
    iterator = enumerate(graphs)
    if TQDM_AVAILABLE:
        iterator = tqdm(iterator, total=num_graphs, desc=f"Epoch {epoch}")

    for i, graph in iterator:
        graph = to_device(graph, device)
        targets = graph.y_node.long()

        optimizer.zero_grad()

        # Forward pass
        output = model(graph)
        node_logits = output["node_logits"]
        edge_attn = output.get("edge_attn", None)

        # Classification loss (multi-class for NSL-KDD)
        cls_loss = criterion(node_logits, targets)

        # Curiosity Loop feedback losses; stay at zero when the loop is
        # disabled, returns nothing, or fails.
        fid_loss = torch.tensor(0.0, device=device)
        sim_loss = torch.tensor(0.0, device=device)

        if edge_attn is not None and curiosity_feedback is not None:
            try:
                curiosity_result = curiosity_feedback.analyze_and_simulate(
                    edge_attn, graph.edge_index, graph, epoch
                )
                if curiosity_result:
                    fid_loss = curiosity_result.final_fidelity_loss
                    # The "sim" term is the curiosity loop's sparsity loss.
                    sim_loss = curiosity_result.final_sparsity_loss

                    # Track curiosity loop metrics
                    total_curiosity_iterations += len(curiosity_result.iterations)
                    if curiosity_result.convergence_achieved:
                        total_convergence_count += 1
                    total_resource_cost += curiosity_result.total_resource_cost

                    # Log curiosity loop details for first few graphs
                    if i < 3:
                        logger.debug(f"Graph {i}: {len(curiosity_result.iterations)} iterations, "
                                   f"converged: {curiosity_result.convergence_achieved}, "
                                   f"cost: {curiosity_result.total_resource_cost:.2f}")

            except Exception as e:
                # Deliberately best effort: a failing simulation must not
                # abort training, so fall back to whatever losses are set.
                logger.warning(f"Curiosity Loop failed for graph {i}: {e}")

        # Combined loss
        total_loss_batch = (cls_loss + 
                           lambda_fid * fid_loss + 
                           lambda_sim * sim_loss)

        total_loss_batch.backward()
        optimizer.step()

        # Accumulate metrics (the feedback losses may be tensors or plain numbers)
        total_loss += total_loss_batch.item()
        total_cls_loss += cls_loss.item()
        total_fid_loss += fid_loss.item() if isinstance(fid_loss, torch.Tensor) else fid_loss
        total_sim_loss += sim_loss.item() if isinstance(sim_loss, torch.Tensor) else sim_loss

        # Accuracy calculation (multi-class)
        _, predicted = torch.max(node_logits, 1)
        total_correct += (predicted == targets).sum().item()
        total_samples += targets.size(0)

        # Update progress bar
        if TQDM_AVAILABLE and hasattr(iterator, 'set_postfix'):
            iterator.set_postfix({
                'Loss': f'{total_loss_batch.item():.4f}',
                # max(..., 1) keeps the running accuracy defined while no
                # labelled node has been seen yet.
                'Acc': f'{total_correct / max(total_samples, 1):.4f}',
                'CurIter': total_curiosity_iterations,
                'Conv': total_convergence_count
            })

    # Calculate epoch metrics
    accuracy = total_correct / total_samples if total_samples > 0 else 0.0

    return {
        'loss': total_loss / num_graphs,
        'cls_loss': total_cls_loss / num_graphs,
        'fid_loss': total_fid_loss / num_graphs,
        'sim_loss': total_sim_loss / num_graphs,
        'accuracy': accuracy,
        'curiosity_iterations': total_curiosity_iterations / num_graphs,
        'convergence_rate': total_convergence_count / num_graphs,
        'resource_cost': total_resource_cost / num_graphs
    }


def _zero_eval_metrics(attack_types: List[str]) -> Dict[str, float]:
    """Return the all-zero metrics reported when there is nothing to evaluate."""
    return {
        'accuracy': 0.0,
        'f1': 0.0,
        'precision': 0.0,
        'recall': 0.0,
        'class_accuracies': {attack_type: 0.0 for attack_type in attack_types}
    }


def evaluate_model(model: nn.Module, 
                  graphs: List[GraphData],
                  device: torch.device,
                  config: dict) -> Dict[str, float]:
    """Evaluate model on validation/test set.

    Runs the model in eval mode without gradients over every graph and
    compares the arg-max of ``output["node_logits"]`` with ``graph.y_node``.

    Args:
        model: Network to evaluate; switched to eval mode here.
        graphs: Graphs to evaluate on.
        device: Device each graph is moved to before the forward pass.
        config: Parsed configuration; ``dataset.attack_types`` names the
            classes (index ``i`` in that list is class id ``i``) and defaults
            to ``['normal', 'dos', 'probe', 'r2l', 'u2r']``.

    Returns:
        Dict with overall ``accuracy``, weighted ``f1`` / ``precision`` /
        ``recall`` and ``class_accuracies`` (attack type -> accuracy; 0.0 for
        classes with no samples). Every metric is 0.0 when there are no
        graphs or no labelled nodes. Labels outside ``[0, num_classes)`` count
        towards the overall metrics but not towards any per-class accuracy.
    """
    model.eval()

    attack_types = config.get('dataset', {}).get('attack_types', ['normal', 'dos', 'probe', 'r2l', 'u2r'])
    num_classes = len(attack_types)

    if len(graphs) == 0:
        # An empty split (e.g. val_split = 0) would otherwise divide by zero.
        return _zero_eval_metrics(attack_types)

    total_correct = 0
    total_samples = 0
    all_predictions = []
    all_labels = []

    # Per-class counters, kept on the CPU and updated one graph at a time
    class_correct = torch.zeros(num_classes, dtype=torch.long)
    class_total = torch.zeros(num_classes, dtype=torch.long)

    with torch.no_grad():
        iterator = graphs
        if TQDM_AVAILABLE:
            iterator = tqdm(graphs, desc="Evaluating")

        for graph in iterator:
            graph = to_device(graph, device)

            output = model(graph)
            node_logits = output["node_logits"]

            _, predicted = torch.max(node_logits, 1)
            # One device-to-host transfer per graph instead of one per node.
            predicted = predicted.cpu()
            labels = graph.y_node.long().cpu()

            matches = predicted == labels
            total_correct += matches.sum().item()
            total_samples += labels.size(0)

            # Per-class metrics. Ids outside the configured classes are left
            # out; in particular a negative id must not wrap around and be
            # credited to another class.
            in_range = (labels >= 0) & (labels < num_classes)
            class_total += torch.bincount(labels[in_range], minlength=num_classes)
            class_correct += torch.bincount(labels[in_range & matches], minlength=num_classes)

            # Store for detailed metrics
            all_predictions.extend(predicted.tolist())
            all_labels.extend(labels.tolist())

    if total_samples == 0:
        # Graphs were supplied but none carried any labelled node.
        return _zero_eval_metrics(attack_types)

    # Calculate metrics
    accuracy = total_correct / total_samples

    # Per-class accuracy
    class_accuracies = {}
    for i, attack_type in enumerate(attack_types):
        seen = int(class_total[i])
        class_accuracies[attack_type] = int(class_correct[i]) / seen if seen > 0 else 0.0

    # Calculate F1 scores (simplified)
    from sklearn.metrics import f1_score, precision_score, recall_score

    try:
        f1 = float(f1_score(all_labels, all_predictions, average='weighted'))
        precision = float(precision_score(all_labels, all_predictions, average='weighted'))
        recall = float(recall_score(all_labels, all_predictions, average='weighted'))
    except Exception as e:
        # Best effort: keep evaluation alive, but say why the scores are zero
        # instead of hiding the failure (and never swallow KeyboardInterrupt).
        get_logger("evaluate_model").warning(f"Failed to compute F1/precision/recall: {e}")
        f1 = precision = recall = 0.0

    return {
        'accuracy': accuracy,
        'f1': f1,
        'precision': precision,
        'recall': recall,
        'class_accuracies': class_accuracies
    }


def main():
    """Command-line entry point: train on NSL-KDD and report test metrics.

    Steps: parse arguments, load the YAML config, build the graph datasets,
    create the model / optimiser / optional Curiosity Loop feedback, train
    with early stopping on validation F1, then evaluate on the test set and
    write ``results.json`` (plus ``best_model.pth`` unless
    ``logging.save_model`` is false) into ``--output_dir``.

    The final test evaluation uses the weights of the epoch with the best
    validation F1, so ``best_val_f1`` and ``test_metrics`` in the results
    describe the same model. If no epoch ever improved on the initial
    ``best_val_f1`` of 0.0, the last-epoch weights are evaluated.

    Returns early (after logging an error) when the dataset cannot be loaded
    or yields no training graphs.
    """
    import copy  # local import: only needed to snapshot the best weights

    parser = argparse.ArgumentParser(description='Train GATv2-NS3 IDS on NSL-KDD with Curiosity Loop')
    parser.add_argument('--config', type=str, required=True,
                       help='Path to configuration file')
    parser.add_argument('--nsl_kdd_path', type=str, default='data/nsl_kdd',
                       help='Path to NSL-KDD dataset')
    parser.add_argument('--output_dir', type=str, required=True,
                       help='Output directory for results')
    parser.add_argument('--epochs', type=int, default=None,
                       help='Number of epochs (overrides config)')
    parser.add_argument('--disable_simulation', action='store_true',
                       help='Disable NS-3 simulation feedback')
    
    args = parser.parse_args()
    
    # Load configuration
    config = load_config(args.config)
    
    # Override epochs if specified
    if args.epochs is not None:
        config['training']['epochs'] = args.epochs
    
    # Setup
    set_seed(42)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger = get_logger("nsl_kdd_train")
    
    logger.info("🎯 Starting NSL-KDD Training with Proposal 2: Curiosity Loop")
    logger.info(f"Device: {device}")
    logger.info(f"Config: {args.config}")
    logger.info(f"Output: {args.output_dir}")
    
    # Create output directory
    ensure_dir(args.output_dir)
    output_dir = Path(args.output_dir)
    
    # Load NSL-KDD dataset
    try:
        dataset_loader = NSLKDDDatasetLoader(args.nsl_kdd_path)
        train_graphs, val_graphs, test_graphs = create_nsl_kdd_graphs(dataset_loader, config)
    except Exception as e:
        logger.error(f"Failed to load NSL-KDD dataset: {e}")
        return
    
    if len(train_graphs) == 0:
        # The model dimensions are read from the first training graph, so an
        # empty training split cannot be handled further down.
        logger.error("NSL-KDD training set is empty; nothing to train on")
        return
    
    # Create model (input dimensions are taken from the first training graph)
    model_config = config['model']
    sample_graph = train_graphs[0]
    
    model = GATv2IDS(
        in_dim_node=sample_graph.x.shape[1],
        in_dim_edge=sample_graph.edge_attr.shape[1] if sample_graph.edge_attr is not None else 0,
        hidden=model_config['hidden'],
        layers=model_config['layers'],
        heads=model_config['heads'],
        num_classes=len(config.get('dataset', {}).get('attack_types', ['normal', 'dos', 'probe', 'r2l', 'u2r'])),
        dropout=model_config['dropout']
    ).to(device)
    
    logger.info(f"Created GATv2IDS model: {model_config}")
    
    # Create optimizer (float() because YAML may parse values like 1e-3 as strings)
    optimizer = optim.Adam(
        model.parameters(),
        lr=float(config['training']['learning_rate']),
        weight_decay=float(config['training']['weight_decay'])
    )
    
    # Create Curiosity Loop feedback system
    curiosity_feedback = None
    if not args.disable_simulation:
        try:
            ns3_client = NS3Client()
            cache = SimCache()
            
            curiosity_feedback = CuriosityLoopFeedback(
                ns3_client=ns3_client,
                cache=cache,
                uncertainty_threshold=config['simulation']['uncertainty_threshold'],
                high_uncertainty_threshold=config['simulation']['high_uncertainty_threshold'],
                forensic_threshold=config['simulation']['forensic_threshold'],
                top_k_edges=config['simulation']['top_k_edges'],
                budget_per_epoch=config['simulation']['budget_per_epoch'],
                max_curiosity_iterations=config['simulation']['max_curiosity_iterations']
            )
            logger.info("Created Curiosity Loop feedback system for NSL-KDD")
        except Exception as e:
            # Training can proceed on the classification loss alone.
            logger.warning(f"Failed to create simulation feedback: {e}")
            curiosity_feedback = None
    else:
        logger.info("Simulation feedback disabled")
    
    # Checked with "is not None" everywhere so the decision never depends on
    # how the feedback object happens to define truthiness.
    use_curiosity = curiosity_feedback is not None
    
    # Training loop
    epochs = config['training']['epochs']
    best_val_f1 = 0.0
    best_model_state = None  # in-memory snapshot of the best-validation weights
    patience = config['training']['patience']
    patience_counter = 0
    epochs_run = 0  # stays 0 when the loop body never executes
    
    training_history = {
        'train_loss': [],
        'train_accuracy': [],
        'val_accuracy': [],
        'val_f1': [],
        'curiosity_metrics': []
    }
    
    logger.info(f"Starting training for {epochs} epochs")
    
    for epoch in range(epochs):
        epochs_run = epoch + 1
        epoch_start_time = time.time()
        
        # Reset curiosity loop budget for new epoch
        if use_curiosity:
            curiosity_feedback.reset_budget_for_epoch(epoch)
        
        # Training
        train_metrics = train_epoch_with_curiosity_loop(
            model, train_graphs, optimizer, device, curiosity_feedback, config, epoch
        )
        
        # Validation
        val_metrics = evaluate_model(model, val_graphs, device, config)
        
        epoch_time = time.time() - epoch_start_time
        
        # Logging (consistent with train_baselines.py)
        logger.info(f"Epoch {epoch+1}/{epochs}: "
                   f"Loss={train_metrics['loss']:.4f}, "
                   f"Train_Acc={train_metrics['accuracy']:.4f}, "
                   f"Val_Acc={val_metrics['accuracy']:.4f}, "
                   f"Val_F1={val_metrics['f1']:.4f}")
        
        if use_curiosity:
            logger.info(f"  Curiosity - Iter: {train_metrics['curiosity_iterations']:.1f}, "
                       f"Conv: {train_metrics['convergence_rate']:.2f}, "
                       f"Cost: {train_metrics['resource_cost']:.2f}")
        
        # Per-class results
        if config.get('logging', {}).get('log_attack_type_metrics', False):
            logger.info("  Per-class accuracy:")
            for attack_type, acc in val_metrics['class_accuracies'].items():
                logger.info(f"    {attack_type}: {acc:.4f}")
        
        # Save training history
        training_history['train_loss'].append(train_metrics['loss'])
        training_history['train_accuracy'].append(train_metrics['accuracy'])
        training_history['val_accuracy'].append(val_metrics['accuracy'])
        training_history['val_f1'].append(val_metrics['f1'])
        
        if use_curiosity:
            training_history['curiosity_metrics'].append({
                'iterations': train_metrics['curiosity_iterations'],
                'convergence_rate': train_metrics['convergence_rate'],
                'resource_cost': train_metrics['resource_cost']
            })
        
        # Early stopping
        if val_metrics['f1'] > best_val_f1:
            best_val_f1 = val_metrics['f1']
            patience_counter = 0
            
            # Snapshot the weights so the test evaluation can use this model
            # even when saving to disk is disabled.
            best_model_state = copy.deepcopy(model.state_dict())
            
            # Save best model
            if config.get('logging', {}).get('save_model', True):
                torch.save(model.state_dict(), output_dir / 'best_model.pth')
        else:
            patience_counter += 1
            
        if patience_counter >= patience:
            logger.info(f"Early stopping at epoch {epoch+1}")
            break
    
    # Evaluate the model that achieved best_val_f1, not whatever the last
    # (possibly over-trained) epoch left behind.
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
        logger.info("Restored best validation weights for final test evaluation")
    
    # Final evaluation on test set
    logger.info("Evaluating on test set...")
    test_metrics = evaluate_model(model, test_graphs, device, config)
    
    logger.info("Final Test Results:")
    logger.info(f"  Accuracy: {test_metrics['accuracy']:.4f}")
    logger.info(f"  F1 Score: {test_metrics['f1']:.4f}")
    logger.info(f"  Precision: {test_metrics['precision']:.4f}")
    logger.info(f"  Recall: {test_metrics['recall']:.4f}")
    
    # Per-class test results
    logger.info("Per-class test accuracy:")
    for attack_type, acc in test_metrics['class_accuracies'].items():
        logger.info(f"  {attack_type}: {acc:.4f}")
    
    # Save results
    results = {
        'config': config,
        'training_history': training_history,
        'test_metrics': test_metrics,
        'best_val_f1': best_val_f1,
        'total_epochs': epochs_run
    }
    
    # default=str stringifies anything json cannot serialise natively.
    with open(output_dir / 'results.json', 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, default=str)
    
    logger.info(f"Training completed! Results saved to {args.output_dir}")

# Script entry point. The relative imports at the top of this file only resolve
# when it is executed as a module inside its package
# (``python -m <package>.<module>``), not when the file is run directly by path.
if __name__ == "__main__":
    main()
