import torch
import numpy as np
import faiss
import math
import time
from transformers import AutoModelForCausalLM, AutoTokenizer

# --- 1. Configuration ---
class Config:
    """Hyperparameters for the CG-MCTS pipeline, stored as class attributes.

    Values can be read from the class itself or from an instance
    (``Config().ALPHA``); nothing is computed per instance.
    """

    # ---- Model & hardware ----
    MODEL_NAME: str = "Qwen3-0.6B"  # Hugging Face model id or local path
    DEVICE: str = "cpu" if not torch.cuda.is_available() else "cuda"

    # ---- Tree search ----
    NUM_ITERATIONS: int = 64            # search iterations after the root expansion
    EXPLORATION_CONSTANT: float = 1.5   # UCT exploration coefficient
    K_EXPAND: int = 3                   # continuations sampled per expansion
    SIMULATION_MAX_LENGTH: int = 60     # max new tokens per sampled continuation

    # ---- Concept Orthogonal Projection (COP) ----
    ALPHA: float = 0.5  # weight of the mechanism's component orthogonal to the problem
    BETA: float = 0.5   # blend weight toward the retrieved anchor ("realism")

    # ---- Value and guidance weights ----
    W_COH: float = 0.6         # coherence term in leaf evaluation
    W_NOV: float = 0.4         # novelty term in leaf evaluation
    W_DIR: float = 0.3         # target-direction bonus during selection
    W_PROMPT_SIM: float = 0.5  # penalty for staying similar to the initial prompt

# --- 2. LLM Interface ---
class LLMInterface:
    """Thin wrapper around a Hugging Face causal LM used by the search.

    Offers three primitives, all built on the tokenizer's chat template:
    a hidden-state embedding (``extract_vector``), sampled short
    continuations (``generate_sentences``) and a per-token log-likelihood
    score (``get_sequence_log_prob``).
    """

    # Strings that should end a sampled "sentence". ``generate`` stops on
    # token ids, so each string is used only if it encodes to a single token.
    _STOP_STRINGS = (".", "\n")

    def __init__(self, model_name, device):
        """Load the tokenizer and model.

        Args:
            model_name: Hugging Face hub id or local path.
            device: only reported in the log message; placement is delegated
                to ``device_map="auto"`` and the effective device is read back
                from the loaded model into ``self.device``.
        """
        print(f"Loading model: {model_name} on {device}...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            dtype="auto",
            device_map="auto"
        ).eval()
        self.device = self.model.device
        self.stop_tokens_ids = self._resolve_stop_token_ids()
        print("Model loaded successfully.")

    def _resolve_stop_token_ids(self) -> list[int]:
        """Return the de-duplicated token ids at which generation stops."""
        tokenizer = self.tokenizer
        candidates = [tokenizer.eos_token_id]
        for stop in self._STOP_STRINGS:
            # convert_tokens_to_ids('\n') does not resolve on byte-level BPE
            # vocabularies (newline is stored under a different symbol), so
            # encode the raw string and keep it only if it is one token.
            ids = tokenizer.encode(stop, add_special_tokens=False)
            if len(ids) == 1:
                candidates.append(ids[0])
        stop_ids = []
        for tid in candidates:
            if tid is None or tid == tokenizer.unk_token_id or tid in stop_ids:
                continue
            stop_ids.append(tid)
        return stop_ids

    @staticmethod
    def _strip_thinking(text: str) -> str:
        """Remove a ``<think>...</think>`` reasoning block from generated text.

        Text after the last ``</think>`` is kept. If an unterminated
        ``<think>`` remains (reasoning cut off by the token budget), the
        sample carries no answer and "" is returned.
        """
        text = text.rpartition("</think>")[2]
        if "<think>" in text:
            return ""
        return text.strip()

    def _apply_template(self, text: str, add_generation_prompt: bool) -> str:
        """Render *text* as a single user turn with the tokenizer's chat template.

        ``enable_thinking=False`` is forwarded to the template renderer.
        Qwen3 templates use it (only when a generation prompt is added) to
        pre-fill an empty reasoning block so the model answers directly.
        Templates that never reference the variable ignore it.
        """
        messages = [{"role": "user", "content": text}]
        return self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=add_generation_prompt,
            enable_thinking=False,
        )

    def extract_vector(self, text: str) -> np.ndarray:
        """Return the last-layer hidden state of the final templated token.

        The vector is returned as a float32 NumPy array on the CPU. Because no
        generation prompt is appended, the final token is whatever the chat
        template emits after the user message (presumably an end-of-turn
        marker -- template dependent), not necessarily a word from *text*.
        """
        templated_text = self._apply_template(text, add_generation_prompt=False)
        with torch.no_grad():
            inputs = self.tokenizer(templated_text, return_tensors="pt").to(self.device)
            outputs = self.model(**inputs, output_hidden_states=True)
            last_layer_hidden_states = outputs.hidden_states[-1]
            last_token_hidden_state = last_layer_hidden_states[0, -1, :]
            return last_token_hidden_state.cpu().to(torch.float32).numpy()

    def generate_sentences(self, text: str, num_sentences: int, max_length: int) -> list[str]:
        """Sample up to *num_sentences* short assistant replies to *text*.

        Sampling uses nucleus sampling (top_p=0.9, temperature=0.8) and stops
        at EOS or at any single-token stop string ('.' / newline). Each
        reply is capped at *max_length* new tokens. Empty replies, and replies
        that are only an unfinished reasoning block, are dropped, so fewer
        than *num_sentences* strings may be returned.
        """
        templated_text = self._apply_template(text, add_generation_prompt=True)
        inputs = self.tokenizer(templated_text, return_tensors="pt").to(self.device)
        prompt_length = inputs.input_ids.shape[-1]

        generated_sequences = self.model.generate(
            input_ids=inputs.input_ids,
            # Explicit mask avoids transformers' "attention mask is not set"
            # ambiguity when pad and eos ids coincide.
            attention_mask=inputs.attention_mask,
            max_new_tokens=max_length,
            num_return_sequences=num_sentences,
            do_sample=True, top_p=0.9, temperature=0.8,
            eos_token_id=self.stop_tokens_ids,
            pad_token_id=self.tokenizer.eos_token_id
        )

        new_texts = []
        for seq in generated_sequences:
            decoded = self.tokenizer.decode(seq[prompt_length:], skip_special_tokens=True)
            # <think> tags are ordinary added tokens for Qwen3, so they survive
            # skip_special_tokens; strip the block instead of discarding the sample.
            new_text = self._strip_thinking(decoded)
            if new_text:
                new_texts.append(new_text)
        return new_texts

    def get_sequence_log_prob(self, text: str) -> float:
        """Return the negated LM loss of *text* rendered with the chat template.

        With ``labels=input_ids`` a Hugging Face causal LM reports the mean
        next-token cross-entropy, so this is the average per-token
        log-likelihood (higher means more fluent). Template tokens are scored
        too.
        """
        templated_text = self._apply_template(text, add_generation_prompt=False)
        with torch.no_grad():
            inputs = self.tokenizer(templated_text, return_tensors="pt").to(self.device)
            input_ids = inputs.input_ids
            outputs = self.model(**inputs, labels=input_ids)
            return -outputs.loss.item()

# --- 3. MCTS Components ---
class MCTSNode:
    """A search-tree node: one full text plus its visit statistics.

    Attributes:
        text_sequence: the complete text this node represents.
        parent: the parent node, or None for the root.
        children: list of expanded child nodes.
        N: number of visits.
        Q: sum of all values backed up through this node.
        is_terminal: flag marking a node that should not be expanded again.
    """

    def __init__(self, text_sequence: str, parent=None):
        self.text_sequence = text_sequence
        self.parent = parent
        self.children = []
        self.N, self.Q = 0, 0.0
        self.is_terminal = False

    @property
    def value(self):
        """Mean backed-up value ``Q / N``; 0 for a node never visited."""
        if self.N <= 0:
            return 0
        return self.Q / self.N

    def is_leaf(self):
        """Return True while the node has no children."""
        return not self.children

class CG_MCTS:
    """Concept-guided Monte Carlo Tree Search over LLM text continuations.

    Every node stores a full text. Expanding a node samples up to
    ``config.K_EXPAND`` continuations and appends each to the node's text.
    Selection is UCT plus a bonus for cosine similarity to the target
    vector. Leaf evaluation mixes coherence and novelty, minus a penalty for
    staying close to the initial prompt.
    """

    def __init__(self, llm_interface, v_target, novelty_db, config, initial_prompt):
        """
        Args:
            llm_interface: object providing ``extract_vector``,
                ``generate_sentences`` and ``get_sequence_log_prob``.
            v_target: target concept vector; stored L2-normalised.
            novelty_db: FAISS-style index queried with ``search(x, 1)``;
                assumed to hold L2-normalised vectors so scores are cosines.
            config: hyperparameter holder (reads EXPLORATION_CONSTANT, W_DIR,
                K_EXPAND, SIMULATION_MAX_LENGTH, W_COH, W_NOV, W_PROMPT_SIM,
                NUM_ITERATIONS).
            initial_prompt: text of the root node.
        """
        self.llm = llm_interface
        self.v_target = v_target / np.linalg.norm(v_target)
        self.novelty_db = novelty_db
        self.config = config
        self.root = MCTSNode(initial_prompt)
        # Hidden-state vectors keyed by text. A node's text never changes, so
        # each text needs a single LLM forward pass however often it is revisited.
        self._vector_cache = {}
        # Unit vector of the initial prompt, used by the similarity penalty.
        self.v_prompt = self._vector(initial_prompt)
        self.v_prompt = self.v_prompt / np.linalg.norm(self.v_prompt)

    def _vector(self, text: str) -> np.ndarray:
        """Return the (cached) LLM vector for *text*. Callers must not mutate it."""
        vec = self._vector_cache.get(text)
        if vec is None:
            vec = self.llm.extract_vector(text)
            self._vector_cache[text] = vec
        return vec

    @staticmethod
    def _sigmoid(x: float) -> float:
        """Logistic function that never raises OverflowError for large |x|."""
        if x >= 0:
            return 1.0 / (1.0 + math.exp(-x))
        z = math.exp(x)
        return z / (1.0 + z)

    def select(self, node: "MCTSNode"):
        """Descend from *node* until reaching a leaf or terminal node.

        An unvisited child is always preferred (score +inf). Otherwise the
        score is ``value + C * sqrt(ln N_parent / N_child) + W_DIR * cos``,
        where ``cos`` is the cosine similarity of the child's vector to the
        target vector.
        """
        while not node.is_leaf() and not node.is_terminal:
            best_child = None
            max_score = -float('inf')
            for child in node.children:
                if child.N == 0:
                    score = float('inf')
                else:
                    utilization = child.value
                    exploration = self.config.EXPLORATION_CONSTANT * math.sqrt(math.log(node.N) / child.N)
                    child_vec = self._vector(child.text_sequence)
                    child_vec_norm = np.linalg.norm(child_vec)
                    # v_target is unit length, so dividing by the child norm gives the cosine.
                    cosine_sim = np.dot(child_vec, self.v_target) / child_vec_norm if child_vec_norm > 0 else 0
                    direction = self.config.W_DIR * cosine_sim
                    score = utilization + exploration + direction
                if score > max_score:
                    max_score = score
                    best_child = child
            if best_child is None:
                return node  # defensive: no child produced a comparable score
            node = best_child
        return node

    def expand(self, node: "MCTSNode"):
        """Attach sampled continuations of *node* as children.

        Returns the first new child, or *node* itself when it is terminal,
        already expanded, or the LLM produced no usable continuation (in
        which case the node is marked terminal).
        """
        if node.is_terminal or (node.N > 0 and not node.is_leaf()):
            return node
        candidate_sentences = self.llm.generate_sentences(node.text_sequence, self.config.K_EXPAND, self.config.SIMULATION_MAX_LENGTH)
        if not candidate_sentences:
            node.is_terminal = True
            return node
        for sentence in candidate_sentences:
            new_text = (node.text_sequence + " " + sentence).strip()
            # Identical samples would create duplicate siblings; keep one.
            if not any(child.text_sequence == new_text for child in node.children):
                child_node = MCTSNode(new_text, parent=node)
                node.children.append(child_node)
        return node.children[0] if node.children else node

    def simulate_and_evaluate(self, node: "MCTSNode"):
        """Score *node*'s text without a rollout.

        value = W_COH * sigmoid(log_prob) + W_NOV * clip(1 - max_sim, 0, 1)
                - W_PROMPT_SIM * cos(text, prompt)
        """
        full_seq = node.text_sequence
        v_coherence = self.llm.get_sequence_log_prob(full_seq)
        # astype() returns a copy, so the cached vector is never modified below.
        seq_vec = self._vector(full_seq).astype('float32')
        seq_vec_norm = np.linalg.norm(seq_vec)
        if seq_vec_norm > 0:
            seq_vec = seq_vec / seq_vec_norm

        # Novelty: 1 - similarity to the closest existing document.
        similarities, _ = self.novelty_db.search(seq_vec.reshape(1, -1), 1)
        v_novelty = 1.0 - similarities[0][0]

        # Penalty for staying semantically close to the initial prompt.
        prompt_similarity = np.dot(seq_vec.flatten(), self.v_prompt.flatten())
        v_prompt_penalty = self.config.W_PROMPT_SIM * prompt_similarity

        v_coherence_norm = self._sigmoid(v_coherence)
        v_novelty_norm = np.clip(v_novelty, 0, 1)

        value = (self.config.W_COH * v_coherence_norm +
                 self.config.W_NOV * v_novelty_norm) - v_prompt_penalty
        return value

    def backpropagate(self, node: "MCTSNode", value: float):
        """Add one visit and *value* to *node* and every ancestor up to the root."""
        while node is not None:
            node.N += 1
            node.Q += value
            node = node.parent

    def search(self):
        """Run the search: seed the root's children, then NUM_ITERATIONS rounds."""
        if self.root.is_leaf():
            print("--- Initial Root Expansion ---")
            self.expand(self.root)
            if self.root.children:
                for child in self.root.children:
                    print(f"  Simulating initial child: ...{child.text_sequence[-80:]}")
                    value = self.simulate_and_evaluate(child)
                    self.backpropagate(child, value)
            print("-" * 25)

        for i in range(self.config.NUM_ITERATIONS):
            print(f"--- Iteration {i+1}/{self.config.NUM_ITERATIONS} ---")
            leaf_node = self.select(self.root)
            print(f"  Select: ...{leaf_node.text_sequence[-80:]}")
            expanded_node = self.expand(leaf_node)
            if expanded_node != leaf_node:
                print(f"  Expand: -> ...{expanded_node.text_sequence[-80:]}")
                node_to_sim = expanded_node
            else:
                node_to_sim = leaf_node
            print("  Simulate & Evaluate...")
            value = self.simulate_and_evaluate(node_to_sim)
            print(f"  Value: {value:.4f}")
            print("  Backpropagate...")
            self.backpropagate(node_to_sim, value)
            print("-" * 25)

    def get_best_sequence(self):
        """Follow the most-visited child from the root down to a leaf; return its text."""
        node = self.root
        while not node.is_leaf():
            node = max(node.children, key=lambda n: n.N)
        return node.text_sequence

# --- 4. Helper Functions ---
def build_faiss_index(texts: list, llm_interface: "LLMInterface"):
    """Embed *texts* with the LLM and store them in an inner-product FAISS index.

    ``faiss.normalize_L2`` normalises the vectors in place before they are
    added. Inner-product search therefore ranks by cosine similarity, and
    the returned ``vectors`` are unit length.

    Args:
        texts: non-empty list of strings to index.
        llm_interface: object providing ``extract_vector(text) -> np.ndarray``.

    Returns:
        ``(index, vectors)`` -- the ``faiss.IndexFlatIP`` and the float32 array
        of L2-normalised embeddings, in the same order as *texts*.

    Raises:
        ValueError: if *texts* is empty (the embedding dimension is unknown).
    """
    if not texts:
        raise ValueError("build_faiss_index requires at least one text")
    print(f"Building FAISS index for {len(texts)} texts...")
    vectors = np.array([llm_interface.extract_vector(text) for text in texts]).astype('float32')
    dimension = vectors.shape[1]
    index = faiss.IndexFlatIP(dimension)
    faiss.normalize_L2(vectors)  # in place: the returned array is normalised too
    index.add(vectors)
    print("FAISS index built.")
    return index, vectors

def calculate_target_vector(problem: str, mechanism: str, anchor_db: faiss.Index, anchor_vectors: np.ndarray, llm: LLMInterface, config: Config):
    """Build the target concept vector with Concept Orthogonal Projection (COP).

    Steps:
      1. Embed *problem* and *mechanism* with ``llm.extract_vector``.
      2. Add ``config.ALPHA`` times the part of the mechanism vector that is
         orthogonal to the problem vector to the problem vector.
      3. L2-normalise that raw vector, retrieve its nearest anchor from
         *anchor_db*, and blend ``(1 - BETA) * raw + BETA * anchor``.

    The normalisation in step 3 puts the raw vector on the same scale as the
    anchors. ``build_faiss_index`` returns unit-length vectors, while raw
    hidden states are not unit length. Without it ``BETA`` would barely move
    the result. Normalising does not change which anchor the inner-product
    search returns, because all scores are scaled by the same factor.

    Args:
        problem: text describing the problem domain.
        mechanism: text describing the mechanism to transfer.
        anchor_db: inner-product index over *anchor_vectors*.
        anchor_vectors: anchor embeddings, row i matching index id i; assumed
            L2-normalised (as ``build_faiss_index`` returns them).
        llm: object providing ``extract_vector``.
        config: provides ``ALPHA`` and ``BETA``.

    Returns:
        1-D NumPy array: the blended target vector (not re-normalised).
    """
    print("Calculating target vector using COP...")
    v_problem = llm.extract_vector(problem)
    v_mechanism = llm.extract_vector(mechanism)
    proj_on_problem = (np.dot(v_mechanism, v_problem) / np.dot(v_problem, v_problem)) * v_problem
    v_ortho = v_mechanism - proj_on_problem
    v_raw_new = v_problem + config.ALPHA * v_ortho
    v_raw_new = v_raw_new.astype('float32').reshape(1, -1)
    raw_norm = np.linalg.norm(v_raw_new)
    if raw_norm > 0:
        v_raw_new = v_raw_new / raw_norm
    _, indices = anchor_db.search(v_raw_new, 1)
    retrieved_index = indices[0][0]
    v_anchor = anchor_vectors[retrieved_index]
    v_target = (1 - config.BETA) * v_raw_new.flatten() + config.BETA * v_anchor
    print("Target vector calculated.")
    return v_target

# --- 5. Main Execution ---
if __name__ == "__main__":
    cfg = Config()
    llm_interface = LLMInterface(cfg.MODEL_NAME, cfg.DEVICE)

    # Anchors pull the target vector toward known concepts; the novelty
    # documents are what generated text is scored as novel against.
    anchor_concepts = ["Machine learning", "Deep learning", "Genomics", "Quantum computing", "Drug discovery"]
    novelty_documents = ["AlphaFold predicts protein structures.", "CRISPR-Cas9 is a gene-editing tool.", "BERT is a transformer-based model."]

    anchor_db, anchor_vectors = build_faiss_index(anchor_concepts, llm_interface)
    novelty_db, _ = build_faiss_index(novelty_documents, llm_interface)

    problem_concept = "Discovering new materials for batteries"
    mechanism_concept = "Using generative adversarial networks (GANs)"

    v_target = calculate_target_vector(
        problem=problem_concept,
        mechanism=mechanism_concept,
        anchor_db=anchor_db,
        anchor_vectors=anchor_vectors,
        llm=llm_interface,
        config=cfg
    )

    print("\n" + "="*20 + " Starting CG-MCTS Search " + "="*20)

    # Root text of the search tree; every expansion appends to it.
    initial_prompt = f"Here is a novel scientific idea for '{problem_concept}' inspired by '{mechanism_concept}'.\n\nIdea:"

    mcts = CG_MCTS(llm_interface, v_target, novelty_db, cfg, initial_prompt)

    # perf_counter is monotonic and intended for measuring elapsed intervals.
    start_time = time.perf_counter()
    mcts.search()
    elapsed = time.perf_counter() - start_time
    print(f"Search completed in {elapsed:.2f} seconds.")

    final_idea = mcts.get_best_sequence()

    print("\n" + "="*20 + " Final Generated Idea " + "="*20)
    print(final_idea)