import torch
import numpy as np
import faiss
import math
import time
import re
from transformers import AutoModelForCausalLM, AutoTokenizer


# --- 1. Configuration ---
class Config:
    """Static hyper-parameters for the CG-MCTS idea-generation run.

    All settings are plain class attributes and can be read either from the
    class itself or from an instance.
    """

    # --- Model / hardware ---
    MODEL_NAME: str = "../Qwen3-0.6B"
    DEVICE: str = "cuda" if torch.cuda.is_available() else "cpu"

    # --- Search budget and UCT exploration ---
    NUM_ITERATIONS: int = 128
    EXPLORATION_CONSTANT: float = 1.5
    K_EXPAND: int = 1  # NOTE(review): not referenced by the search code shown here

    # --- Generation lengths (passed as max_new_tokens) ---
    SIMULATION_MAX_LENGTH: int = 160  # expansion reply; raised to allow for longer thoughts
    ROLLOUT_MAX_LENGTH: int = 70      # rollout continuation used for scoring

    # --- Value-function weights (the former length bonus was removed) ---
    W_COH: float = 0.6  # weight on normalised coherence
    W_NOV: float = 0.4  # weight on normalised novelty

    # --- NOTE(review): not referenced by the search code shown here ---
    ALPHA: float = 0.5
    BETA: float = 0.5
    TARGET_LENGTH: int = 80

# --- 2. LLM Interface ---
class LLMInterface:
    """Wrap a Hugging Face causal LM and its tokenizer for the search code.

    Provides three services: a last-token hidden-state embedding
    (``extract_vector``), a sampled chat reply (``generate_chat_completion``)
    and an average per-token log-probability score (``get_sequence_log_prob``).
    """

    def __init__(self, model_name, device):
        """Load the tokenizer and model from ``model_name`` in eval mode.

        Args:
            model_name: Path or hub id passed to ``from_pretrained``.
            device: Only used in the log message. Placement is decided by
                ``device_map="auto"``, and the model's actual device is stored
                in ``self.device``.
        """
        print(f"Loading model: {model_name} on {device}...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            dtype="auto",
            device_map="auto"
        ).eval()
        # device_map="auto" chooses placement, so read the real device back
        # and move every input tensor there.
        self.device = self.model.device
        print("Model loaded successfully.")

    def extract_vector(self, text: str) -> np.ndarray:
        """Embed ``text`` as the last token's final hidden state.

        Returns:
            A 1-D float32 numpy array taken from ``hidden_states[-1][0, -1, :]``.
        """
        with torch.no_grad():
            inputs = self.tokenizer(text, return_tensors="pt").to(self.device)
            outputs = self.model(**inputs, output_hidden_states=True)
            last_layer_hidden_states = outputs.hidden_states[-1]
            last_token_hidden_state = last_layer_hidden_states[0, -1, :]
            # Upcast before .numpy(): the model may run in a reduced dtype
            # (dtype="auto") that numpy cannot represent.
            return last_token_hidden_state.cpu().to(torch.float32).numpy()

    def generate_chat_completion(self, prompt: str, max_length: int) -> list[str]:
        """Sample a reply to ``prompt``, sent as a single user chat turn.

        Args:
            prompt: User message content.
            max_length: Maximum number of newly generated tokens.

        Returns:
            A one-element list holding the decoded reply. Prompt tokens are
            removed and special tokens are skipped.
        """
        messages = [{"role": "user", "content": prompt}]
        text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        model_inputs = self.tokenizer([text], return_tensors="pt").to(self.device)

        # Pass the tokenizer's attention mask explicitly. pad_token_id is
        # forced to EOS, so generate() should not be left to infer the mask
        # from input_ids alone.
        generated_ids = self.model.generate(
            model_inputs.input_ids,
            attention_mask=model_inputs.attention_mask,
            max_new_tokens=max_length,
            do_sample=True, top_p=0.9, temperature=0.7,
            pad_token_id=self.tokenizer.eos_token_id
        )
        # Keep only the continuation: slice off each row's prompt tokens.
        generated_ids = [
            output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
        ]

        response = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        return response

    def get_sequence_log_prob(self, text: str) -> float:
        """Return the negated language-modelling loss of ``text``.

        With ``labels=input_ids`` a causal LM's ``loss`` is the mean token
        cross-entropy. The result is therefore an average per-token
        log-probability (<= 0), not the total log-probability of the sequence.
        """
        with torch.no_grad():
            inputs = self.tokenizer(text, return_tensors="pt").to(self.device)
            input_ids = inputs.input_ids
            outputs = self.model(**inputs, labels=input_ids)
            return -outputs.loss.item()

# --- 3. MCTS Components ---
class MCTSNode:
    """One node of the search tree, holding a single text fragment.

    The text a node stands for is the concatenation of the fragments on the
    path from the root down to it (see ``get_full_text``).

    Attributes:
        text_fragment: Text contributed by this node.
        parent: Parent node, or ``None`` for the root.
        node_type: Fragment label; this file uses 'root', 'idea' and 'think'.
        children: Child nodes in insertion order.
        N: Visit count.
        Q: Sum of the values backpropagated through this node.
        is_terminal: When set, selection stops here and expansion is skipped.
    """

    def __init__(self, text_fragment: str, parent=None, node_type: str = 'idea'):
        self.text_fragment = text_fragment
        self.parent = parent
        self.node_type = node_type
        self.children = []
        self.N = 0
        self.Q = 0.0
        self.is_terminal = False

    def _lineage(self):
        """Yield this node, then each ancestor in turn up to the root."""
        cursor = self
        while cursor is not None:
            yield cursor
            cursor = cursor.parent

    def get_full_text(self) -> str:
        """Join the root-to-node fragments with single spaces and strip the ends."""
        bottom_up = [ancestor.text_fragment for ancestor in self._lineage()]
        return " ".join(bottom_up[::-1]).strip()

    @property
    def value(self):
        """Mean backpropagated value ``Q / N``, or ``0`` before any visit."""
        if self.N > 0:
            return self.Q / self.N
        return 0

    def is_leaf(self):
        """Return True when the node has no children yet."""
        return not self.children

class CG_MCTS:
    """Monte Carlo Tree Search over LLM-generated text fragments.

    Each iteration does four steps:

    1. Select a leaf with UCT.
    2. Expand it with one LLM reply split into 'idea'/'think' fragments.
    3. Score a rollout from the first new child as weighted coherence plus
       novelty.
    4. Backpropagate that score to the root.

    NOTE(review): ``expand`` attaches one reply's fragments as a linear chain,
    and only leaves are ever expanded. Every node therefore ends up with at
    most one child, so the "tree" is a single path that grows each iteration,
    and the context sent to the LLM grows with it. ``v_target`` and
    ``config.K_EXPAND`` are not used by this class.
    """

    def __init__(self, llm_interface, v_target, novelty_db, config, initial_prompt: str):
        """Create a search whose root node holds ``initial_prompt``.

        Args:
            llm_interface: Provides ``tokenizer``, ``model``, ``device``,
                ``generate_chat_completion``, ``get_sequence_log_prob`` and
                ``extract_vector``.
            v_target: Stored on the instance; not used by the search.
            novelty_db: FAISS index. Its top search score for a rollout is
                treated as a similarity, and novelty is 1 minus that score.
            config: Hyper-parameters; reads ``NUM_ITERATIONS``,
                ``EXPLORATION_CONSTANT``, ``SIMULATION_MAX_LENGTH``,
                ``ROLLOUT_MAX_LENGTH``, ``W_COH`` and ``W_NOV``.
            initial_prompt: Text fragment of the root node.
        """
        self.llm = llm_interface
        self.v_target = v_target
        self.novelty_db = novelty_db
        self.config = config
        self.root = MCTSNode(initial_prompt, node_type='root')

    def _uct_score(self, parent: 'MCTSNode', child: 'MCTSNode') -> float:
        """Return the UCT score of ``child``; unvisited children score +inf."""
        if child.N <= 0:
            return float('inf')
        exploitation = child.Q / child.N
        # parent.N >= child.N >= 1 here, because backpropagate increments every
        # ancestor, so the log is always defined.
        exploration = math.sqrt(math.log(parent.N) / child.N)
        return exploitation + self.config.EXPLORATION_CONSTANT * exploration

    def select(self, node: 'MCTSNode'):
        """Descend from ``node`` by maximal UCT until a leaf or terminal node."""
        while not node.is_leaf() and not node.is_terminal:
            parent = node
            node = max(parent.children, key=lambda child: self._uct_score(parent, child))
        return node

    @staticmethod
    def _split_think_segments(text: str) -> list[tuple[str, str]]:
        """Split ``text`` into ordered ``(node_type, fragment)`` pairs.

        Text inside ``<think>...</think>`` is labelled 'think' and text outside
        is labelled 'idea'. An unclosed ``<think>`` runs to the end of the text.
        Fragments are stripped, and empty ones are dropped.
        """
        open_tag, close_tag = '<think>', '</think>'
        raw = []
        cursor = 0
        while cursor < len(text):
            start = text.find(open_tag, cursor)
            if start == -1:  # no more think tags
                raw.append(('idea', text[cursor:]))
                break
            raw.append(('idea', text[cursor:start]))
            body_start = start + len(open_tag)
            end = text.find(close_tag, start)
            if end == -1:  # unclosed think tag: the rest of the text is thought
                raw.append(('think', text[body_start:]))
                break
            raw.append(('think', text[body_start:end]))
            cursor = end + len(close_tag)
        return [(kind, fragment.strip()) for kind, fragment in raw if fragment.strip()]

    def expand(self, node: 'MCTSNode'):
        """Grow the tree below ``node`` with one LLM reply.

        The accumulated text of ``node`` is sent as a chat prompt. The reply is
        split into 'idea'/'think' fragments, which are attached as a chain: the
        first is a child of ``node`` and each later one a child of the previous
        fragment. If the reply yields no non-empty fragment, ``node`` is marked
        terminal.
        """
        if node.is_terminal or (node.N > 0 and not node.is_leaf()):
            return

        full_context = node.get_full_text()
        responses = self.llm.generate_chat_completion(full_context, self.config.SIMULATION_MAX_LENGTH)
        segments = self._split_think_segments(responses[0]) if responses else []
        if not segments:
            # generate_chat_completion returns a one-element list even when the
            # reply is empty, so the list check alone is not enough. Without
            # this, an empty-reply leaf would be selected and re-expanded again.
            node.is_terminal = True
            return

        current_parent = node
        for node_type, fragment in segments:
            child_node = MCTSNode(fragment, parent=current_parent, node_type=node_type)
            current_parent.children.append(child_node)
            current_parent = child_node

    def simulate_and_evaluate(self, node: 'MCTSNode') -> float:
        """Roll out from ``node`` and score the resulting text.

        The node's raw text (no chat template) is continued for up to
        ``ROLLOUT_MAX_LENGTH`` new tokens. The concatenation is scored as::

            W_COH * sigmoid(log_prob) + W_NOV * clip(1 - top_similarity, 0, 1)

        where ``log_prob`` comes from ``llm.get_sequence_log_prob`` and
        ``top_similarity`` is the best ``novelty_db`` match for the
        L2-normalised embedding.
        """
        full_context = node.get_full_text()

        inputs = self.llm.tokenizer(full_context, return_tensors='pt').to(self.llm.device)
        prompt_len = inputs.input_ids.shape[-1]
        # **inputs passes attention_mask as well as input_ids, so generate()
        # does not have to infer it (pad_token_id is forced to EOS).
        rollout_ids = self.llm.model.generate(
            **inputs,
            max_new_tokens=self.config.ROLLOUT_MAX_LENGTH,
            pad_token_id=self.llm.tokenizer.eos_token_id
        )
        rollout_text = self.llm.tokenizer.decode(rollout_ids[0][prompt_len:], skip_special_tokens=True)

        final_seq = (full_context + " " + rollout_text).strip()

        v_coherence = self.llm.get_sequence_log_prob(final_seq)
        seq_vec = self.llm.extract_vector(final_seq).astype('float32').reshape(1, -1)
        faiss.normalize_L2(seq_vec)  # in place; query must be 2-D float32
        similarities, _ = self.novelty_db.search(seq_vec, 1)
        # Convert to a Python float so the final value is not computed in float32.
        v_novelty = 1.0 - float(similarities[0][0])

        v_coherence_norm = 1 / (1 + math.exp(-v_coherence))
        v_novelty_norm = float(np.clip(v_novelty, 0.0, 1.0))

        # Value function without the former length bonus.
        return float(self.config.W_COH * v_coherence_norm +
                     self.config.W_NOV * v_novelty_norm)

    def backpropagate(self, node: 'MCTSNode', value: float):
        """Add one visit and ``value`` to ``node`` and every ancestor up to the root."""
        while node is not None:
            node.N += 1
            node.Q += value
            node = node.parent

    def search(self):
        """Run ``NUM_ITERATIONS`` rounds of select/expand/simulate/backpropagate, logging progress."""
        for i in range(self.config.NUM_ITERATIONS):
            print(f"--- Iteration {i+1}/{self.config.NUM_ITERATIONS} ---")
            leaf_node = self.select(self.root)

            print(f"  Select: ...{leaf_node.get_full_text()[-80:]}")

            self.expand(leaf_node)

            # After expansion, simulate from the first new child; otherwise
            # (e.g. terminal leaf) simulate from the leaf itself.
            node_to_sim = leaf_node.children[0] if leaf_node.children else leaf_node

            print(f"  Simulate from ({node_to_sim.node_type}): ...{node_to_sim.text_fragment[-80:]}")
            value = self.simulate_and_evaluate(node_to_sim)
            print(f"  Value: {value:.4f}")

            print("  Backpropagate...")
            self.backpropagate(node_to_sim, value)
            print("-" * 25)

    def get_best_sequence(self):
        """Follow the most-visited child from the root and join non-'think' fragments.

        The root fragment (the initial prompt) is included and 'think'
        fragments are omitted. Ties in visit count go to the earliest child.
        """
        node = self.root
        kept = [node.text_fragment] if node.node_type != 'think' else []
        while not node.is_leaf():
            node = max(node.children, key=lambda child: child.N)
            if node.node_type != 'think':
                kept.append(node.text_fragment)
        return " ".join(kept).strip()

# --- 4. Helper Functions ---
def build_faiss_index(texts: list, llm_interface: "LLMInterface"):
    """Embed ``texts`` and add them to an inner-product FAISS index.

    Each text is embedded with ``llm_interface.extract_vector``. The vectors
    are L2-normalised in place before being added to a ``faiss.IndexFlatIP``,
    so the index's inner-product scores behave as cosine similarities for
    normalised queries.

    Args:
        texts: Non-empty list of strings to index.
        llm_interface: Object providing ``extract_vector(text) -> np.ndarray``.

    Returns:
        ``(index, vectors)``, where ``vectors`` is the ``(len(texts), dim)``
        float32 array of normalised embeddings that was added to the index.

    Raises:
        ValueError: If ``texts`` is empty, because the embedding dimension
            cannot be determined.
    """
    if not texts:
        raise ValueError("build_faiss_index requires at least one text")
    print(f"Building FAISS index for {len(texts)} texts...")
    # np.stack guarantees a 2-D (n, dim) array, or fails loudly on ragged
    # embeddings instead of building an object array.
    vectors = np.stack([llm_interface.extract_vector(text) for text in texts]).astype('float32')
    dimension = vectors.shape[1]
    index = faiss.IndexFlatIP(dimension)
    faiss.normalize_L2(vectors)  # in place, before adding
    index.add(vectors)
    print("FAISS index built.")
    return index, vectors

# --- 5. Main Execution ---
if __name__ == "__main__":
    # Script entry point: load the model, build the FAISS reference indexes,
    # run the search and print the most-visited path.
    cfg = Config()
    llm_interface = LLMInterface(cfg.MODEL_NAME, cfg.DEVICE)

    anchor_concepts = ["Machine learning", "Deep learning", "Genomics", "Quantum computing", "Drug discovery"]
    novelty_documents = ["AlphaFold predicts protein structures.", "CRISPR-Cas9 is a gene-editing tool.", "BERT is a transformer-based model."]

    # NOTE(review): anchor_db is built but never passed to the search below.
    anchor_db, _ = build_faiss_index(anchor_concepts, llm_interface)
    novelty_db, _ = build_faiss_index(novelty_documents, llm_interface)

    problem_concept = "Discovering new materials for batteries"
    mechanism_concept = "Using generative adversarial networks (GANs)"

    initial_prompt = f"""You are an expert in materials science and AI. Your task is to generate a novel scientific idea combining '{problem_concept}' with '{mechanism_concept}'.

Use the <think> tag to reason about the problem, decompose it, and plan your approach. The thoughts inside the <think> tags will be used for your own reasoning but will be hidden in the final output.

Directly state the core concepts of the idea outside the <think> tags.

Begin."""

    print("\n" + "="*20 + " Starting CG-MCTS Search " + "="*20)

    # Placeholder all-zeros target vector sized to the model's hidden size
    # (v_target is not the focus of current debugging).
    v_target_dummy = np.zeros(llm_interface.model.config.hidden_size)
    mcts = CG_MCTS(llm_interface, v_target=v_target_dummy, novelty_db=novelty_db, config=cfg, initial_prompt=initial_prompt)

    # perf_counter is monotonic, so the measured duration is unaffected by
    # system clock adjustments (unlike time.time()).
    start_time = time.perf_counter()
    mcts.search()
    elapsed = time.perf_counter() - start_time
    print(f"Search completed in {elapsed:.2f} seconds.")

    final_idea = mcts.get_best_sequence()

    print("\n" + "="*20 + " Final Generated Idea " + "="*20)
    print(final_idea)
