#!/usr/bin/env python3
"""
Simple Example Usage of Info-Gain Sampler

This script demonstrates the simplest way to use the Info-Gain Sampler
for text generation with Masked Diffusion Models (MDMs) like LLaDA and Dream.
"""

import os
import sys
import torch
from transformers import AutoTokenizer, AutoModel

# Make the repository root importable: this file lives one directory below
# it, so the root is the parent of the directory containing this script.
current_script_path = os.path.abspath(__file__)
scripts_dir = os.path.split(current_script_path)[0]
project_root = os.path.split(scripts_dir)[0]
if project_root not in sys.path:
    # Prepend so the project's own packages are found first.
    sys.path.insert(0, project_root)

from src.generate import generate


def main():
    """
    Generate a chat response with the Info-Gain Sampler and print it.

    Loads the model and tokenizer named by ``model_name``, resolves the mask
    token id and the baseline corpus path for the detected model family,
    formats a single user prompt with the tokenizer's chat template, runs
    ``generate`` with the ``confidence`` heuristic, and prints the decoded
    continuation (everything after the prompt tokens) to stdout.

    Takes no arguments and returns ``None``. Exceptions raised while loading
    the model or generating are not handled here and propagate to the caller.
    """
    # Step 1: Load model and tokenizer
    model_name = "GSAI-ML/LLaDA-8B-Instruct"  # or a Dream checkpoint
    # Prefer the first GPU, but keep the example runnable on CPU-only hosts.
    if torch.cuda.is_available():
        device = "cuda:0"
    else:
        device = "cpu"
        print("CUDA is not available; falling back to CPU.")

    # Detect model type from the checkpoint name
    model_name_lower = model_name.lower()
    is_dream = 'dream' in model_name_lower
    # NOTE(review): 'llama' is also treated as LLaDA here — confirm intent.
    is_llada = 'llada' in model_name_lower or 'llama' in model_name_lower

    print(f"Loading model: {model_name}")
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModel.from_pretrained(model_name, trust_remote_code=True, torch_dtype=torch.bfloat16).to(device)
    model.eval()

    # Auto-detect mask_id: prefer the model config, then the tokenizer.
    # A config attribute that exists but is None must not block the
    # tokenizer fallback, so test the value rather than mere presence.
    mask_id = getattr(getattr(model, 'config', None), 'mask_token_id', None)
    if mask_id is None:
        mask_id = getattr(tokenizer, 'mask_token_id', None)

    # Set baseline file path based on model type
    if is_dream:
        baseline_file = "reference_corpus_dream.json"
    elif is_llada:
        baseline_file = "reference_corpus_llada.json"
    else:
        baseline_file = "reference_corpus.json"
    baseline_name = os.path.join(project_root, "data", "baseline", baseline_file)

    # Step 2: Prepare input prompt
    # NOTE(review): the prompt text contains typos ("Infomation", "Decison");
    # it is left unchanged so generated output stays comparable.
    prompt_text = "What is the Infomation Gain in Decison Tree?"
    messages = [{"role": "user", "content": prompt_text}]
    prompt_str = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
    prompt_ids = tokenizer(prompt_str)['input_ids']
    # Add a leading batch dimension of size 1.
    prompt = torch.tensor(prompt_ids).to(device).unsqueeze(0)

    # Step 3: Configure generation parameters
    gen_length = 256      # Length of generated sequence
    steps = 256           # Number of decoding steps
    block_length = 32     # Block size for decoding
    temperature = 0.7     # Sampling temperature

    # Step 4: Configure Info-Gain Sampler (using confidence heuristic)
    heuristic = 'confidence'  # Use confidence as the heuristic function
    candidate_number = 8      # Number of candidate actions (1 = greedy, >1 = Info-Gain mode)

    # Step 5: Generate text
    print("Generating text with Info-Gain Sampler...")

    # Inference only: no gradients are needed.
    with torch.no_grad():
        output = generate(
            model=model,
            prompt=prompt,
            steps=steps,
            gen_length=gen_length,
            block_length=block_length,
            temperature=temperature,
            candidate_number=candidate_number,
            position_temperature=0.3,  # Set to 0.3 for action sampling
            heuristic=heuristic,
            mask_id=mask_id,           # Auto-detected from config/tokenizer (may be None)
            is_dream=is_dream,         # Auto-detected based on model name
            baseline_name=baseline_name,  # Baseline file path
            eos_penalty=1.0,
        )

    # Step 6: Decode and display result
    # Slice off the prompt tokens so only the generated part is shown;
    # special tokens are kept in the decoded text.
    prompt_len = prompt.shape[1]
    generated_text = tokenizer.batch_decode(output[:, prompt_len:], skip_special_tokens=False)[0]
    separator = "=" * 80
    print("\n" + separator)
    print("Generated Text:")
    print(separator)
    print(generated_text)
    print(separator)


if __name__ == "__main__":
    # Top-level boundary: report any failure with its traceback, then exit
    # with a non-zero status so shells and CI can detect that the run failed.
    try:
        main()
    except Exception as e:
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)
