"""
Minimal SFT Training Example with HR2R Model using Prefill Forward Pass

This example demonstrates how to train the HR2R model using supervised fine-tuning
with iteration counts (iter_count) for hierarchical recurrent processing.
"""

import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset

from hr2r.model.recurrent_transformer import HR2RForCausalLM
from hr2r.model.hr2r_config import HR2RConfig
from hr2r.model.utils import HR2RForCasualLM_generate

class SimpleDataset(Dataset):
    """Tiny in-memory question/answer dataset for the SFT demonstration.

    Every item is one (question, answer) pair rendered through the
    tokenizer's chat template and tokenized into 1-D tensors.

    Args:
        tokenizer: Object providing ``apply_chat_template``, ``__call__`` and
            ``decode`` with the keyword arguments used below.
        max_length: Truncation limit forwarded to the tokenizer call.
    """

    def __init__(self, tokenizer, max_length=128):
        self.tokenizer = tokenizer
        self.max_length = max_length

        # Simple training examples as (question, answer) tuples
        self.examples = [
            (
                "What is machine learning?",
                "Machine learning is a subset of artificial intelligence."
            ),
            (
                "Explain neural networks.",
                "Neural networks are computing systems inspired by biological neural networks."
            ),
            (
                "What is deep learning?",
                "Deep learning uses neural networks with multiple layers to learn patterns."
            ),
            (
                "Define AI.",
                "Artificial Intelligence is the simulation of human intelligence by machines."
            ),
        ]

    def __len__(self):
        """Return the number of (question, answer) pairs."""
        return len(self.examples)

    def __getitem__(self, idx):
        """Tokenize example ``idx``.

        Returns:
            dict with 1-D tensors of equal length:
            ``input_ids``, ``attention_mask``, ``labels`` and ``iter_count``.
        """
        question, answer = self.examples[idx]
        text = self.tokenizer.apply_chat_template(
            [
                {"role": "user", "content": question},
                {"role": "assistant", "content": answer}
            ],
            tokenize=False,
        )

        encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )

        # Drop only the leading batch axis. A bare squeeze() would also
        # collapse a one-token sequence into a 0-d tensor, which breaks the
        # slice assignment below and any later batching.
        input_ids = encoding.input_ids.squeeze(0)
        attention_mask = encoding.attention_mask.squeeze(0)

        # Causal-LM labels mirror the inputs. Positions the attention mask
        # marks as padding are set to -100 so they do not contribute to the
        # loss (assumes the model's loss ignores -100, the default
        # ignore_index of torch's CrossEntropyLoss -- TODO confirm for HR2R).
        # With the tokenizer call above no padding is produced, so this is a
        # no-op unless padding is enabled later.
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100

        # Two iterations for real tokens, one for padding tokens.
        iter_count = torch.full_like(input_ids, 2)
        iter_count[attention_mask == 0] = 1

        # Simplified heuristic: when the decoded text contains "Answer:",
        # give the last (up to) 10 tokens three iterations.
        text_str = self.tokenizer.decode(input_ids, skip_special_tokens=True)
        if "Answer:" in text_str:
            iter_count[-10:] = 3

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels,
            'iter_count': iter_count
        }


def collate_fn(batch, pad_token_id=0, label_pad_id=-100):
    """Collate dataset items into a batch, right-padding to a common length.

    The dataset items are unpadded, so sequences in one batch generally have
    different lengths and cannot be combined with ``torch.stack``. Each field
    is padded on the right to the longest sequence in the batch; when all
    sequences already share a length the result equals a plain stack.

    Args:
        batch: List of dicts with 1-D tensors under ``input_ids``,
            ``attention_mask``, ``labels`` and ``iter_count``.
        pad_token_id: Fill value for padded ``input_ids`` positions. Bind the
            tokenizer's real pad id with ``functools.partial`` if it matters;
            padded positions carry ``attention_mask == 0`` either way.
        label_pad_id: Fill value for padded ``labels`` positions (-100 is the
            default ``ignore_index`` of torch's CrossEntropyLoss).

    Returns:
        dict with the same four keys, each a ``(batch, max_len)`` tensor.
    """
    # Function-scope import keeps this block self-contained.
    from torch.nn.utils.rnn import pad_sequence

    # Padding positions get mask 0 and a single iteration, matching the
    # dataset's "padding tokens get 1 iteration" rule.
    pad_values = {
        'input_ids': pad_token_id,
        'attention_mask': 0,
        'labels': label_pad_id,
        'iter_count': 1,
    }

    return {
        key: pad_sequence(
            # atleast_1d guards against 0-d tensors from one-token items.
            [torch.atleast_1d(item[key]) for item in batch],
            batch_first=True,
            padding_value=fill,
        )
        for key, fill in pad_values.items()
    }


def train_step(model, batch, optimizer, device):
    """Run one optimization step on ``batch`` and return the loss as a float.

    Args:
        model: Callable module accepting ``input_ids``, ``iter_count``,
            ``attention_mask``, ``labels`` and ``use_cache`` keyword
            arguments and returning an object with a ``loss`` attribute.
        batch: Dict holding ``input_ids``, ``attention_mask``, ``labels``
            and ``iter_count`` tensors.
        optimizer: Optimizer that updates the model's parameters.
        device: Device the batch tensors are moved to.

    Returns:
        The Python float value of the step's loss.
    """
    model.train()

    # Transfer the four expected tensors to the target device.
    on_device = {
        key: batch[key].to(device)
        for key in ('input_ids', 'attention_mask', 'labels', 'iter_count')
    }

    # The batch's per-token counts only provide shape/dtype/device here:
    # every position is overwritten with -1 before the forward call
    # (presumably a sentinel that defers to the model's iteration decider
    # -- confirm against the HR2R implementation).
    placeholder_counts = torch.full_like(on_device['iter_count'], -1)

    outputs = model(
        input_ids=on_device['input_ids'],
        iter_count=placeholder_counts,
        attention_mask=on_device['attention_mask'],
        labels=on_device['labels'],
        use_cache=True  # Must use cache during training
    )
    step_loss = outputs.loss

    # Clear stale gradients, backpropagate, then apply the update.
    optimizer.zero_grad()
    step_loss.backward()
    optimizer.step()

    return step_loss.item()


def main():
    """Run the full demo: load the model, wrap it, train, then generate.

    Loads the tokenizer and base causal LM, wraps the model in
    ``HR2RForCausalLM``, trains for ``num_epochs`` over ``SimpleDataset``
    and finally samples a completion for a fixed prompt using randomly
    drawn per-token iteration counts. Progress is reported via ``print``.
    """
    print("=== Minimal HR2R SFT Training Example ===")

    # ---- Run configuration ----
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model_name = "Qwen/Qwen3-0.6B"
    max_length, batch_size, num_epochs = 128, 1, 1
    learning_rate = 5e-5

    print(f"Using device: {device}")
    print(f"Model: {model_name}")

    # ---- Tokenizer and base model ----
    print("\nLoading model and tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(
        model_name, trust_remote_code=True, padding_side="right"
    )
    # Reuse the EOS token for padding when the tokenizer defines no pad token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    backbone = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map=device,
        attn_implementation="sdpa",
    )

    # ---- HR2R wrapper ----
    print("Creating HR2R wrapper...")
    max_iter = 3
    decider_kwargs = {"threshold": 1.0, "max_iter": max_iter}
    updater_kwargs = {"topk": 100}
    wrapper_config = HR2RConfig(
        embedding_key="model.embed_tokens",
        max_iter=max_iter,
        iter_decider="EntropyIterDecider",
        iter_decider_kwargs=decider_kwargs,
        input_updater="AdditiveUpdater",
        input_updater_kwargs=updater_kwargs,
    )
    hr2r_model = HR2RForCausalLM(
        base_model=backbone, config=wrapper_config, device_map=device
    )

    # ---- Data pipeline ----
    print("Preparing dataset...")
    dataset = SimpleDataset(tokenizer, max_length=max_length)
    dataloader = DataLoader(
        dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn
    )

    optimizer = AdamW(hr2r_model.parameters(), lr=learning_rate)

    print(f"\nDataset size: {len(dataset)}")
    print(f"Batch size: {batch_size}")
    print(f"Number of epochs: {num_epochs}")
    print(f"Learning rate: {learning_rate}")

    # ---- Training ----
    print("\n=== Starting Training ===")
    for epoch in range(num_epochs):
        epoch_losses = []

        # Anomaly detection stays enabled for the whole epoch.
        with torch.autograd.set_detect_anomaly(True):
            for batch_idx, batch in enumerate(dataloader):
                loss = train_step(hr2r_model, batch, optimizer, device)
                epoch_losses.append(loss)
                print(f"Epoch {epoch+1}/{num_epochs}, Batch {batch_idx+1}/{len(dataloader)}, Loss: {loss:.4f}")

        avg_loss = sum(epoch_losses) / len(epoch_losses)
        print(f"Epoch {epoch+1} completed. Average Loss: {avg_loss:.4f}")

    print("\n=== Training Complete ===")

    # ---- Generation smoke test ----
    print("\n=== Testing Trained Model with HR2R Generation ===")
    hr2r_model.eval()

    test_prompt = "Question: What is AI? Answer:"
    prompt_inputs = tokenizer(test_prompt, return_tensors="pt").to(device)
    prompt_ids = prompt_inputs.input_ids

    # Draw one iteration count per prompt token from {1, 2, 3} with equal
    # weights.
    count_choices = [1, 2, 3]
    count_weights = torch.tensor([0.33, 0.33, 0.33])
    sampled_slots = torch.multinomial(
        count_weights, prompt_ids.numel(), replacement=True
    )
    per_token_counts = [count_choices[i] for i in sampled_slots]
    test_iter_count = torch.tensor(
        per_token_counts, device=prompt_ids.device
    ).view_as(prompt_ids)

    print(f"Test prompt: {test_prompt}")
    print(f"Input iter_count: {test_iter_count.squeeze().tolist()}")
    print("Generated response:")

    # Sampling settings for HR2R generation.
    # Note: For probabilistic iteration counts during generation, use RandomIterDecider instead
    sampling_options = {
        "max_new_tokens": 100,
        "do_sample": True,
        "temperature": 0.7,
        "top_p": 0.9,
        "top_k": 40,
        "min_p": 0.0,
        "verbose": True,
    }
    with torch.no_grad():
        output_tokens, generated_text = HR2RForCasualLM_generate(
            hr2r_model=hr2r_model,
            tokenizer=tokenizer,
            model_inputs=prompt_inputs,
            iter_count=test_iter_count,
            **sampling_options,
        )

    print(f"\nGenerated text: {generated_text}")
    print(f"Number of generated tokens: {len(output_tokens)}")

    print("\nTraining example completed successfully!")


# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()