import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

class LLM:
    """Wrap a causal language model and its tokenizer for chat-style generation."""

    def __init__(self, model_name="Qwen3-0.6B"):
        """Load the tokenizer and the model named ``model_name``.

        Args:
            model_name: Identifier passed unchanged to ``from_pretrained`` for
                both the tokenizer and the model.
        """
        print(f"Loading model: {model_name}...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto").eval()
        self.device = self.model.device
        print("Model loaded successfully.")

    def generate(self, messages: list[dict], max_new_tokens: int = 50, do_sample: bool = True, temperature: float = 0.7, top_p: float = 0.9) -> str:
        """Generate the next reply for a chat conversation.

        Args:
            messages: Chat history handed to the tokenizer's chat template.
            max_new_tokens: Upper bound on the number of generated tokens.
            do_sample: Whether to sample instead of decoding greedily.
            temperature: Sampling temperature; forwarded only when
                ``do_sample`` is true.
            top_p: Nucleus-sampling threshold; forwarded only when
                ``do_sample`` is true.

        Returns:
            The decoded continuation only (prompt tokens removed), with
            special tokens skipped and surrounding whitespace stripped.
        """
        # Render the conversation to a prompt string that ends with the
        # generation prompt.
        text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

        # The rendered template already carries the special tokens, so they
        # must not be added a second time while tokenizing (this would e.g.
        # duplicate BOS on tokenizers that prepend one).
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            add_special_tokens=False,
        ).to(self.device)
        input_ids = inputs["input_ids"]

        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

        generation_kwargs = {
            "max_new_tokens": max_new_tokens,
            "do_sample": do_sample,
            "pad_token_id": self.tokenizer.pad_token_id,
        }
        if do_sample:
            # Sampling knobs are meaningless for greedy decoding; sending
            # them alongside do_sample=False only produces warnings.
            generation_kwargs["temperature"] = temperature
            generation_kwargs["top_p"] = top_p

        output_ids = self.model.generate(
            input_ids=input_ids,
            # Pass the mask explicitly: it cannot be inferred reliably when
            # the pad token id equals the EOS token id.
            attention_mask=inputs.get("attention_mask"),
            **generation_kwargs,
        )

        # Keep only the tokens produced after the prompt.
        new_tokens = output_ids[0, input_ids.shape[-1]:]
        return self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

def chain_of_trees_decode(llm: "LLM", initial_messages: list[dict], num_branches: int = 3, max_branch_tokens: int = 80, max_final_answer_tokens: int = 100) -> str:
    """
    Simulates Chain of Trees (CoT) decoding.

    Generates ``num_branches`` independent reasoning branches, selects the
    longest one as the "best", and then asks the model for a final answer
    that builds on that branch.

    Args:
        llm: Object exposing ``generate(messages, max_new_tokens=...,
            temperature=..., top_p=...)`` that returns a string.
        initial_messages: Chat history every branch starts from. It is not
            modified.
        num_branches: Number of branches to generate.
        max_branch_tokens: Token budget passed to ``llm.generate`` for each
            branch.
        max_final_answer_tokens: Token budget passed to ``llm.generate`` for
            the final answer.

    Returns:
        The selected branch's thought and the final answer joined by a
        newline, or ``"No branches generated."`` when no branch was produced
        (``num_branches`` below 1).
    """
    print(f"Generating with Chain of Trees (simulated, {num_branches} branches)...")

    branches = []
    for i in range(num_branches):
        # Each branch gets its own copy of the history plus a branch prompt.
        branch_messages = [
            *initial_messages,
            {"role": "user", "content": f"Branch {i+1} - Let's explore a path:"},
        ]
        branch_thought = llm.generate(branch_messages, max_new_tokens=max_branch_tokens, temperature=0.7, top_p=0.9)
        # Simplified evaluation: the score is just the thought's length.
        # A real scorer would weigh relevance, coherence, novelty, etc.
        branches.append({"messages": branch_messages, "thought": branch_thought, "score": len(branch_thought)})
        print(f"\nBranch {i+1} Thought:\n{branch_thought}")

    if not branches:
        return "No branches generated."

    # max() returns the first maximal element, so ties go to the earliest branch.
    best_branch = max(branches, key=lambda branch: branch["score"])
    best_thought = best_branch["thought"]

    # Bound to a local first: reusing the enclosing quote character inside
    # an f-string replacement field is a SyntaxError before Python 3.12.
    print(f"\nSelected Best Branch (based on length):\n{best_thought}")

    # Generate final answer based on the best branch
    final_answer_messages = best_branch["messages"] + [
        {"role": "assistant", "content": best_thought},
        {"role": "user", "content": "Based on this exploration, the final idea is:"},
    ]
    final_answer = llm.generate(final_answer_messages, max_new_tokens=max_final_answer_tokens, temperature=0.7, top_p=0.9)
    print(f"\nFinal Answer:\n{final_answer}")

    return best_thought + "\n" + final_answer

if __name__ == "__main__":
    # Demo entry point: load the default model, then run one simulated
    # chain-of-trees decode on a fixed ideation prompt and print the result.
    demo_llm = LLM()

    seed_prompt = (
        "My goal is to find a novel scientific idea for "
        "'Discovering new materials for batteries' inspired by "
        "'Using generative adversarial networks (GANs)'. "
        "Please begin generating a concept.\nIdea:"
    )
    conversation = [{"role": "user", "content": seed_prompt}]

    print("\n--- Chain of Trees Decoding Example ---")
    idea = chain_of_trees_decode(
        demo_llm,
        conversation,
        num_branches=3,
        max_branch_tokens=120,
        max_final_answer_tokens=100,
    )
    print("\nFull CoT (Trees) Generated Idea:")
    print(idea)