"""
Example: Using InterleavedKLTrainer for Mixed Batches

This example demonstrates how to use InterleavedKLTrainer to apply KL divergence
only to interleaved samples within each batch, while standard samples only receive
the base SFT loss.

The key is to create a dataset that provides an 'is_interleaved' field indicating
which samples should receive the KL regularization term.
"""

import torch
from datasets import Dataset, concatenate_datasets
from transformers import AutoModelForCausalLM, AutoTokenizer
from ldifs.trainer import InterleavedKLTrainer


# ============================================================================
# Option 1: Dataset with 'is_interleaved' column
# ============================================================================

def create_dataset_with_interleaved_flag():
    """
    Build a small, shuffled dataset whose rows carry an 'is_interleaved' flag.

    Two rows are flagged 0 (standard samples, no KL regularization) and two
    rows are flagged 1 (interleaved samples, KL regularization applies). The
    two groups are concatenated and shuffled with a fixed seed so that batches
    drawn from the result can mix both kinds of sample.

    Returns:
        A dataset with a 'text' column and an integer 'is_interleaved' column.
    """
    # Plain SFT examples: flag 0 means "no KL term for this row".
    plain_texts = [
        'Question: What is 2+2? Answer: 4',
        'Question: What is the capital of France? Answer: Paris',
    ]
    # Context-augmented examples: flag 1 means "apply the KL term to this row".
    context_texts = [
        'Context: The sky is blue.\nQuestion: What color is the sky? Answer: Blue',
        'Context: Cats are mammals.\nQuestion: Are cats mammals? Answer: Yes',
    ]

    parts = [
        Dataset.from_dict({
            'text': plain_texts,
            'is_interleaved': [0] * len(plain_texts),
        }),
        Dataset.from_dict({
            'text': context_texts,
            'is_interleaved': [1] * len(context_texts),
        }),
    ]

    # Merge first, then shuffle with a fixed seed for a reproducible mix.
    return concatenate_datasets(parts).shuffle(seed=42)


# ============================================================================
# Option 2: Custom Collator that Adds 'is_interleaved' Field
# ============================================================================

class InterleavedDataCollator:
    """
    Custom data collator that ensures 'is_interleaved' is passed to the model.

    Wraps another collator: the per-sample 'is_interleaved' flag is read from
    each feature, the remaining fields are collated by the wrapped collator,
    and the flags are attached to the resulting batch as a float tensor.

    Use this if your dataset has an 'is_interleaved' column that needs to be
    included in the batch dictionary.
    """

    def __init__(self, base_collator):
        """
        Args:
            base_collator: The base data collator (e.g., DataCollatorForLanguageModeling).
                It is called with a list of feature dicts and must return a
                mapping that supports item assignment.
        """
        self.base_collator = base_collator

    def __call__(self, features):
        """
        Collate `features` and attach their 'is_interleaved' flags.

        Args:
            features: Sequence of dict-like samples. A sample without an
                'is_interleaved' key is treated as a standard sample (flag 0).

        Returns:
            The batch produced by the base collator, with an extra
            'is_interleaved' entry: a float tensor holding one flag per sample,
            in the same order as `features`.
        """
        # Read the flags without mutating the caller's samples. Popping the key
        # in place would make a second pass over the same dict objects silently
        # fall back to the default of 0.
        is_interleaved = [f.get('is_interleaved', 0) for f in features]

        # The base collator must not see the flag (it only knows how to pad
        # model inputs), so hand it shallow copies with the key removed.
        model_features = [
            {key: value for key, value in f.items() if key != 'is_interleaved'}
            for f in features
        ]

        # Call base collator to handle padding, label creation, etc.
        batch = self.base_collator(model_features)

        # Add is_interleaved as a tensor to the batch
        batch['is_interleaved'] = torch.tensor(is_interleaved, dtype=torch.float)

        return batch


# ============================================================================
# Option 3: Separate Datasets (Automatically Tagged)
# ============================================================================

def create_separate_datasets():
    """
    Build two datasets from distinct sources, tag each, then merge them.

    Each source starts with only a 'text' column; an 'is_interleaved' column is
    added afterwards via `map` (0 for the standard source, 1 for the
    interleaved source). This is useful when the two kinds of data come from
    different places and carry no flag of their own.

    Returns:
        The concatenation of both tagged datasets, shuffled with a fixed seed.
    """
    def tagged(texts, flag):
        # Create the text-only dataset, then attach the same flag to every row.
        source = Dataset.from_dict({'text': texts})
        return source.map(lambda row: {'is_interleaved': flag})

    # Standard SFT source: NOT interleaved.
    plain = tagged(
        [
            'Question: What is 2+2? Answer: 4',
            'Question: What is the capital of France? Answer: Paris',
        ],
        0,
    )

    # Interleaved source (e.g., data with context from retrieval): IS interleaved.
    with_context = tagged(
        [
            'Context: The sky is blue.\nQuestion: What color is the sky? Answer: Blue',
            'Context: Cats are mammals.\nQuestion: Are cats mammals? Answer: Yes',
        ],
        1,
    )

    merged = concatenate_datasets([plain, with_context])
    return merged.shuffle(seed=42)


# ============================================================================
# Training Example
# ============================================================================

def train_with_interleaved_kl():
    """
    Example training script using InterleavedKLTrainer.

    Steps:
        1. Load the tokenizer, the trainable model and a frozen reference copy
           of the same checkpoint (the reference is used for the KL term).
        2. Build and tokenize the flagged example dataset, keeping the
           'is_interleaved' column next to the token columns.
        3. Wrap a causal-LM collator in InterleavedDataCollator so the flag is
           delivered in every batch.
        4. Run InterleavedKLTrainer and print a short summary.

    Side effects: loads the "gpt2" checkpoint twice, uses
    ./interleaved_kl_output as the Trainer output directory, and prints to
    stdout.
    """
    from transformers import DataCollatorForLanguageModeling, TrainingArguments

    # Load model and tokenizer
    model_name = "gpt2"  # Replace with your model
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Reuse EOS as the padding token so the collator can pad batches.
    tokenizer.pad_token = tokenizer.eos_token

    # Load model to train
    model = AutoModelForCausalLM.from_pretrained(model_name)

    # Load frozen reference model (for KL divergence)
    frozen_model = AutoModelForCausalLM.from_pretrained(model_name)
    frozen_model.eval()
    frozen_model.requires_grad_(False)  # no gradients for the reference copy

    # Create dataset with interleaved flags
    dataset = create_dataset_with_interleaved_flag()

    # Tokenize dataset
    def tokenize_function(examples):
        # Padding is left to the collator so each batch is padded only to its
        # own longest sequence.
        result = tokenizer(
            examples['text'],
            truncation=True,
            max_length=512,
            padding=False,
        )
        # Keep the is_interleaved flag
        result['is_interleaved'] = examples['is_interleaved']
        return result

    tokenized_dataset = dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=['text'],
    )

    # Set up data collator
    base_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,
    )
    collator = InterleavedDataCollator(base_collator)

    # Initialize trainer
    training_args = TrainingArguments(
        output_dir="./interleaved_kl_output",
        num_train_epochs=3,
        per_device_train_batch_size=4,
        save_steps=1000,
        logging_steps=100,
        # By default Trainer drops dataset columns that are not arguments of
        # model.forward. 'is_interleaved' is not one of them for a stock causal
        # LM, so it would be removed before the collator runs and every sample
        # would silently get the default flag 0 (no KL at all).
        # NOTE(review): InterleavedKLTrainer may already handle this; its code
        # is not visible here. Disabling the pruning is safe either way because
        # the tokenized dataset only holds model inputs plus the flag.
        remove_unused_columns=False,
    )

    trainer = InterleavedKLTrainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
        data_collator=collator,
        frozen_model=frozen_model,
        ldifs_lambda=10.0,  # Weight for KL divergence term
    )

    # Train
    trainer.train()

    print("\nTraining complete!")
    print("The model was trained with:")
    print("  - Base SFT loss on ALL samples")
    print("  - KL divergence regularization ONLY on interleaved samples (is_interleaved=1)")


# ============================================================================
# Verification: Check that masking works correctly
# ============================================================================

def verify_sample_masking():
    """
    Print KL values computed with and without a per-sample mask.

    Builds random logits for a batch of four samples (the last two flagged as
    interleaved) and prints four quantities from `kl_divergence_loss`:
    the masked KL, the unmasked KL, the KL over only the interleaved rows, and
    the KL with an all-zeros mask. Nothing is asserted or returned; the output
    is meant to be compared by eye.
    """
    rule = "=" * 70
    print("\n" + rule)
    print("Verification: Sample Masking Behavior")
    print(rule)

    # Batch layout: (samples, tokens, vocabulary entries).
    n_samples = 4
    logit_shape = (n_samples, 10, 50257)

    # First two samples are standard, last two are interleaved.
    interleaved_flags = torch.tensor([0.0, 0.0, 1.0, 1.0])

    # Random stand-ins for the two models' outputs; every token is attended to.
    trained = torch.randn(*logit_shape)
    frozen = torch.randn(*logit_shape)
    token_mask = torch.ones(*logit_shape[:2])

    from ldifs.kl_divergence import kl_divergence_loss

    def batchmean_kl(student, reference, tokens, flags):
        # Every comparison below uses the same reduction; only the inputs vary.
        return kl_divergence_loss(
            student, reference, tokens,
            sample_mask=flags,
            reduction='batchmean'
        )

    # 1) whole batch, interleaved rows selected through the sample mask
    kl_masked = batchmean_kl(trained, frozen, token_mask, interleaved_flags)
    # 2) whole batch, no sample mask
    kl_full = batchmean_kl(trained, frozen, token_mask, None)
    # 3) interleaved rows sliced out by hand, no sample mask
    kl_interleaved_only = batchmean_kl(
        trained[2:], frozen[2:], token_mask[2:], None
    )

    masked_value = kl_masked.item()
    sliced_value = kl_interleaved_only.item()
    print(f"\nKL with sample masking (only interleaved): {masked_value:.4f}")
    print(f"KL without masking (all samples):          {kl_full.item():.4f}")
    print(f"KL computed only on interleaved samples:   {sliced_value:.4f}")
    print("\n✓ Masked and interleaved-only should be approximately equal")
    print(f"  Difference: {abs(masked_value - sliced_value):.6f}")

    # 4) mask that selects no sample at all
    kl_none = batchmean_kl(trained, frozen, token_mask, torch.zeros(n_samples))
    print(f"\nKL with all-zeros mask (should be ~0):     {kl_none.item():.4f}")
    print("✓ With all samples masked out, KL should be near zero")


# ============================================================================
# Main
# ============================================================================

if __name__ == "__main__":
    # Script entry point: run the masking check, then show both ways of
    # building a flagged dataset. Training itself stays opt-in (see the end).
    banner = "=" * 70

    print(banner, "InterleavedKLTrainer Example", banner, sep="\n")

    # The masking check goes first, before any dataset is built.
    verify_sample_masking()

    print("\n" + banner, "Dataset Creation Examples", banner, sep="\n")

    # Example 1: flags supplied directly as a dataset column
    dataset1 = create_dataset_with_interleaved_flag()
    print(
        f"\nDataset 1: {len(dataset1)} samples",
        f"Columns: {dataset1.column_names}",
        f"Sample: {dataset1[0]}",
        sep="\n",
    )

    # Example 2: separate sources tagged after creation
    dataset2 = create_separate_datasets()
    print(
        f"\nDataset 2: {len(dataset2)} samples",
        f"Columns: {dataset2.column_names}",
        f"Sample: {dataset2[0]}",
        sep="\n",
    )

    print(
        "\n" + banner,
        "To run full training, uncomment the line below:",
        banner,
        sep="\n",
    )
    # train_with_interleaved_kl()

