import torch
import numpy as np
import faiss
import math
import time
import re
import json
import random
from sklearn.cluster import KMeans
from transformers import AutoModelForCausalLM, AutoTokenizer

# --- 1. Configuration ---
class Config:
    """Static hyperparameters for theme generation and the narrative MCTS.

    All values are class attributes; the rest of the script only reads them.
    """

    # --- Model / runtime ---
    MODEL_NAME: str = "../Qwen3-0.6B"  # passed to from_pretrained for tokenizer and model
    DEVICE: str = "cuda" if torch.cuda.is_available() else "cpu"

    # --- Search budget and tree shape ---
    NUM_ITERATIONS: int = 30           # select/expand/evaluate rounds after the root expansion
    EXPLORATION_CONSTANT: float = 1.5  # UCT exploration weight used during selection
    K_EXPAND: int = 3                  # number of "[Option N]" continuations requested per expansion
    # N_ROLLOUTS is gone: in Narrative MCTS each node is scored by one direct evaluation.

    # --- Generation budgets (max_new_tokens per call) ---
    EXPAND_MAX_LENGTH: int = 1024
    THEME_GEN_MAX_LENGTH: int = 1024

    # --- Automated theme generation ---
    NUM_CLUSTERS: int = 3  # KMeans clusters ("conceptual continents")

    # --- Target direction and value weights (T-MCTS/PE) ---
    ALPHA_NOVELTY: float = 0.7  # weight of the mechanism component orthogonal to the problem vector
    W_DIR: float = 1.0          # direction-guidance weight in selection
    W_COH: float = 0.5          # coherence weight in a node's value
    W_NOV: float = 0.3          # novelty weight in a node's value
    W_PROG: float = 0.2         # progress weight in a node's value

# --- 2. Helper Functions ---
def parse_llm_json_output(response: str) -> dict | None:
    """Robustly parses JSON from LLM output that might include markdown."""
    try:
        match = re.search(r"```json\n(.*?)\n```", response, re.DOTALL)
        if match:
            json_str = match.group(1)
            return json.loads(json_str)
        else:
            return json.loads(response)
    except (json.JSONDecodeError, IndexError):
        print(f"Warning: Failed to parse LLM response as JSON.")
        return None

# --- 3. LLM Interface ---
class LLMInterface:
    """Wrapper around a Hugging Face causal LM for chat generation, scoring and embeddings.

    Embeddings are the final-layer hidden state of the last input token.
    Scoring/embedding inputs longer than ``max_input_length`` tokens are
    truncated from the LEFT, so that "last token" is the true end of the text.
    """

    def __init__(self, model_name, device, max_input_length: int = 512):
        """Load the tokenizer and model.

        Args:
            model_name: Hub id or local path passed to ``from_pretrained``.
            device: Requested device. It is only used in the log message;
                placement is decided by ``device_map="auto"``, and the
                resulting device is stored as ``self.device``.
            max_input_length: Token limit applied by ``get_vector`` and
                ``get_prob_and_vector`` (default 512).
        """
        print(f"Loading model: {model_name} on {device}...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Tokenizers truncate on the right by default. The vectors below are
        # read from the LAST token, so right truncation would give every
        # narrative that shares its first `max_input_length` tokens the same
        # vector and the same score. Keep the most recent tokens instead.
        self.tokenizer.truncation_side = "left"
        self.max_input_length = max_input_length
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            dtype="auto",
            device_map="auto"
        ).eval()
        self.device = self.model.device
        print("Model loaded successfully.")

    def _encode(self, text: str):
        """Tokenize *text* as a batch of one, left-truncated to ``max_input_length``, on ``self.device``."""
        return self.tokenizer(
            text, return_tensors="pt", truncation=True, max_length=self.max_input_length
        ).to(self.device)

    def get_vector(self, text: str) -> np.ndarray:
        """Return the last-token, final-layer hidden state of *text* as a 1-D float32 array."""
        with torch.no_grad():
            inputs = self._encode(text)
            outputs = self.model(**inputs, output_hidden_states=True)
            vector = outputs.hidden_states[-1][0, -1, :].cpu().to(torch.float32).numpy()
            del inputs, outputs
            return vector

    def generate_chat_completion(self, messages: list, max_length: int, temperature: float = 0.7, thinking=True) -> str:
        """Sample a chat reply and return only the newly generated text.

        Args:
            messages: Chat messages (``{"role": ..., "content": ...}`` dicts)
                rendered with the tokenizer's chat template.
            max_length: Maximum number of NEW tokens to generate.
            temperature: Sampling temperature (nucleus sampling, top_p=0.9).
            thinking: Forwarded to the chat template as ``enable_thinking``.

        Returns:
            The decoded continuation with special tokens skipped (the prompt is
            sliced off before decoding).
        """
        text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, enable_thinking=thinking)
        model_inputs = self.tokenizer([text], return_tensors="pt").to(self.device)
        generated_ids = self.model.generate(
            model_inputs.input_ids,
            attention_mask=model_inputs.attention_mask,
            max_new_tokens=max_length,
            do_sample=True, top_p=0.9, temperature=temperature,
            pad_token_id=self.tokenizer.eos_token_id
        )
        # Drop the prompt tokens; keep only what the model produced.
        response = self.tokenizer.batch_decode(generated_ids[:, model_inputs.input_ids.shape[-1]:], skip_special_tokens=True)[0]
        del text, model_inputs, generated_ids
        return response

    def get_prob_and_vector(self, text: str) -> tuple[float, np.ndarray]:
        """Score and embed *text* in a single forward pass.

        Returns:
            ``(log_prob, vector)``. ``log_prob`` is the negated LM loss
            (``-outputs.loss``) with the input ids as labels, and ``vector``
            is the same last-token embedding ``get_vector`` returns. On any
            exception the error is printed and ``(0.0, None)`` is returned.
        """
        with torch.no_grad():
            try:
                inputs = self._encode(text)
                outputs = self.model(**inputs, labels=inputs.input_ids, output_hidden_states=True)
                log_prob = -outputs.loss.item()
                vector = outputs.hidden_states[-1][0, -1, :].cpu().to(torch.float32).numpy()
                del inputs, outputs
                return log_prob, vector
            except Exception as e:
                # Best-effort scoring: callers treat a None vector as "no score".
                print(f"Error in get_prob_and_vector: {e}")
                return 0.0, None

# --- 4. Automated Theme Generator ---
class AutomatedThemeGenerator:
    """Propose research themes by blending concepts from two document clusters.

    Documents are clustered once with KMeans at construction time. Each
    ``generate_theme`` call:

    1. picks a random cluster A;
    2. picks a partner cluster B at "medium" centroid distance from A;
    3. draws one document from each cluster;
    4. asks the LLM to synthesize a theme that combines them.
    """

    def __init__(self, llm_interface: LLMInterface, documents: list, vectors: np.ndarray, config: Config):
        """Store inputs and build the cluster map.

        Args:
            llm_interface: Model wrapper used for theme synthesis.
            documents: Document strings; ``documents[i]`` corresponds to ``vectors[i]``.
            vectors: One embedding row per document, clustered with KMeans.
            config: Provides ``NUM_CLUSTERS`` and ``THEME_GEN_MAX_LENGTH``.
        """
        print("\n--- Initializing Automated Theme Generator ---")
        self.llm = llm_interface
        self.documents = documents
        self.vectors = vectors
        self.config = config
        self.labels = None
        self.centroids = None
        self._build_knowledge_map()

    def _build_knowledge_map(self):
        """Fit KMeans (``NUM_CLUSTERS`` clusters, fixed seed) and store per-document labels and centroids."""
        print(f"Phase 0: Clustering {len(self.documents)} documents into {self.config.NUM_CLUSTERS} 'Conceptual Continents'...")
        kmeans = KMeans(n_clusters=self.config.NUM_CLUSTERS, random_state=42, n_init=10)
        self.labels = kmeans.fit_predict(self.vectors)
        self.centroids = kmeans.cluster_centers_
        print("Knowledge map built.")

    def generate_theme(self) -> str:
        """Generate one theme string.

        Choosing partner cluster B:

        - Clusters are ranked by centroid distance from cluster A.
        - B is drawn from the middle third of that ranking (positions
          ``n // 3`` up to ``2 * (n // 3)``), excluding A itself.
        - If that slice is empty (for example with fewer than 3 clusters), B is
          drawn from all other clusters.
        - With a single cluster, both concepts come from cluster A.

        Returns:
            The model response with ``<think>...</think>`` blocks removed, or
            the raw response if nothing remains after removal.
        """
        print("\n--- Generating a new theme automatically ---")
        print("Phase 1 & 2: Sampling continents and finding concepts...")
        num_clusters = self.config.NUM_CLUSTERS
        cluster_a_idx = random.randint(0, num_clusters - 1)
        distances = np.linalg.norm(self.centroids - self.centroids[cluster_a_idx], axis=1)
        sorted_indices = [int(i) for i in np.argsort(distances)]
        medium_dist_start = num_clusters // 3
        medium_dist_end = 2 * (num_clusters // 3)
        # Check emptiness by length, not `ndarray.any()`: cluster id 0 is a
        # valid candidate but falsy. Also exclude A in case of tied centroids.
        medium_distance_indices = [
            i for i in sorted_indices[medium_dist_start:medium_dist_end] if i != cluster_a_idx
        ]
        if not medium_distance_indices:
            medium_distance_indices = [i for i in sorted_indices if i != cluster_a_idx]
        if not medium_distance_indices:
            # Only one cluster exists: pair two documents from it.
            medium_distance_indices = [cluster_a_idx]
        cluster_b_idx = random.choice(medium_distance_indices)

        def get_representative_doc(cluster_idx):
            # A uniformly random member of the cluster.
            doc_indices = np.where(self.labels == cluster_idx)[0]
            return self.documents[random.choice(doc_indices)]

        concept_a = get_representative_doc(cluster_a_idx)
        concept_b = get_representative_doc(cluster_b_idx)

        print("Phase 3: Synthesizing a high-level theme...")
        prompt = f'''
        First, think step-by-step in a <think> block about how to creatively combine the two following scientific concepts. Then, output the final high-level, forward-looking research theme as a single, concise string.

        Concept 1: "{concept_a}"
        Concept 2: "{concept_b}"
        '''
        response = self.llm.generate_chat_completion([{"role": "user", "content": prompt}], self.config.THEME_GEN_MAX_LENGTH, temperature=0.5)
        theme = re.sub(r"<think>.*?</think>", "", response, flags=re.DOTALL).strip()
        return theme if theme else response

# --- 5. MCTS Components (Narrative MCTS) ---
class MCTSNode:
    """One node of the narrative search tree.

    A node holds a single narrative block. Its full narrative is the
    concatenation of the blocks on the path from the root down to it.
    ``N`` (visit count) and ``Q`` (accumulated value) are updated by the
    search's backpropagation.
    """

    def __init__(self, narrative_block: str, parent=None, thoughts: str = None):
        self.narrative_block = narrative_block
        self.parent = parent
        self.thoughts = thoughts  # model reasoning captured when this node was generated
        self.children = []
        self.N = 0
        self.Q = 0.0
        self.is_terminal = False
        self._vector = None  # cached embedding of get_full_text(), filled lazily

    def get_vector(self, llm_interface: LLMInterface) -> np.ndarray:
        """Return the cached, L2-normalized embedding of this node's full narrative.

        The vector is computed on first use via ``llm_interface.get_vector``
        and normalized in place. Returns None when the narrative is empty.
        """
        if self._vector is not None:
            return self._vector
        full_text = self.get_full_text()
        if not full_text:
            return self._vector
        embedding = llm_interface.get_vector(full_text)
        self._vector = embedding
        if embedding is not None:
            # normalize_L2 works in place on the (1, d) view of the cached array.
            faiss.normalize_L2(embedding.reshape(1, -1))
        return self._vector

    def get_full_text(self) -> str:
        """Join the root-to-node blocks with blank lines; the root block is the theme."""
        blocks = [ancestor.narrative_block for ancestor in self.get_path()]
        return "\n\n".join(blocks).strip()

    def get_path(self):
        """Return the list of nodes from the root down to (and including) this node."""
        lineage = []
        cursor = self
        while cursor is not None:
            lineage.append(cursor)
            cursor = cursor.parent
        return lineage[::-1]

    def get_depth(self) -> int:
        """Number of edges between this node and the root (the root has depth 0)."""
        depth = 0
        ancestor = self.parent
        while ancestor is not None:
            depth += 1
            ancestor = ancestor.parent
        return depth

    @property
    def value(self):
        """Mean backpropagated value, or 0 for an unvisited node."""
        if self.N > 0:
            return self.Q / self.N
        return 0

    def is_leaf(self):
        """True when the node has not been expanded (no children)."""
        return not self.children

class CG_MCTS:
    """Direction-guided Monte Carlo Tree Search over research-proposal narratives.

    The root node holds the theme; every other node holds one proposal
    section, and a node's narrative is the root-to-node concatenation
    (``MCTSNode.get_full_text``). Selection adds a direction bonus,
    ``W_DIR * <node vector, v_target>``, to UCT. Each new node is scored once
    by ``simulate_and_evaluate``, and that score is backpropagated to the root.
    """

    def __init__(self, llm_interface, novelty_db, config, theme: str):
        """Set up the tree and compute the target direction.

        Args:
            llm_interface: Model wrapper used for generation, scoring and embeddings.
            novelty_db: FAISS index queried with normalized vectors; the top-1
                inner product is treated as similarity to known work.
            config: Hyperparameter object (see ``Config``).
            theme: Research theme; stored as the root's narrative block.
        """
        self.llm = llm_interface
        self.novelty_db = novelty_db
        self.config = config
        self.theme = theme
        # The root node contains the initial theme as the first narrative block
        self.root = MCTSNode(theme)
        # Note: this performs an LLM generation call at construction time.
        self.v_target = self._initialize_target_vector()

    @staticmethod
    def _first_concept(components, key):
        """Return the first entry of ``components[key]`` as a non-empty string, or None.

        Accepts a bare string as well as a list, because indexing a string with
        ``[0]`` would yield only its first character.
        """
        if not isinstance(components, dict):
            return None
        value = components.get(key)
        if isinstance(value, str):
            return value.strip() or None
        if isinstance(value, (list, tuple)) and value:
            return str(value[0]).strip() or None
        return None

    def _initialize_target_vector(self):
        """Compute the unit-norm target direction ``v_target`` used by ``select``.

        Steps:

        1. Ask the LLM to split the theme into "problems" and "mechanisms".
        2. Embed the first of each, giving ``v_p`` and ``v_m``.
        3. Return ``normalize(v_p + ALPHA_NOVELTY * (v_m - proj_{v_p}(v_m)))``:
           the problem direction tilted toward the part of the mechanism that
           is orthogonal to it.

        If the theme cannot be deconstructed, a random vector is returned,
        normalized the same way so the ``W_DIR`` term stays on a cosine scale.
        """
        print("\n--- Phase 1: Target Setting (Calculating v_target) ---")
        prompt = f'''
        First, think in a <think> block about how to deconstruct the scientific theme 
'{self.theme}
' into core \'problems\' and potential \'mechanisms\'. Then, output the result as a JSON object.

        Example:
        <think>The user wants to break down \'AI for drug discovery\'. The problems are the goals, like finding protein structures. The mechanisms are the AI tools used, like GNNs or Transformers.</think>
        ```json
        {{
            "problems": ["protein folding prediction", "molecule screening"],
            "mechanisms": ["graph neural networks", "reinforcement learning"]
        }}
        ```
        '''
        response = self.llm.generate_chat_completion([{"role": "user", "content": prompt}], self.config.THEME_GEN_MAX_LENGTH, temperature=0.2)
        print(response)
        components = parse_llm_json_output(response)

        problem_concept = self._first_concept(components, "problems")
        mechanism_concept = self._first_concept(components, "mechanisms")
        if problem_concept is None or mechanism_concept is None:
            print("Could not deconstruct theme. Using a random target vector.")
            # Unnormalized rand() has norm ~sqrt(d/3), which would let the
            # direction bonus swamp the UCT terms in select().
            fallback = np.random.rand(self.llm.model.config.hidden_size).astype(np.float32)
            faiss.normalize_L2(fallback.reshape(1, -1))
            return fallback

        print(f"Selected Pair: Problem='{problem_concept}', Mechanism='{mechanism_concept}'")

        v_p = self.llm.get_vector(problem_concept)
        v_m = self.llm.get_vector(mechanism_concept)

        proj_v_m_on_v_p = (np.dot(v_m, v_p) / np.dot(v_p, v_p)) * v_p
        v_m_ortho = v_m - proj_v_m_on_v_p
        v_target = v_p + self.config.ALPHA_NOVELTY * v_m_ortho
        faiss.normalize_L2(v_target.reshape(1, -1))
        print("v_target calculated and normalized.")
        return v_target

    def _selection_score(self, parent: MCTSNode, c: MCTSNode) -> float:
        """UCT score of child *c* plus ``W_DIR * <c's vector, v_target>``; unvisited children score +inf."""
        if c.N == 0:
            return float('inf')
        exploitation = c.Q / c.N
        exploration = self.config.EXPLORATION_CONSTANT * math.sqrt(math.log(parent.N) / c.N)
        direction_guidance = 0
        child_vec = c.get_vector(self.llm)
        if child_vec is not None and self.v_target is not None:
            direction_guidance = self.config.W_DIR * np.dot(child_vec, self.v_target)
        return exploitation + exploration + direction_guidance

    def select(self, node: MCTSNode):
        """Descend from *node* to a leaf, taking the highest-scoring child at each level."""
        while not node.is_leaf():
            node = max(node.children, key=lambda child, parent=node: self._selection_score(parent, child))
        return node

    @staticmethod
    def _split_reasoning(response: str):
        """Split raw model output into ``(thoughts, visible_answer)``.

        Handles these cases:

        - closed ``<think>...</think>`` blocks;
        - a block left open because generation hit ``max_new_tokens``, in which
          case everything after ``<think>`` is reasoning;
        - a stray ``</think>`` with no opening tag, which can happen if the chat
          template emits ``<think>`` itself (not verified for this model).

        ``thoughts`` is ``"No thoughts captured."`` when no reasoning is found.
        """
        think_pattern = r"<think>(.*?)(?:</think>|\Z)"
        match = re.search(think_pattern, response, re.DOTALL)
        if match:
            thoughts = match.group(1).strip()
        elif "</think>" in response:
            thoughts = response.split("</think>", 1)[0].strip()
        else:
            thoughts = "No thoughts captured."
        visible = re.sub(think_pattern, "", response, flags=re.DOTALL)
        if "</think>" in visible:
            visible = visible.rsplit("</think>", 1)[1]
        return thoughts, visible.strip()

    def expand(self, node: MCTSNode):
        """Ask the LLM for K_EXPAND candidate next sections and attach them as children.

        Options are parsed from the ``[Option N]`` headers of the visible
        answer; reasoning text is excluded so option mentions inside
        ``<think>`` are not turned into nodes. If no headers are present, an
        answer longer than 100 characters becomes a single child. If no
        child is produced, the node is marked terminal.
        """
        # ULTIMATE VERSION: "Principle-Guided Open-Ended Generation"
        if node.is_terminal: return

        print(f"  Expanding node {node.get_depth()} with principle-guided generation...")
        idea_so_far = node.get_full_text()
        
        prompt = f'''
        You are a world-class Principal Investigator, known for writing clear, compelling, and fundable research proposals.
        Your current task is to expand on the following research idea:
        ---
        {idea_so_far}
        ---
        You will now write the **next logical section** of this proposal.

        To do this, you must follow these core principles of scientific writing:

        1.  **Progressive Deepening:** Your new section MUST logically follow from the existing text. It should deepen the idea, moving from a general concept to specific details, or from a hypothesis to a method of testing it. Do not repeat existing information; build upon it.

        2.  **Concrete Detail:** Be specific and avoid vague language. If you propose an experiment, name the key techniques, models, or datasets. If you describe a mechanism, explain it with sufficient detail for another expert to understand.

        3.  **Critical Thinking:** Briefly acknowledge potential challenges, limitations, or alternative approaches to your proposed section. This demonstrates foresight.

        Based on these principles, please propose {self.config.K_EXPAND} distinct, well-reasoned, and detailed "next sections" for this research plan. Prefix each version with `[Option 1]`, `[Option 2]`, etc.
        
        Continue this:
         {idea_so_far}
        '''
        
        response = self.llm.generate_chat_completion([{"role": "user", "content": prompt}], self.config.EXPAND_MAX_LENGTH, temperature=0.6) # Temp slightly higher for creativity
        print(response)

        thoughts, visible = self._split_reasoning(response)

        # Each "[Option N]" header starts a block that runs to the next header or the end.
        option_blocks = re.findall(r'\[Option \d+\](.*?)(?=\[Option \d+\]|\Z)', visible, re.DOTALL)

        if option_blocks:
            contents = [block.strip() for block in option_blocks]
        elif len(visible) > 100: # Arbitrary threshold for substantial content
            # Fallback for models that might just provide one solid block without headers
            print("Warning: No [Option X] headers found. Using entire response as a single option.")
            contents = [visible]
        else:
            contents = []

        for content in contents:
            if content:
                node.children.append(MCTSNode(content, parent=node, thoughts=thoughts))

        if not node.children:
            print("Warning: Expansion failed to generate valid options.")
            node.is_terminal = True

    def simulate_and_evaluate(self, node: MCTSNode):
        """Score *node* once from its full narrative (no rollouts).

        The value is::

            W_COH  * sigmoid(-LM loss of the full text)
          + W_NOV  * (1 - top-1 inner product against novelty_db)
          + W_PROG * (1 - <parent vector, node vector>)

        Returns 0.0 for an empty block, or when the text could not be scored.
        """
        # REWRITTEN for "Narrative MCTS" - evaluates the full text directly
        if not node.narrative_block: return 0.0
        
        full_text = node.get_full_text()
        print(f"  Simulating/Evaluating full text of depth {node.get_depth()}...")

        v_coherence, seq_vec = self.llm.get_prob_and_vector(full_text)
        if seq_vec is None: return 0.0

        seq_vec = seq_vec.reshape(1, -1)
        faiss.normalize_L2(seq_vec)

        # Coherence Value
        v_coherence_norm = 1 / (1 + math.exp(-v_coherence))
        
        # Novelty Value
        similarities, _ = self.novelty_db.search(seq_vec, 1)
        v_novelty = 1.0 - similarities[0][0]
        
        # Progress Value
        v_progress = 0.0 # Root node has no progress
        if node.parent:
            parent_vec = node.parent.get_vector(self.llm)
            # The node's own vector is seq_vec, which we just computed
            child_vec = seq_vec.flatten()
            if parent_vec is not None and child_vec is not None:
                # Cosine distance is 1 - cosine similarity
                v_progress = (1.0 - np.dot(parent_vec.flatten(), child_vec))

        value = (self.config.W_COH * v_coherence_norm + 
                 self.config.W_NOV * v_novelty + 
                 self.config.W_PROG * v_progress)
        
        print(f'  Coh={v_coherence_norm:.3f}, Nov={v_novelty:.3f}, Prog={v_progress:.3f} -> Total Value: {value:.4f}')
        return value

    def backpropagate(self, node: MCTSNode, value: float):
        """Add one visit and *value* to *node* and every ancestor up to the root."""
        while node is not None:
            node.N += 1
            node.Q += value
            node = node.parent

    def search(self):
        """Expand and score the root's children, then run NUM_ITERATIONS select/expand/evaluate rounds.

        A selected leaf that is terminal, or whose expansion produced no
        children, is re-scored and backpropagated. ``select`` is deterministic,
        so without that update the same dead-end leaf would be chosen again on
        every remaining iteration.
        """
        print("\n--- Phase 2: Guided Narrative Search ---")
        if self.root.is_leaf():
            self.expand(self.root)
            # Initial expansion and evaluation of the first children
            for child in self.root.children:
                value = self.simulate_and_evaluate(child)
                self.backpropagate(child, value)

        if self.root.is_terminal:
            print("Root could not be expanded; skipping search iterations.")
            return

        for i in range(self.config.NUM_ITERATIONS):
            print(f"--- Iteration {i+1}/{self.config.NUM_ITERATIONS} ---")
            leaf_node = self.select(self.root)
            print(f"  Select ({leaf_node.get_depth()}): ...{leaf_node.narrative_block[-80:].strip()}")
            
            if not leaf_node.is_terminal:
                # Expand the leaf node, creating new children
                self.expand(leaf_node)
                # Evaluate each new child
                for child in leaf_node.children:
                    if child.N == 0:
                        value = self.simulate_and_evaluate(child)
                        self.backpropagate(child, value)

            if leaf_node.is_terminal:
                # Dead end: score the leaf itself so its visit count grows and
                # the exploration term lets selection move to other branches.
                value = self.simulate_and_evaluate(leaf_node)
                self.backpropagate(leaf_node, value)
            print("-" * 25)

    def get_best_sequence(self, debug=True):
        """Follow the most-visited child from the root and return the resulting narrative.

        Args:
            debug: When True, append a per-node trace of depth, visits, mean
                value, a reasoning snippet (first 400 characters) and the block.

        Returns:
            The blocks on the chosen path joined by blank lines, optionally
            followed by the debug trace.
        """
        path_nodes = []
        node = self.root
        while node is not None:
            path_nodes.append(node)
            if not node.children: break
            # Choose the most visited child to continue the path
            node = max(node.children, key=lambda n: n.N if n.N > 0 else -1)

        # The final text is the concatenation of all narrative blocks on the best path
        main_text = "\n\n".join(n.narrative_block for n in path_nodes)
        
        if not debug: return main_text

        debug_parts = []
        for i, n in enumerate(path_nodes):
            debug_parts.append(f"--- Node {i} (Depth: {n.get_depth()}, Visits: {n.N}, Value: {n.value:.4f}) ---")
            if n.thoughts:
                thought_snippet = (n.thoughts[:400] + '...') if len(n.thoughts) > 400 else n.thoughts
                debug_parts.append(f"  [REASONING]:\n{thought_snippet}")
            debug_parts.append(f"  [NARRATIVE BLOCK]:\n{n.narrative_block}\n")
        debug_trace = "\n".join(debug_parts)
        
        return f"--- FINAL NARRATIVE ---\n{main_text}\n\n\n--- DEBUG TRACE (BEST PATH) ---\n{debug_trace}"

# --- 6. Helper & Main Execution ---
def build_faiss_index(vectors: np.ndarray):
    """Build an exact inner-product FAISS index over L2-normalized rows.

    Because each row is normalized before insertion, inner-product search on
    the returned index ranks by cosine similarity once queries are
    normalized too.

    Side effect: ``faiss.normalize_L2`` rewrites ``vectors`` in place, so the
    caller's array is normalized as well.

    Args:
        vectors: 2-D array with one row per item. Presumably float32; the
            main script casts with ``astype('float32')`` before calling.

    Returns:
        A ``faiss.IndexFlatIP`` that contains every row of ``vectors``.
    """
    print("Building FAISS index...")
    faiss.normalize_L2(vectors)
    index = faiss.IndexFlatIP(vectors.shape[1])
    index.add(vectors)
    print("FAISS index built.")
    return index

def main():
    """Run the end-to-end demo.

    Steps: embed the seed documents, generate a theme, run the search, and
    print the best narrative.
    """
    cfg = Config()
    llm_interface = LLMInterface(cfg.MODEL_NAME, cfg.DEVICE)

    # Seed corpus: both the novelty reference set and the material for themes.
    base_documents = [
        "AlphaFold predicts protein structures using a deep learning approach.", 
        "CRISPR-Cas9 is a gene-editing tool that allows for precise modification of DNA sequences.", 
        "BERT is a transformer-based model for natural language understanding.",
        "Graphene is a single layer of carbon atoms arranged in a two-dimensional honeycomb lattice.",
        "Lithium-ion batteries are a type of rechargeable battery.",
        "Perovskite solar cells include a perovskite-structured compound as the light-harvesting active layer."
    ]

    print(f"\nVectorizing {len(base_documents)} base documents...")
    doc_vectors = np.array([llm_interface.get_vector(doc) for doc in base_documents]).astype('float32')

    # build_faiss_index normalizes doc_vectors in place, so the theme
    # generator below clusters the L2-normalized vectors.
    novelty_db = build_faiss_index(doc_vectors)

    theme_generator = AutomatedThemeGenerator(llm_interface, base_documents, doc_vectors, cfg)
    automated_theme = theme_generator.generate_theme()
    print(f"\n>>> Automatically Generated Theme: '{automated_theme}' <<<\n")

    mcts = CG_MCTS(llm_interface, novelty_db=novelty_db, config=cfg, theme=automated_theme)

    started = time.time()
    mcts.search()
    elapsed = time.time() - started
    print(f"Search completed in {elapsed:.2f} seconds.")

    final_narrative = mcts.get_best_sequence(debug=True)
    print(f"\n{'=' * 20} Final Generated Narrative {'=' * 20}")
    print(final_narrative)


if __name__ == "__main__":
    main()
