import torch
import numpy as np
import faiss
import math
import time
import re
import json
import random
from sklearn.cluster import KMeans
from transformers import AutoModelForCausalLM, AutoTokenizer

# --- 1. Configuration ---
class Config:
    """Static hyper-parameters shared by the theme generator and the guided MCTS."""

    # Model and hardware
    MODEL_NAME = "../Qwen3-0.6B"
    DEVICE = "cpu" if not torch.cuda.is_available() else "cuda"

    # Search budget and tree shape
    NUM_ITERATIONS = 64            # select/expand/simulate rounds in CG_MCTS.search
    EXPLORATION_CONSTANT = 1.5     # UCT exploration weight
    K_EXPAND = 3                   # max children created per expansion
    N_ROLLOUTS = 3                 # rollouts averaged when scoring a node

    # Token budgets (passed to generation as max_new_tokens)
    EXPAND_MAX_LENGTH = 1024       # long "thinking" generation during expansion
    ROLLOUT_MAX_LENGTH = 128       # short, direct rollouts
    THEME_GEN_MAX_LENGTH = 1024    # theme synthesis and theme deconstruction

    # Automated theme generation
    NUM_CLUSTERS = 3               # KMeans clusters over the base documents

    # Target direction and scoring weights
    ALPHA_NOVELTY = 0.7            # weight of the mechanism component orthogonal to the problem
    W_DIR = 1.0                    # selection bonus toward v_target
    W_COH = 0.5                    # rollout coherence weight
    W_NOV = 0.3                    # rollout novelty weight
    W_PROG = 0.2                   # rollout progress weight

# --- 2. Helper Functions ---
def parse_llm_json_output(response: str) -> dict | None:
    """Robustly parses JSON from LLM output that might include markdown."""
    try:
        match = re.search(r"```json\n(.*?)\n```", response, re.DOTALL)
        if match:
            json_str = match.group(1)
            return json.loads(json_str)
        else:
            return json.loads(response)
    except (json.JSONDecodeError, IndexError):
        print(f"Warning: Failed to parse LLM response as JSON.")
        return None

# --- 3. LLM Interface ---
class LLMInterface:
    """Wrapper around a Hugging Face causal LM used for generation, embeddings and scoring."""

    # Maximum number of tokens fed to the embedding / scoring forward passes.
    MAX_ENCODE_TOKENS = 512

    def __init__(self, model_name, device):
        print(f"Loading model: {model_name} on {device}...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # get_vector / get_prob_and_vector read the hidden state of the LAST token, so
        # over-long inputs must lose tokens from the front. With the default right-side
        # truncation, all texts sharing the same first MAX_ENCODE_TOKENS tokens (e.g. the
        # descendants of a long MCTS path, or every rollout appended to it) would be cut
        # back to that same prefix and map to one identical vector.
        self.tokenizer.truncation_side = "left"
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            dtype="auto",
            device_map="auto"
        ).eval()
        # `device` is only reported above; inputs are always moved to the model's own device.
        self.device = self.model.device
        print("Model loaded successfully.")

    def _encode(self, text: str):
        """Tokenize `text` for a forward pass, keeping at most the last MAX_ENCODE_TOKENS tokens."""
        return self.tokenizer(
            text, return_tensors="pt", truncation=True, max_length=self.MAX_ENCODE_TOKENS
        ).to(self.device)

    def get_vector(self, text: str) -> np.ndarray:
        """Return the final-layer hidden state of the last input token as a float32 numpy vector."""
        with torch.no_grad():
            inputs = self._encode(text)
            outputs = self.model(**inputs, output_hidden_states=True)
            vector = outputs.hidden_states[-1][0, -1, :].cpu().to(torch.float32).numpy()
            del inputs, outputs
            return vector

    def generate_completion(self, messages: str, max_length: int, temperature: float = 0.7) -> str:
        """Sample a raw continuation of `messages` (no chat template applied).

        Args:
            messages: Plain prompt text.
            max_length: Maximum number of new tokens to generate.
            temperature: Sampling temperature (nucleus sampling with top_p=0.9).

        Returns:
            The decoded newly generated text, without the prompt.
        """
        model_inputs = self.tokenizer(messages, return_tensors="pt").to(self.device)
        generated_ids = self.model.generate(
                model_inputs.input_ids,
                attention_mask=model_inputs.attention_mask,
                max_new_tokens=max_length,
                do_sample=True, top_p=0.9, temperature=temperature,
                pad_token_id=self.tokenizer.eos_token_id
        )
        # Drop the prompt tokens so only the continuation is decoded.
        response = self.tokenizer.batch_decode(generated_ids[:, model_inputs.input_ids.shape[-1]:], skip_special_tokens=True)[0]
        del model_inputs, generated_ids
        return response

    def generate_chat_completion(self, messages: list, max_length: int, temperature: float = 0.7, thinking=True) -> str:
        """Sample a reply to chat `messages` using the tokenizer's chat template.

        Args:
            messages: List of ``{"role": ..., "content": ...}`` dicts.
            max_length: Maximum number of new tokens to generate.
            temperature: Sampling temperature (nucleus sampling with top_p=0.9).
            thinking: Forwarded to the chat template as ``enable_thinking``.

        Returns:
            The decoded reply text, without the prompt.
        """
        text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, enable_thinking=thinking)
        model_inputs = self.tokenizer([text], return_tensors="pt").to(self.device)
        generated_ids = self.model.generate(
            model_inputs.input_ids,
            attention_mask=model_inputs.attention_mask,
            max_new_tokens=max_length,
            do_sample=True, top_p=0.9, temperature=temperature,
            pad_token_id=self.tokenizer.eos_token_id
        )
        response = self.tokenizer.batch_decode(generated_ids[:, model_inputs.input_ids.shape[-1]:], skip_special_tokens=True)[0]
        del text, model_inputs, generated_ids
        return response

    def get_prob_and_vector(self, text: str) -> tuple[float, np.ndarray]:
        """Score `text` and embed it in one forward pass.

        Returns:
            ``(log_prob, vector)``. ``log_prob`` is the negated language-model loss, i.e.
            the mean per-token log-likelihood of the (possibly left-truncated) text.
            ``vector`` is the final-layer hidden state of the last token as float32.
            NOTE(review): a single-token input presumably yields a NaN loss; callers pass
            longer texts.
        """
        with torch.no_grad():
            inputs = self._encode(text)
            outputs = self.model(**inputs, labels=inputs.input_ids, output_hidden_states=True)
            log_prob = -outputs.loss.item()
            vector = outputs.hidden_states[-1][0, -1, :].cpu().to(torch.float32).numpy()
            del inputs, outputs
            return log_prob, vector

# --- 4. Automated Theme Generator ---
class AutomatedThemeGenerator:
    """Generate research themes by pairing concepts from different document clusters.

    On construction the document vectors are clustered with KMeans into
    ``config.NUM_CLUSTERS`` "conceptual continents". Each ``generate_theme()`` call picks
    a random cluster, pairs it with a cluster at medium centroid distance, samples one
    document from each and asks the LLM to fuse the two into a single theme.
    """

    def __init__(self, llm_interface: "LLMInterface", documents: list, vectors: np.ndarray, config: "Config"):
        print("\n--- Initializing Automated Theme Generator ---")
        self.llm = llm_interface
        self.documents = documents
        self.vectors = vectors
        self.config = config
        self.labels = None     # cluster id per document, set by _build_knowledge_map
        self.centroids = None  # KMeans cluster centers, one row per cluster
        self._build_knowledge_map()

    def _build_knowledge_map(self):
        """Cluster the document vectors and store per-document labels and cluster centroids."""
        print(f"Phase 0: Clustering {len(self.documents)} documents into {self.config.NUM_CLUSTERS} 'Conceptual Continents'...")
        kmeans = KMeans(n_clusters=self.config.NUM_CLUSTERS, random_state=42, n_init=10)
        self.labels = kmeans.fit_predict(self.vectors)
        self.centroids = kmeans.cluster_centers_
        print("Knowledge map built.")

    def _pick_partner_cluster(self, cluster_a_idx: int) -> int:
        """Return a cluster at "medium" centroid distance from `cluster_a_idx`.

        Clusters are ranked by centroid distance to cluster A, and the slice
        ``[n // 3 : 2 * (n // 3)]`` of that ranking is preferred, excluding A itself.
        If that slice is empty, any other cluster is used; if A is the only cluster,
        A itself is returned.
        """
        distances = np.linalg.norm(self.centroids - self.centroids[cluster_a_idx], axis=1)
        sorted_indices = np.argsort(distances)
        num_clusters = len(self.centroids)
        medium_dist_start = num_clusters // 3
        medium_dist_end = 2 * (num_clusters // 3)
        # Emptiness is checked with the list's truthiness. The old `ndarray.any()` test
        # also treated the valid cluster index 0 as "empty".
        candidates = [int(i) for i in sorted_indices[medium_dist_start:medium_dist_end] if i != cluster_a_idx]
        if not candidates:
            candidates = [int(i) for i in sorted_indices if i != cluster_a_idx]
        if not candidates:
            return cluster_a_idx
        return random.choice(candidates)

    def _sample_document(self, cluster_idx: int):
        """Return a uniformly random document assigned to `cluster_idx`."""
        doc_indices = np.where(self.labels == cluster_idx)[0]
        return self.documents[random.choice(doc_indices)]

    def generate_theme(self) -> str:
        """Sample two concepts from distinct clusters and ask the LLM for a combined theme.

        Returns:
            The LLM answer with any ``<think>...</think>`` reasoning removed. If nothing
            remains after removal, the raw response is returned instead.
        """
        print("\n--- Generating a new theme automatically ---")
        print("Phase 1 & 2: Sampling continents and finding concepts...")
        cluster_a_idx = random.randint(0, len(self.centroids) - 1)
        cluster_b_idx = self._pick_partner_cluster(cluster_a_idx)

        concept_a = self._sample_document(cluster_a_idx)
        concept_b = self._sample_document(cluster_b_idx)

        print("Phase 3: Synthesizing a high-level theme...")
        prompt = f"""
        First, think step-by-step in a <think> block about how to creatively combine the two following scientific concepts. Then, output the final high-level, forward-looking research theme as a single, concise string.

        Concept 1: "{concept_a}"
        Concept 2: "{concept_b}"
        """
        response = self.llm.generate_chat_completion([{"role": "user", "content": prompt}], self.config.THEME_GEN_MAX_LENGTH, temperature=0.5)
        theme = re.sub(r"<think>.*?</think>", "", response, flags=re.DOTALL).strip()
        return theme if theme else response

# --- 5. MCTS Components ---
class MCTSNode:
    """One node of the idea-search tree.

    A node stores only its own text fragment; the full idea it stands for is the
    concatenation of the fragments on the path from the root down to this node.
    """

    def __init__(self, text_fragment: str, parent=None, thoughts: str = None):
        self.text_fragment = text_fragment
        self.parent = parent
        self.thoughts = thoughts  # reasoning text captured when this node was proposed
        self.children = []
        self.N = 0                # visit count
        self.Q = 0.0              # sum of backpropagated values
        self.is_terminal = False
        self._vector = None       # lazily computed, cached embedding

    def get_vector(self, llm_interface: LLMInterface) -> np.ndarray:
        """Return the cached L2-normalised embedding of this node's full text.

        The embedding is computed on first use. Returns None when nothing is cached and
        the full text is empty.
        """
        if self._vector is not None:
            return self._vector
        full_text = self.get_full_text()
        if full_text:
            embedding = llm_interface.get_vector(full_text)
            self._vector = embedding
            if embedding is not None:
                # normalize_L2 works in place; for a contiguous array the reshape is a view,
                # so the cached vector itself ends up normalised.
                faiss.normalize_L2(embedding.reshape(1, -1))
        return self._vector

    def get_full_text(self) -> str:
        """Join the fragments from the root to this node with blank lines, then strip."""
        return "\n\n".join(step.text_fragment for step in self.get_path()).strip()

    def get_path(self):
        """Return the nodes from the root down to (and including) this node."""
        lineage = []
        cursor = self
        while cursor is not None:
            lineage.append(cursor)
            cursor = cursor.parent
        lineage.reverse()
        return lineage

    def get_depth(self) -> int:
        """Return the number of edges between this node and the root."""
        depth = 0
        ancestor = self.parent
        while ancestor is not None:
            depth += 1
            ancestor = ancestor.parent
        return depth

    @property
    def value(self):
        """Mean backpropagated value, or 0 for an unvisited node."""
        if self.N <= 0:
            return 0
        return self.Q / self.N

    def is_leaf(self):
        """True when the node has no children yet."""
        return not self.children

class CG_MCTS:
    """Concept-guided Monte Carlo Tree Search over LLM-generated research ideas.

    The root holds the theme; every other node holds one proposed next step, and a
    node's full idea is the root-to-node concatenation of fragments
    (``MCTSNode.get_full_text``).
    - Selection is UCT plus a bonus ``W_DIR * <v_child, v_target>`` that steers toward a
      target embedding derived from the theme.
    - Expansion asks the LLM (thinking enabled) for up to ``K_EXPAND`` "Proposal n:" lines.
    - Each new child is scored by ``N_ROLLOUTS`` short rollouts, and the score is
      backpropagated to the root.

    Args:
        llm_interface: Object providing generate_chat_completion, get_vector,
            get_prob_and_vector and ``model.config.hidden_size``.
        novelty_db: Index queried with ``search(vec, 1)``. The top similarity to it is
            treated as closeness to already-known work.
        config: Hyper-parameter holder (see ``Config``).
        theme: Text of the root node.
    """

    def __init__(self, llm_interface, novelty_db, config, theme: str):
        self.llm = llm_interface
        self.novelty_db = novelty_db
        self.config = config
        self.theme = theme
        self.root = MCTSNode(theme)
        self.v_target = self._initialize_target_vector()

    @staticmethod
    def _first_concept(components, key: str):
        """Return the first non-blank string listed under ``components[key]``, or None.

        A single string is accepted in place of a list; any other shape yields None.
        """
        if not isinstance(components, dict):
            return None
        entries = components.get(key)
        if isinstance(entries, str):
            entries = [entries]
        if not isinstance(entries, list):
            return None
        return next((e for e in entries if isinstance(e, str) and e.strip()), None)

    def _random_target_vector(self) -> np.ndarray:
        """Return a random unit-length float32 vector of the model's hidden size.

        It is normalised like every other vector used in selection, so the
        direction bonus stays on the same scale as the UCT terms.
        """
        vector = np.random.rand(self.llm.model.config.hidden_size).astype(np.float32)
        faiss.normalize_L2(vector.reshape(1, -1))
        return vector

    def _initialize_target_vector(self):
        """Compute the unit-length direction that biases selection.

        The LLM deconstructs the theme into 'problems' and 'mechanisms'. The first entry
        of each is embedded as v_p and v_m, and then
        v_target = normalize(v_p + ALPHA_NOVELTY * (v_m minus its projection on v_p)).
        A random unit vector is used when the deconstruction cannot be parsed.
        """
        print("\n--- Phase 1: Target Setting (Calculating v_target) ---")
        prompt = f"""
        First, think in a <think> block about how to deconstruct the scientific theme '{self.theme}' into core 'problems' and potential 'mechanisms'. Then, output the result as a JSON object.

        Example:
        <think>The user wants to break down 'AI for drug discovery'. The problems are the goals, like finding protein structures. The mechanisms are the AI tools used, like GNNs or Transformers.</think>
        ```json
        {{
            "problems": ["protein folding prediction", "molecule screening"],
            "mechanisms": ["graph neural networks", "reinforcement learning"]
        }}
        ```
        """
        response = self.llm.generate_chat_completion([{"role": "user", "content": prompt}], self.config.THEME_GEN_MAX_LENGTH, temperature=0.2)
        print(response)
        components = parse_llm_json_output(response)

        problem_concept = self._first_concept(components, "problems")
        mechanism_concept = self._first_concept(components, "mechanisms")
        if problem_concept is None or mechanism_concept is None:
            print("Could not deconstruct theme. Using a random target vector.")
            return self._random_target_vector()

        print(f"Selected Pair: Problem='{problem_concept}', Mechanism='{mechanism_concept}'")

        v_p = self.llm.get_vector(problem_concept)
        v_m = self.llm.get_vector(mechanism_concept)

        # Keep only the part of the mechanism direction that is new relative to the
        # problem (one Gram-Schmidt step), then blend it in with weight ALPHA_NOVELTY.
        proj_v_m_on_v_p = (np.dot(v_m, v_p) / np.dot(v_p, v_p)) * v_p
        v_m_ortho = v_m - proj_v_m_on_v_p
        # normalize_L2 needs C-contiguous float32 and works in place.
        v_target = np.ascontiguousarray(v_p + self.config.ALPHA_NOVELTY * v_m_ortho, dtype=np.float32)
        faiss.normalize_L2(v_target.reshape(1, -1))
        print("v_target calculated and normalized.")
        return v_target

    def _selection_score(self, parent: MCTSNode, child: MCTSNode) -> float:
        """Guided UCT score of `child` under `parent`; unvisited children score +inf."""
        if child.N == 0:
            return float('inf')
        exploitation = child.Q / child.N
        exploration = self.config.EXPLORATION_CONSTANT * math.sqrt(math.log(parent.N) / child.N)
        direction_guidance = 0
        child_vec = child.get_vector(self.llm)
        if child_vec is not None and self.v_target is not None:
            direction_guidance = self.config.W_DIR * np.dot(child_vec, self.v_target)
        return exploitation + exploration + direction_guidance

    def select(self, node: MCTSNode):
        """Descend from `node` to a leaf, always taking the child with the best guided-UCT score."""
        while not node.is_leaf():
            parent = node
            node = max(parent.children, key=lambda child: self._selection_score(parent, child))
        return node

    @staticmethod
    def _extract_proposals(text: str) -> list:
        """Return the non-empty texts following "Proposal <n>:" markers in `text`.

        Markdown emphasis such as ``**Proposal 1:**`` or ``**Proposal 1**:`` is tolerated,
        and leftover asterisks are stripped from the captured text.
        """
        proposals = []
        for raw in re.findall(r"Proposal\s+\d+\s*\**\s*:\s*\**\s*(.*)", text):
            cleaned = raw.strip().strip("*").strip()
            if cleaned:
                proposals.append(cleaned)
        return proposals

    def expand(self, node: MCTSNode):
        """Create up to K_EXPAND children of `node` from one long reasoning generation.

        Proposals are read from the answer that follows the reasoning block. Only if the
        answer has none is the whole response searched. A node whose expansion yields no
        proposal is marked terminal.
        """
        if node.is_terminal: return

        print(f"  Expanding node {node.get_depth()} with deep thought...")
        history_text = node.get_full_text()
        prompt = f"""
        You are a research scientist brainstorming. Based on the idea so far:
        '''{history_text}'''
        First, perform a detailed reasoning session within a <think> block. Analyze the idea, consider next steps, and weigh different options.
        After your reasoning, explicitly propose {self.config.K_EXPAND} distinct, concise, and actionable next steps, each prefixed with "Proposal 1:", "Proposal 2:", etc.
        You should begin with repeating the idea, and add more details after that.
        """
        # 1. Generator call (long-form reasoning).
        long_form_response = self.llm.generate_chat_completion([{"role": "user", "content": prompt}], self.config.EXPAND_MAX_LENGTH)
        print(long_form_response)

        # 2. Proposer step (parsing and node creation).
        thoughts_match = re.search(r"<think>(.*?)</think>", long_form_response, re.DOTALL)
        thoughts = thoughts_match.group(1).strip() if thoughts_match else "No thoughts captured."

        # The reasoning block often drafts proposals or quotes the prompt's own
        # '"Proposal 1:", "Proposal 2:"' text, so it must not shadow the final answer.
        answer = long_form_response.rsplit("</think>", 1)[-1]
        proposals = self._extract_proposals(answer) or self._extract_proposals(long_form_response)

        if proposals:
            for proposal_text in proposals[:self.config.K_EXPAND]:
                node.children.append(MCTSNode(proposal_text, parent=node, thoughts=thoughts))
        else:
            print("Warning: Expansion failed to generate valid proposals.")
            node.is_terminal = True

    def simulate_and_evaluate(self, node: MCTSNode):
        """Score `node` as the mean value of N_ROLLOUTS short, direct rollouts.

        Each rollout's value is
        ``W_COH * sigmoid(mean token log-prob of idea + rollout)
        + W_NOV * (1 - top similarity in novelty_db)
        + W_PROG * progress``.
        Progress is ``(1 - <v_parent, v_node>) / 2``, or 0.5 when no parent vector exists.
        Returns 0.0 for an empty fragment or when N_ROLLOUTS is not positive.
        """
        if not node.text_fragment: return 0.0
        n_rollouts = self.config.N_ROLLOUTS
        if n_rollouts <= 0: return 0.0

        # Progress depends only on the node and its parent, so compute it once.
        v_progress = 0.5
        parent_vec = node.parent.get_vector(self.llm) if node.parent is not None else None
        if parent_vec is not None:
            child_vec = node.get_vector(self.llm)
            if child_vec is not None:
                v_progress = (1.0 - np.dot(parent_vec.flatten(), child_vec.flatten())) / 2.0

        rollout_prompt = node.get_full_text()
        messages = [{"role": "user", "content": f"Continue this idea very **briefly and directly**. Provide enough details to understand the next logical step or implication: {rollout_prompt}"}]

        total_value = 0
        for _ in range(n_rollouts):
            rollout_text = self.llm.generate_chat_completion(messages, self.config.ROLLOUT_MAX_LENGTH, temperature=0.8, thinking=False)
            print('ASKING::\n',messages)
            print('RESPONSE::\n',rollout_text)
            final_seq = (rollout_prompt + "\n\n" + rollout_text).strip()

            v_coherence, seq_vec = self.llm.get_prob_and_vector(final_seq)
            seq_vec = seq_vec.reshape(1, -1)
            faiss.normalize_L2(seq_vec)

            v_coherence_norm = 1 / (1 + math.exp(-v_coherence))
            similarities, _ = self.novelty_db.search(seq_vec, 1)
            v_novelty = 1.0 - similarities[0][0]

            value = (self.config.W_COH * v_coherence_norm + self.config.W_NOV * v_novelty + self.config.W_PROG * v_progress)
            print(f'Cohere value:{v_coherence_norm}, Novel value: {v_novelty}, Progress value: {v_progress}, \nTotal Value:{value}')
            total_value += value
        return total_value / n_rollouts

    def backpropagate(self, node: MCTSNode, value: float):
        """Add one visit and `value` to `node` and every ancestor up to the root."""
        while node is not None:
            node.N += 1
            node.Q += value
            node = node.parent

    def search(self):
        """Expand the root if needed, then run NUM_ITERATIONS select/expand/simulate/backpropagate rounds."""
        print("\n--- Phase 2: Guided Search ---")
        if self.root.is_leaf():
            self.expand(self.root)
            for child in self.root.children:
                value = self.simulate_and_evaluate(child)
                self.backpropagate(child, value)

        if self.root.is_terminal and self.root.is_leaf():
            print("  Root could not be expanded; stopping search.")
            return

        for i in range(self.config.NUM_ITERATIONS):
            print(f"--- Iteration {i+1}/{self.config.NUM_ITERATIONS} ---")
            leaf_node = self.select(self.root)
            print(f"  Select ({leaf_node.get_depth()}): ...{leaf_node.text_fragment[-60:]}")

            if not leaf_node.is_terminal:
                self.expand(leaf_node)

            simulated = False
            for child in leaf_node.children:
                if child.N == 0:
                    value = self.simulate_and_evaluate(child)
                    print(f"  Simulate ({child.get_depth()}): ...{child.text_fragment[-60:]}, Value: {value:.4f}")
                    self.backpropagate(child, value)
                    simulated = True

            if not simulated:
                # Terminal leaf (already, or because its expansion just failed). Re-credit its
                # current mean so the visit counts grow and UCT moves elsewhere. Otherwise
                # select() would return this same leaf on every remaining iteration.
                self.backpropagate(leaf_node, leaf_node.value)
            print("-" * 25)

    def get_best_sequence(self, debug=True):
        """Follow the most-visited child from the root and return the resulting idea.

        Args:
            debug: When True, append a per-node trace (visits, mean value, a reasoning
                snippet and content) to the final text.

        Returns:
            The fragments joined with blank lines, optionally followed by the debug trace.
        """
        path_nodes = []
        node = self.root
        while node is not None:
            path_nodes.append(node)
            if not node.children: break
            node = max(node.children, key=lambda n: n.N if n.N > 0 else -1)

        main_text = "\n\n".join(n.text_fragment for n in path_nodes)
        if not debug: return main_text

        debug_parts = []
        for i, n in enumerate(path_nodes):
            debug_parts.append(f"--- Node {i} (Depth: {n.get_depth()}, Visits: {n.N}, Value: {n.value:.4f}) ---")
            if n.thoughts:
                # Show only a snippet of the reasoning to keep the final print short.
                thought_snippet = (n.thoughts[:400] + '...') if len(n.thoughts) > 400 else n.thoughts
                debug_parts.append(f"  [REASONING]:\n{thought_snippet}")
            debug_parts.append(f"  [CONTENT]:\n{n.text_fragment}\n")
        debug_trace = "\n".join(debug_parts)

        return f"--- FINAL IDEA ---\n{main_text}\n\n\n--- DEBUG TRACE (BEST PATH) ---\n{debug_trace}"

# --- 6. Helper & Main Execution ---
def build_faiss_index(vectors: np.ndarray):
    """Build an exact inner-product FAISS index over the L2-normalised rows of `vectors`.

    Because the stored vectors have unit length, searching with a unit-length query
    returns cosine similarities.

    Args:
        vectors: An ``(n, d)`` array, or a single ``(d,)`` vector. It is converted to
            C-contiguous float32, as FAISS requires. When it already has that layout no
            copy is made, so the caller's array is normalised in place (the original
            behaviour).

    Returns:
        A ``faiss.IndexFlatIP`` holding the normalised vectors.
    """
    print("Building FAISS index...")
    vectors = np.ascontiguousarray(vectors, dtype=np.float32)
    if vectors.ndim == 1:
        vectors = vectors.reshape(1, -1)
    dimension = vectors.shape[1]
    index = faiss.IndexFlatIP(dimension)
    faiss.normalize_L2(vectors)
    index.add(vectors)
    print("FAISS index built.")
    return index

if __name__ == "__main__":
    config = Config()
    llm = LLMInterface(config.MODEL_NAME, config.DEVICE)

    # Reference corpus: it seeds both the novelty index and the theme generator's clusters.
    base_documents = [
        "AlphaFold predicts protein structures using a deep learning approach.", 
        "CRISPR-Cas9 is a gene-editing tool that allows for precise modification of DNA sequences.", 
        "BERT is a transformer-based model for natural language understanding.",
        "Graphene is a single layer of carbon atoms arranged in a two-dimensional honeycomb lattice.",
        "Lithium-ion batteries are a type of rechargeable battery.",
        "Perovskite solar cells include a perovskite-structured compound as the light-harvesting active layer."
    ]

    print(f"\nVectorizing {len(base_documents)} base documents...")
    doc_vectors = np.array([llm.get_vector(doc) for doc in base_documents]).astype('float32')

    # build_faiss_index L2-normalises doc_vectors in place, so the theme generator
    # below clusters the normalised vectors.
    novelty_index = build_faiss_index(doc_vectors)

    generator = AutomatedThemeGenerator(llm, base_documents, doc_vectors, config)
    theme = generator.generate_theme()

    print(f"\n>>> Automatically Generated Theme: '{theme}' <<<\n")

    searcher = CG_MCTS(llm, novelty_db=novelty_index, config=config, theme=theme)

    t0 = time.time()
    searcher.search()
    elapsed = time.time() - t0
    print(f"Search completed in {elapsed:.2f} seconds.")

    report = searcher.get_best_sequence(debug=True)

    banner = "=" * 20
    print(f"\n{banner} Final Generated Idea {banner}")
    print(report)
