import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from transformers.optimization import get_linear_schedule_with_warmup
from torch.optim import AdamW
from sklearn.model_selection import train_test_split
import numpy as np
from tqdm import tqdm
import warnings
import logging
import os
import random
from datetime import datetime
import argparse
import matplotlib.pyplot as plt

# Suppress every warning category for the whole process. This runs before the
# project imports below, so warnings they emit at import time are hidden too.
warnings.filterwarnings('ignore')

# Disable tokenizer parallelism to avoid forking warnings
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Import shared components from utils
from utils import (
    InstructionDataset, TransformerInstructionClassifier, collate_fn,
    set_random_seeds, setup_logging, check_gpu_availability, 
    calculate_class_weights, get_device, predict_instructions,
    AugmentationConfig, AdversarialAugmenter, ComprehensivePunctuationAugmenter, HTMLTagAugmenter, create_augmented_collate_fn
)

# Import evaluation functions from evaluate
from evaluate import evaluate_model, plot_precision_recall_curve, run_evaluation


def map_model_name(friendly_name):
    """Translate a short CLI model alias into its HuggingFace hub identifier.

    Args:
        friendly_name: One of the supported aliases (e.g. 'modern-bert-base').

    Returns:
        The hub identifier string registered for the alias.

    Raises:
        ValueError: If the alias is not one of the supported names.
    """
    hub_ids = {
        'modern-bert-base': 'answerdotai/ModernBERT-base',
        'modern-bert-large': 'answerdotai/ModernBERT-large',
        'xlm-roberta-base': 'xlm-roberta-base',
        'xlm-roberta-large': 'FacebookAI/xlm-roberta-large'
    }

    # Happy path first; anything unknown falls through to the error below.
    if friendly_name in hub_ids:
        return hub_ids[friendly_name]

    raise ValueError(f"Unsupported model name: {friendly_name}. Supported models: {list(hub_ids.keys())}")

def plot_loss_curves(train_losses, val_losses, save_path, evaluations_per_epoch=4):
    """Plot training and validation loss curves and save the figure to disk.

    Both series are expected to hold one value per evaluation point. If their
    lengths differ, a warning is logged and both are truncated to the shorter
    length. The x-axis is expressed in (fractional) epochs.

    This function is best-effort: any exception raised while plotting or
    saving is logged and swallowed, so it never raises to the caller.

    Args:
        train_losses: Sequence of training-loss values, one per evaluation point.
        val_losses: Sequence of validation-loss values, one per evaluation point.
        save_path: Destination passed to ``plt.savefig``.
        evaluations_per_epoch: Number of evaluation points per epoch, used to
            convert point indices into epoch units.

    Returns:
        None.
    """
    logger = logging.getLogger(__name__)
    fig = None

    try:
        fig = plt.figure(figsize=(12, 6))

        # Both training and validation losses are now aligned (same evaluation points)
        num_points = len(val_losses)
        if len(train_losses) != num_points:
            logger.warning(f"Training losses ({len(train_losses)}) and validation losses ({num_points}) length mismatch")
            # Truncate both series to the shorter length so they can share an x-axis
            min_len = min(len(train_losses), num_points)
            train_losses = train_losses[:min_len]
            val_losses = val_losses[:min_len]
            num_points = min_len

        # Create x-axis as evaluation points (fractional epochs)
        x_points = np.arange(num_points) / evaluations_per_epoch

        plt.plot(x_points, train_losses, 'b-', label='Training Loss', alpha=0.8, linewidth=2)
        plt.plot(x_points, val_losses, 'r-', label='Validation Loss', marker='o', linewidth=2, markersize=4, alpha=0.8)

        # Add vertical lines at epoch boundaries
        max_epoch = int(np.ceil(x_points.max())) if x_points.size > 0 else 1
        for epoch in range(1, max_epoch + 1):
            plt.axvline(x=epoch, color='gray', linestyle='--', alpha=0.3)

        plt.xlabel('Epoch', fontsize=12)
        plt.ylabel('Loss', fontsize=12)
        plt.title(f'Training and Validation Loss Over Time ({evaluations_per_epoch} evaluations per epoch)', fontsize=14)
        plt.legend(fontsize=11)
        plt.grid(True, alpha=0.3)

        # Add some statistics. Both series have exactly num_points entries here;
        # testing the length (not truthiness) also works for NumPy arrays.
        min_train_loss = min(train_losses) if num_points > 0 else 0
        min_val_loss = min(val_losses) if num_points > 0 else 0

        plt.text(0.02, 0.98,
                 f'Min Training Loss: {min_train_loss:.4f}\nMin Validation Loss: {min_val_loss:.4f}\nEvaluations: {num_points} points',
                 transform=plt.gca().transAxes, fontsize=10, verticalalignment='top',
                 bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.7))

        plt.tight_layout()
        plt.savefig(save_path, dpi=300, bbox_inches='tight')

        logger.info(f"Loss curves saved to: {save_path} (showing {num_points} evaluation points)")

    except Exception as e:
        # Plotting is auxiliary to training; report the problem but do not abort the run.
        logger.error(f"Error creating loss curves plot: {e}")
    finally:
        # Always release the figure, even when plotting or saving failed,
        # so repeated calls do not accumulate open figures.
        if fig is not None:
            plt.close(fig)

def train_model(data_path: str, model_name: str = 'xlm-roberta-base', 
                epochs: int = 3, batch_size: int = 16, learning_rate: float = 2e-5,
                max_length: int = 512, overlap: int = 256, loss_type: str = 'weighted_ce',
                run_dir: str = 'data/logs', dropout: float = 0.1, custom_model_path: str = None,
                augmentation_config: AugmentationConfig = None, pretrained_model_path: str = None):
    """Train the instruction classifier and keep the best checkpoint by validation F1.

    The dataset is windowed by ``InstructionDataset``, split into train and
    validation sets by ``sample_id`` (so all windows of one sample land in the
    same split), and trained with AdamW and a linear schedule without warmup.
    The model is evaluated four times per epoch; whenever the validation
    ``token_f1`` improves, the weights are snapshotted in memory and written
    to ``<run_dir>/models`` together with the tokenizer.

    Args:
        data_path: Path handed to ``InstructionDataset``.
        model_name: Model identifier used for tokenizer and model when
            ``custom_model_path`` is not given.
        epochs: Number of passes over the training windows.
        batch_size: Batch size for both the training and validation loaders.
        learning_rate: Learning rate passed to AdamW.
        max_length: Passed to ``InstructionDataset`` as both ``max_length``
            and ``window_size``, and to the evaluation and collate helpers.
        overlap: Passed to ``InstructionDataset`` as ``overlap``.
        loss_type: Forwarded to ``TransformerInstructionClassifier``.
        run_dir: Directory receiving the ``models/`` sub-directory and the
            precision-recall plot.
        dropout: Forwarded to ``TransformerInstructionClassifier``.
        custom_model_path: If set, tokenizer and model are loaded from this
            path instead of ``model_name``.
        augmentation_config: Optional augmentation settings. Augmentation is
            active only when at least one of its ``enable_*`` flags is truthy.
        pretrained_model_path: Optional state-dict file loaded into the model
            before training. A load failure is logged and training continues.

    Returns:
        An 8-tuple ``(model, tokenizer, metrics, train_losses, val_losses,
        best_f1, best_epoch, best_model_state)``. ``metrics`` is the final
        evaluation of the best model, or the last validation metrics
        (possibly ``None``) when no evaluation ever improved on an F1 of 0.0.
        If the dataset has fewer than 10 windows the function returns
        ``(None, None, None, [], [], 0.0, 0, None)`` without training.
    """
    import copy  # local import: only needed to snapshot the best weights

    logger = logging.getLogger(__name__)
    
    # Set device - prioritize MPS for Apple Silicon Macs
    device = get_device()
    logger.info(f"Using device: {device}")
    
    # Initialize tokenizer - use custom path if provided, otherwise use model_name
    if custom_model_path:
        logger.info(f"Loading tokenizer from custom path: {custom_model_path}")
        tokenizer = AutoTokenizer.from_pretrained(custom_model_path)
        actual_model_path = custom_model_path
    else:
        logger.info(f"Loading tokenizer from HuggingFace: {model_name}")
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        actual_model_path = model_name
    
    # Load and split data with sliding windows
    logger.info("Loading and processing data...")
    full_dataset = InstructionDataset(data_path, tokenizer, max_length, is_training=True, 
                                    window_size=max_length, overlap=overlap)
    
    # Validate dataset size
    if len(full_dataset) < 10:
        logger.error(f"Dataset too small ({len(full_dataset)} samples). Need at least 10 samples.")
        logger.error("Consider using a larger dataset for meaningful training.")
        return None, None, None, [], [], 0.0, 0, None
    
    # Calculate class weights for handling imbalanced data
    logger.info("=== Calculating Class Weights for Imbalanced Data ===")
    class_weights = calculate_class_weights(full_dataset)
    class_weights = class_weights.to(device)  # Move to device
    
    # Split into train and validation BY SAMPLE_ID (not by individual windows)
    # This ensures all windows for a sample stay together in the same split
    
    # Sample ID of every window, in dataset order (read once, reused below)
    window_sample_ids = [
        full_dataset.processed_data[i]['sample_id'] for i in range(len(full_dataset))
    ]
    
    # dict.fromkeys de-duplicates while keeping first-seen order, so the seeded
    # split below is the same as with an order-preserving list scan.
    sample_ids = list(dict.fromkeys(window_sample_ids))
    
    logger.info(f"Total unique samples: {len(sample_ids)}")
    
    # Split sample IDs (not windows) into train/val
    val_sample_ratio = 0.1  # 10% for validation
    val_sample_count = max(1, int(val_sample_ratio * len(sample_ids)))
    
    train_sample_ids, val_sample_ids = train_test_split(
        sample_ids, test_size=val_sample_count, random_state=42, shuffle=True
    )
    
    logger.info(f"Train samples: {len(train_sample_ids)}, Validation samples: {len(val_sample_ids)}")
    
    # Create indices for windows based on sample ID membership.
    # Sets give O(1) membership tests instead of scanning the ID lists per window.
    train_id_set = set(train_sample_ids)
    val_id_set = set(val_sample_ids)
    train_indices = []
    val_indices = []
    
    for i, sample_id in enumerate(window_sample_ids):
        if sample_id in train_id_set:
            train_indices.append(i)
        elif sample_id in val_id_set:
            val_indices.append(i)
        else:
            logger.warning(f"Sample {sample_id} not found in train or val splits!")
    
    # Create subsets using the indices
    train_dataset = torch.utils.data.Subset(full_dataset, train_indices)
    val_dataset = torch.utils.data.Subset(full_dataset, val_indices)
    
    logger.info(f"Training windows: {len(train_dataset)} (from {len(train_sample_ids)} samples)")
    logger.info(f"Validation windows: {len(val_dataset)} (from {len(val_sample_ids)} samples)")
    logger.info(f"✅ SPLIT BY SAMPLE: All windows for each sample stay together in train/val")
    
    if len(train_dataset) < 5:
        logger.warning("Very small training set. Results may be unreliable.")
    
    # Setup augmentation if enabled
    augmentation_enabled = (augmentation_config and 
                           (augmentation_config.enable_char_aug or 
                            augmentation_config.enable_punctuation_aug or 
                            augmentation_config.enable_html_tag_aug))
    
    if augmentation_enabled:
        logger.info("=== Adversarial Data Augmentation Enabled ===")
        
        # Initialize augmenters list
        augmenters = []
        
        # Add character-level augmenters (punctuation augmentation takes
        # precedence over the legacy character augmenter when both are enabled)
        if augmentation_config.enable_punctuation_aug:
            logger.info("✓ Comprehensive Punctuation Augmentation")
            logger.info(f"  Instruction word augmentation probability: {augmentation_config.instruction_word_aug_prob}")
            logger.info(f"  Punctuation characters: {augmentation_config.punctuation_chars}")
            logger.info(f"  Character substitution probability: {augmentation_config.char_substitution_prob}")
            logger.info(f"  Case mixing probability: {augmentation_config.case_mixing_prob}")
            augmenters.append(ComprehensivePunctuationAugmenter(augmentation_config))
            
        elif augmentation_config.enable_char_aug:
            logger.info("✓ Legacy Character Augmentation")
            logger.info(f"  Character augmentation probability: {augmentation_config.char_aug_prob}")
            logger.info(f"  Max characters per word: {augmentation_config.max_chars_per_word}")
            logger.info(f"  Instruction boost factor: {augmentation_config.instruction_boost_factor}")
            augmenters.append(AdversarialAugmenter(augmentation_config))
        
        # Add HTML tag augmenter
        if augmentation_config.enable_html_tag_aug:
            logger.info("✓ HTML Tag Token-Level Augmentation")
            logger.info(f"  HTML tag augmentation probability: {augmentation_config.html_tag_aug_prob}")
            logger.info(f"  Curriculum start: {augmentation_config.html_tag_curriculum_start}")
            logger.info(f"  Curriculum end: {augmentation_config.html_tag_curriculum_end}")
            logger.info(f"  Instruction boost factor: {augmentation_config.html_tag_instruction_boost}")
            logger.info(f"  Max tags per sample: {augmentation_config.html_tag_max_per_sample}")
            augmenters.append(HTMLTagAugmenter(augmentation_config))
        
        # Store augmenters for training
        augmenter = augmenters  # Pass list of augmenters
        
        logger.info(f"Sample augmentation probability: {augmentation_config.aug_sample_prob}")
        logger.info(f"Validation augmentation probability: {augmentation_config.val_aug_prob}")
        
        if augmentation_config.curriculum_aug:
            logger.info(f"Curriculum learning enabled:")
            logger.info(f"  Start probability: {augmentation_config.curriculum_start_prob}")
            logger.info(f"  End probability: {augmentation_config.curriculum_end_prob}")
        else:
            logger.info("Curriculum learning disabled (fixed augmentation probability)")
        
        # We'll create the data loaders with dynamic collate functions during training
        train_loader = None
        val_loader = None
    else:
        logger.info("Adversarial data augmentation disabled - using standard training")
        augmenter = None
        
        # Create standard data loaders
        train_loader = DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True, 
            collate_fn=collate_fn, num_workers=0  # Set to 0 to use main process only for GPU compatibility
        )
        val_loader = DataLoader(
            val_dataset, batch_size=batch_size, shuffle=False, 
            collate_fn=collate_fn, num_workers=0  # Set to 0 to use main process only for GPU compatibility
        )
    
    # Initialize model with class weights - use custom path if provided
    model = TransformerInstructionClassifier(actual_model_path, class_weights=class_weights, loss_type=loss_type, dropout=dropout)
    model.to(device)
    
    # Load pretrained model if specified (for multi-stage training)
    if pretrained_model_path:
        logger.info(f"=== Loading Pretrained Model for Multi-Stage Training ===")
        logger.info(f"Loading pretrained weights from: {pretrained_model_path}")
        try:
            # NOTE(review): torch.load unpickles the file; only point this at
            # checkpoints from a trusted source.
            pretrained_state_dict = torch.load(pretrained_model_path, map_location=device)
            model.load_state_dict(pretrained_state_dict)
            logger.info("✅ Pretrained model loaded successfully")
            logger.info("Starting multi-stage fine-tuning...")
        except Exception as e:
            # Deliberately best-effort: training proceeds with the freshly built model.
            logger.error(f"Failed to load pretrained model: {e}")
            logger.error("Continuing with randomly initialized model...")
    else:
        logger.info("Training from scratch (no pretrained model specified)")
    
    # Setup optimizer and scheduler
    optimizer = AdamW(model.parameters(), lr=learning_rate)
    
    # Calculate total steps based on dataset size (since train_loader might be None with augmentation)
    if train_loader is not None:
        steps_per_epoch = len(train_loader)
    else:
        # Estimate steps per epoch for augmented training
        steps_per_epoch = (len(train_dataset) + batch_size - 1) // batch_size
    
    total_steps = steps_per_epoch * epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=0, num_training_steps=total_steps
    )
    
    # Initialize loss tracking
    train_losses = []
    val_losses = []
    
    # Initialize best model tracking
    best_f1 = 0.0
    best_model_state = None
    best_epoch = 0
    val_metrics = None  # Initialize to avoid undefined variable issues
    
    # Calculate equally spaced evaluation points within each epoch
    # We want 4 total evaluations per epoch: 3 equally spaced + 1 at the end
    evaluations_per_epoch = 4
    
    # Calculate how often to log during training (still 3 times per epoch)
    log_frequency = max(1, steps_per_epoch // 3)
    
    # Create models directory and setup paths for best model saving
    models_dir = os.path.join(run_dir, 'models')
    os.makedirs(models_dir, exist_ok=True)
    
    best_model_path = os.path.join(models_dir, 'best_instruction_classifier.pth')
    best_tokenizer_path = os.path.join(models_dir, 'best_instruction_classifier_tokenizer')
    
    # Training loop
    logger.info("=== Starting Training ===")
    if custom_model_path:
        logger.info(f"Model: Custom model from {custom_model_path}")
    else:
        logger.info(f"Model: {model_name}")
    logger.info(f"Loss Type: {loss_type}")
    logger.info(f"Epochs: {epochs}, Batch Size: {batch_size}, Learning Rate: {learning_rate}")
    logger.info(f"Logging frequency: every {log_frequency} batches")
    logger.info(f"Evaluations per epoch: 4 equally spaced evaluations")
    logger.info(f"Loss curves will show {evaluations_per_epoch * epochs} evaluation points total")
    
    for epoch in range(epochs):
        # Create data loaders for this epoch if using augmentation
        if augmenter is not None:
            # Create augmented collate functions for this epoch
            train_collate_fn = create_augmented_collate_fn(
                tokenizer=tokenizer, 
                augmenter=augmenter, 
                max_length=max_length,
                current_epoch=epoch, 
                total_epochs=epochs, 
                is_validation=False
            )
            # Always use clean collate function for validation (no augmentation)
            val_collate_fn = collate_fn
            
            # Create data loaders with epoch-specific augmentation
            current_train_loader = DataLoader(
                train_dataset, batch_size=batch_size, shuffle=True, 
                collate_fn=train_collate_fn, num_workers=0
            )
            current_val_loader = DataLoader(
                val_dataset, batch_size=batch_size, shuffle=False, 
                collate_fn=val_collate_fn, num_workers=0
            )
            
            # Log current augmentation probabilities
            if isinstance(augmenter, list) and len(augmenter) > 0:
                # Every augmenter is queried through a hasattr guard: not all of
                # them are assumed to expose get_current_aug_prob.
                logger.info(f"Epoch {epoch+1} augmentation probabilities:")
                for aug in augmenter:
                    aug_name = aug.__class__.__name__
                    if hasattr(aug, 'get_current_aug_prob'):
                        train_prob = aug.get_current_aug_prob(epoch, epochs, is_validation=False)
                        val_prob = aug.get_current_aug_prob(epoch, epochs, is_validation=True)
                        logger.info(f"  {aug_name}: Train={train_prob:.3f}, Val={val_prob:.3f}")
                    elif hasattr(aug, 'get_current_html_tag_prob'):
                        # Special handling for HTMLTagAugmenter
                        train_prob = aug.get_current_html_tag_prob(epoch, epochs, is_validation=False)
                        val_prob = aug.get_current_html_tag_prob(epoch, epochs, is_validation=True)
                        logger.info(f"  {aug_name}: Train={train_prob:.3f}, Val={val_prob:.3f}")
            else:
                current_train_aug_prob = augmenter.get_current_aug_prob(epoch, epochs, is_validation=False)
                current_val_aug_prob = augmenter.get_current_aug_prob(epoch, epochs, is_validation=True)
                logger.info(f"Epoch {epoch+1}: Train aug prob: {current_train_aug_prob:.3f}, Val aug prob: {current_val_aug_prob:.3f}")
        else:
            # Use the static data loaders
            current_train_loader = train_loader
            current_val_loader = val_loader
        
        # Calculate evaluation points for this epoch (1-based batch counts)
        eval_points_per_epoch = []
        total_batches = len(current_train_loader)
        for i in range(1, evaluations_per_epoch + 1):
            eval_point = int((i * total_batches) / evaluations_per_epoch)
            eval_points_per_epoch.append(eval_point)
        model.train()
        total_loss = 0
        valid_batches = 0
        epoch_train_losses = []
        
        progress_bar = tqdm(current_train_loader, desc=f"Epoch {epoch+1}/{epochs}")
        
        for batch_idx, batch in enumerate(progress_bar):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            
            optimizer.zero_grad()
            
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )
            
            loss = outputs['loss']
            
            # Skip batch if loss is NaN (this also skips any evaluation
            # scheduled for this batch)
            if torch.isnan(loss):
                logger.warning(f"Skipping batch with NaN loss at epoch {epoch+1}, batch {batch_idx}")
                continue
            
            loss.backward()
            
            # Gradient clipping to prevent exploding gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            
            total_loss += loss.item()
            valid_batches += 1
            epoch_train_losses.append(loss.item())
            progress_bar.set_postfix({'loss': loss.item()})
            
            # Log training loss periodically
            if (batch_idx + 1) % log_frequency == 0 or (batch_idx + 1) == len(current_train_loader):
                avg_recent_loss = np.mean(epoch_train_losses[-log_frequency:])
                logger.info(f"Epoch {epoch+1}, Batch {batch_idx+1}/{len(current_train_loader)} - Avg Loss: {avg_recent_loss:.4f}")
            
            # Evaluate model at pre-calculated equally spaced points
            if (batch_idx + 1) in eval_points_per_epoch:
                logger.info(f"=== Evaluation at Epoch {epoch+1}, Batch {batch_idx+1} ===")
                val_metrics = evaluate_model(model, current_val_loader, device, tokenizer, max_length, save_predictions=False)
                val_losses.append(val_metrics['validation_loss'])
                
                # Record training loss at this evaluation point
                # Calculate how many batches to average over for this evaluation point
                eval_point_index = eval_points_per_epoch.index(batch_idx + 1)
                if eval_point_index == 0:
                    # First evaluation point - average from start
                    batches_to_average = batch_idx + 1
                else:
                    # Subsequent evaluation points - average since last evaluation
                    prev_eval_point = eval_points_per_epoch[eval_point_index - 1]
                    batches_to_average = (batch_idx + 1) - prev_eval_point
                
                current_train_loss = np.mean(epoch_train_losses[-batches_to_average:]) if epoch_train_losses else 0.0
                train_losses.append(current_train_loss)
                
                current_f1 = val_metrics['token_f1']
                
                # Check if this is the best model so far
                if current_f1 > best_f1:
                    best_f1 = current_f1
                    # Deep copy: state_dict() tensors share storage with the live
                    # parameters, so a shallow dict copy would be overwritten by
                    # later optimizer steps and no longer hold the best weights.
                    best_model_state = copy.deepcopy(model.state_dict())
                    best_epoch = epoch + 1
                    
                    # Save the best model immediately (overwrite previous)
                    torch.save(model.state_dict(), best_model_path)
                    tokenizer.save_pretrained(best_tokenizer_path)
                    
                    logger.info(f"🎯 NEW BEST MODEL! F1-score: {best_f1:.4f} (Epoch {best_epoch}, Batch {batch_idx+1})")
                    logger.info(f"✅ Best model saved to: {best_model_path}")
                    logger.info(f"✅ Best tokenizer saved to: {best_tokenizer_path}")
                else:
                    logger.info(f"Current F1: {current_f1:.4f}, Best F1: {best_f1:.4f} (no improvement)")
                
                logger.info(f"Validation Metrics (word-level via BERT approach):")
                logger.info(f"  Validation Loss: {val_metrics['validation_loss']:.4f}")
                logger.info(f"  Word Accuracy: {val_metrics['token_accuracy']:.4f}")
                logger.info(f"  Word Precision: {val_metrics['token_precision']:.4f}")
                logger.info(f"  Word Recall: {val_metrics['token_recall']:.4f}")
                logger.info(f"  Word F1: {val_metrics['token_f1']:.4f}")
                logger.info(f"  Sequence Accuracy: {val_metrics['sequence_accuracy']:.4f}")
                logger.info(f"  Total Sequences: {val_metrics['total_sequences']}")
                logger.info(f"  Best F1 so far: {best_f1:.4f} (Epoch {best_epoch})")
                logger.info("-" * 50)
                
                # Return to training mode
                model.train()
        
        if valid_batches > 0:
            avg_loss = total_loss / valid_batches
            logger.info(f"Epoch {epoch+1} - Average Loss: {avg_loss:.4f} (Valid batches: {valid_batches})")
        else:
            logger.error(f"Epoch {epoch+1} - No valid batches (all had NaN loss)")
    
    # Final evaluation on best model for PR curve generation
    if best_model_state is not None:
        logger.info("=== Final Evaluation on Best Model for PR Curve ===")
        # Restore the snapshotted best weights (the live model holds the last step's weights)
        model.load_state_dict(best_model_state)
        # For final evaluation, use clean validation data (no augmentation)
        if augmenter is not None:
            final_val_loader = DataLoader(
                val_dataset, batch_size=batch_size, shuffle=False, 
                collate_fn=collate_fn, num_workers=0
            )
        else:
            final_val_loader = val_loader
            
        final_metrics = evaluate_model(model, final_val_loader, device, tokenizer, max_length, save_predictions=False)
        
        logger.info("Best Model Final Performance (word-level via BERT approach):")
        logger.info(f"  Word Accuracy: {final_metrics['token_accuracy']:.4f}")
        logger.info(f"  Word Precision: {final_metrics['token_precision']:.4f}")
        logger.info(f"  Word Recall: {final_metrics['token_recall']:.4f}")
        logger.info(f"  Word F1: {final_metrics['token_f1']:.4f}")
        logger.info(f"  Sequence Accuracy: {final_metrics['sequence_accuracy']:.4f}")
        logger.info(f"  Best model saved at: {best_model_path}")
        logger.info(f"  Best tokenizer saved at: {best_tokenizer_path}")
        
        # Generate word-level precision-recall curve on best model
        if len(final_metrics['token_labels']) > 0 and len(final_metrics['token_probabilities']) > 0:
            pr_curve_path = os.path.join(run_dir, f'best_model_word_precision_recall_curve.png')
            os.makedirs(os.path.dirname(pr_curve_path), exist_ok=True)
            
            plot_precision_recall_curve(
                final_metrics['token_labels'], 
                final_metrics['token_probabilities'], 
                save_path=pr_curve_path
            )
            logger.info(f"Word-level Precision-Recall curve for best model saved to: {pr_curve_path}")
        else:
            logger.warning("Cannot generate word-level precision-recall curve: insufficient data")
        
        return model, tokenizer, final_metrics, train_losses, val_losses, best_f1, best_epoch, best_model_state
    else:
        logger.warning("No best model found, returning final model state")
        return model, tokenizer, val_metrics, train_losses, val_losses, best_f1, best_epoch, best_model_state

# Example usage
if __name__ == "__main__":
    # Parse command line arguments
    parser = argparse.ArgumentParser(description='Train instruction classification model')
    
    # Data parameters
    parser.add_argument('--data_path', type=str, required=True,
                        help='Path to the training data (JSONL file)')
    parser.add_argument('--run_dir', type=str, default='data/logs',
                        help='Base directory where timestamped training run folder will be created (default: data/logs)')
    
    # Model parameters
    parser.add_argument('--model_name', type=str, default='xlm-roberta-base',
                        choices=['modern-bert-base', 'modern-bert-large', 'xlm-roberta-base', 'xlm-roberta-large'],
                        help='Pre-trained model name - choose from: modern-bert-base, modern-bert-large, xlm-roberta-base, xlm-roberta-large (default: xlm-roberta-base)')
    parser.add_argument('--custom_model_path', type=str, default=None,
                        help='Path to custom pre-trained model (e.g., domain-adapted model). If provided, this overrides --model_name')
    parser.add_argument('--max_length', type=int, default=512,
                        help='Maximum sequence length (default: 512)')
    parser.add_argument('--overlap', type=int, default=256,
                        help='Overlap between sliding windows (default: 100)')
    parser.add_argument('--dropout', type=float, default=0.1,
                        help='Dropout rate for the classifier (default: 0.1)')
    
    # Training parameters
    parser.add_argument('--epochs', type=int, default=5,
                        help='Number of training epochs (default: 5)')
    parser.add_argument('--batch_size', type=int, default=8,
                        help='Training batch size (default: 8)')
    parser.add_argument('--learning_rate', type=float, default=1e-5,
                        help='Learning rate (default: 1e-5)')
    parser.add_argument('--loss_type', type=str, default='weighted_ce',
                        choices=['standard', 'weighted_ce', 'focal'],
                        help='Loss function type (default: weighted_ce)')
    
    # Other parameters
    parser.add_argument('--seed', type=int, default=42,
                        help='Random seed for reproducibility (default: 42)')
    parser.add_argument('--test_file', type=str, default=None,
                        help='Optional path to test file for evaluation after training (JSONL format)')

    # Adversarial Data Augmentation parameters
    parser.add_argument('--enable_char_aug', action='store_true',
                        help='Enable adversarial character insertion augmentation')
    parser.add_argument('--char_aug_prob', type=float, default=0.4,
                        help='Probability of applying character augmentation to each word (default: 0.4)')
    parser.add_argument('--max_chars_per_word', type=int, default=1,
                        help='Maximum number of characters to insert per word (default: 1)')
    parser.add_argument('--aug_sample_prob', type=float, default=0.4,
                        help='Probability of applying augmentation to each sample (default: 0.4)')
    
    # Curriculum learning parameters  
    parser.add_argument('--curriculum_aug', action='store_true',
                        help='Enable curriculum learning for augmentation (gradual ramp-up)')
    parser.add_argument('--curriculum_start_prob', type=float, default=0.1,
                        help='Starting augmentation probability for curriculum learning (default: 0.1)')
    parser.add_argument('--curriculum_end_prob', type=float, default=0.4,
                        help='Ending augmentation probability for curriculum learning (default: 0.4)')
    
    # Advanced augmentation parameters
    parser.add_argument('--instruction_boost_factor', type=float, default=1.5,
                        help='Multiplier for augmentation probability on instruction tokens (default: 1.5)')
    parser.add_argument('--val_aug_prob', type=float, default=0.1,
                        help='Probability of applying augmentation to validation samples (default: 0.1)')

    # New Comprehensive Punctuation Augmentation parameters
    parser.add_argument('--enable_punctuation_aug', action='store_true',
                        help='Enable comprehensive punctuation-based adversarial augmentation')
    parser.add_argument('--instruction_word_aug_prob', type=float, default=0.3,
                        help='Probability of applying augmentation to each instruction word (default: 0.3)')
    parser.add_argument('--punctuation_chars', type=str, default="-_.@#!$%^&*+=|\\:;\"'<>?/~`",
                        help='Punctuation characters to use for augmentation (default: all common ASCII punctuation)')
    parser.add_argument('--char_substitution_prob', type=float, default=0.1,
                        help='Probability of applying character substitution (o->0, e->3, etc.) (default: 0.1)')
    parser.add_argument('--case_mixing_prob', type=float, default=0.05,
                        help='Probability of applying case mixing augmentation (default: 0.05)')
    
    # HTML Tag Augmentation parameters
    parser.add_argument('--enable_html_tag_aug', action='store_true',
                        help='Enable HTML tag-based token-level adversarial augmentation')
    parser.add_argument('--html_tag_aug_prob', type=float, default=0.2,
                        help='Base probability of applying HTML tag augmentation to each word/phrase (default: 0.2)')
    parser.add_argument('--html_tag_curriculum_start', type=float, default=0.0,
                        help='Starting probability for HTML tag augmentation curriculum (default: 0.0)')
    parser.add_argument('--html_tag_curriculum_end', type=float, default=0.3,
                        help='Ending probability for HTML tag augmentation curriculum (default: 0.3)')
    parser.add_argument('--html_tag_instruction_boost', type=float, default=2.0,
                        help='Multiplier for HTML tag augmentation probability on instruction tokens (default: 2.0)')
    parser.add_argument('--html_tag_max_per_sample', type=int, default=3,
                        help='Maximum number of HTML tags to add per sample (default: 3)')
    
    # Multi-stage training parameters
    parser.add_argument('--pretrained_model_path', type=str, default=None,
                        help='Path to pretrained model for multi-stage training')
    parser.add_argument('--fine_tune_lr_factor', type=float, default=0.3,
                        help='Learning rate factor for fine-tuning (default: 0.3 = 30% of base LR)')

    
    args = parser.parse_args()
    
    # Determine model to use - custom path takes precedence
    if args.custom_model_path:
        actual_model_name = args.custom_model_path
        using_custom_model = True
    else:
        # Map friendly model name to actual HuggingFace model name
        actual_model_name = map_model_name(args.model_name)
        using_custom_model = False
    
    # Set random seeds for reproducibility
    set_random_seeds(args.seed)
    
    # Create timestamped directory for this training run
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    base_log_dir = args.run_dir
    run_dir = os.path.join(base_log_dir, f'train_run_{timestamp}')
    os.makedirs(run_dir, exist_ok=True)
    
    # Setup logging with the new timestamped directory
    logger, log_file = setup_logging(log_dir=run_dir)
    logger.info("=== Instruction Classification Training Started ===")
    logger.info(f"Training run directory: {run_dir}")
    
    # Check GPU availability first
    check_gpu_availability()
    
    # Create augmentation configuration
    augmentation_config = AugmentationConfig(
        enable_char_aug=args.enable_char_aug,
        char_aug_prob=args.char_aug_prob,
        max_chars_per_word=args.max_chars_per_word,
        aug_sample_prob=args.aug_sample_prob,
        curriculum_aug=args.curriculum_aug,
        curriculum_start_prob=args.curriculum_start_prob,
        curriculum_end_prob=args.curriculum_end_prob,
        instruction_boost_factor=args.instruction_boost_factor,
        val_aug_prob=args.val_aug_prob,
        # New punctuation augmentation parameters
        enable_punctuation_aug=args.enable_punctuation_aug,
        instruction_word_aug_prob=args.instruction_word_aug_prob,
        punctuation_chars=args.punctuation_chars,
        char_substitution_prob=args.char_substitution_prob,
        case_mixing_prob=args.case_mixing_prob,
        # HTML tag augmentation parameters
        enable_html_tag_aug=args.enable_html_tag_aug,
        html_tag_aug_prob=args.html_tag_aug_prob,
        html_tag_curriculum_start=args.html_tag_curriculum_start,
        html_tag_curriculum_end=args.html_tag_curriculum_end,
        html_tag_instruction_boost=args.html_tag_instruction_boost,
        html_tag_max_per_sample=args.html_tag_max_per_sample
    )
    
    # Log training configuration
    logger.info("Training configuration:")
    logger.info(f"Data path: {args.data_path}")
    logger.info(f"Run directory: {run_dir}")
    if using_custom_model:
        logger.info(f"Model: Custom model from {actual_model_name}")
    else:
        logger.info(f"Model name: {args.model_name} -> {actual_model_name}")
    logger.info(f"Epochs: {args.epochs}")
    logger.info(f"Batch size: {args.batch_size}")
    logger.info(f"Learning rate: {args.learning_rate}")
    logger.info(f"Max length: {args.max_length}")
    logger.info(f"Overlap: {args.overlap}")
    logger.info(f"Dropout: {args.dropout}")
    logger.info(f"Loss type: {args.loss_type}")
    logger.info(f"Random seed: {args.seed}")
    logger.info(f"Test file: {args.test_file if args.test_file else 'None (no test evaluation)'}")
    
    # Log augmentation configuration
    if args.enable_char_aug or args.enable_punctuation_aug:
        logger.info("=== Augmentation Configuration ===")
        if args.enable_punctuation_aug:
            logger.info(f"Punctuation-based augmentation: enabled")
            logger.info(f"Instruction word aug probability: {args.instruction_word_aug_prob}")
            logger.info(f"Punctuation characters: {args.punctuation_chars}")
            logger.info(f"Character substitution probability: {args.char_substitution_prob}")
            logger.info(f"Case mixing probability: {args.case_mixing_prob}")
        elif args.enable_char_aug:
            logger.info(f"Legacy character augmentation: enabled")
            logger.info(f"Character aug probability: {args.char_aug_prob}")
            logger.info(f"Max chars per word: {args.max_chars_per_word}")
            logger.info(f"Instruction boost factor: {args.instruction_boost_factor}")
        
        logger.info(f"Sample aug probability: {args.aug_sample_prob}")
        logger.info(f"Validation aug probability: {args.val_aug_prob}")
        logger.info(f"Curriculum learning: {'enabled' if args.curriculum_aug else 'disabled'}")
        if args.curriculum_aug:
            logger.info(f"  Start probability: {args.curriculum_start_prob}")
            logger.info(f"  End probability: {args.curriculum_end_prob}")
    else:
        logger.info("Augmentation: disabled")
    
    # Log multi-stage training configuration
    if args.pretrained_model_path:
        logger.info("=== Multi-Stage Training Configuration ===")
        logger.info(f"Pretrained model path: {args.pretrained_model_path}")
        logger.info(f"Fine-tune learning rate factor: {args.fine_tune_lr_factor}")
        actual_learning_rate = args.learning_rate * args.fine_tune_lr_factor
        logger.info(f"Effective learning rate: {actual_learning_rate}")
    else:
        actual_learning_rate = args.learning_rate
        logger.info("Multi-stage training: disabled (training from scratch)")
    
    # Train the model with the run directory for saving outputs
    if using_custom_model:
        model, tokenizer, metrics, train_losses, val_losses, best_f1, best_epoch, best_model_state = train_model(
            data_path=args.data_path,
            model_name=actual_model_name,  # This will be the custom path
            epochs=args.epochs,
            batch_size=args.batch_size,
            learning_rate=actual_learning_rate,
            max_length=args.max_length,
            overlap=args.overlap,
            loss_type=args.loss_type,
            run_dir=run_dir,  # Pass the run directory to train_model
            dropout=args.dropout,
            custom_model_path=actual_model_name,
            augmentation_config=augmentation_config,
            pretrained_model_path=args.pretrained_model_path
        )
    else:
        model, tokenizer, metrics, train_losses, val_losses, best_f1, best_epoch, best_model_state = train_model(
            data_path=args.data_path,
            model_name=actual_model_name,
            epochs=args.epochs,
            batch_size=args.batch_size,
            learning_rate=actual_learning_rate,
            max_length=args.max_length,
            overlap=args.overlap,
            loss_type=args.loss_type,
            run_dir=run_dir,  # Pass the run directory to train_model
            dropout=args.dropout,
            augmentation_config=augmentation_config,
            pretrained_model_path=args.pretrained_model_path
        )
    
    # Check if training was successful
    if model is None or tokenizer is None or metrics is None:
        logger.error("Training failed. Exiting...")
        exit(1)
    
    # Best model has already been loaded and evaluated inside train_model function
    
    # Log final results (the final evaluation has already run inside train_model)
    logger.info("=== Training Completed Successfully ===")
    logger.info(f"  Best F1-score achieved: {best_f1:.4f} (Epoch {best_epoch})")
    
    # Plot and save loss curves in the run directory
    logger.info("=== Plotting Loss Curves ===")
    loss_curves_path = os.path.join(run_dir, f'loss_curves.png')
    plot_loss_curves(train_losses, val_losses, loss_curves_path, evaluations_per_epoch=4)
    
    # Run evaluation on test file if provided
    if args.test_file:
        logger.info("=== Running Evaluation on Test File ===")
        logger.info(f"Test file: {args.test_file}")
        
        # Create timestamped directory for this evaluation run inside the training run directory
        eval_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        eval_run_dir = os.path.join(run_dir, f'eval_run_{eval_timestamp}')
        os.makedirs(eval_run_dir, exist_ok=True)
        logger.info(f"Evaluation output directory: {eval_run_dir}")
        
        # Prepare paths for best model
        models_dir = os.path.join(run_dir, 'models')
        best_model_path = os.path.join(models_dir, 'best_instruction_classifier.pth')
        best_tokenizer_path = os.path.join(models_dir, 'best_instruction_classifier_tokenizer')
        
        try:
            # Run evaluation using the same parameters as training
            test_metrics = run_evaluation(
                model_path=best_model_path,
                tokenizer_path=best_tokenizer_path,
                data_path=args.test_file,
                model_name=actual_model_name,
                batch_size=args.batch_size,
                max_length=args.max_length,
                overlap=args.overlap,
                output_dir=eval_run_dir,  # Save evaluation outputs in separate eval_run directory
                loss_type=args.loss_type,
                dropout=args.dropout,
                save_predictions=True
            )
            
            if test_metrics is not None:
                logger.info("=== Test Evaluation Results ===")
                logger.info(f"Test Word Accuracy: {test_metrics['token_accuracy']:.4f}")
                logger.info(f"Test Word Precision: {test_metrics['token_precision']:.4f}")
                logger.info(f"Test Word Recall: {test_metrics['token_recall']:.4f}")
                logger.info(f"Test Word F1: {test_metrics['token_f1']:.4f}")
                logger.info(f"Test Sequence Accuracy: {test_metrics['sequence_accuracy']:.4f}")
                logger.info(f"Test Total Sequences: {test_metrics['total_sequences']}")
                
                # Log sample-level metrics if available
                if 'sample_level_metrics' in test_metrics:
                    sample_metrics = test_metrics['sample_level_metrics']
                    logger.info("=== Test Sample-Level Results ===")
                    logger.info(f"Sample Accuracy: {sample_metrics['accuracy']:.4f}")
                    logger.info(f"Sample Precision: {sample_metrics['precision']:.4f}")
                    logger.info(f"Sample Recall: {sample_metrics['recall']:.4f}")
                    logger.info(f"Sample F1: {sample_metrics['f1']:.4f}")
                
                logger.info("Test evaluation completed successfully!")
                logger.info(f"Test evaluation outputs saved to: {eval_run_dir}")
            else:
                logger.error("Test evaluation failed!")
                
        except Exception as e:
            logger.error(f"Error during test evaluation: {e}")
            logger.error("Continuing with training completion...")
    
    # Example prediction - device will be auto-detected
    logger.info("=== Sample Prediction ===")
    sample_text_1 = "Please click on the Submit button and then fill out the form."
    sample_text_2 = "I am a student. Please don't sit on the grass, Rita!"
    tokens_1, predictions_1 = predict_instructions(model, tokenizer, sample_text_1)
    tokens_2, predictions_2 = predict_instructions(model, tokenizer, sample_text_2)
    
    logger.info("Sample Prediction Results:")
    for token, pred in zip(tokens_1, predictions_1):
        label = "INSTRUCTION" if pred == 1 else "OTHER"
        logger.info(f"  {token}: {label}")
    for token, pred in zip(tokens_2, predictions_2):
        label = "INSTRUCTION" if pred == 1 else "OTHER"
        logger.info(f"  {token}: {label}")
    
    logger.info("=== Training Session Complete ===")
    logger.info(f"All logs saved to: {log_file}")
    logger.info(f"Training outputs saved to: {run_dir}")
    if args.test_file:
        logger.info("Test evaluation results saved in separate eval_run directory within the training run")