"""Basic example of using CMI Loss for supervised fine-tuning."""

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    DataCollatorForLanguageModeling,
)
from datasets import Dataset
from cmi_loss import CMILossTrainer, CMILossConfig


def create_sample_dataset(tokenizer, num_samples=100):
    """Create a small synthetic, tokenized dataset for demonstration.

    Rows are split between "benign" samples (a ``<think>`` span followed by a
    numbered answer) and "harmful" samples (a ``<think>`` span followed by a
    refusal). When ``num_samples`` is odd the extra row is a benign sample, so
    the dataset always contains exactly ``num_samples`` rows.

    Args:
        tokenizer: Hugging Face tokenizer used to encode the texts. It must
            have a pad token, because texts are padded to ``max_length=128``.
        num_samples: Total number of rows to generate.

    Returns:
        A ``datasets.Dataset`` with the original ``text`` / ``sample_type``
        columns plus the tokenizer outputs, ``labels`` and ``sample_types``.
    """
    # Split so that the two halves always add back up to num_samples
    # (two independent ``num_samples // 2`` loops would drop a row when odd).
    num_harmful = num_samples // 2
    num_benign = num_samples - num_harmful

    # Benign samples with reasoning
    data = [
        {
            "text": f"<think>Let me consider this question carefully...</think> The answer is {i}.",
            "sample_type": "benign",
        }
        for i in range(num_benign)
    ]
    # Samples requiring safety consideration
    data.extend(
        {
            "text": "<think>I need to be careful here...</think> I cannot provide that information.",
            "sample_type": "harmful",
        }
        for _ in range(num_harmful)
    )

    def tokenize_function(examples):
        """Tokenize a batch of rows and attach labels and sample types."""
        outputs = tokenizer(
            examples["text"],
            truncation=True,
            padding="max_length",
            max_length=128,
        )
        masks = outputs.get("attention_mask")
        if masks is None:
            # No attention mask available: fall back to plain copies.
            outputs["labels"] = [list(ids) for ids in outputs["input_ids"]]
        else:
            # Padding positions get -100 (the ignore index for the LM loss) so
            # that, with pad == EOS, the model is not trained to emit padding.
            outputs["labels"] = [
                [token if keep else -100 for token, keep in zip(ids, mask)]
                for ids, mask in zip(outputs["input_ids"], masks)
            ]
        outputs["sample_types"] = examples["sample_type"]
        return outputs

    dataset = Dataset.from_list(data)
    return dataset.map(tokenize_function, batched=True)


def main():
    """Fine-tune GPT-2 with CMI Loss on a toy dataset, save it, and sample once.

    Loads the tokenizer and model, registers ``<think>`` / ``</think>`` as
    special tokens, builds small train/eval datasets, trains with
    ``CMILossTrainer``, saves the result to ``./cmi_loss_model`` and prints
    one sampled generation.
    """
    import dataclasses

    # Model and tokenizer setup
    model_name = "gpt2"  # Use a small model for demonstration
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Reuse EOS as the pad token so fixed-length padding works.
    tokenizer.pad_token = tokenizer.eos_token

    # Add thinking tokens to vocabulary if not present
    special_tokens = {"additional_special_tokens": ["<think>", "</think>"]}
    tokenizer.add_special_tokens(special_tokens)

    model = AutoModelForCausalLM.from_pretrained(model_name)
    # The added special tokens need rows in the embedding matrix.
    model.resize_token_embeddings(len(tokenizer))

    # Create datasets. Evaluation is scheduled every `eval_steps`, and Trainer
    # refuses a "steps" evaluation strategy without an eval_dataset.
    train_dataset = create_sample_dataset(tokenizer)
    eval_dataset = create_sample_dataset(tokenizer, num_samples=20)

    # CMI Loss configuration
    cmi_config = CMILossConfig(
        cmi_lambda=-0.1,  # Negative value encourages reasoning
        cmi_lambda_start=-0.01,  # Start with smaller magnitude
        cmi_warmup_ratio=0.3,  # 30% warmup
        cmi_rampup_ratio=0.5,  # 50% rampup
        cmi_thinking_weight=0.1,  # Partial weight for thinking tokens
        cmi_loss_normalize=True,  # Normalize losses
        cmi_apply_to_harmful_only=False,  # Apply to all samples
    )

    # Newer transformers releases renamed `evaluation_strategy` to
    # `eval_strategy` and later dropped the old name; use whichever field this
    # installed version of TrainingArguments (a dataclass) actually defines.
    arg_names = {field.name for field in dataclasses.fields(TrainingArguments)}
    eval_strategy_key = (
        "eval_strategy" if "eval_strategy" in arg_names else "evaluation_strategy"
    )

    # Training arguments
    training_args = TrainingArguments(
        output_dir="./cmi_loss_example",
        num_train_epochs=3,  # Overridden: max_steps below takes precedence
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        warmup_steps=100,
        logging_steps=10,
        save_steps=500,
        eval_steps=100,
        logging_dir="./logs",
        report_to="none",  # Disable wandb/tensorboard for example
        max_steps=1000,  # Set max steps for CMI scheduling
        **{eval_strategy_key: "steps"},
    )

    # Data collator (causal LM, no masking)
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,
    )

    # Initialize trainer with CMI Loss
    trainer = CMILossTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
        cmi_config=cmi_config,
    )

    # Train the model
    print("Starting training with CMI Loss...")
    trainer.train()

    # Save the model
    trainer.save_model("./cmi_loss_model")
    print("Training completed! Model saved to ./cmi_loss_model")

    # Test generation
    model.eval()
    test_input = "What is the meaning of life?"
    # The trainer may have moved the model to an accelerator; keep the inputs
    # on the same device to avoid a device-mismatch error in generate().
    inputs = tokenizer(test_input, return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_length=100,
            temperature=0.7,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=False)
    print(f"\nTest generation:\nInput: {test_input}\nOutput: {generated_text}")


# Run the example only when executed as a script, not when imported.
if __name__ == "__main__":
    main()