import torch
import numpy as np
import faiss
import math
import time
from transformers import AutoModelForCausalLM, AutoTokenizer

# --- 1. Configuration ---
class Config:
    """Static hyperparameters for the CG-MCTS pipeline.

    All settings are class attributes, so they can be read from the class itself
    or from an instance (``Config().K_EXPAND``).
    """

    # Model and Device
    MODEL_NAME = "Qwen3-0.6B"  # Name/path handed to from_pretrained (a ~0.6B-parameter model)
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

    # MCTS Hyperparameters
    NUM_ITERATIONS = 1024  # "Thinking budget": number of select/expand/simulate/backprop rounds
    EXPLORATION_CONSTANT = 1.5  # c in the UCT term c * sqrt(ln N_parent / N_child)
    K_EXPAND = 3  # Number of top-k next tokens attached as children on expansion
    ROLLOUT_MAX_LENGTH = 512  # max_new_tokens for each simulation rollout

    # Guided UCT and Evaluation Weights
    W_DIR = 0.3  # Weight of cosine(node, target vector) added to the UCT score
    W_COH = 0.6  # Weight of the sigmoid-squashed mean token log-probability
    W_NOV = 0.4  # Weight of novelty (1 - nearest cosine similarity, scaled to [0, 1])

    # Length reward
    W_LEN = 0.25  # Weight of the length bonus min(1, tokens / TARGET_LENGTH)
    TARGET_LENGTH = 40  # Token count at which the length bonus saturates at 1.0

    # COP Hyperparameters
    ALPHA = 0.5  # Scale of the mechanism component orthogonal to the problem vector
    BETA = 0.5  # Realism factor: blend weight pulling the target toward the nearest anchor

# --- 2. LLM Interface ---
class LLMInterface:
    """Thin wrapper around a Hugging Face causal LM and its tokenizer.

    Exposes the four model operations the search relies on: last-token
    hidden-state embeddings, top-k next-token probabilities, sampled
    continuations, and the mean per-token log-probability of a text.
    """

    def __init__(self, model_name, device):
        print(f"Loading model: {model_name} on {device}...")
        load_options = dict(torch_dtype="auto", device_map="auto")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name, **load_options)
        # Placement is chosen by device_map="auto"; `device` only appears in the log
        # line above. Inputs are later moved to wherever the model actually landed.
        self.device = self.model.device
        print("Model loaded successfully.")

    def _tokenize(self, text: str):
        """Tokenize one string into PyTorch tensors placed on the model's device."""
        return self.tokenizer(text, return_tensors="pt").to(self.device)

    def extract_vector(self, text: str) -> np.ndarray:
        """Embed ``text`` as the final-layer hidden state of its last token.

        Returns:
            A 1-D float32 NumPy array.
        """
        self.model.eval()
        with torch.no_grad():
            hidden_states = self.model(**self._tokenize(text), output_hidden_states=True).hidden_states
            # Last entry is the final layer; pick batch item 0 at the last position.
            final_token_state = hidden_states[-1][0, -1, :]
            return final_token_state.cpu().to(torch.float32).numpy()

    def get_next_token_probs(self, text: str, top_k: int):
        """Return the ``top_k`` most likely tokens to follow ``text``.

        Returns:
            ``(token_ids, probabilities)`` as two Python lists, in the order
            produced by ``torch.topk`` (highest probability first).
        """
        self.model.eval()
        with torch.no_grad():
            next_logits = self.model(**self._tokenize(text)).logits[0, -1, :]
            best = torch.topk(torch.softmax(next_logits, dim=-1), top_k)
            return best.indices.cpu().tolist(), best.values.cpu().tolist()

    def generate_rollout(self, text: str, max_length: int) -> str:
        """Sample a continuation of ``text`` (top_p=0.9, temperature=0.7).

        Args:
            text: Prompt to continue.
            max_length: Maximum number of NEW tokens to generate.

        Returns:
            Only the newly generated text, decoded with special tokens removed.
        """
        batch = self.tokenizer([text], return_tensors="pt").to(self.device)
        prompt_length = batch.input_ids.shape[1]
        sampled = self.model.generate(
            **batch,
            max_new_tokens=max_length,
            do_sample=True,
            top_p=0.9,
            temperature=0.7,
        )
        new_token_ids = sampled[0][prompt_length:].tolist()
        return self.tokenizer.decode(new_token_ids, skip_special_tokens=True)

    def get_sequence_log_prob(self, text: str) -> float:
        """Return the mean per-token log-probability of ``text`` (higher is more fluent).

        Feeding the input ids back as ``labels`` makes the model report the mean
        next-token negative log-likelihood as ``loss``; negating it yields the score.
        """
        self.model.eval()
        with torch.no_grad():
            encoded = self._tokenize(text)
            mean_nll = self.model(**encoded, labels=encoded.input_ids).loss
            return -mean_nll.item()

# --- 3. MCTS Components ---
class MCTSNode:
    """One state of the search tree: a text prefix plus its visit statistics.

    Attributes:
        text_sequence: Full text represented by this node.
        parent: Parent node, or None for the root.
        children: Child nodes attached during expansion.
        N: Number of times this node was visited during backpropagation.
        Q: Sum of all values backed up through this node.
        token_id: Id of the token that extended the parent's text (None for the root).
        prob: Probability the parent assigned to that token.
    """

    def __init__(self, text_sequence: str, parent=None, token_id=None, prob=0.0):
        # Tree links.
        self.parent = parent
        self.children = []
        # How this node was reached.
        self.text_sequence = text_sequence
        self.token_id = token_id
        self.prob = prob
        # Running statistics updated by backpropagation.
        self.N = 0
        self.Q = 0.0

    @property
    def value(self):
        """Mean backed-up value Q / N, or 0 for a node that was never visited."""
        if not self.N > 0:
            return 0
        return self.Q / self.N

    def is_leaf(self):
        """Return True while no children have been attached."""
        return not self.children

class CG_MCTS:
    """Concept-Guided Monte Carlo Tree Search (CG-MCTS) over text continuations.

    Every node stores a text prefix; expansion adds children that extend it by one
    of the top-k next tokens. Selection uses UCT plus a direction bonus (cosine
    similarity between a node's last-token hidden state and ``v_target``).
    Rollouts are scored by coherence (mean token log-probability), novelty
    against ``novelty_db``, and a length bonus.
    """

    def __init__(self, llm_interface, v_target, novelty_db, config):
        """
        Args:
            llm_interface: Object exposing ``extract_vector``, ``get_next_token_probs``,
                ``generate_rollout``, ``get_sequence_log_prob`` and ``tokenizer``.
            v_target: 1-D target direction; stored L2-normalized.
            novelty_db: FAISS-style index searched with L2-normalized float32
                queries (presumably holding L2-normalized vectors, so scores are
                cosine similarities).
            config: Hyperparameter holder such as ``Config``.

        Raises:
            ValueError: If ``v_target`` has zero norm (direction guidance would be NaN).
        """
        target_norm = np.linalg.norm(v_target)
        if target_norm == 0:
            raise ValueError("v_target must be a non-zero vector")
        self.llm = llm_interface
        # Unit vector, so cosine(child, target) = dot(child, target) / ||child||.
        self.v_target = v_target / target_norm
        self.novelty_db = novelty_db
        self.config = config
        # cosine(node, v_target) memoized by text. extract_vector is a full forward pass
        # and depends only on the text, so recomputing it on every selection is waste.
        self._cosine_cache = {}
        self.root = MCTSNode("<|im_start|>user\n")  # Start with the user prompt starter

    def _cosine_to_target(self, node: MCTSNode) -> float:
        """Return cosine similarity between the node's text vector and ``v_target`` (cached)."""
        text = node.text_sequence
        cached = self._cosine_cache.get(text)
        if cached is None:
            vec = self.llm.extract_vector(text)
            vec_norm = np.linalg.norm(vec)
            cached = 0.0 if vec_norm == 0 else float(np.dot(vec, self.v_target) / vec_norm)
            self._cosine_cache[text] = cached
        return cached

    @staticmethod
    def _sigmoid(x: float) -> float:
        """Logistic function written so that ``math.exp`` never receives a large positive argument."""
        if x >= 0:
            return 1.0 / (1.0 + math.exp(-x))
        z = math.exp(x)
        return z / (1.0 + z)

    def select(self, node: MCTSNode):
        """Descend from ``node`` to a leaf using the guided UCT rule.

        At each level the first unvisited child is taken immediately. Otherwise the
        child maximizing ``value + c * sqrt(ln N_parent / N_child) + W_DIR * cosine``
        wins, with ties going to the earlier child.
        """
        while not node.is_leaf():
            unvisited = next((child for child in node.children if child.N == 0), None)
            if unvisited is not None:
                node = unvisited
                continue

            log_parent_visits = math.log(node.N)
            best_child = None
            max_score = -float('inf')
            for child in node.children:
                utilization = child.value
                exploration = self.config.EXPLORATION_CONSTANT * math.sqrt(log_parent_visits / child.N)
                direction = self.config.W_DIR * self._cosine_to_target(child)
                score = utilization + exploration + direction
                if score > max_score:
                    max_score = score
                    best_child = child
            node = best_child
        return node

    def expand(self, node: MCTSNode):
        """Attach the top-``K_EXPAND`` next-token continuations of ``node`` as children.

        Only already-visited nodes (or the root) are expanded; any other node is
        returned unchanged. Returns the first new child (with ``LLMInterface`` this
        is the most probable token) for simulation, or ``node`` if none were created.
        """
        if node.N > 0 or node == self.root:
            next_token_ids, probs = self.llm.get_next_token_probs(node.text_sequence, self.config.K_EXPAND)
            for token_id, prob in zip(next_token_ids, probs):
                # NOTE(review): decoding one token and appending text means re-tokenizing
                # child_text may not reproduce the same token ids.
                child_text = node.text_sequence + self.llm.tokenizer.decode([token_id])
                child_node = MCTSNode(child_text, parent=node, token_id=token_id, prob=prob)
                node.children.append(child_node)
            return node.children[0] if node.children else node
        return node

    def simulate_and_evaluate(self, node: MCTSNode):
        """Sample one rollout from ``node`` and score the resulting full text.

        value = W_COH * sigmoid(mean token log-prob)
              + W_NOV * (1 - nearest inner-product score in novelty_db) / 2
              + W_LEN * min(1, token_count / TARGET_LENGTH)
        """
        rollout_seq = self.llm.generate_rollout(node.text_sequence, self.config.ROLLOUT_MAX_LENGTH)
        full_seq = node.text_sequence + rollout_seq

        # Coherence: negated mean NLL squashed into (0, 1).
        v_coherence_norm = self._sigmoid(self.llm.get_sequence_log_prob(full_seq))

        # Novelty: 1 - similarity to the nearest stored vector. For unit vectors this
        # lies in [0, 2], so halving maps it to [0, 1].
        seq_vec = self.llm.extract_vector(full_seq).astype('float32').reshape(1, -1)
        faiss.normalize_L2(seq_vec)  # Normalize for IP search
        similarities, _ = self.novelty_db.search(seq_vec, 1)
        v_novelty_norm = (1.0 - similarities[0][0]) / 2.0

        # NOTE(review): full_seq includes the whole root prompt. With a prompt longer
        # than TARGET_LENGTH tokens (likely for the prompt built in __main__) this bonus
        # is always 1.0 and only shifts every value by W_LEN. Confirm whether it should
        # count generated tokens only.
        seq_len = len(self.llm.tokenizer.encode(full_seq))
        length_bonus = min(1.0, seq_len / self.config.TARGET_LENGTH)

        value = (self.config.W_COH * v_coherence_norm +
                 self.config.W_NOV * v_novelty_norm +
                 self.config.W_LEN * length_bonus)
        return value

    def backpropagate(self, node: MCTSNode, value: float):
        """Add one visit and ``value`` to ``node`` and every ancestor up to the root."""
        while node is not None:
            node.N += 1
            node.Q += value
            node = node.parent

    def search(self):
        """Run ``NUM_ITERATIONS`` rounds of select / expand / simulate / backpropagate."""
        for i in range(self.config.NUM_ITERATIONS):
            print(f"--- Iteration {i+1}/{self.config.NUM_ITERATIONS} ---")
            # 1. Selection
            leaf_node = self.select(self.root)
            print(f"  Select: ...{leaf_node.text_sequence}")

            # 2. Expansion (a leaf is simulated once before it gets children)
            if leaf_node.N > 0:
                leaf_node = self.expand(leaf_node)
                print(f"  Expand: -> ...{leaf_node.text_sequence[-30:]}")

            # 3. Simulation & Evaluation
            print("  Simulate & Evaluate...")
            value = self.simulate_and_evaluate(leaf_node)
            print(f"  Value: {value:.4f}")

            # 4. Backpropagation
            print("  Backpropagate...")
            self.backpropagate(leaf_node, value)
            print("-" * (20 + 5))

    def get_best_sequence(self, max_len=100):
        """Follow the most-visited child from the root and return the reached node's text.

        The walk stops at a leaf, or as soon as the CURRENT node's text encodes to at
        least ``max_len`` tokens. Previously, the length check looked only at the
        root's text, so ``max_len`` never bounded the walk.
        """
        node = self.root
        while not node.is_leaf() and len(self.llm.tokenizer.encode(node.text_sequence)) < max_len:
            node = max(node.children, key=lambda n: n.N)
        return node.text_sequence

# --- 4. Helper Functions ---
def build_faiss_index(texts: list, llm_interface: LLMInterface, use_gpu: bool = False):
    """Embed ``texts`` and store them in an inner-product FAISS index.

    Vectors are L2-normalized before insertion, so inner-product search ranks by
    cosine similarity.

    Args:
        texts: Non-empty list of strings to embed with ``llm_interface.extract_vector``.
        llm_interface: Embedding provider.
        use_gpu: Try to copy the index to GPU device 0; on any failure, fall back to CPU.

    Returns:
        ``(index, vectors)``. ``index`` is the GPU copy if the move succeeded, otherwise
        the CPU index. ``vectors`` is the float32 matrix of L2-normalized embeddings
        that was added, where row ``i`` corresponds to ``texts[i]``.

    Raises:
        ValueError: If ``texts`` is empty (the vector dimension cannot be inferred).
    """
    if not texts:
        raise ValueError("build_faiss_index needs at least one text to infer the vector dimension")
    print(f"Building FAISS index for {len(texts)} texts...")
    # FAISS needs a C-contiguous float32 matrix; normalize_L2 below modifies it in place.
    vectors = np.ascontiguousarray(
        np.stack([llm_interface.extract_vector(text) for text in texts]),
        dtype=np.float32,
    )
    dimension = vectors.shape[1]

    index_cpu = faiss.IndexFlatIP(dimension)  # Use Inner Product for Cosine Similarity
    faiss.normalize_L2(vectors)  # Normalize vectors for cosine similarity
    index_cpu.add(vectors)

    if not use_gpu:
        print("Using CPU-based FAISS index.")
        return index_cpu, vectors

    gpu_device = 0
    try:
        res = faiss.StandardGpuResources()
        index_gpu = faiss.index_cpu_to_gpu(res, gpu_device, index_cpu)
    except Exception as e:
        # Deliberate best-effort boundary: CPU-only FAISS builds or GPU errors
        # should degrade to the CPU index rather than abort the run.
        print(f"Failed to move FAISS index to GPU. Error: {e}")
        print("Falling back to CPU-based FAISS index.")
        return index_cpu, vectors
    print(f"Successfully moved FAISS index to GPU device {gpu_device}.")
    return index_gpu, vectors

def calculate_target_vector(problem: str, mechanism: str, anchor_db: faiss.Index, anchor_vectors: np.ndarray, llm: LLMInterface, config: Config):
    """Build the search target vector with Concept Orthogonal Projection (COP).

    Steps:
      1. Embed ``problem`` and ``mechanism`` with ``llm.extract_vector``.
      2. Keep the part of the mechanism vector orthogonal to the problem vector and
         form ``v_raw_new = v_problem + ALPHA * v_ortho``.
      3. Find the nearest anchor in ``anchor_db`` and blend:
         ``(1 - BETA) * unit(v_raw_new) + BETA * unit(v_anchor)``.

    Both blend inputs are unit-normalized so that ``BETA`` weighs two directions.
    ``anchor_vectors`` coming from ``build_faiss_index`` are already unit length,
    whereas raw hidden states are not. Without normalization, the anchor term
    would be swamped whenever ``||v_raw_new||`` is large.

    Args:
        problem: Text describing the problem concept.
        mechanism: Text describing the mechanism concept.
        anchor_db: Index whose ``search(query, k)`` returns ``(scores, indices)``.
        anchor_vectors: Row-aligned vectors for the entries of ``anchor_db``.
        llm: Embedding provider.
        config: Supplies ``ALPHA`` and ``BETA``.

    Returns:
        1-D target vector. Only its direction matters, since ``CG_MCTS`` re-normalizes it.

    Raises:
        ValueError: If the problem embedding is the zero vector (projection undefined).
    """
    print("Calculating target vector using COP...")
    v_problem = llm.extract_vector(problem)
    v_mechanism = llm.extract_vector(mechanism)

    problem_sq_norm = float(np.dot(v_problem, v_problem))
    if problem_sq_norm == 0.0:
        raise ValueError("problem embedding has zero norm; cannot project onto it")

    # Component of v_mechanism orthogonal to v_problem.
    proj_on_problem = (np.dot(v_mechanism, v_problem) / problem_sq_norm) * v_problem
    v_ortho = v_mechanism - proj_on_problem

    # v_raw_new = v_problem + alpha * v_ortho
    v_raw_new = (v_problem + config.ALPHA * v_ortho).astype('float32').reshape(1, -1)

    # Nearest anchor by inner product. Scaling the query by a positive constant does
    # not change the ranking, so this matches nearest-by-cosine over unit anchors.
    _, indices = anchor_db.search(v_raw_new, 1)
    v_anchor = anchor_vectors[indices[0][0]]

    raw_direction = v_raw_new.flatten()
    raw_norm = np.linalg.norm(raw_direction)
    if raw_norm > 0:
        raw_direction = raw_direction / raw_norm
    anchor_norm = np.linalg.norm(v_anchor)
    if anchor_norm > 0:
        v_anchor = v_anchor / anchor_norm

    # v_target = (1-beta)*unit(v_raw_new) + beta*unit(v_anchor)
    v_target = (1 - config.BETA) * raw_direction + config.BETA * v_anchor
    print("Target vector calculated.")
    return v_target

# --- 5. Main Execution ---
if __name__ == "__main__":
    cfg = Config()
    llm_interface = LLMInterface(cfg.MODEL_NAME, cfg.DEVICE)

    # --- Step 1: Prepare Databases (using dummy data) ---
    # In a real scenario, these would be large datasets.
    anchor_concepts = [
        "Machine learning", "Deep learning", "Genomics", "CRISPR",
        "Quantum computing", "Drug discovery", "Protein folding"
    ]
    novelty_documents = [
        "AlphaFold is a deep learning system that predicts protein structures.",
        "CRISPR-Cas9 is a gene-editing tool that allows for precise modification of DNA.",
        "BERT is a transformer-based model for natural language processing.",
        "Reinforcement learning has been used to master games like Go and Chess."
    ]

    use_gpu_for_faiss = (cfg.DEVICE == "cuda")

    anchor_db, anchor_vectors = build_faiss_index(anchor_concepts, llm_interface, use_gpu=use_gpu_for_faiss)
    # Only the novelty index is searched during MCTS; its raw vectors are not needed.
    novelty_db, _ = build_faiss_index(novelty_documents, llm_interface, use_gpu=use_gpu_for_faiss)

    # --- Step 2: Define Problem and Calculate Target Vector ---
    problem_concept = "Discovering new materials for batteries"
    mechanism_concept = "Using generative adversarial networks (GANs)"

    v_target = calculate_target_vector(
        problem=problem_concept,
        mechanism=mechanism_concept,
        anchor_db=anchor_db,
        anchor_vectors=anchor_vectors,  # CPU-side vectors aligned with anchor_db
        llm=llm_interface,
        config=cfg
    )

    # --- Step 3: Run CG-MCTS ---
    print("\n" + "="*20 + " Starting CG-MCTS Search " + "="*20)
    mcts = CG_MCTS(llm_interface, v_target, novelty_db, cfg)

    # Append the task prompt to the root node's chat header.
    initial_prompt = f"My goal is to find a novel scientific idea for '{problem_concept}' inspired by '{mechanism_concept}'. Please begin generating a concept.\nIdea:"
    mcts.root.text_sequence += initial_prompt

    start_time = time.time()
    mcts.search()
    end_time = time.time()
    print(f"Search completed in {end_time - start_time:.2f} seconds.")

    # --- Step 4: Get Final Result ---
    final_idea = mcts.get_best_sequence(max_len=1024)

    print("\n" + "="*20 + " Final Generated Idea " + "="*20)
    # Strip the chat header and the prompt so only the generated idea is shown.
    final_idea_cleaned = (
        final_idea.replace("<|im_start|>user\n", "").replace(initial_prompt, "").strip()
    )
    print(final_idea_cleaned)

