import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
import numpy as np
from tqdm import tqdm
import argparse
import logging
import os
import shutil
from datetime import datetime
import matplotlib.pyplot as plt
import json
import re
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, precision_recall_curve, auc

# Disable tokenizer parallelism to avoid forking warnings
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Import shared components from utils
from utils import (
    InstructionDataset, TransformerInstructionClassifier, collate_fn,
    setup_logging, check_gpu_availability, get_device, predict_instructions
)

def map_model_name(friendly_name):
    """Resolve a short model alias to its HuggingFace model name.

    Args:
        friendly_name: One of the aliases listed in the lookup table below.

    Returns:
        The HuggingFace model name registered for the alias.

    Raises:
        ValueError: If the alias is not present in the lookup table.
    """
    hub_names = {
        'modern-bert-base': 'answerdotai/ModernBERT-base',
        'modern-bert-large': 'answerdotai/ModernBERT-large',
        'xlm-roberta-base': 'xlm-roberta-base',
        'xlm-roberta-large': 'FacebookAI/xlm-roberta-large'
    }

    # Every value in the table is a non-empty string, so None reliably means "unknown alias".
    resolved = hub_names.get(friendly_name)
    if resolved is None:
        raise ValueError(f"Unsupported model name: {friendly_name}. Supported models: {list(hub_names.keys())}")

    return resolved

def has_instruction_tags(text):
    """Report whether ``text`` holds at least one ``<instruction>...</instruction>`` span.

    Falsy input (``None`` or an empty string) yields ``False``. The span may
    cross line breaks because the search runs with ``re.DOTALL``.
    """
    if text:
        found = re.search(r'<instruction>.*?</instruction>', text, re.DOTALL)
        return found is not None
    return False

def calculate_sample_level_metrics(label_texts, predict_texts, sample_probabilities, original_data=None):
    """Score instruction detection at the granularity of whole samples.

    A sample counts as predicted-positive when its predicted text contains an
    ``<instruction>...</instruction>`` span. Its ground truth comes from the
    ``sample_truth`` field of ``original_data`` when that field is usable,
    and from instruction tags in the label text otherwise.

    Args:
        label_texts: List of label texts
        predict_texts: List of predicted texts
        sample_probabilities: List of probabilities for each sample; passed
            through unchanged in the result (for the PR curve)
        original_data: Optional list of original data dictionaries. It is
            only consulted when it has the same length as ``label_texts`` and
            at least one entry carries a ``sample_truth`` key.

    Returns:
        dict: Confusion counts, accuracy/precision/recall/F1, the binary
        per-sample labels, the given probabilities and the name of the
        ground-truth method. ``None`` if ``label_texts`` and
        ``predict_texts`` differ in length.
    """
    log = logging.getLogger(__name__)

    if len(label_texts) != len(predict_texts):
        log.error(f"Mismatch: {len(label_texts)} labels vs {len(predict_texts)} predictions")
        return None

    # Decide once, for the whole run, where ground truth comes from
    use_sample_truth = False
    if original_data is None:
        log.info("No original data provided, using instruction tags method")
    elif len(original_data) != len(label_texts):
        log.warning(f"Original data length mismatch: {len(original_data)} vs {len(label_texts)}, using instruction tags method")
    elif any('sample_truth' in record for record in original_data):
        use_sample_truth = True
        log.info("Using 'sample_truth' field for sample-level ground truth")
    else:
        log.info("'sample_truth' field not found in data, using instruction tags method")

    # Confusion counts keyed by (ground truth is positive, prediction is positive)
    outcome_counts = {
        (True, True): 0,
        (True, False): 0,
        (False, True): 0,
        (False, False): 0,
    }
    sample_labels = []  # 1 if the sample truly contains instructions, 0 if not

    for idx, (label_text, predict_text) in enumerate(zip(label_texts, predict_texts)):
        truth_value = original_data[idx].get('sample_truth', None) if use_sample_truth else None

        if truth_value is not None:
            is_positive = bool(truth_value)
        else:
            if use_sample_truth:
                # This particular record lacks the field even though others have it
                log.warning(f"Sample {idx} missing 'sample_truth' field, falling back to instruction tags")
            is_positive = bool(has_instruction_tags(label_text))

        # The prediction side always relies on tags in the predicted text
        is_flagged = bool(has_instruction_tags(predict_text))

        sample_labels.append(int(is_positive))
        outcome_counts[(is_positive, is_flagged)] += 1

    true_positives = outcome_counts[(True, True)]
    false_negatives = outcome_counts[(True, False)]
    false_positives = outcome_counts[(False, True)]
    true_negatives = outcome_counts[(False, False)]
    total_samples = len(label_texts)

    # Every ratio falls back to 0.0 when its denominator is zero
    predicted_positive = true_positives + false_positives
    actual_positive = true_positives + false_negatives

    precision = 0.0
    if predicted_positive > 0:
        precision = true_positives / predicted_positive

    recall = 0.0
    if actual_positive > 0:
        recall = true_positives / actual_positive

    f1 = 0.0
    if (precision + recall) > 0:
        f1 = 2 * (precision * recall) / (precision + recall)

    accuracy = 0.0
    if total_samples > 0:
        accuracy = (true_positives + true_negatives) / total_samples

    if use_sample_truth:
        ground_truth_method = "sample_truth field"
    else:
        ground_truth_method = "instruction tags"

    log.info("=== Sample-Level Classification Results ===")
    log.info(f"Ground truth method: {ground_truth_method}")
    log.info(f"True Positives: {true_positives}")
    log.info(f"False Positives: {false_positives}")
    log.info(f"True Negatives: {true_negatives}")
    log.info(f"False Negatives: {false_negatives}")
    log.info(f"Total Samples: {total_samples}")
    log.info(f"Sample-Level Accuracy: {accuracy:.4f}")
    log.info(f"Sample-Level Precision: {precision:.4f}")
    log.info(f"Sample-Level Recall: {recall:.4f}")
    log.info(f"Sample-Level F1: {f1:.4f}")

    return {
        'true_positives': true_positives,
        'false_positives': false_positives,
        'true_negatives': true_negatives,
        'false_negatives': false_negatives,
        'total_samples': total_samples,
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'sample_labels': sample_labels,
        'sample_probabilities': sample_probabilities,
        'ground_truth_method': ground_truth_method
    }

def plot_sample_level_precision_recall_curve(sample_labels, sample_probabilities, save_path, utility=None, utility_with_defense=None, total_samples=None):
    """Generate and save the sample-level precision-recall curve.

    The figure shows the PR curve, its AUC, and a caption with the confusion
    matrix at a fixed 0.5 probability threshold, the attack success rate
    (false negatives divided by all ground-truth positives) and, when all
    three utility arguments are given, the utility counts.

    Args:
        sample_labels: Binary ground-truth label per sample (1 = positive).
        sample_probabilities: Score per sample, aligned with ``sample_labels``.
        save_path: Destination path handed to ``plt.savefig``.
        utility: Optional count shown as ``utility/total_samples``.
        utility_with_defense: Optional count shown as
            ``utility_with_defense/total_samples``.
        total_samples: Optional denominator for the two utility counts.

    Returns:
        The PR AUC, or ``None`` when only one class is present in
        ``sample_labels`` or when any step raises (the error is logged).
    """
    logger = logging.getLogger(__name__)
    fig = None  # Remembered so the figure can be released even if plotting fails

    try:
        # Check if we have both classes
        unique_labels = set(sample_labels)
        if len(unique_labels) < 2:
            logger.warning(f"Cannot create sample-level PR curve: only one class present: {unique_labels}")
            return None

        # Calculate precision-recall curve (the thresholds are not needed here)
        precision, recall, _ = precision_recall_curve(sample_labels, sample_probabilities)

        # Calculate AUC
        pr_auc = auc(recall, precision)

        # Calculate confusion matrix values using 0.5 threshold.
        # Labels other than 0/1 are counted in none of the four cells.
        tp = fp = tn = fn = 0
        for true, prob in zip(sample_labels, sample_probabilities):
            pred = 1 if prob >= 0.5 else 0
            if true == 1 and pred == 1:
                tp += 1
            elif true == 0 and pred == 1:
                fp += 1
            elif true == 0 and pred == 0:
                tn += 1
            elif true == 1 and pred == 0:
                fn += 1

        # Calculate attack success rate (proportion of actual instructions missed)
        total_positive_samples = tp + fn  # Total samples with ground truth = 1
        attack_success = fn / total_positive_samples if total_positive_samples > 0 else 0.0

        has_utility = utility is not None and utility_with_defense is not None and total_samples is not None

        # Create the plot
        fig = plt.figure(figsize=(8, 7.5))  # Taller to accommodate multi-line caption
        plt.plot(recall, precision, marker='o', linewidth=2, markersize=4, color='red')
        plt.xlabel('Recall', fontsize=12)
        plt.ylabel('Precision', fontsize=12)
        plt.title('Sample-Level Precision-Recall Curve\nInstruction Classification', fontsize=14)
        plt.grid(True, alpha=0.3)

        # Add AUC score to the plot
        plt.text(0.05, 0.95, f'PR AUC: {pr_auc:.3f}', transform=plt.gca().transAxes,
                fontsize=12, verticalalignment='top', bbox=dict(boxstyle='round', facecolor='lightcoral', alpha=0.5))

        # Add confusion matrix values, attack success, and utility metrics as caption
        confusion_text = f'Confusion Matrix (threshold=0.5): TP={tp}, FP={fp}, TN={tn}, FN={fn}'
        attack_text = f'Attack Success: {attack_success:.3f} ({fn}/{total_positive_samples})'

        if has_utility:
            utility_text = f'Utility: {utility}/{total_samples} | Utility w/ Defense: {utility_with_defense}/{total_samples}'
            caption_text = f'{confusion_text}\n{attack_text} | {utility_text}'
        else:
            caption_text = f'{confusion_text} | {attack_text}'

        plt.figtext(0.5, 0.02, caption_text, ha='center', fontsize=9, style='italic',
                   bbox=dict(boxstyle='round,pad=0.4', facecolor='lightgray', alpha=0.7))

        # Save the plot
        plt.tight_layout()
        plt.subplots_adjust(bottom=0.15)  # Make room for multi-line caption
        plt.savefig(save_path, dpi=300, bbox_inches='tight')

        logger.info(f"Sample-level Precision-Recall curve saved to: {save_path}")
        logger.info(f"Sample-level PR AUC Score: {pr_auc:.4f}")
        logger.info(f"Confusion Matrix (threshold=0.5): TP={tp}, FP={fp}, TN={tn}, FN={fn}")
        logger.info(f"Attack Success Rate: {attack_success:.4f} ({fn}/{total_positive_samples})")

        # Log utility metrics if available
        if has_utility:
            logger.info(f"Utility (successful attacks): {utility}/{total_samples}")
            logger.info(f"Utility with Defense (blocked attacks): {utility_with_defense}/{total_samples}")

        return pr_auc

    except Exception as e:
        # Best-effort: a failed plot is reported but must not abort the caller
        logger.error(f"Error creating sample-level precision-recall curve: {e}")
        return None

    finally:
        # Close the figure on every path; previously a failure after
        # plt.figure() (e.g. in savefig) left the figure open.
        if fig is not None:
            plt.close(fig)

def _aggregate_windows(windows, logger, debug_enabled=False):
    """Aggregate predictions from multiple overlapping windows for a single sample"""
    if len(windows) == 1:
        # Single window, return as-is
        window = windows[0]
        tokens = window['tokens']
        preds = window['predictions']
        labels = window['labels']
        probs = window['probabilities']
        
        min_len = min(len(tokens), len(preds), len(labels), len(probs))
        if min_len < len(tokens):
            logger.warning(f"Single window length mismatch: tokens={len(tokens)}, preds={len(preds)}, labels={len(labels)}, probs={len(probs)}")
        
        return {
            'tokens': tokens[:min_len],
            'predictions': preds[:min_len],
            'labels': labels[:min_len],
            'probabilities': probs[:min_len]
        }
    
    # Sort windows by subword start position to maintain order
    windows.sort(key=lambda w: w['subword_start'])
    
    # Start with the first window
    aggregated_tokens = list(windows[0]['tokens'])
    aggregated_predictions = list(windows[0]['predictions'])
    aggregated_labels = list(windows[0]['labels'])
    aggregated_probabilities = list(windows[0]['probabilities'])
    
    # Process each subsequent window
    for i, window in enumerate(windows[1:], 1):
        window_tokens = window['tokens']
        window_predictions = window['predictions']
        window_labels = window['labels']
        window_probabilities = window['probabilities']
        
        # Find overlap with current aggregated sequence
        overlap_length = 0
        max_possible_overlap = min(len(aggregated_tokens), len(window_tokens))
        
        # Try different overlap lengths to find the best match
        for j in range(1, max_possible_overlap + 1):
            agg_suffix = aggregated_tokens[-j:]
            win_prefix = window_tokens[:j]
            
            if agg_suffix == win_prefix:
                overlap_length = j
        
        if overlap_length > 0:
            # Merge predictions in overlap region using majority voting
            for j in range(overlap_length):
                agg_idx = len(aggregated_tokens) - overlap_length + j
                win_idx = j
                
                # Majority vote for predictions
                agg_pred = aggregated_predictions[agg_idx]
                win_pred = window_predictions[win_idx]
                
                # Simple majority vote (could be enhanced with confidence weighting)
                if agg_pred == win_pred:
                    # Predictions agree, average probabilities
                    aggregated_probabilities[agg_idx] = (aggregated_probabilities[agg_idx] + window_probabilities[win_idx]) / 2
                else:
                    # Predictions disagree, choose the one with higher probability
                    if window_probabilities[win_idx] > aggregated_probabilities[agg_idx]:
                        aggregated_predictions[agg_idx] = win_pred
                        aggregated_probabilities[agg_idx] = window_probabilities[win_idx]
            
            # Append non-overlapping part of current window
            non_overlap_start = overlap_length
            aggregated_tokens.extend(window_tokens[non_overlap_start:])
            aggregated_predictions.extend(window_predictions[non_overlap_start:])
            aggregated_labels.extend(window_labels[non_overlap_start:])
            aggregated_probabilities.extend(window_probabilities[non_overlap_start:])
        
        else:
            # No overlap found, this suggests a gap - log warning but continue
            logger.warning(f"No overlap found between aggregated sequence and window {i}. This may indicate a gap in coverage.")
            logger.warning(f"Last 5 aggregated tokens: {aggregated_tokens[-5:] if len(aggregated_tokens) >= 5 else aggregated_tokens}")
            logger.warning(f"First 5 window tokens: {window_tokens[:5] if len(window_tokens) >= 5 else window_tokens}")
            
            # Append with potential gap
            aggregated_tokens.extend(window_tokens)
            aggregated_predictions.extend(window_predictions)
            aggregated_labels.extend(window_labels)
            aggregated_probabilities.extend(window_probabilities)
    
    # Ensure all arrays have the same length
    min_len = min(len(aggregated_tokens), len(aggregated_predictions), 
                  len(aggregated_labels), len(aggregated_probabilities))
    
    if min_len < len(aggregated_tokens):
        logger.warning(f"Length mismatch in aggregation, truncating to {min_len}")
    
    return {
        'tokens': aggregated_tokens[:min_len],
        'predictions': aggregated_predictions[:min_len],
        'labels': aggregated_labels[:min_len],
        'probabilities': aggregated_probabilities[:min_len]
    }

def _extract_word_level_outputs(pred_row, prob_row, label_row, mask_row, original_tokens, sample_id, logger):
    """Reduce one window's per-subword outputs to one entry per word.

    Positions whose label is not -100 are treated as the first subtoken of a
    word and supply that word's prediction, class-1 probability and label.
    If there are fewer such positions than ``original_tokens``, the
    shortfall is filled first from unlabelled, non-padding interior
    positions (labelled with the previous word's label, or 0) and then with
    neutral defaults (prediction 0, probability 0.5, label 0).

    Args:
        pred_row: Predicted class per subword position (plain Python list).
        prob_row: Class-1 probability per subword position.
        label_row: Label per subword position, -100 where no label applies.
        mask_row: Attention mask per subword position (1 = real token).
        original_tokens: The words this window covers.
        sample_id: Identifier used in warning messages only.
        logger: Logger for alignment warnings.

    Returns:
        Tuple ``(word_tokens, word_predictions, word_probabilities,
        word_labels)``. All four have the same length; on an alignment
        failure they are truncated to the shortest of them.
    """
    word_predictions = []
    word_probabilities = []
    word_labels = []

    try:
        # BERT approach: only the first subtoken of each word carries the label
        word_positions = [j for j, label in enumerate(label_row) if label != -100]
        word_predictions = [pred_row[j] for j in word_positions]
        word_probabilities = [prob_row[j] for j in word_positions]
        word_labels = [label_row[j] for j in word_positions]

        missing_words = len(original_tokens) - len(word_positions)
        if missing_words > 0:
            # Fallback: borrow unlabelled positions that are neither the
            # first/last position (special tokens) nor padding.
            last_index = len(label_row) - 1
            remaining_positions = [
                j for j, label in enumerate(label_row)
                if label == -100 and 0 < j < last_index and mask_row[j] == 1
            ]

            for pos in remaining_positions[:missing_words]:
                word_predictions.append(pred_row[pos])
                word_probabilities.append(prob_row[pos])
                # For missing word labels, use the previous label or default to 0
                word_labels.append(word_labels[-1] if word_labels else 0)

            # If we still have missing words, pad with defaults
            shortfall = len(original_tokens) - len(word_predictions)
            if shortfall > 0:
                word_predictions.extend([0] * shortfall)  # Default to "OTHER"
                word_probabilities.extend([0.5] * shortfall)  # Neutral probability
                word_labels.extend([0] * shortfall)  # Default to "OTHER"

        # Explicit checks rather than assert, so they still run under `python -O`
        if len(word_predictions) != len(original_tokens):
            raise ValueError(f"Word alignment failed: {len(word_predictions)} != {len(original_tokens)}")
        if len(word_probabilities) != len(original_tokens):
            raise ValueError(f"Word prob alignment failed: {len(word_probabilities)} != {len(original_tokens)}")
        if len(word_labels) != len(original_tokens):
            raise ValueError(f"Word label alignment failed: {len(word_labels)} != {len(original_tokens)}")

    except Exception as e:
        logger.warning(f"Word extraction failed for sample {sample_id}: {e}")
        # Fallback: truncate everything to the shortest list to avoid crashes
        min_word_length = min(len(original_tokens), len(word_predictions),
                              len(word_probabilities), len(word_labels))

        if min_word_length < len(original_tokens):
            logger.warning(f"Word extraction length mismatch for sample {sample_id}: "
                         f"tokens={len(original_tokens)}, preds={len(word_predictions)}, "
                         f"labels={len(word_labels)}, probs={len(word_probabilities)} - truncating to {min_word_length}")

        word_predictions = word_predictions[:min_word_length]
        word_probabilities = word_probabilities[:min_word_length]
        word_labels = word_labels[:min_word_length]
        original_tokens = original_tokens[:min_word_length]

    return original_tokens, word_predictions, word_probabilities, word_labels


def evaluate_model(model, dataloader, device, tokenizer, max_length=512, save_predictions=False, token_threshold=None):
    """Evaluate the model and return word-level, sequence-level and loss metrics.

    Every batch is run through the model without gradients. Each row of a
    batch is one window of a sample; subword outputs are reduced to word
    level using the BERT approach (only positions whose label is not -100
    count), windows are grouped by sample id, and metrics are computed over
    all words. Batches whose loss is NaN are skipped entirely.

    NOTE(review): for samples split into several windows, the word-level
    predictions, labels and the sample probability are taken from the first
    window (by ``subword_start``) only; later windows do not contribute to
    the metrics. This matches the previous behaviour and is flagged here
    because it can hide instructions that appear late in long inputs.

    Args:
        model: Called as ``model(input_ids=..., attention_mask=...,
            labels=...)``; must return a mapping with ``'loss'`` and
            ``'logits'``. The logits are indexed as (batch, position, class)
            with class 1 being the positive class.
        dataloader: Yields batches with the keys ``'input_ids'``,
            ``'attention_mask'``, ``'labels'``, ``'original_tokens'``,
            ``'sample_ids'``, ``'window_ids'``, ``'total_windows'``,
            ``'window_starts'`` and ``'window_ends'``.
        device: Device the batch tensors are moved to.
        tokenizer: Used to turn input ids back into subword tokens.
        max_length: Currently unused; kept for interface compatibility.
        save_predictions: When true, per-sample tokens, predictions and true
            labels are collected and returned under ``'sample_predictions'``.
        token_threshold: If given, a position is predicted as class 1 when
            its class-1 probability is >= this value; otherwise the argmax
            of the logits is used.

    Returns:
        Dict with ``'token_accuracy'``, ``'token_precision'``,
        ``'token_recall'``, ``'token_f1'``, ``'sequence_accuracy'``,
        ``'total_sequences'``, ``'validation_loss'``, ``'token_labels'``,
        ``'token_probabilities'``, ``'sample_probabilities'`` and
        ``'sample_predictions'`` (``None`` unless ``save_predictions``).
        The five metric values are 0.0 when no words were found or when
        metric computation fails.
    """
    logger = logging.getLogger(__name__)
    model.eval()

    # Track window-level results for aggregation
    window_results = {}  # sample_id -> list of window results

    # Flat word-level results across all samples
    all_token_predictions = []
    all_token_labels = []
    all_token_probabilities = []  # For precision-recall curve
    all_sample_probabilities = []  # For sample-level precision-recall curve

    # For saving predictions
    sample_predictions = [] if save_predictions else None

    # Track validation loss
    total_loss = 0
    valid_batches = 0

    sequence_correct = 0
    total_sequences = 0
    total_windows = 0
    samples_with_multiple_windows = 0

    with torch.no_grad():
        progress_bar = tqdm(dataloader, desc="Evaluating", unit="batch")
        for batch in progress_bar:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            # Forward pass
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )

            loss = outputs['loss']
            logits = outputs['logits']

            # Skip batch if loss is NaN
            if torch.isnan(loss):
                logger.warning("Skipping batch with NaN loss during evaluation")
                continue

            total_loss += loss.item()
            valid_batches += 1

            # Get predictions and probabilities
            probabilities = torch.softmax(logits, dim=-1)

            # Use threshold-based prediction if token_threshold is provided, otherwise use argmax
            if token_threshold is not None:
                # Use threshold: predict 1 if probability of class 1 >= threshold, else 0
                predictions = (probabilities[:, :, 1] >= token_threshold).long()
            else:
                # Use argmax (original behavior)
                predictions = torch.argmax(logits, dim=-1)

            # Process each window in the batch
            for i, original_tokens in enumerate(batch['original_tokens']):
                sample_id = batch['sample_ids'][i]
                window_id = batch['window_ids'][i]
                total_windows_for_sample = batch['total_windows'][i]
                window_start = batch['window_starts'][i]
                window_end = batch['window_ends'][i]

                # Find actual sequence length (excluding padding)
                actual_length = attention_mask[i].sum().item()

                # Pull each row to Python lists once instead of calling
                # .item() per position inside the loops below.
                pred_row = predictions[i].tolist()
                prob_row = probabilities[i][:, 1].tolist()  # Probability of class 1
                label_row = labels[i].tolist()
                mask_row = attention_mask[i].tolist()

                # Extract actual subword tokens using tokenizer
                actual_subword_ids = input_ids[i][:actual_length].tolist()
                actual_subword_tokens = tokenizer.convert_ids_to_tokens(actual_subword_ids)

                # Convert subword predictions to word-level predictions using BERT approach
                word_tokens, word_predictions, word_probabilities, word_labels = _extract_word_level_outputs(
                    pred_row, prob_row, label_row, mask_row, original_tokens, sample_id, logger
                )

                # Store window result for aggregation
                window_results.setdefault(sample_id, []).append({
                    'window_id': window_id,
                    'subword_start': window_start,
                    'subword_end': window_end,
                    'total_windows': total_windows_for_sample,
                    # Store subword tokens for consistent aggregation
                    'tokens': actual_subword_tokens,
                    'predictions': pred_row[:actual_length],
                    'labels': label_row[:actual_length],  # Keep for debugging (-100 on non-first subtokens)
                    'probabilities': prob_row[:actual_length],
                    # Store word-level data separately for final metrics
                    'word_tokens': word_tokens,
                    'word_predictions': word_predictions,
                    'word_labels': word_labels,
                    'word_probabilities': word_probabilities,
                    'sample_id': sample_id
                })

                total_windows += 1

    # Aggregate results from multiple windows per sample
    logger.info(f"Aggregating results from {total_windows} windows across {len(window_results)} samples")

    for sample_id, windows in window_results.items():
        if len(windows) > 1:
            samples_with_multiple_windows += 1

        # The merged subword sequence is not consumed yet, but the call is
        # kept on purpose: it logs coverage-gap warnings and, for
        # multi-window samples, sorts `windows` in place by 'subword_start',
        # which determines which window is windows[0] below.
        _aggregate_windows(windows, logger)

        # Word-level data comes from the first window only (see NOTE(review)
        # in the docstring).
        # TODO: Could implement proper word-level aggregation here if needed
        first_window = windows[0]
        word_predictions = first_window['word_predictions']
        word_probabilities = first_window['word_probabilities']
        word_tokens = first_window['word_tokens']
        word_labels = first_window['word_labels']

        # Ensure all arrays have the same length before extending
        if len(word_predictions) == len(word_labels) == len(word_probabilities):
            all_token_predictions.extend(word_predictions)
            all_token_labels.extend(word_labels)
            all_token_probabilities.extend(word_probabilities)
        else:
            logger.warning(f"Length mismatch for sample {sample_id}: "
                         f"preds={len(word_predictions)}, labels={len(word_labels)}, probs={len(word_probabilities)}")
            # Take the minimum length to avoid crashes
            min_len = min(len(word_predictions), len(word_labels), len(word_probabilities))
            all_token_predictions.extend(word_predictions[:min_len])
            all_token_labels.extend(word_labels[:min_len])
            all_token_probabilities.extend(word_probabilities[:min_len])

        # Calculate sample-level probability (max instruction token probability)
        sample_prob = max(word_probabilities) if word_probabilities else 0.0
        all_sample_probabilities.append(sample_prob)

        # Store per-sample predictions if requested
        if save_predictions and sample_predictions is not None:
            sample_predictions.append({
                'tokens': word_tokens,
                'predictions': word_predictions,
                'true_labels': word_labels
            })

        # Check sequence-level accuracy
        if word_predictions == word_labels:
            sequence_correct += 1

        total_sequences += 1

    # Debug information for validation sets
    logger.info("=== EVALUATION SUMMARY ===")
    logger.info(f"Total multi-window samples processed: {samples_with_multiple_windows}")
    logger.info(f"Total words in evaluation: {len(all_token_labels)} (BERT approach: first subtokens only)")
    logger.info(f"Word label distribution: {np.bincount(all_token_labels) if all_token_labels else 'No words'}")
    logger.info(f"Word prediction distribution: {np.bincount(all_token_predictions) if all_token_predictions else 'No predictions'}")
    logger.info(f"Sliding window statistics: {total_windows} windows processed for {total_sequences} samples, {samples_with_multiple_windows} samples had multiple windows")

    # Mean loss over the batches that were not skipped; inf when none were valid
    validation_loss = total_loss / valid_batches if valid_batches > 0 else float('inf')

    # Start from the "nothing could be computed" result and fill in what succeeds
    results = {
        'token_accuracy': 0.0,
        'token_precision': 0.0,
        'token_recall': 0.0,
        'token_f1': 0.0,
        'sequence_accuracy': 0.0,
        'total_sequences': total_sequences,
        'validation_loss': validation_loss,
        'token_labels': all_token_labels,
        'token_probabilities': all_token_probabilities,
        'sample_probabilities': all_sample_probabilities,
        'sample_predictions': sample_predictions
    }

    # Calculate metrics with error handling for small datasets
    if len(all_token_labels) == 0 or len(all_token_predictions) == 0:
        logger.warning("No words found for evaluation!")
        # With no words there is nothing meaningful to plot either
        results.update(token_labels=[], token_probabilities=[], sample_probabilities=[])
        return results

    # Check if we have both classes
    unique_labels = set(all_token_labels)
    unique_predictions = set(all_token_predictions)

    if len(unique_labels) < 2:
        logger.warning(f"Only one class in validation labels: {unique_labels}")
    if len(unique_predictions) < 2:
        logger.warning(f"Only one class in predictions: {unique_predictions}")

    # Calculate word-level metrics using BERT approach
    try:
        token_accuracy = accuracy_score(all_token_labels, all_token_predictions)
        precision, recall, f1, _ = precision_recall_fscore_support(
            all_token_labels, all_token_predictions, average='binary', pos_label=1, zero_division='warn'
        )
    except Exception as e:
        # Best-effort: report the failure and return the zeroed metrics
        logger.error(f"Error calculating metrics: {e}")
        return results

    results.update(
        token_accuracy=token_accuracy,
        token_precision=precision,
        token_recall=recall,
        token_f1=f1,
        sequence_accuracy=sequence_correct / total_sequences if total_sequences > 0 else 0.0,
    )
    return results

def reconstruct_text_with_tags(tokens, predictions):
    """Rebuild a space-joined text, wrapping predicted instruction spans in tags.

    Consecutive tokens whose prediction equals 1 are merged into a single
    ``<instruction>...</instruction>`` span; every other token is emitted
    unchanged. All pieces are joined with single spaces.

    Args:
        tokens: Sequence of token/word strings.
        predictions: Sequence of per-token labels (1 = INSTRUCTION, anything
            else = OTHER), expected to be parallel to ``tokens``.

    Returns:
        str: The reconstructed text. An empty input yields an empty string.
    """
    # Local import keeps this block self-contained (the file does not import
    # itertools at module level).
    from itertools import groupby

    logger = logging.getLogger(__name__)

    if len(tokens) != len(predictions):
        logger.warning(f"Length mismatch: tokens ({len(tokens)}) vs predictions ({len(predictions)})")
        # zip() below stops at the shorter sequence, which is the truncation
        # announced here.
        min_length = min(len(tokens), len(predictions))
        logger.warning(f"Truncated both to length {min_length}")

    result_parts = []
    # groupby() yields maximal runs of equal keys, so each run of INSTRUCTION
    # tokens (including one at the very end of the text) becomes one tag.
    for is_instruction, run in groupby(zip(tokens, predictions), key=lambda pair: pair[1] == 1):
        run_tokens = [token for token, _ in run]
        if is_instruction:
            result_parts.append(f"<instruction>{' '.join(run_tokens)}</instruction>")
        else:
            result_parts.extend(run_tokens)

    return ' '.join(result_parts)

def load_original_data(data_path):
    """Load the raw JSONL records so their original fields can be preserved.

    Args:
        data_path: Path to a UTF-8 encoded JSONL file (one JSON value per line).

    Returns:
        list: The parsed records in file order. Blank lines are skipped
        silently; lines that are not valid JSON are logged as errors and
        skipped.
    """
    logger = logging.getLogger(__name__)
    original_data = []

    logger.info(f"Loading original data from: {data_path}")

    with open(data_path, 'r', encoding='utf-8') as f:
        for line_num, line in enumerate(f, 1):
            stripped = line.strip()
            if not stripped:
                # Blank/trailing lines are not data; don't report them as errors.
                continue
            try:
                original_data.append(json.loads(stripped))
            except json.JSONDecodeError as e:
                logger.error(f"Error loading original data line {line_num}: {e}")

    logger.info(f"Loaded {len(original_data)} original samples")
    return original_data

def save_predictions_as_jsonl(original_data, sample_predictions, output_path):
    """Write one JSONL record per sample: original fields plus the prediction.

    Each output record is a copy of the original record extended with
    ``predict_text`` (the text rebuilt with instruction tags) and
    ``sample_prediction`` (whether that text contains an instruction tag).

    Args:
        original_data: List of original record dicts.
        sample_predictions: List of dicts providing ``'tokens'`` and
            ``'predictions'`` for each evaluated sample.
        output_path: Destination JSONL file (overwritten).

    Returns:
        int: Number of records actually written.
    """
    log = logging.getLogger(__name__)

    log.info(f"Saving predictions to: {output_path}")

    # Keep only records whose 'sanity_check' value is anything other than the
    # literal False; a missing key defaults to False and is therefore dropped.
    # (The original author notes this mirrors InstructionDataset's filtering.)
    kept_records = [
        record for record in original_data
        if record.get('sanity_check', False) is not False
    ]

    log.info(f"Original data: {len(original_data)} samples")
    log.info(f"Filtered data (passed sanity check): {len(kept_records)} samples")
    log.info(f"Predictions: {len(sample_predictions)} samples")

    if len(kept_records) != len(sample_predictions):
        log.error(f"Mismatch: {len(kept_records)} filtered samples vs {len(sample_predictions)} predictions")
        log.error("This might indicate an issue with data processing. Attempting to save what we can...")

        # Pair up only as many items as both sides can supply.
        usable = min(len(kept_records), len(sample_predictions))
        kept_records = kept_records[:usable]
        sample_predictions = sample_predictions[:usable]

    written = 0

    with open(output_path, 'w', encoding='utf-8') as out_file:
        for record, prediction in zip(kept_records, sample_predictions):
            try:
                tagged_text = reconstruct_text_with_tags(
                    prediction['tokens'],
                    prediction['predictions']
                )
                # Sample-level verdict: does the rebuilt text carry any tag?
                flagged = has_instruction_tags(tagged_text)

                enriched = record.copy()
                enriched['predict_text'] = tagged_text
                enriched['sample_prediction'] = flagged

                out_file.write(json.dumps(enriched, ensure_ascii=False) + '\n')
                written += 1

            except Exception as e:
                # Best effort: a failing sample is logged and skipped.
                log.error(f"Error processing sample {record.get('id', 'unknown')}: {e}")

    log.info(f"Successfully saved predictions for {written} samples to {output_path}")
    return written

def save_metrics_as_json(metrics, sample_level_metrics, model_name, training_dataset, test_dataset, save_path):
    """Write a compact JSON summary of word-level and sample-level metrics.

    Args:
        metrics: Dict of word-level metrics; the keys read are
            ``token_accuracy``, ``token_precision``, ``token_recall``,
            ``token_f1`` and ``word_pr_auc`` (each defaults to 0.0).
        sample_level_metrics: Dict with ``accuracy``, ``precision``,
            ``recall``, ``f1`` and ``pr_auc`` (each defaults to 0.0), or
            ``None`` to report all-zero sample-level metrics.
        model_name: Name of the model used.
        training_dataset: Name/path of the training dataset.
        test_dataset: Name/path of the test dataset.
        save_path: Path of the JSON file to write.

    Returns:
        dict | None: The summary that was written, or ``None`` if building
        or saving it failed (the error is logged).
    """
    log = logging.getLogger(__name__)

    # Output field names, paired with the key each one is read from in `metrics`.
    report_fields = ("accuracy", "precision", "recall", "f1", "pr_auc")
    word_source_keys = ('token_accuracy', 'token_precision', 'token_recall', 'token_f1', 'word_pr_auc')

    try:
        word_section = {
            field: float(metrics.get(source_key, 0.0))
            for field, source_key in zip(report_fields, word_source_keys)
        }

        if sample_level_metrics is None:
            sample_section = {field: 0.0 for field in report_fields}
        else:
            sample_section = {
                field: float(sample_level_metrics.get(field, 0.0))
                for field in report_fields
            }

        summary = {
            "model_name": model_name,
            "training_dataset": training_dataset,
            "test_dataset": test_dataset,
            "word_level_metrics": word_section,
            "sample_level_metrics": sample_section,
        }

        with open(save_path, 'w', encoding='utf-8') as handle:
            json.dump(summary, handle, indent=2, ensure_ascii=False)

        log.info(f"Metrics saved to JSON file: {save_path}")
        log.info("JSON metrics summary:")
        log.info(f"  Model: {model_name}")
        log.info(f"  Training dataset: {training_dataset}")
        log.info(f"  Test dataset: {test_dataset}")
        log.info(f"  Word-level F1: {word_section['f1']:.4f}")
        log.info(f"  Sample-level F1: {sample_section['f1']:.4f}")

        return summary

    except Exception as e:
        log.error(f"Error saving metrics to JSON: {e}")
        return None

def plot_precision_recall_curve(y_true, y_probabilities, save_path='data/logs/precision_recall_curve.png'):
    """Generate and save the word-level precision-recall curve.

    Args:
        y_true: Iterable of ground-truth labels.
        y_probabilities: Scores parallel to ``y_true``, passed to
            ``precision_recall_curve`` as the prediction scores.
        save_path: Where to write the PNG.

    Returns:
        The PR AUC on success; ``None`` when fewer than two distinct labels
        are present or when plotting/saving fails (the failure is logged).
    """
    logger = logging.getLogger(__name__)

    try:
        # A PR curve is only produced when both classes occur in the labels.
        unique_labels = set(y_true)
        if len(unique_labels) < 2:
            logger.warning(f"Cannot create PR curve: only one class present: {unique_labels}")
            return None

        precision, recall, _thresholds = precision_recall_curve(y_true, y_probabilities)
        pr_auc = auc(recall, precision)

        fig = plt.figure(figsize=(8, 6))
        try:
            plt.plot(recall, precision, marker='o', linewidth=2, markersize=4)
            plt.xlabel('Recall', fontsize=12)
            plt.ylabel('Precision', fontsize=12)
            plt.title('Word-Level Precision-Recall Curve\nInstruction Word Classification (BERT Approach)', fontsize=14)
            plt.grid(True, alpha=0.3)

            # Annotate the plot with the area under the curve.
            plt.text(0.05, 0.95, f'PR AUC: {pr_auc:.3f}', transform=plt.gca().transAxes,
                     fontsize=12, verticalalignment='top',
                     bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

            plt.tight_layout()
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
        finally:
            # Always release the figure, even if drawing or saving raised;
            # otherwise failed calls accumulate open figures.
            plt.close(fig)

        logger.info(f"Word-level Precision-Recall curve saved to: {save_path}")
        logger.info(f"Word-level PR AUC Score: {pr_auc:.4f}")

        return pr_auc

    except Exception as e:
        logger.error(f"Error creating word-level precision-recall curve: {e}")
        return None

def load_model_and_tokenizer(model_path, tokenizer_path, model_name='xlm-roberta-base', 
                            loss_type='standard', dropout=0.1, device=None):
    """Restore a trained classifier and its tokenizer, ready for inference.

    Args:
        model_path: Path to the saved model state dict (.pth file).
        tokenizer_path: Path to the saved tokenizer directory.
        model_name: Pre-trained model name used during training; forwarded to
            ``TransformerInstructionClassifier``.
        loss_type: Loss type used during training ('standard', 'weighted_ce',
            'focal'); forwarded to the classifier constructor.
        dropout: Dropout value forwarded to the classifier constructor.
        device: Device to load the model on (``get_device()`` is used if None).

    Returns:
        tuple: ``(model, tokenizer)`` with the model moved to ``device`` and
        switched to eval mode.
    """
    log = logging.getLogger(__name__)

    target_device = get_device() if device is None else device

    log.info(f"Loading tokenizer from: {tokenizer_path}")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

    log.info(f"Loading model from: {model_path}")
    log.info(f"Using loss type: {loss_type} (Note: loss function not used during evaluation)")

    # No class weights are supplied here; the original author notes they are
    # irrelevant at evaluation time.
    model = TransformerInstructionClassifier(model_name, class_weights=None, loss_type=loss_type, dropout=dropout)

    # NOTE(review): torch.load unpickles the checkpoint; only load files from
    # a trusted source (consider weights_only=True where the installed torch
    # version supports it).
    state_dict = torch.load(model_path, map_location=target_device)

    # Entries belonging to the loss function ('loss_fct*') are dropped before
    # loading; everything else is passed through untouched.
    skipped_keys = [key for key in state_dict if key.startswith('loss_fct')]
    filtered_state_dict = {
        key: value for key, value in state_dict.items()
        if not key.startswith('loss_fct')
    }

    if skipped_keys:
        log.info(f"Skipping loss function parameters: {skipped_keys}")
        log.info("(These are not needed for evaluation)")

    # strict=False: key mismatches are reported below instead of raising.
    missing_keys, unexpected_keys = model.load_state_dict(filtered_state_dict, strict=False)

    if missing_keys:
        log.warning(f"Missing keys when loading model: {missing_keys}")
    if unexpected_keys:
        log.warning(f"Unexpected keys when loading model: {unexpected_keys}")

    model.to(target_device)
    model.eval()

    log.info(f"Model and tokenizer loaded successfully on device: {target_device}")
    return model, tokenizer

def _record_has_successful_attack(record):
    """Return True when a record is flagged as a successful attack.

    ``metadata.success`` is consulted first; the top-level ``utility`` field
    is the fallback when that flag is missing or falsy.
    """
    metadata = record.get('metadata')
    # A null or non-dict 'metadata' value must not abort the whole evaluation.
    if isinstance(metadata, dict) and metadata.get('success', False):
        return True
    return bool(record.get('utility', False))


def run_evaluation(model_path, tokenizer_path, data_path, 
                  model_name='xlm-roberta-base', 
                  batch_size=16, max_length=512, overlap=256, output_dir='data/logs', loss_type='standard',
                  dropout=0.1, save_predictions=True, training_dataset_name=None, token_threshold=None):
    """Run the complete evaluation pipeline for one JSONL dataset.

    Steps: load model and tokenizer, build the evaluation dataset, run
    ``evaluate_model``, optionally compute sample-level/utility metrics and
    save per-sample predictions, draw the word-level and sample-level PR
    curves, and write ``metrics_summary.json``.

    Args:
        model_path: Path to the saved model (.pth file).
        tokenizer_path: Path to the saved tokenizer directory.
        data_path: Path to the JSONL evaluation data.
        model_name: Model name forwarded to ``load_model_and_tokenizer`` and
            recorded in the metrics summary.
        batch_size: Batch size of the evaluation ``DataLoader``.
        max_length: Passed to ``InstructionDataset`` both as the maximum
            length and as the window size, and to ``evaluate_model``.
        overlap: Window overlap passed to ``InstructionDataset``.
        output_dir: Directory receiving predictions, plots and the summary.
        loss_type: Forwarded to ``load_model_and_tokenizer``.
        dropout: Forwarded to ``load_model_and_tokenizer``.
        save_predictions: When True (and ``evaluate_model`` returned sample
            predictions), sample-level metrics are computed and
            ``predictions.jsonl`` is written.
        training_dataset_name: Label recorded in the summary; defaults to
            ``"unknown_training_dataset"``.
        token_threshold: Forwarded to ``evaluate_model``.

    Returns:
        dict | None: The metrics dict from ``evaluate_model``, extended with
        ``'word_pr_auc'`` and ``'sample_level_metrics'`` when those could be
        computed; ``None`` when the dataset contains no samples.
    """
    logger = logging.getLogger(__name__)

    # Get device
    device = get_device()
    logger.info(f"Using device: {device}")

    # Load model and tokenizer
    model, tokenizer = load_model_and_tokenizer(model_path, tokenizer_path, model_name=model_name, 
                                               loss_type=loss_type, dropout=dropout, device=device)

    # Load dataset with sliding windows
    logger.info(f"Loading evaluation dataset from: {data_path}")
    eval_dataset = InstructionDataset(data_path, tokenizer, max_length, is_training=False, 
                                    window_size=max_length, overlap=overlap)

    if len(eval_dataset) == 0:
        logger.error("No valid samples found in evaluation dataset!")
        return None

    logger.info(f"Loaded {len(eval_dataset)} samples for evaluation")

    # Create data loader
    eval_loader = DataLoader(
        eval_dataset, batch_size=batch_size, shuffle=False, 
        collate_fn=collate_fn, num_workers=0
    )

    # Run evaluation
    logger.info("Starting evaluation...")
    metrics = evaluate_model(model, eval_loader, device, tokenizer, max_length, save_predictions, token_threshold)

    # Log results
    logger.info("=== Word-Level Evaluation Results (BERT Approach) ===")
    logger.info(f"Word Accuracy: {metrics['token_accuracy']:.4f}")
    logger.info(f"Word Precision: {metrics['token_precision']:.4f}")
    logger.info(f"Word Recall: {metrics['token_recall']:.4f}")
    logger.info(f"Word F1: {metrics['token_f1']:.4f}")
    logger.info(f"Sequence Accuracy: {metrics['sequence_accuracy']:.4f}")
    logger.info(f"Total Sequences: {metrics['total_sequences']}")

    # Everything below writes into output_dir, so make sure it exists.
    # (An empty output_dir means "current directory" for os.path.join.)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Explicit defaults replace the former `'name' in locals()` probing.
    sample_level_metrics = None
    utility = None               # samples with metadata.success or utility set
    utility_with_defense = None  # ... of those, samples predicted tag-free
    predict_texts = []

    sample_predictions = metrics['sample_predictions']
    if save_predictions and sample_predictions is not None:
        # The original file is read once and reused for both the sample-level
        # metrics and the predictions export.
        original_data = load_original_data(data_path)

        # Keep records whose 'sanity_check' is anything but the literal False
        # (missing key counts as False), as the predictions export does.
        filtered_original_data = [
            data for data in original_data
            if data.get('sanity_check', False) is not False
        ]
        aligned_data = filtered_original_data[:len(sample_predictions)]

        label_texts = [data.get('label_text', '') for data in aligned_data]
        predict_texts = [
            reconstruct_text_with_tags(sample_pred['tokens'], sample_pred['predictions'])
            for sample_pred in sample_predictions
        ]

        utility = 0
        utility_with_defense = 0
        for data, predict_text in zip(aligned_data, predict_texts):
            if _record_has_successful_attack(data):
                utility += 1
                # Counted when no instruction tag was predicted for the sample.
                if not has_instruction_tags(predict_text):
                    utility_with_defense += 1

        logger.info("=== Utility Metrics ===")
        logger.info(f"Utility (successful attacks): {utility}/{len(predict_texts)}")
        logger.info(f"Utility with Defense (blocked attacks): {utility_with_defense}/{len(predict_texts)}")

        sample_level_metrics = calculate_sample_level_metrics(
            label_texts, 
            predict_texts, 
            metrics['sample_probabilities'][:len(predict_texts)],
            original_data=filtered_original_data  # used for the sample_truth field
        )

        logger.info("=== Saving Predictions ===")
        predictions_path = os.path.join(output_dir, 'predictions.jsonl')
        # save_predictions_as_jsonl applies the sanity-check filter itself,
        # so it receives the unfiltered records.
        save_predictions_as_jsonl(original_data, sample_predictions, predictions_path)

    # Generate word-level precision-recall curve
    if len(metrics['token_labels']) > 0 and len(metrics['token_probabilities']) > 0:
        pr_curve_path = os.path.join(output_dir, 'word_precision_recall_curve.png')
        word_pr_auc = plot_precision_recall_curve(
            metrics['token_labels'], 
            metrics['token_probabilities'], 
            save_path=pr_curve_path
        )
        if word_pr_auc is not None:
            metrics['word_pr_auc'] = word_pr_auc
    else:
        logger.warning("Cannot generate word-level precision-recall curve: insufficient data")

    # Generate sample-level precision-recall curve
    if sample_level_metrics is not None and len(sample_level_metrics['sample_labels']) > 0:
        sample_pr_curve_path = os.path.join(output_dir, 'sample_precision_recall_curve.png')

        sample_pr_auc = plot_sample_level_precision_recall_curve(
            sample_level_metrics['sample_labels'],
            sample_level_metrics['sample_probabilities'],
            save_path=sample_pr_curve_path,
            utility=utility,
            utility_with_defense=utility_with_defense,
            total_samples=len(predict_texts)
        )

        # Add AUC to the metrics
        if sample_pr_auc is not None:
            sample_level_metrics['pr_auc'] = sample_pr_auc
    else:
        logger.warning("Cannot generate sample-level precision-recall curve: insufficient data")

    # Add sample-level metrics to return value
    if sample_level_metrics is not None:
        metrics['sample_level_metrics'] = sample_level_metrics

    # Save metrics as JSON file
    logger.info("=== Saving Metrics Summary ===")

    test_dataset_name = os.path.basename(data_path) if data_path else "unknown_test_dataset"
    if training_dataset_name is None:
        training_dataset_name = "unknown_training_dataset"

    metrics_json_path = os.path.join(output_dir, 'metrics_summary.json')
    save_metrics_as_json(
        metrics=metrics,
        sample_level_metrics=sample_level_metrics,
        model_name=model_name,
        training_dataset=training_dataset_name,
        test_dataset=test_dataset_name,
        save_path=metrics_json_path
    )

    return metrics

def run_sample_predictions(model_path, tokenizer_path, model_name='xlm-roberta-base', 
                          loss_type='standard', dropout=0.1):
    """Log per-token predictions for a handful of built-in example sentences.

    Args:
        model_path: Path to the saved model (.pth file).
        tokenizer_path: Path to the saved tokenizer directory.
        model_name: Forwarded to ``load_model_and_tokenizer``.
        loss_type: Forwarded to ``load_model_and_tokenizer``.
        dropout: Forwarded to ``load_model_and_tokenizer``.

    Returns:
        None. Results are only written to the log.
    """
    log = logging.getLogger(__name__)

    device = get_device()
    model, tokenizer = load_model_and_tokenizer(
        model_path, tokenizer_path, model_name=model_name,
        loss_type=loss_type, dropout=dropout, device=device,
    )

    # Fixed demonstration inputs.
    demo_texts = (
        "Please click on the Submit button and then fill out the form.",
        "I am a student. Please don't sit on the grass, Rita!",
        "Can you help me find the settings page?",
        "The weather is nice today. Let's go for a walk.",
        "Navigate to the menu and select the export option.",
        "This is a simple sentence without any instructions.",
    )

    log.info("=== Sample Predictions ===")
    for index, text in enumerate(demo_texts, start=1):
        tokens, predictions = predict_instructions(model, tokenizer, text, device)

        log.info(f"\nSample {index}: {text}")
        log.info("Token predictions:")
        for token, pred in zip(tokens, predictions):
            # 1 is reported as INSTRUCTION; any other value as OTHER.
            if pred == 1:
                label = "INSTRUCTION"
            else:
                label = "OTHER"
            log.info(f"  {token}: {label}")

def find_jsonl_files(directory_path):
    """Return the sorted full paths of the ``*.jsonl`` entries in a directory.

    Only the directory's immediate entries are inspected (no recursion), and
    matching is done purely on the ``.jsonl`` name suffix.
    """
    matches = [
        os.path.join(directory_path, entry)
        for entry in os.listdir(directory_path)
        if entry.endswith('.jsonl')
    ]
    matches.sort()
    return matches


def _mean_or_zero(values):
    """Return the arithmetic mean of ``values`` as a plain float (0.0 if empty)."""
    # np.mean([]) would yield NaN (plus a RuntimeWarning) and end up as an
    # invalid NaN literal in the JSON report.
    return float(np.mean(values)) if values else 0.0


def _draw_metric_bars(ax, dataset_names, values, title, ylabel, color):
    """Draw one per-dataset bar panel of the combined comparison figure."""
    positions = range(len(dataset_names))
    ax.bar(positions, values, alpha=0.8, color=color)
    ax.set_title(title, fontsize=12, fontweight='bold')
    ax.set_xlabel('Dataset')
    ax.set_ylabel(ylabel)
    ax.set_xticks(positions)
    # Underscores become line breaks so long dataset names wrap under the bar.
    ax.set_xticklabels([name.replace('_', '\n') for name in dataset_names],
                       rotation=45, ha='right', fontsize=9)
    ax.set_ylim(0, 1)
    ax.grid(True, alpha=0.3)


def combine_metrics_and_plots(individual_results, combined_output_dir):
    """Combine metrics from multiple evaluations and create combined plots.

    Writes ``combined_metrics.json`` and ``combined_metrics_comparison.png``
    into ``combined_output_dir``.

    Args:
        individual_results: Mapping of dataset name to that dataset's metrics
            dict. The keys read are ``token_f1``, ``token_precision``,
            ``token_recall``, ``token_accuracy``, ``word_pr_auc`` and the
            optional nested ``sample_level_metrics`` dict (``f1``,
            ``precision``, ``recall``, ``accuracy``, ``pr_auc``). Missing
            values count as 0.0.
        combined_output_dir: Directory receiving the report and the plot;
            this function does not create it.

    Returns:
        dict: The combined report with ``evaluation_summary``,
        ``individual_results`` and ``aggregate_metrics`` sections.
    """
    logger = logging.getLogger(__name__)

    metric_names = ('f1', 'precision', 'recall', 'accuracy', 'pr_auc')

    all_datasets = list(individual_results)
    combined_sample_metrics = {}
    combined_word_metrics = {}

    for dataset_name, result in individual_results.items():
        # Word-level metrics live at the top level of each result dict.
        combined_word_metrics[dataset_name] = {
            'f1': result.get('token_f1', 0.0),
            'precision': result.get('token_precision', 0.0),
            'recall': result.get('token_recall', 0.0),
            'accuracy': result.get('token_accuracy', 0.0),
            'pr_auc': result.get('word_pr_auc', 0.0)
        }

        # Sample-level metrics are optional; absent or empty means all zeros.
        sample_metrics = result.get('sample_level_metrics') or {}
        combined_sample_metrics[dataset_name] = {
            name: sample_metrics.get(name, 0.0) for name in metric_names
        }

    logger.info("Creating combined metrics report...")

    def _aggregate(per_dataset):
        # Unweighted mean across datasets for every reported metric.
        return {
            f'mean_{name}': _mean_or_zero([per_dataset[d][name] for d in all_datasets])
            for name in metric_names
        }

    combined_report = {
        'evaluation_summary': {
            'datasets_evaluated': all_datasets,
            'total_datasets': len(all_datasets),
            'evaluation_timestamp': datetime.now().isoformat()
        },
        'individual_results': {
            dataset: {
                'sample_metrics': combined_sample_metrics[dataset],
                'word_metrics': combined_word_metrics[dataset]
            }
            for dataset in all_datasets
        },
        'aggregate_metrics': {
            'sample_level': _aggregate(combined_sample_metrics),
            'word_level': _aggregate(combined_word_metrics)
        }
    }

    # Save combined metrics
    combined_metrics_file = os.path.join(combined_output_dir, 'combined_metrics.json')
    with open(combined_metrics_file, 'w', encoding='utf-8') as f:
        json.dump(combined_report, f, indent=2)
    logger.info(f"Combined metrics saved to: {combined_metrics_file}")

    # Create combined plots
    logger.info("Creating combined plots...")

    combined_metrics_plot = os.path.join(combined_output_dir, 'combined_metrics_comparison.png')
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 12))
    try:
        # (axis, metrics source, metric key, title, y-label, bar colour)
        panels = (
            (ax1, combined_sample_metrics, 'f1', 'Sample-Level F1 Scores by Dataset', 'F1 Score', 'skyblue'),
            (ax2, combined_word_metrics, 'f1', 'Word-Level F1 Scores by Dataset', 'F1 Score', 'lightcoral'),
            (ax3, combined_sample_metrics, 'pr_auc', 'Sample-Level PR-AUC by Dataset', 'PR-AUC', 'lightgreen'),
            (ax4, combined_word_metrics, 'pr_auc', 'Word-Level PR-AUC by Dataset', 'PR-AUC', 'gold'),
        )
        for ax, per_dataset, metric, title, ylabel, color in panels:
            values = [per_dataset[d][metric] for d in all_datasets]
            _draw_metric_bars(ax, all_datasets, values, title, ylabel, color)

        plt.tight_layout()
        plt.savefig(combined_metrics_plot, dpi=300, bbox_inches='tight')
    finally:
        # Release the figure even when drawing or saving fails.
        plt.close(fig)

    logger.info(f"Combined metrics comparison plot saved to: {combined_metrics_plot}")

    # Log summary statistics
    aggregate = combined_report['aggregate_metrics']
    logger.info("=== Combined Evaluation Summary ===")
    logger.info(f"Total datasets evaluated: {len(all_datasets)}")
    logger.info(f"Datasets: {', '.join(all_datasets)}")
    logger.info(f"Average Sample-level F1: {aggregate['sample_level']['mean_f1']:.4f}")
    logger.info(f"Average Word-level F1: {aggregate['word_level']['mean_f1']:.4f}")
    logger.info(f"Average Sample-level PR-AUC: {aggregate['sample_level']['mean_pr_auc']:.4f}")
    logger.info(f"Average Word-level PR-AUC: {aggregate['word_level']['mean_pr_auc']:.4f}")

    return combined_report


def run_multi_file_evaluation(model_path, tokenizer_path, data_directory,
                              model_name='xlm-roberta-base',
                              batch_size=16, max_length=512, overlap=256, output_dir='data/logs',
                              loss_type='standard', dropout=0.1, save_predictions=True,
                              training_dataset_name=None, token_threshold=None):
    """Run evaluation on all JSONL files in a directory.

    Every file returned by ``find_jsonl_files(data_directory)`` is evaluated
    with ``run_evaluation``; its outputs go to ``<output_dir>/<file stem>/``.
    A failure on one file is logged (with traceback) and does not stop the
    remaining files. Successful per-file metrics are then passed to
    ``combine_metrics_and_plots`` and also dumped to
    ``<output_dir>/individual_results_summary.json``.

    Args:
        model_path: Forwarded unchanged to ``run_evaluation``.
        tokenizer_path: Forwarded unchanged to ``run_evaluation``.
        data_directory: Directory searched for JSONL files.
        model_name: Forwarded unchanged to ``run_evaluation``.
        batch_size: Forwarded unchanged to ``run_evaluation``.
        max_length: Forwarded unchanged to ``run_evaluation``.
        overlap: Forwarded unchanged to ``run_evaluation``.
        output_dir: Root directory for per-dataset folders and combined outputs.
        loss_type: Forwarded unchanged to ``run_evaluation``.
        dropout: Forwarded unchanged to ``run_evaluation``.
        save_predictions: Forwarded unchanged to ``run_evaluation``.
        training_dataset_name: Forwarded unchanged to ``run_evaluation``.
        token_threshold: Forwarded unchanged to ``run_evaluation``.

    Returns:
        The value returned by ``combine_metrics_and_plots``, or ``None`` when
        no JSONL file was found or no dataset was evaluated successfully.
    """
    logger = logging.getLogger(__name__)

    # Find all JSONL files in the directory
    jsonl_files = find_jsonl_files(data_directory)

    if not jsonl_files:
        logger.error(f"No JSONL files found in directory: {data_directory}")
        return None

    logger.info(f"Found {len(jsonl_files)} JSONL files to evaluate:")
    for file_path in jsonl_files:
        logger.info(f"  - {os.path.basename(file_path)}")

    # Create main output directory
    os.makedirs(output_dir, exist_ok=True)

    individual_results = {}

    # Run evaluation on each file
    for file_path in jsonl_files:
        # NOTE: the file stem is both the result key and the folder name, so
        # two files sharing a stem would map to the same key and folder.
        dataset_name = os.path.splitext(os.path.basename(file_path))[0]
        logger.info(f"\n=== Evaluating {dataset_name} ===")

        # Create subdirectory for this dataset
        dataset_output_dir = os.path.join(output_dir, dataset_name)
        os.makedirs(dataset_output_dir, exist_ok=True)

        try:
            # Run evaluation for this file
            metrics = run_evaluation(
                model_path=model_path,
                tokenizer_path=tokenizer_path,
                data_path=file_path,
                model_name=model_name,
                batch_size=batch_size,
                max_length=max_length,
                overlap=overlap,
                output_dir=dataset_output_dir,
                loss_type=loss_type,
                dropout=dropout,
                save_predictions=save_predictions,
                training_dataset_name=training_dataset_name,
                token_threshold=token_threshold
            )
        except Exception as e:
            # Best-effort across datasets: keep going, but record the full
            # traceback so the failure can be diagnosed from the log.
            logger.exception(f"❌ Error evaluating {dataset_name}: {e}")
        else:
            if metrics is not None:
                individual_results[dataset_name] = metrics
                logger.info(f"✅ Successfully evaluated {dataset_name}")
            else:
                logger.error(f"❌ Failed to evaluate {dataset_name}")

    if not individual_results:
        logger.error("No datasets were successfully evaluated")
        return None

    # Combine results and create combined plots
    logger.info(f"\n=== Combining Results from {len(individual_results)} Datasets ===")
    combined_report = combine_metrics_and_plots(individual_results, output_dir)

    # Also save individual results summary
    individual_summary_file = os.path.join(output_dir, 'individual_results_summary.json')
    with open(individual_summary_file, 'w') as f:
        json.dump(individual_results, f, indent=2)
    logger.info(f"Individual results summary saved to: {individual_summary_file}")

    return combined_report


if __name__ == "__main__":
    # Script entry point: parse CLI options, locate the trained model under
    # --run_dir, create an evaluation output folder inside that run directory,
    # then evaluate either one JSONL file or every JSONL file in a directory.
    import sys  # local import: only the CLI path needs it (sys.exit)

    # Parse command line arguments
    parser = argparse.ArgumentParser(description='Evaluate instruction classification model')

    # Required parameters
    parser.add_argument('--run_dir', type=str, required=True,
                        help='Path to training run directory containing models/ subfolder')
    parser.add_argument('--data_path', type=str, required=True,
                        help='Path to the evaluation data (JSONL file or directory containing JSONL files)')

    # Optional parameters
    parser.add_argument('--model_name', type=str, default='xlm-roberta-base',
                        choices=['modern-bert-base', 'modern-bert-large', 'xlm-roberta-base', 'xlm-roberta-large'],
                        help='Pre-trained model name - choose from: modern-bert-base, modern-bert-large, xlm-roberta-base, xlm-roberta-large (default: xlm-roberta-base)')
    parser.add_argument('--batch_size', type=int, default=16,
                        help='Evaluation batch size (default: 16)')
    parser.add_argument('--max_length', type=int, default=512,
                        help='Maximum sequence length (default: 512)')
    parser.add_argument('--overlap', type=int, default=100,
                        help='Overlap between sliding windows (default: 100)')
    parser.add_argument('--dropout', type=float, default=0.1,
                        help='Dropout rate used in the trained model (default: 0.1)')
    parser.add_argument('--loss_type', type=str, default='standard',
                        choices=['standard', 'weighted_ce', 'focal'],
                        help='Loss function type used during training (default: standard)')
    parser.add_argument('--save_predictions', action='store_true', default=True,
                        help='Save predictions as JSONL file (default: True)')
    parser.add_argument('--no_save_predictions', dest='save_predictions', action='store_false',
                        help='Disable saving predictions')
    parser.add_argument('--run_samples', action='store_true',
                        help='Run sample predictions after evaluation')
    parser.add_argument('--training_dataset_name', type=str, default=None,
                        help='Name or description of the training dataset used to train the model')
    parser.add_argument('--eval_folder_name', type=str, default=None,
                        help='Custom name for evaluation output folder (default: uses timestamp eval_run_YYYYMMDD_HHMMSS)')
    parser.add_argument('--token_threshold', type=float, default=None,
                        help='Token-level threshold for binary classification (0.0-1.0). If provided, tokens with probability >= threshold are classified as 1, otherwise 0. If not provided, uses argmax (default behavior)')

    args = parser.parse_args()

    # Map friendly model name to actual HuggingFace model name
    actual_model_name = map_model_name(args.model_name)

    # Derive model and tokenizer paths from run_dir
    models_dir = os.path.join(args.run_dir, 'models')
    model_path = os.path.join(models_dir, 'best_instruction_classifier.pth')
    tokenizer_path = os.path.join(models_dir, 'best_instruction_classifier_tokenizer')

    # Verify that the model and tokenizer exist
    if not os.path.exists(model_path):
        print(f"Error: Model file not found at {model_path}")
        sys.exit(1)
    if not os.path.exists(tokenizer_path):
        print(f"Error: Tokenizer directory not found at {tokenizer_path}")
        sys.exit(1)

    # Create evaluation directory for this evaluation run inside the training run directory
    if args.eval_folder_name:
        # Use custom folder name (can overwrite existing)
        eval_run_dir = os.path.join(args.run_dir, args.eval_folder_name)

        # The target is wiped below, so refuse any name that resolves to the
        # models directory, something inside it, or one of its ancestors
        # (e.g. '.', '..', 'models', or an absolute path above the run dir).
        eval_real = os.path.normcase(os.path.realpath(eval_run_dir))
        models_real = os.path.normcase(os.path.realpath(models_dir))
        overlaps_models = (
            eval_real == models_real
            or models_real.startswith(os.path.join(eval_real, ''))
            or eval_real.startswith(os.path.join(models_real, ''))
        )
        if overlaps_models:
            print(f"Error: --eval_folder_name '{args.eval_folder_name}' resolves to {eval_real}, "
                  f"which overlaps the models directory {models_real}; refusing to delete it")
            sys.exit(1)

        # Remove existing directory if it exists to ensure clean overwrite
        if os.path.exists(eval_run_dir):
            shutil.rmtree(eval_run_dir)
        os.makedirs(eval_run_dir, exist_ok=True)
    else:
        # Use timestamp-based folder name
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        eval_run_dir = os.path.join(args.run_dir, f'eval_run_{timestamp}')
        os.makedirs(eval_run_dir, exist_ok=True)

    # Setup logging inside the evaluation run directory
    logger, log_file = setup_logging(log_dir=eval_run_dir)
    logger.info("=== Instruction Classification Evaluation Started ===")
    logger.info(f"Training run directory: {args.run_dir}")
    logger.info(f"Evaluation run directory: {eval_run_dir}")

    # Check GPU availability
    check_gpu_availability()

    # Log evaluation configuration
    logger.info("Evaluation configuration:")
    logger.info(f"Run directory: {args.run_dir}")
    logger.info(f"Model path: {model_path}")
    logger.info(f"Tokenizer path: {tokenizer_path}")
    logger.info(f"Data path: {args.data_path}")
    logger.info(f"Model name: {args.model_name} -> {actual_model_name}")
    logger.info(f"Batch size: {args.batch_size}")
    logger.info(f"Max length: {args.max_length}")
    logger.info(f"Overlap: {args.overlap}")
    logger.info(f"Dropout: {args.dropout}")
    logger.info(f"Evaluation output directory: {eval_run_dir}")
    logger.info(f"Loss type: {args.loss_type}")
    logger.info(f"Save predictions: {args.save_predictions}")
    logger.info(f"Training dataset name: {args.training_dataset_name or 'Not specified'}")
    # Compare against None explicitly: a threshold of 0.0 is falsy but is
    # still a user-supplied value and must be reported as such.
    if args.token_threshold is None:
        logger.info("Token threshold: Not specified (using argmax)")
    else:
        logger.info(f"Token threshold: {args.token_threshold}")

    # Determine if data_path is a file or directory and run appropriate evaluation
    try:
        if os.path.isdir(args.data_path):
            logger.info("Data path is a directory. Running multi-file evaluation...")
            logger.info(f"Directory: {args.data_path}")
            metrics = run_multi_file_evaluation(
                model_path=model_path,
                tokenizer_path=tokenizer_path,
                data_directory=args.data_path,
                model_name=actual_model_name,
                batch_size=args.batch_size,
                max_length=args.max_length,
                overlap=args.overlap,
                output_dir=eval_run_dir,
                loss_type=args.loss_type,
                dropout=args.dropout,
                save_predictions=args.save_predictions,
                training_dataset_name=args.training_dataset_name,
                token_threshold=args.token_threshold
            )
        elif os.path.isfile(args.data_path):
            logger.info("Data path is a file. Running single-file evaluation...")
            logger.info(f"File: {args.data_path}")
            metrics = run_evaluation(
                model_path=model_path,
                tokenizer_path=tokenizer_path,
                data_path=args.data_path,
                model_name=actual_model_name,
                batch_size=args.batch_size,
                max_length=args.max_length,
                overlap=args.overlap,
                output_dir=eval_run_dir,
                loss_type=args.loss_type,
                dropout=args.dropout,
                save_predictions=args.save_predictions,
                training_dataset_name=args.training_dataset_name,
                token_threshold=args.token_threshold
            )
        else:
            logger.error(f"Data path does not exist: {args.data_path}")
            sys.exit(1)

        if metrics is None:
            logger.error("Evaluation failed!")
            sys.exit(1)

        # Run sample predictions if requested
        if args.run_samples:
            run_sample_predictions(
                model_path=model_path,
                tokenizer_path=tokenizer_path,
                model_name=actual_model_name,
                loss_type=args.loss_type,
                dropout=args.dropout
            )

        logger.info("=== Evaluation Complete ===")
        logger.info(f"All logs saved to: {log_file}")
        logger.info(f"All outputs saved to: {eval_run_dir}")

    except Exception as e:
        # sys.exit() above raises SystemExit, which is not an Exception
        # subclass, so it passes through this handler untouched.
        logger.exception(f"Evaluation failed with error: {e}")
        sys.exit(1)