import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

class LLM:
    """Thin wrapper around a Hugging Face causal LM for chat-style generation."""

    def __init__(self, model_name="Qwen3-0.6B"):
        """Load the tokenizer and model for ``model_name``.

        Args:
            model_name: Hub id or local path passed to ``from_pretrained``.
        """
        print(f"Loading model: {model_name}...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # device_map="auto" lets transformers place the weights; .eval() puts
        # the model in inference mode.
        self.model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto").eval()
        # Inputs are moved to the model's reported device before generation.
        self.device = self.model.device
        print("Model loaded successfully.")

    def generate(
        self,
        messages: list[dict],
        max_new_tokens: int = 50,
        do_sample: bool = True,
        temperature: float = 0.7,
        top_p: float = 0.9,
    ) -> str:
        """Generate a reply to a chat conversation.

        Args:
            messages: Chat messages (dicts with "role"/"content") rendered
                through the tokenizer's chat template.
            max_new_tokens: Maximum number of tokens to generate.
            do_sample: Sample when True; otherwise decode greedily.
            temperature: Sampling temperature (only used when ``do_sample``).
            top_p: Nucleus-sampling threshold (only used when ``do_sample``).

        Returns:
            Only the newly generated text (prompt excluded), decoded with
            special tokens skipped and surrounding whitespace stripped.
        """
        # Render the conversation to text, ending with the assistant prompt so
        # the model continues as the assistant.
        text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

        # The chat template already contains every special token the model
        # expects; tokenizing with add_special_tokens=True would add them a
        # second time (e.g. a duplicated BOS) on tokenizers that prepend one.
        encoded = self.tokenizer(text, return_tensors="pt", add_special_tokens=False)
        input_ids = encoded["input_ids"].to(self.device)
        attention_mask = encoded.get("attention_mask")
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.device)

        # Some tokenizers have no pad token; fall back to EOS so generate()
        # has a valid pad id.
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

        output_ids = self.model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            pad_token_id=self.tokenizer.pad_token_id,
            **self._sampling_kwargs(do_sample, temperature, top_p),
        )
        # Drop the prompt tokens; keep only what the model produced.
        new_tokens = output_ids[0, input_ids.shape[-1]:]
        return self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    @staticmethod
    def _sampling_kwargs(do_sample, temperature, top_p):
        """Build the decoding-strategy kwargs for ``model.generate``.

        temperature/top_p are only forwarded when sampling: with greedy
        decoding they are ignored and only produce config warnings.
        """
        if do_sample:
            return {"do_sample": do_sample, "temperature": temperature, "top_p": top_p}
        return {"do_sample": do_sample}

def cot_decode(
    llm: LLM,
    initial_messages: list[dict],
    max_thought_tokens: int = 100,
    max_answer_tokens: int = 100,
    *,
    thought_prompt: str = "Let's think step by step:",
    answer_prompt: str = "Therefore, the idea is:",
    temperature: float = 0.7,
    top_p: float = 0.9,
) -> str:
    """Perform two-stage Chain of Thought (CoT) decoding.

    First generates a reasoning path by appending ``thought_prompt`` as a user
    turn. The reasoning is then fed back as the assistant's turn, followed by
    ``answer_prompt``, to generate the final answer.

    Args:
        llm: Model wrapper whose ``generate`` method is called once per stage.
        initial_messages: Chat history to start from. It is not modified;
            new lists are built for each stage.
        max_thought_tokens: Token budget for the reasoning stage.
        max_answer_tokens: Token budget for the answer stage.
        thought_prompt: User turn appended to elicit the reasoning.
        answer_prompt: User turn appended after the reasoning to elicit the
            final answer. The default is phrased for idea generation.
        temperature: Sampling temperature passed to both stages.
        top_p: Nucleus-sampling threshold passed to both stages.

    Returns:
        The reasoning and the answer joined by a newline.
    """
    print("Generating with Chain of Thought...")

    # Stage 1: elicit reasoning. Unpacking into a new list leaves the caller's
    # history untouched and accepts any sequence of messages.
    thought_messages = [
        *initial_messages,
        {"role": "user", "content": thought_prompt},
    ]
    thought = llm.generate(
        thought_messages,
        max_new_tokens=max_thought_tokens,
        temperature=temperature,
        top_p=top_p,
    )
    print(f"\nThought:\n{thought}")

    # Stage 2: replay the reasoning as the assistant's own turn, then ask for
    # the final answer conditioned on it.
    answer_messages = thought_messages + [
        {"role": "assistant", "content": thought},
        {"role": "user", "content": answer_prompt},
    ]
    answer = llm.generate(
        answer_messages,
        max_new_tokens=max_answer_tokens,
        temperature=temperature,
        top_p=top_p,
    )
    print(f"\nAnswer:\n{answer}")

    return thought + "\n" + answer

if __name__ == "__main__":
    # Demo: load the default model and ask it for a battery-materials idea
    # inspired by GANs, reasoning first and answering second.
    llm_model = LLM()

    task = "Discovering new materials for batteries"
    inspiration = "Using generative adversarial networks (GANs)"
    initial_prompt_content = (
        f"My goal is to find a novel scientific idea for '{task}' "
        f"inspired by '{inspiration}'. Please begin generating a concept.\nIdea:"
    )
    initial_messages = [{"role": "user", "content": initial_prompt_content}]

    print("--- Chain of Thought Decoding Example ---")
    generated_idea = cot_decode(
        llm_model,
        initial_messages,
        max_thought_tokens=150,
        max_answer_tokens=100,
    )
    # Header and result on separate lines, exactly as two print calls would emit.
    print("\nFull CoT Generated Idea:", generated_idea, sep="\n")