import os
from transformers import AutoModelForCausalLM, AutoTokenizer
from hr2r.model.recurrent_transformer import HR2RForCausalLM
import torch

from transformers.cache_utils import Cache

# Seed PyTorch's random number generator so that any stochastic operation in
# this script is reproducible across runs. The argmax token selection in
# greedy_generate below draws no random numbers itself.
torch.manual_seed(42)

def greedy_generate(hr2r_model, tokenizer, input_ids, max_new_tokens=32768, verbose=False):
    """
    Greedily decode a continuation of ``input_ids``, one token per step.

    The first forward pass consumes the whole prompt; every later pass feeds
    only the token chosen in the previous step and reuses the key/value cache
    returned by the model. Each chosen token (including a final EOS token) is
    decoded and streamed to stdout as soon as it is picked, independently of
    ``verbose``.

    Args:
        hr2r_model: callable model (e.g. an HR2RForCausalLM instance) that
            accepts ``input_ids``, ``past_key_values`` and ``use_cache`` and
            returns an object exposing ``logits`` and ``past_key_values``
        tokenizer: tokenizer providing ``decode`` and ``eos_token_id``
        input_ids: torch.Tensor of shape (batch_size, seq_len). Only the
            logits of the first batch row are read, so this function
            effectively supports a batch size of 1.
        max_new_tokens: maximum number of new tokens to generate
        verbose: whether to print start/end banners and periodic progress

    Returns:
        generated_tokens: list of generated token IDs (a terminating EOS
            token is not included)
        generated_text: decoded text of generated tokens
    """
    current_input_ids = input_ids.clone()
    cache = None
    output_tokens = []
    eos_token_id = tokenizer.eos_token_id

    if verbose:
        print("=== Starting Greedy Generation ===")

    # Decoding never back-propagates. Without no_grad every forward pass would
    # record autograd state that stays alive through the cached keys/values,
    # growing memory with each generated token.
    with torch.no_grad():
        # Unified generation loop - first iteration processes the full
        # sequence, subsequent ones process single tokens.
        for step in range(max_new_tokens):
            outputs = hr2r_model(
                input_ids=current_input_ids,
                past_key_values=cache,
                use_cache=True
            )

            # Carry the cache forward so the next pass only needs one token.
            cache = outputs.past_key_values

            # Pick the highest-scoring token at the last position of row 0.
            last_token_logits = outputs.logits[0, -1, :]
            next_token_id = torch.argmax(last_token_logits, dim=-1, keepdim=True)
            # Convert to a Python int once; each .item() is a host sync.
            next_token = next_token_id.item()

            # Stream the token to stdout as it is produced.
            print(tokenizer.decode(next_token_id), end="", flush=True)

            # Stop on EOS; the EOS token itself is not returned.
            if eos_token_id is not None and next_token == eos_token_id:
                if verbose:
                    print("EOS token generated, stopping.")
                break

            output_tokens.append(next_token)

            # Prepare next iteration - switch to single token input after first pass
            current_input_ids = next_token_id.unsqueeze(0)  # (1, 1)

            if verbose and (step + 1) % 10 == 0:
                current_generated = tokenizer.decode(output_tokens)
                print(f"Step {step + 1}: Generated so far: '{current_generated[:100]}...'")

    generated_text = tokenizer.decode(output_tokens)

    if verbose:
        print("=== Generation Complete ===")

    return output_tokens, generated_text

model_name = "Qwen/Qwen3-0.6B"

# Tokenizer and model weights are both loaded from the same checkpoint name.
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Loading options: dtype chosen automatically, model placed on CUDA device
# index 5, and the SDPA attention implementation requested.
_load_kwargs = {
    "torch_dtype": "auto",
    "device_map": "cuda:5",
    "attn_implementation": "sdpa",
}
base_model = AutoModelForCausalLM.from_pretrained(model_name, **_load_kwargs)

# No HR2R wrapper is constructed here: `model` is simply an alias for the
# plain base model.
model = base_model

# Build a single-turn conversation and render it to a prompt string with the
# tokenizer's chat template (rendered as text, not token IDs).
prompt = "Give me a short introduction to large language model."
messages = [dict(role="user", content=prompt)]
text = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=False,
    # Switches between thinking and non-thinking modes. Default is True.
    enable_thinking=True,
)

print("text:", text)

# Tokenize the rendered prompt as a batch of one and move the tensors to the
# device the model lives on.
model_inputs = tokenizer([text], return_tensors="pt")
model_inputs = model_inputs.to(base_model.device)
batch_size, seq_len = model_inputs.input_ids.shape

# Decode with the hand-written greedy loop rather than model.generate().
output_tokens, generated_text = greedy_generate(
    model,
    tokenizer,
    model_inputs.input_ids,
    max_new_tokens=32768,
    verbose=False,
)

print("generated_text:", generated_text)