import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

class LLM:
    """Chat-style text generation with a Hugging Face causal language model.

    Holds a tokenizer and a model loaded from the same ``model_name`` and
    exposes ``generate``, which renders a chat conversation with the
    tokenizer's chat template and returns only the newly generated text.
    """

    def __init__(self, model_name="Qwen3-0.6B"):
        """Load the tokenizer and model and put the model in eval mode.

        Args:
            model_name: Identifier passed to ``from_pretrained`` for both the
                tokenizer and the model. NOTE(review): the default has no
                ``Qwen/`` organisation prefix, so it presumably refers to a
                local directory rather than a Hub repo id — confirm.
        """
        print(f"Loading model: {model_name}...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # device_map="auto" lets transformers choose placement; inputs are
        # later moved to whatever device the loaded model reports.
        self.model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto").eval()
        self.device = self.model.device
        print("Model loaded successfully.")

    def generate(
        self,
        messages: list[dict],
        max_new_tokens: int = 50,
        do_sample: bool = True,
        top_k: int = 0,
        top_p: float = 1.0,
        temperature: float = 1.0,
    ) -> str:
        """Generate a reply to a chat conversation.

        Args:
            messages: Chat turns handed to the tokenizer's chat template
                (presumably ``{"role": ..., "content": ...}`` dicts).
            max_new_tokens: Maximum number of tokens to generate.
            do_sample: Sample from the distribution if True, otherwise decode
                greedily.
            top_k: Top-K cutoff used when sampling (0 disables top-k
                filtering in transformers).
            top_p: Nucleus (top-p) threshold used when sampling; 1.0 keeps
                the whole distribution.
            temperature: Softmax temperature used when sampling.

        Returns:
            The decoded continuation only (the prompt is excluded), with
            special tokens skipped and surrounding whitespace stripped.
        """
        # Render the conversation and append the assistant-turn prefix.
        text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

        # The rendered template already contains the special tokens the model
        # expects; letting the tokenizer add them again (e.g. a second BOS)
        # would corrupt the prompt.
        encoded = self.tokenizer(text, return_tensors="pt", add_special_tokens=False)
        input_ids = encoded["input_ids"].to(self.device)
        # Pass the mask explicitly instead of letting generate() infer it
        # (it warns when it cannot, e.g. when pad == eos).
        attention_mask = encoded["attention_mask"].to(self.device)

        # Ensure pad_token_id is set for generation; some tokenizers ship
        # without a pad token, so fall back to EOS.
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

        generation_kwargs = {
            "max_new_tokens": max_new_tokens,
            "do_sample": do_sample,
            "pad_token_id": self.tokenizer.pad_token_id,
        }
        # Sampling knobs are ignored by greedy decoding; forwarding them with
        # do_sample=False only triggers transformers warnings.
        if do_sample:
            generation_kwargs.update(top_k=top_k, top_p=top_p, temperature=temperature)

        output_ids = self.model.generate(
            input_ids,
            attention_mask=attention_mask,
            **generation_kwargs,
        )
        # generate() returns prompt + continuation; keep only the new tokens
        # of the single sequence in the batch.
        new_tokens = output_ids[0, input_ids.shape[-1]:]
        return self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

def top_k_top_p_decode(
    llm: LLM,
    messages: list[dict],
    max_new_tokens: int = 4096,
    top_k: int = 50,
    top_p: float = 1.0,
    temperature: float = 0.7,
) -> str:
    """Generate text using Top-K and Top-P (nucleus) sampling.

    Args:
        llm: Loaded model wrapper whose ``generate`` method does the work.
        messages: Chat turns forwarded unchanged to ``llm.generate``.
        max_new_tokens: Maximum number of tokens to generate.
        top_k: Keep only the ``top_k`` most likely tokens at each step.
        top_p: Nucleus threshold; 1.0 keeps the whole distribution.
        temperature: Softmax temperature. It was previously hard-coded, and
            the default of 0.7 preserves that behaviour.

    Returns:
        The generated text exactly as returned by ``llm.generate``.
    """
    print(f"Generating with Top-K={top_k}, Top-P={top_p}...")
    # Sampling is always on here; that is what makes top-k/top-p meaningful.
    return llm.generate(
        messages,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
    )

if __name__ == "__main__":
    # Demo: load the default model and ask it for a research idea.
    model = LLM()

    # A completion-style prompt wrapped as a single user turn so it can be
    # rendered through the chat template.
    prompt = (
        "My goal is to find a novel scientific idea for 'Discovering new materials for batteries' "
        "inspired by 'Using generative adversarial networks (GANs)'. "
        "Please begin generating a concept.\nIdea:"
    )
    conversation = [{"role": "user", "content": prompt}]

    print("\n--- Top-K/Top-P Decoding Example ---")
    idea = top_k_top_p_decode(model, conversation)
    print("Generated Idea:", idea, sep="\n")