import json
import os
import sqlite3
import sys
from enum import Enum
from typing import TypedDict, List, Dict, Any

import requests
import numpy as np
import faiss
from dotenv import load_dotenv
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from rank_bm25 import BM25Okapi
from pydantic import BaseModel, Field
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from openai import OpenAI
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.sqlite import SqliteSaver

from prompt_format import generate_user_prompt
import config

# --- Step 0: Configure your environment ---
load_dotenv("keys.env")

# --- Step 1: Define the State for the Graph ---
class GraphState(TypedDict):
    """
    Represents the state of our ProAgent graph.

    Attributes:
        sample: The raw input sample (a dict of task fields) supplied by the caller.
        task_type: The task label assigned by the planner node.
        retrieval_query: The query string used for retrieval-augmented generation.
        generation_input: The task-specific fields used to build the generation prompt.
        documents: The documents returned by the retrieve node.
        generation: The final generated response from the LLM.
    """
    sample: Dict[str, Any]
    task_type: str
    retrieval_query: str
    generation_input: Dict[str, Any]
    # NOTE(review): annotated as List[str], but retrieve_node stores the
    # retriever's output here and generation_node reads `.metadata` and
    # `.page_content` from each entry, so these are Document-like objects at
    # runtime.
    documents: List[str]
    generation: str

# --- Step 2: Set up the Tools (Planner, Retriever, Models) ---
# --- 2a. The Planner (Intent Recognition) ---
# Chat model used by the LLM-based planner. The API key and base URL are read
# from environment variables; replace the variable names with your own.
# NOTE(review): "YOUR_API_UIL" looks like a typo for a "..._URL" variable --
# confirm it matches the name actually defined in keys.env before changing it.
llm_planner = ChatOpenAI(
    model=config.PLANNER_MODEL_NAME, 
    openai_api_key=os.getenv("YOUR_API_KEY"),
    openai_api_base=os.getenv("YOUR_API_UIL")
)
    
# Closed set of task labels the LLM planner may emit (the field type of
# RouteQuery.task_type). Each value is the plain label string that the graph
# nodes compare `task_type` against.
class TaskType(str, Enum):
    PQA = "PQA"
    # LABQA is disabled for the LLM planner; rule_based_planner_node can still
    # emit the plain string 'LABQA'.
    # LABQA = "LABQA"
    ORD = "ORD"
    ERR = "ERR"
    REA_ERR = "REA-ERR"
    GEN = "GEN"
    REA_GEN = "REA-GEN"
    

# Output schema for the LLM planner: it is handed to
# `llm_planner.with_structured_output(...)`, and planner_node reads the
# resulting object's `task_type` attribute.
class RouteQuery(BaseModel):
    """Routes the user query to the correct BioProBench task."""
    # Must be one of the TaskType members.
    task_type: TaskType = Field(
        description="The type of task to perform based on the user query."
    )

# Wrap the planner LLM so that its output is parsed into a RouteQuery.
structured_llm_planner = llm_planner.with_structured_output(RouteQuery)

# Two-message prompt: a system message describing the six task categories and
# a human message carrying the text to classify.
planner_prompt = ChatPromptTemplate.from_messages([
    ("system", """You are an expert at understanding and routing queries related to biological experiment protocols.
Your goal is to classify the user's query into one of the following six categories from the BioProBench benchmark:
1.  **PQA (Protocol Question Answering):** The user is asking a specific question about a protocol, such as "What concentration of PBS is used in this protocol?" or "How long should the incubation step be?".
2.  **ORD (Step Ordering):** The user provides a set of protocol steps and asks for them to be put in the correct logical order.
3.  **ERR (Error Correction):** The user provides a protocol step or a full protocol that contains an error and asks for it to be identified and corrected.
4.  **REA-ERR (Reasoning-based Error Correction):** A request to evaluate a step's validity with detailed reasoning about operations, reagents, and parameters.
5.  **GEN (Protocol Generation):** A request to generate a new protocol from a high-level instruction.
6.  **REA-GEN (Reasoning-based Generation):** A request to generate a protocol that explicitly includes a chain-of-thought reasoning process.

Classify the query into one of these six `task_type`s."""),
    # `query_for_planner` is the template variable filled in by planner_node
    # with the text that should be classified.
    ("human", "{query_for_planner}")
])
# **LABQA (LabBench Question Answering):** The user is asking a specific question about a protocol based on [distractors], such as "What steps could improve this result?" or "What could you do to improve to improve the number of embryoid bodies?".

def prepare_inputs_node(state: GraphState):
    """
    Derive the retrieval query and the generation payload for the current sample.

    The fields read from ``state["sample"]`` depend on ``state["task_type"]``:

    * ``PQA``            - ``question``, ``choices``
    * ``ERR``/``REA-ERR`` - ``corrupted_text``, ``corrected_text``, ``is_correct``, ``context``
    * ``ORD``            - ``question``, ``wrong_steps``
    * ``GEN``/``REA-GEN`` - ``system_prompt``, ``instruction``, ``input``
    * ``LABQA``          - ``question``, ``ideal``, ``distractors``, ``protocol``

    Any other task type prints a warning and falls back to the stringified
    sample for both outputs. Missing fields default to an empty string
    (``choices`` and ``wrong_steps`` default to an empty list).

    Returns:
        A partial state update with the keys ``retrieval_query`` (str) and
        ``generation_input`` (dict).
    """
    task_type = state["task_type"]
    sample = state["sample"]

    def field(name, default=''):
        # Read one field from the sample, tolerating its absence.
        return sample.get(name, default)

    def packed(query, payload):
        # Shape of the state update shared by every branch.
        return {"retrieval_query": query, "generation_input": payload}

    if task_type == 'PQA':
        question = field('question')
        choices = field('choices', [])
        return packed(
            f"{question}\nChoices:{choices}",
            {"question": question, "choices": str(choices)},
        )

    if task_type in ('ERR', 'REA-ERR'):
        corrupted_text = field('corrupted_text')
        corrected_text = field('corrected_text')
        is_correct = field('is_correct')
        context = field('context')
        # Retrieve on the corrupted text when there is one, otherwise on the
        # corrected text.
        text_to_process = corrupted_text or corrected_text
        return packed(
            f"{text_to_process}\nContext:{context}",
            {
                "corrupted_text": corrupted_text,
                "corrected_text": corrected_text,
                "context": context,
                "is_correct": is_correct,
            },
        )

    if task_type == 'ORD':
        question = field('question')
        wrong_steps = field('wrong_steps', [])
        return packed(
            f"{question}\nThe steps are:{wrong_steps}",
            {"question": question, "wrong_steps": wrong_steps},
        )

    if task_type in ('GEN', 'REA-GEN'):
        system_prompt = field('system_prompt')
        instruction = field('instruction')
        user_input = field('input')
        return packed(
            f"{instruction}\n{user_input}",
            {"system_prompt": system_prompt, "instruction": instruction, "input": user_input},
        )

    if task_type == 'LABQA':
        question = field('question')
        ideal = field('ideal')
        distractors = field('distractors')
        protocol = field('protocol')
        return packed(
            f"{question}\n{ideal}\n{distractors}\n{protocol}",
            {"question": question, "ideal": ideal, "distractors": distractors, "protocol": protocol},
        )

    # Unrecognised task type: pass the whole sample through as text.
    print(f"Warning: Unhandled task_type '{task_type}' in prepare_inputs_node.")
    raw_text = str(sample)
    return packed(raw_text, {"raw_input": raw_text})

# LLM planner chain: the formatted prompt is piped into the structured-output
# model. planner_node invokes it and reads `.task_type` from the result.
planner = planner_prompt | structured_llm_planner


# --- 2b. The Retriever ---
class CustomFaissRetriever(BaseRetriever):
    """
    A custom LangChain Retriever that uses a Faiss index and custom embeddings obtained via an API.

    Two retrieval modes are supported (selected by ``mode``):

    * ``"faiss"``  - return the ``k`` nearest neighbours from the Faiss index.
    * ``"hybrid"`` - recall ``recall_k`` neighbours from Faiss, score those
      chunks with a BM25 model built on just the recalled chunks, and keep the
      ``k`` best by ``alpha * faiss_score + beta * bm25_score``.

    Every returned Document carries a rounded ``relevance_score`` entry in its
    metadata. An embedding failure yields an empty result list rather than an
    exception.
    """
    index: faiss.Index
    # Chunk texts and their metadata, looked up by the ids Faiss returns.
    text_chunks: List[str]
    metadata_list: List[Dict[str, Any]]
    # Corpus-wide BM25 model. The hybrid path builds its own BM25 model per
    # query over the recalled candidates and does not read this field.
    bm25: BM25Okapi
    k: int = 3
    mode: str = "faiss"
    recall_k: int = 50
    alpha: float = 0.6
    beta: float = 0.4
    # Seconds to wait for the embedding API before giving up. Without a
    # timeout, `requests.post` can block indefinitely on a stalled endpoint.
    request_timeout: float = 30.0

    # Maps query text to its embedding so repeated queries skip the API call.
    embedding_cache: Dict[str, List[float]] = {}

    def _get_embedding(self, text: str) -> List[float]:
        """
        Get the embedding of the text by calling an API.

        Returns the cached embedding when the text was seen before. On a
        request failure or an unexpected response shape, an error is printed
        and an empty list is returned.
        """
        if text in self.embedding_cache:
            return self.embedding_cache[text]

        payload = {
            "model": "qwen3-embedding-8b",
            "input": text,
            "encoding_format": "float"
        }
        headers = {
            "Authorization": f"Bearer {os.getenv('BAIDU_API_KEY')}",
            "Content-Type": "application/json"
        }

        try:
            response = requests.post(
                os.getenv("BAIDU_EMB_URL"),
                json=payload,
                headers=headers,
                timeout=self.request_timeout,
            )
            response.raise_for_status()
            embedding = response.json()["data"][0]["embedding"]
            self.embedding_cache[text] = embedding
            return embedding
        except requests.exceptions.RequestException as e:
            # Includes connection errors, HTTP error statuses and timeouts.
            print(f"Error calling embedding API: {e}")
            return []
        except (KeyError, IndexError, TypeError) as e:
            # TypeError covers bodies such as {"data": null}.
            print(f"Error parsing embedding API response: {e}")
            return []

    def _search_index(self, query: str, top_k: int) -> Dict[int, float]:
        """
        Embed the query and return ``{chunk_id: score}`` for its ``top_k`` neighbours.

        The score is ``exp(-distance / 2) * 100``; entries Faiss reports as
        ``-1`` (no neighbour) are skipped. Returns an empty dict when no
        embedding could be obtained.
        """
        query_embedding_list = self._get_embedding(query)
        if not query_embedding_list:
            print("Failed to get query embedding.")
            return {}

        query_embedding = np.array([query_embedding_list]).astype('float32')
        distances, indices = self.index.search(query_embedding, top_k)

        return {
            i: np.exp(-dist / 2) * 100
            for i, dist in zip(indices[0], distances[0])
            if i != -1
        }

    def _search_faiss(self, query: str) -> Dict[int, float]:
        """Performs a FAISS search and returns a dictionary of {index: score}."""
        return self._search_index(query, self.k)

    def _fuse_with_bm25(self, query: str, faiss_scores: Dict[int, float]) -> Dict[int, float]:
        """
        Combine Faiss scores with BM25 scores computed over the recalled chunks only.

        BM25 scores are divided by their maximum and scaled by 100 before
        being mixed with the Faiss scores using ``alpha`` and ``beta``.
        """
        candidate_ids = list(faiss_scores)
        tokenized_candidates = [self.text_chunks[i].split() for i in candidate_ids]

        # Temporary BM25 model limited to the recalled candidates.
        rerank_bm25 = BM25Okapi(tokenized_candidates)
        bm25_scores = rerank_bm25.get_scores(query.split())

        # The small constant avoids dividing by zero when no candidate
        # shares a term with the query.
        max_bm25_score = max(bm25_scores) + 1e-6

        return {
            idx: self.alpha * faiss_scores[idx]
            + self.beta * ((bm25_score / max_bm25_score) * 100)
            for idx, bm25_score in zip(candidate_ids, bm25_scores)
        }

    def _build_document(self, idx: int, score: float) -> Document:
        """Create the Document for chunk ``idx`` with its rounded relevance score attached."""
        # Copy so the shared metadata list is not modified.
        doc_metadata = self.metadata_list[idx].copy()
        doc_metadata['relevance_score'] = round(score)
        return Document(page_content=self.text_chunks[idx], metadata=doc_metadata)

    def _get_relevant_documents(self, query: str, **kwargs) -> List[Document]:
        """
        The core implementation method of the LangChain Retriever.

        Returns at most ``k`` Documents, best first.

        Raises:
            ValueError: if ``mode`` is neither ``"faiss"`` nor ``"hybrid"``.
        """
        if self.mode == "faiss":
            scores = self._search_index(query, self.k)
            # Faiss already returns neighbours in ranked order.
            ranked_indices = list(scores)

        elif self.mode == "hybrid":
            faiss_results = self._search_index(query, self.recall_k)
            if not faiss_results:
                return []

            scores = self._fuse_with_bm25(query, faiss_results)
            ranked_indices = sorted(scores, key=lambda x: scores[x], reverse=True)[:self.k]

        else:
            raise ValueError(f"Unknown retrieval mode: {self.mode}")

        return [self._build_document(i, scores[i]) for i in ranked_indices]

def get_retriever(k: int = 3, mode: str = "hybrid") -> CustomFaissRetriever:
    """
    Load data and index files, and initialize a custom Retriever.

    The Faiss index is read from ``config.FAISS_INDEX_PATH`` and the chunk
    records (each with ``page_content`` and ``metadata``) from
    ``config.CHUNKS_JSON_PATH``.

    Args:
        k: Number of documents the retriever returns per query.
        mode: Retrieval mode, ``"hybrid"`` (default) or ``"faiss"``.

    Returns:
        A configured CustomFaissRetriever.

    Raises:
        ValueError: if ``mode`` is not a supported retrieval mode.
        FileNotFoundError: if either input file is missing.
        Exception: any error raised while reading the files is printed and
            re-raised unchanged.
    """
    # Fail here rather than on the first query.
    if mode not in ("faiss", "hybrid"):
        raise ValueError(f"Unknown retrieval mode: {mode}")

    print("--- Initializing Custom FAISS Retriever... ---")

    index_path = config.FAISS_INDEX_PATH
    chunks_json_path = config.CHUNKS_JSON_PATH

    # Check if files exist
    if not os.path.exists(index_path) or not os.path.exists(chunks_json_path):
        raise FileNotFoundError(
            f"Required files not found. Please ensure '{index_path}' and '{chunks_json_path}' exist."
        )

    # Load Faiss index and chunk data
    try:
        index = faiss.read_index(index_path)
        with open(chunks_json_path, "r", encoding="utf-8") as f:
            raw_data = json.load(f)

        text_chunks = [item["page_content"] for item in raw_data]
        metadata_list = [item["metadata"] for item in raw_data]

        print(f"--- Successfully loaded Faiss index with {index.ntotal} vectors and {len(text_chunks)} chunks. ---")

        # Faiss ids are used directly as positions in text_chunks, so a size
        # mismatch means some ids may point at the wrong chunk or at none.
        if index.ntotal != len(text_chunks):
            print(
                f"Warning: Faiss index has {index.ntotal} vectors but {len(text_chunks)} chunks were loaded; "
                "retrieved ids may not line up with the chunk list."
            )

        # --- Initialize BM25 ---
        tokenized_corpus = [doc.split() for doc in text_chunks]
        bm25 = BM25Okapi(tokenized_corpus)

    except Exception as e:
        print(f"Error loading retriever data: {e}")
        raise

    # Instantiate and return the retriever
    return CustomFaissRetriever(
        index=index,
        text_chunks=text_chunks,
        metadata_list=metadata_list,
        bm25=bm25,
        k=k,
        mode=mode
    )

# --- 2c. The LLMs for Generation ---
# Two chat clients that differ only in the model name taken from `config`.
# Both read the same OPENAI_API_KEY / OPENAI_BASE_URL environment variables.
# generation_node picks one of them by model name.
gemini_llm = ChatOpenAI(
    model=config.GEMINI_STYLE_MODEL_NAME, 
    openai_api_key=os.getenv("OPENAI_API_KEY"),
    openai_api_base=os.getenv("OPENAI_BASE_URL")
)
claude_llm = ChatOpenAI(
    model=config.CLAUDE_STYLE_MODEL_NAME, 
    openai_api_key=os.getenv("OPENAI_API_KEY"),
    openai_api_base=os.getenv("OPENAI_BASE_URL")
)

# --- Step 3: Define the Nodes for the Graph ---
def rule_based_planner_node(state: GraphState):
    """
    Classify the sample by which keys it contains, without calling an LLM.

    Rules are checked in this order and the first match wins:

    1. ``instruction``    -> ``GEN``
    2. ``corrupted_text`` -> ``ERR``
    3. ``wrong_steps``    -> ``ORD``
    4. ``question``       -> ``LABQA`` if ``distractors`` is also present,
       otherwise ``PQA``

    A sample matching none of the rules is labelled ``UNKNOWN``. The
    reasoning variants ``REA-ERR`` and ``REA-GEN`` are never produced here.

    Returns:
        A partial state update ``{"task_type": <label>}``.
    """
    sample = state["sample"]

    # (marker key, label) pairs, in priority order.
    key_rules = (
        ('instruction', 'GEN'),
        ('corrupted_text', 'ERR'),
        ('wrong_steps', 'ORD'),
    )
    task_type = next((label for key, label in key_rules if key in sample), None)

    if task_type is None:
        if 'question' not in sample:
            task_type = "UNKNOWN"
        else:
            task_type = 'LABQA' if 'distractors' in sample else 'PQA'

    print(f"--- Rule-based planner classified task as: {task_type} ---")
    return {"task_type": task_type}

def planner_node(state: GraphState):
    """
    Classify the task with the LLM planner chain.

    The text sent to the planner is ``state["retrieval_query"]`` when it is
    already present and non-empty. In this graph the planner runs before
    ``prepare_inputs`` (which is what writes ``retrieval_query``), so the key
    is normally absent at this point; in that case the stringified sample is
    used, mirroring the fallback ``prepare_inputs_node`` applies to unhandled
    task types.

    Returns:
        A partial state update ``{"task_type": <label>}`` where the label is
        a plain string.
    """
    # print("--- 1. EXECUTING PLANNER NODE ---")
    query_for_planner = state.get("retrieval_query") or str(state["sample"])

    route = planner.invoke({"query_for_planner": query_for_planner})
    # print(f"--- Planner classified task as: {route.task_type} ---")

    # The planner returns an Enum member; store its plain string value so the
    # state holds the same kind of label the rule-based planner produces.
    task_type = route.task_type
    return {"task_type": getattr(task_type, "value", task_type)}

# Shared retriever instance, built when this module is imported (this loads
# the index and chunk files). retrieve_node overwrites its `k` on every call.
default_retriever = get_retriever(k=5)

def retrieve_node(state: GraphState):
    """
    Fetch documents for ``state["retrieval_query"]`` from the shared retriever.

    The number of documents comes from the ``retriever_k`` entry of the
    task's configuration, falling back to ``config.RETRIEVER_K_DEFAULT``.
    The value is written onto the module-level ``default_retriever`` before
    the query is issued.

    Returns:
        A partial state update ``{"documents": <retrieved documents>}``.
    """
    # print("--- 2. EXECUTING RETRIEVE NODE ---")
    task_settings = config.TASK_CONFIG.get(state["task_type"], {})
    default_retriever.k = task_settings.get("retriever_k", config.RETRIEVER_K_DEFAULT)

    return {"documents": default_retriever.invoke(state["retrieval_query"])}

def generation_node(state: GraphState):
    """
    Generate answers using the correct context format and LLM calling method.

    Steps:
      1. Format the retrieved documents (highest relevance first) into a
         context section, or a fixed notice when there are none.
      2. Build the task-specific prompt with ``generate_user_prompt`` and wrap
         it, together with the context, in the master prompt.
      3. Call the LLM configured for the task type in ``config.TASK_CONFIG``
         (default: the Gemini-style model).

    Returns:
        A partial state update ``{"generation": <LLM response text>}``.

    Raises:
        KeyError: if the configured model name has no matching LLM client.
    """
    task_type = state["task_type"]
    generation_input = state["generation_input"]
    # Treat a missing or None entry as "nothing retrieved".
    documents = state.get("documents") or []

    # print(f"--- 3. EXECUTING GENERATION NODE ({task_type}) ---")

    # --- 1. Construct retrieved context with relevance scores and polished format ---
    if documents:
        # Sort by relevance score in descending order to ensure the most
        # important information comes first. A sorted copy is used so the
        # list held in the graph state is not reordered in place.
        ranked_documents = sorted(
            documents,
            key=lambda doc: doc.metadata.get('relevance_score', 0),
            reverse=True,
        )

        context_parts = []
        for doc in ranked_documents:
            score = doc.metadata.get('relevance_score', 'N/A')
            title = doc.metadata.get('title', 'N/A')
            method = doc.metadata.get('method', 'N/A')
            content = doc.page_content
            context_parts.append(
                f"[Relevance Score: {score}%] Source: {title} (Method: {method})\n"
                f"Protocol Content: {content}"
            )
        retrieved_context_str = "\n---\n".join(context_parts)
    else:
        retrieved_context_str = "No relevant documents were found in the knowledge base."

    # --- 2. Build the final, structured prompt for different tasks ---
    master_prompt_template = """You are a world-class expert in biological experimental protocols and reasoning. Your task is to critically integrate the [Retrieved Context] with your extensive knowledge of biological sciences to expertly address the [Specific Task]. 

Please consider [retrieved context] as a reference, especially if its [relevance score] is high. However, this information may only be a relevant clue, not a direct or complete answer. 
---
{main_content}
---
"""

    task_specific_prompt = generate_user_prompt(generation_input, task_type)
    main_content = f"""[Retrieved Context]
{retrieved_context_str}

[Specific Task]
** Please answer strictly according to the instructions and requirements in [Specific Task]. **
{task_specific_prompt}
"""
    final_prompt = master_prompt_template.format(main_content=main_content)

    # --- 3. Pick the LLM configured for this task type ---
    model_name = config.TASK_CONFIG.get(task_type, {}).get("llm", config.GEMINI_STYLE_MODEL_NAME) # default

    llm_map = {
        config.GEMINI_STYLE_MODEL_NAME: gemini_llm,
        config.CLAUDE_STYLE_MODEL_NAME: claude_llm,
    }
    try:
        llm_to_use = llm_map[model_name]
    except KeyError:
        # Same exception type as a bare lookup, with an actionable message.
        raise KeyError(
            f"No LLM client is configured for model '{model_name}' (task type '{task_type}')."
        ) from None

    response = llm_to_use.invoke(final_prompt)
    return {"generation": response.content}

# --- Step 4: Define the Conditional Edges ---

def should_retrieve(state: GraphState):
    """
    Pick the next graph node for the current task type.

    Returns ``"retrieve"`` when the task's entry in ``config.TASK_CONFIG``
    has a truthy ``requires_retrieval`` value, and ``"generate"`` otherwise
    (including when the task type has no entry at all).
    """
    task_settings = config.TASK_CONFIG.get(state["task_type"], {})
    return "retrieve" if task_settings.get("requires_retrieval", False) else "generate"

# --- Step 5: Build and Compile the Graph ---

# Define the workflow: planner -> prepare_inputs -> (retrieve ->) generate -> END
workflow = StateGraph(GraphState)

# Add the nodes. The rule-based planner is registered under the name
# "planner"; the LLM-based planner_node is kept as a commented-out alternative.
# workflow.add_node("planner", planner_node)
workflow.add_node("planner", rule_based_planner_node)
workflow.add_node("prepare_inputs", prepare_inputs_node)
workflow.add_node("retrieve", retrieve_node)
workflow.add_node("generate", generation_node)

# Set the entry point
workflow.set_entry_point("planner")

# Add the edges. After prepare_inputs, should_retrieve returns either
# "retrieve" or "generate"; the mapping below turns that into the next node.
workflow.add_edge("planner", "prepare_inputs")
workflow.add_conditional_edges(
    "prepare_inputs",
    should_retrieve,
    {
        "retrieve": "retrieve",
        "generate": "generate",
    }
)
workflow.add_edge("retrieve", "generate")
workflow.add_edge("generate", END)

# Module-level app backed by an in-memory SQLite checkpointer.
# SqliteSaver.from_conn_string() is used as a context manager in the __main__
# block below; in releases where it is one, passing its return value straight
# to compile() hands the graph a context-manager object instead of a saver.
# Building the saver from an explicit connection needs no `with` block.
# check_same_thread=False lets the connection be used from threads other than
# the one that created it.
memory = SqliteSaver(sqlite3.connect(":memory:", check_same_thread=False))
app = workflow.compile(checkpointer=memory)

# --- Step 6: Run the Agent ---

def run_proagent(sample: Dict[str, Any], app):
    """
    Run one sample through the compiled graph and print the generated answer.

    The checkpoint thread id is the sample's ``id`` field when present,
    otherwise ``"thread-<hash of the stringified sample>"``.

    Args:
        sample: The input sample passed to the graph as ``{"sample": sample}``.
        app: The compiled graph; only its ``invoke`` method is used.

    Returns:
        The final graph state returned by ``app.invoke``.
    """
    fallback_thread_id = f"thread-{hash(str(sample))}"
    # Named run_config so the imported `config` module is not shadowed.
    run_config = {"configurable": {"thread_id": sample.get("id", fallback_thread_id)}}

    final_state = app.invoke({"sample": sample}, config=run_config)

    print("\n--- Agent Run Complete ---")
    print("Final Answer:")
    print(final_state["generation"])
    return final_state

# Demo entry point: compile the graph with a fresh in-memory checkpointer and
# run a single protocol-generation sample through it.
if __name__ == "__main__":
    with SqliteSaver.from_conn_string(":memory:") as memory:
        # The app is compiled inside the block, using the active checkpointer
        app = workflow.compile(checkpointer=memory)

        # This example demonstrates using ProAgent to complete the task of generating a protocol (GEN).
        # The sample contains an 'instruction' key, which the rule-based planner maps to GEN.
        gen_sample = {
            "system_prompt": "As an expert in molecular biology protocols, provide clear and concise step-by-step instructions for experimental procedures.",
            "instruction": "Please describe the protocol in a flat list format (using only 1., 2., 3. numbers). Include only the steps, not a rationale or materials list. Use concise language and maintain a chronological order.",
            "input": "How to prepare and normalize cell extracts for kinase activity assays?"
        }
        
        run_proagent(gen_sample, app)
        