"""
Task utilities for document QA system
"""

import os
import re
import json
import logging
from typing import Dict, List, Optional

from langchain.llms.base import BaseLLM

from document_qa.core.state import DocumentQAState
from document_qa.core.document_manager import DocumentManager
from document_qa.utils.llm_utils import get_llm_response, extract_json_from_text

# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)

def task_decomposer(query: str, llm: Optional[BaseLLM], state: DocumentQAState, documents: Optional[List[str]] = None, doc_mode: str = "paths", retrieval_mode: str = "vector") -> Dict:
    """Decomposes a complex query into a series of executable steps with requirements.

    The plan is requested from the LLM as JSON. When the reply has no usable
    (non-empty list) ``"steps"`` entry, or any exception is raised while
    planning, a two-step fallback plan is used instead: retrieve on the raw
    query, then synthesize. Steps that set ``requires_document`` get the
    document references (``document_paths`` or ``document_contents``,
    depending on ``doc_mode``) and ``retrieval_mode`` attached. Dependencies
    of the returned steps are coerced to ints and restricted to lower step
    numbers.

    Args:
        query: The user query
        llm: Language model (not used; the LLM is called via get_llm_response)
        state: Shared state; planner notes are recorded with ``add_thought``
        documents: List of document references (paths or contents)
        doc_mode: Mode for document reference - "paths" or "contents"
        retrieval_mode: Method for retrieval - "llm" or "vector"

    Returns:
        Dict containing decomposed steps, shaped ``{"steps": [...]}``.
    """
    # Describe the available documents in the planning prompt.
    document_info = ""
    doc_content = ""

    if documents:
        if doc_mode == "paths":
            document_info = f"\nAvailable documents: {', '.join([os.path.basename(path) for path in documents])}"
        elif doc_mode == "contents" and isinstance(documents[0], str):
            # Only the first document's text is inlined into the prompt.
            doc_content = "DOCUMENT CONTENT: " + documents[0]

    prompt = f"""
    As an expert research assistant, analyze and break down this query into clear executable steps:
    
    ORIGINAL QUERY: "{query}"{document_info}

    {doc_content}
    
    For each step, provide:
    1. A clear description of what needs to be done
    2. Whether document retrieval is needed for this step
    3. Specific keywords to search for (if document retrieval is needed)
    4. Dependencies on previous steps (if any)
    5. Use as few steps as possible; ideally, complete it in a single step. If it can be done in one step, there's no need for a second step to summarize the first, and just use the original query.

    
    Important: For each step, consider if it depends on results from previous steps. If it does, list those dependencies.
    CONSTRAINT: A step can ONLY depend on steps with LOWER step numbers (i.e., steps that come before it). 
    For example, Step 3 can depend on Steps 1 and 2, but not on Steps 4, 5, etc.
    
    THE LAST STEP MUST BE A SYNTHESIS STEP that integrates results from previous steps to provide the final answer. 
    Also, the last step description should concat the ORIGINAL QUERY in the last sentence.
    
    Output in JSON format with the following structure:
    {{
        "steps": [
            {{
                "step_number": 1,
                "description": "...",
                "requires_document": true/false,
                "search_keywords": ["keyword1", "keyword2", ...],
                "reasoning": "...",
                "depends_on_steps": [list of step numbers this step depends on]
            }},
            ...
        ]
    }}
    """

    try:
        response = get_llm_response(prompt, model="qwen-long")
        result = extract_json_from_text(response)

        # A missing, non-list or empty "steps" entry cannot drive the pipeline.
        steps = result.get("steps") if isinstance(result, dict) else None
        if not isinstance(steps, list) or not steps:
            state.add_thought("planner", "Task decomposition failed due to incorrect result format. I need to try again.")
            result = _fallback_plan(query)
            steps = result["steps"]

        _attach_documents(steps, documents, doc_mode, retrieval_mode)

        # Enforce the constraint that steps can only depend on previous steps
        for step in steps:
            if "depends_on_steps" in step:
                step["depends_on_steps"] = _earlier_dependencies(step)

        return result

    except Exception as e:
        # Planning is best-effort: any failure (LLM call, JSON extraction,
        # malformed step entries) degrades to the fallback plan.
        logger.warning("Task decomposition failed; using fallback plan", exc_info=True)
        state.add_thought("planner", f"Error in task decomposition: {str(e)}. I will use a simplified plan.")
        fallback = _fallback_plan(query)
        _attach_documents(fallback["steps"], documents, doc_mode, retrieval_mode)
        return fallback


def _fallback_plan(query: str) -> Dict:
    """Return the default two-step plan: retrieve on *query*, then synthesize."""
    return {
        "steps": [
            {
                "step_number": 1,
                "description": f"Search for information related to: {query}",
                "requires_document": True,
                "search_keywords": query.split(),
                "reasoning": "Need to find relevant information from documents.",
                "depends_on_steps": []
            },
            {
                "step_number": 2,
                "description": "Synthesize answer based on found information",
                "requires_document": False,
                "reasoning": "Need to provide a coherent answer based on retrieved information.",
                "depends_on_steps": [1]
            }
        ]
    }


def _attach_documents(steps: List[Dict], documents: Optional[List[str]], doc_mode: str, retrieval_mode: str) -> None:
    """Attach *documents* and *retrieval_mode* to every step that requires documents.

    The key is ``document_paths`` for doc_mode "paths" and
    ``document_contents`` for "contents"; other modes attach nothing.
    """
    if not documents:
        return
    if doc_mode == "paths":
        key = "document_paths"
    elif doc_mode == "contents":
        key = "document_contents"
    else:
        return
    for step in steps:
        if step.get("requires_document", False):
            step[key] = documents
            step["retrieval_mode"] = retrieval_mode


def _earlier_dependencies(step: Dict) -> List[int]:
    """Return the step's dependencies as ints, keeping only lower step numbers.

    LLM-produced JSON may encode numbers as strings or give a single value
    instead of a list; non-numeric entries are dropped.
    """
    def as_int(value):
        try:
            return int(value)
        except (TypeError, ValueError):
            return None

    step_num = as_int(step.get("step_number", 0))
    raw = step.get("depends_on_steps")
    if raw is None:
        raw = []
    elif not isinstance(raw, (list, tuple)):
        raw = [raw]
    if step_num is None:
        return []
    deps = (as_int(dep) for dep in raw)
    return [dep for dep in deps if dep is not None and dep < step_num]

def optimize_original_steps(steps: List[Dict]) -> List[Dict]:
    """Renumber planner steps 1..n and keep their dependencies consistent.

    Steps are modified in place and the same list is returned.

    * ``step_number`` is rewritten to the step's 1-based position.
    * ``depends_on_steps`` (when present) is translated from the planner's
      original numbering to the new numbering, so renumbering does not
      silently drop or retarget dependencies. References to unknown steps,
      to the step itself, or to later steps are removed. Duplicates are
      collapsed, preserving order.
    * If there is more than one step and the last step has no dependencies,
      it is made to depend on every earlier step.

    Args:
        steps: Plan steps as produced by the planner.

    Returns:
        The same ``steps`` list, updated in place.
    """

    def to_int(value) -> Optional[int]:
        # LLM-produced JSON sometimes encodes numbers as strings ("2").
        try:
            return int(value)
        except (TypeError, ValueError):
            return None

    # Record old number -> new position before overwriting. If the planner
    # reused a number, the first step carrying it wins.
    new_number_for: Dict[int, int] = {}
    for position, step in enumerate(steps, start=1):
        old_number = to_int(step.get("step_number"))
        if old_number is not None:
            new_number_for.setdefault(old_number, position)
        step["step_number"] = position

    # Remap dependencies into the new numbering and keep only earlier steps.
    for step in steps:
        if "depends_on_steps" not in step:
            continue
        raw_deps = step["depends_on_steps"]
        if raw_deps is None:
            raw_deps = []
        elif not isinstance(raw_deps, (list, tuple)):
            raw_deps = [raw_deps]
        remapped: List[int] = []
        for dep in raw_deps:
            new_dep = new_number_for.get(to_int(dep))
            if new_dep is not None and new_dep < step["step_number"] and new_dep not in remapped:
                remapped.append(new_dep)
        step["depends_on_steps"] = remapped

    # For the last step, ensure it depends on at least one previous step
    if len(steps) > 1:
        last_step = steps[-1]
        if not last_step.get("depends_on_steps"):
            last_step["depends_on_steps"] = list(range(1, last_step["step_number"]))

    return steps

def data_structure_selector(step_info: Dict, llm: Optional[BaseLLM], state: DocumentQAState) -> str:
    """Ask the LLM which data structure best organises the answer to a step.

    The step's ``description`` is sent as the question. The reply is expected
    in the form ``{answer: <structure>}``. The pattern is searched anywhere in
    the reply, case-insensitively, and JSON-style quotes are tolerated. A
    reply consisting only of a structure name is also accepted.

    Args:
        step_info: Current step; reads ``description`` and ``step_number``.
        llm: Not used; the LLM is called via get_llm_response.
        state: Shared state; the choice is stored in
            ``state.data_structure[step_num]`` and logged as a thought.

    Returns:
        One of "Text Description", "Tree", "Table" or "Graph". Falls back to
        "Text Description" (the default information_extractor also uses)
        when the reply cannot be matched to a known structure.
    """
    question = step_info.get("description", "")
    step_num = step_info.get("step_number", state.current_step)

    structures = ["Text Description", "Tree", "Table", "Graph"]
    default_structure = "Text Description"

    prompt = (
        "This is a data structure selection task. Based on the given `question`, choose the most suitable data structure "
        "to answer the question. You can choose from the following options: \n"
    )
    prompt += "".join(f"- {structure}\n" for structure in structures)
    prompt += (
        "Your answer should be concise and to the point. Return your answer in the following format directly: {answer: data structure}.\n\n"
        f"The question is: {question}"
    )

    ans = get_llm_response(prompt, model="qwen2-72b-instruct")
    reply = ans if isinstance(ans, str) else ""

    # re.search (not re.match) so text or whitespace before the braces does
    # not make parsing fail.
    match = re.search(r"\{\s*['\"]?answer['\"]?\s*:\s*['\"]?([A-Za-z\s]+?)['\"]?\s*\}", reply, re.IGNORECASE)
    candidate = match.group(1) if match else reply

    # Map case/whitespace variants ("table", "Text  description") onto the
    # canonical names; anything else falls back to the default.
    canonical = {name.lower(): name for name in structures}
    ds = canonical.get(" ".join(candidate.split()).lower(), default_structure)

    state.data_structure[step_num] = ds
    state.add_thought("structure", f"Structure selection result: {ds}")
    return ds



def count_content_length(text):
    """Measure *text* as ``(english_words, chinese_chars)``.

    Characters in the CJK Unified Ideographs block (U+4E00..U+9FFF) are
    counted one by one. Those ideographs are then treated as separators, and
    every remaining whitespace-delimited run of characters counts as one word.

    Args:
        text: The string to measure.

    Returns:
        Tuple ``(english_words, chinese_chars)``.
    """
    def is_cjk(ch):
        return '\u4e00' <= ch <= '\u9fff'

    cjk_count = 0
    buffer = []
    for ch in text:
        if is_cjk(ch):
            cjk_count += 1
            buffer.append(' ')  # ideographs act as word separators
        else:
            buffer.append(ch)

    word_count = len(''.join(buffer).split())
    return word_count, cjk_count

def context_retriever(step_info: Dict, doc_manager: DocumentManager, state: DocumentQAState, doc_mode: str = "paths", retrieval_mode: str = "llm") -> str:
    """Retrieves relevant document contexts based on step requirements.

    Supported ``(retrieval_mode, doc_mode)`` combinations:

    * ``("vector", "paths")``: top-k chunk search via
      ``doc_manager.retrieve_relevant_documents`` (``top_k`` from the step,
      default 5).
    * ``("llm", "contents")``: uses the raw document text. Text whose word
      count plus CJK character count is at most 50,000 is returned whole.
      Longer text is sent to the LLM, which selects relevant fragments.
    * ``("llm", "paths")``: loads each path with
      ``doc_manager.load_document_content`` and asks the LLM for the most
      relevant segments.

    Any other combination returns an "Unsupported retrieval configuration"
    message instead of failing. Retrieved text is stored in
    ``state.retrieved_docs`` and ``state.retrieved_context`` under the step
    number.

    Args:
        step_info: Information about the current step
        doc_manager: Document manager for retrieval operations
        state: The shared state object
        doc_mode: Mode for document reference - "paths" or "contents"
        retrieval_mode: Default retrieval method - "llm" or "vector"; a
            ``retrieval_mode`` key in ``step_info`` takes precedence

    Returns:
        String containing the retrieved context, or a human-readable message
        explaining why nothing was retrieved.
    """
    step_num = step_info.get("step_number", state.current_step)

    # A retrieval_mode attached to the step overrides the default argument.
    step_retrieval_mode = step_info.get("retrieval_mode", retrieval_mode)

    if not step_info.get("requires_document", False):
        state.add_thought("retriever", f"Step {step_num} does not require document retrieval.")
        return "Document retrieval not needed for this step."

    keywords = step_info.get("search_keywords", [])
    description = step_info.get("description", "")

    if not keywords and not description:
        return "No search keywords or description available. Please specify what to search for."

    # Document references attached to the step. Initialised to None so an
    # unrecognised doc_mode cannot leave the name unbound.
    documents = None
    if doc_mode == "paths":
        documents = step_info.get("document_paths", None)
    elif doc_mode == "contents":
        documents = step_info.get("document_contents", None)

    # Build enhanced search query from description and keywords
    if keywords:
        search_query = f"{description} [SEP] {' '.join(keywords)}"
    else:
        search_query = description

    state.add_thought("retriever", f"I will use the following query for search: \"{search_query}\"")

    # Only these pairs are implemented. Rejecting everything else up front
    # guarantees that each branch below assigns `result`.
    supported_modes = {("vector", "paths"), ("llm", "contents"), ("llm", "paths")}
    if (step_retrieval_mode, doc_mode) not in supported_modes:
        state.add_thought("retriever", f"Unsupported combination of retrieval_mode={step_retrieval_mode} and doc_mode={doc_mode}")
        return f"Unsupported retrieval configuration: retrieval_mode={step_retrieval_mode}, doc_mode={doc_mode}"

    if step_retrieval_mode == "vector":
        # Vector search over paths; documents may be None and is passed through as-is.
        top_k = step_info.get("top_k", 5)

        docs = doc_manager.retrieve_relevant_documents(search_query, k=top_k, document_paths=documents)

        if not docs:
            state.add_thought("retriever", "Unfortunately, no relevant documents were found. We may need to adjust the search keywords.")
            return "No relevant documents found. Consider adjusting search keywords."

        doc_texts = []
        for i, doc in enumerate(docs):
            source = doc.metadata.get('source', 'unknown source')
            chunk_id = doc.metadata.get('chunk_id', i)
            doc_texts.append(f"Document Chunk {chunk_id} from {source}:\n{doc.page_content}")

        result = "\n\n".join(doc_texts)

    else:
        # LLM-based retrieval (step_retrieval_mode == "llm").
        if doc_mode == "contents":
            if not documents:
                state.add_thought("retriever", f"Step {step_num} has no document contents available for retrieval.")
                return "No document contents available for retrieval."

            # documents[0] is used while "retrieve_attempts" is present and
            # below 2; otherwise documents[1] when it exists, else documents[0].
            # NOTE(review): what the two entries represent is not visible
            # here - confirm with the caller.
            if "retrieve_attempts" in step_info and step_info["retrieve_attempts"] < 2:
                retrieve_content = documents[0]
            else:
                retrieve_content = documents[1] if len(documents) > 1 else documents[0]

            english_words, chinese_chars = count_content_length(retrieve_content)

            # Short enough: return the whole text. Unlike the other paths,
            # the text is not echoed into the thought log here.
            if english_words + chinese_chars <= 50000:
                result = retrieve_content
                state.retrieved_docs[step_num] = result
                state.retrieved_context[step_num] = result
                return result

            # Too long: ask the LLM to select relevant fragments.
            retrieve_prompt = f"""
            - Each fragment must:
            - Retain **complete semantic meaning** (e.g., full paragraphs, full tables, bullet lists, or code blocks).  
            - **Preserve structured content** (do not cut off tables, item lists, or inline references).  
            - Come from a document that clearly contains information relevant to the question.

            - Limitations:
            - Each fragment should be **no longer than 500 tokens** (truncate carefully if necessary, without breaking structure).  
            - Return **at most 3 fragments**, prioritized by **most relevant first**.  
            - If no content is relevant, return nothing.

            ---

            **Output format** (one per line, sorted by relevance):

            ```
            <filename>: <relevant content>
            <filename>: <relevant content>
            ...
            ```

            - Do not summarize, generate answers, or include explanations.  
            - Only output the selected fragments in the format above.

            ---

            **Question**:  
            {search_query}

            **Documents**:  
            {retrieve_content}
            """
            result = get_llm_response(retrieve_prompt, model="qwen-long")

        else:
            # doc_mode == "paths": load each document's text, then ask the
            # LLM for the relevant parts. A missing path list behaves like
            # an empty one.
            loaded_contents = []
            for doc_path in documents or []:
                try:
                    content = doc_manager.load_document_content(doc_path)
                    loaded_contents.append(f"Content from {doc_path}:\n{content}")
                except Exception as e:
                    # Best-effort: skip unreadable documents but record why.
                    state.add_thought("retriever", f"Error loading document {doc_path}: {str(e)}")

            if not loaded_contents:
                return "Could not load contents from the specified document paths."

            combined_content = "\n\n".join(loaded_contents)

            retrieve_prompt = f"""
            Find the most relevant content from these documents for the following question:

            **Question**:  
            {search_query}

            **Documents**:  
            {combined_content}

            Return only the most relevant segments, in order of relevance.
            """
            result = get_llm_response(retrieve_prompt, model="qwen-long")

    state.retrieved_docs[step_num] = result
    state.retrieved_context[step_num] = result

    state.add_thought("retriever", result)

    return result

def information_extractor(step_info: Dict, llm: Optional[BaseLLM], state: DocumentQAState) -> str:
    """Extracts relevant information from retrieved documents for the current step.

    Reads the step's retrieved documents (``state.retrieved_docs``) and the
    structure chosen for it (``state.data_structure``, defaulting to
    "Text Description"). It asks the LLM to extract task-relevant
    information in that structure and stores the reply in
    ``state.extracted_info``.

    Args:
        step_info: Current step; reads ``step_number`` and ``description``.
        llm: Not used; the LLM is called via get_llm_response.
        state: Shared state.

    Returns:
        The extracted information, or a message when no document content is
        available for this step.
    """
    step_num = step_info.get("step_number", state.current_step)

    # `or` (rather than a .get default) also covers a stored None, e.g. an
    # unparsed structure choice.
    structure = state.data_structure.get(step_num) or "Text Description"
    docs = state.retrieved_docs.get(step_num, "")

    if not docs or docs == "Document retrieval not needed for this step.":
        state.add_thought("extractor", f"Step {step_num} has no document content available for extraction.")
        return "No document content available for extraction."

    prompt = f"""Extract the most relevant information from these document information and transform into the following structure:
    
    STRUCTURE: {structure}
    
    TASK: {step_info.get('description', '')}
    
    DOCUMENT INFORMATION:
    {docs}
    
    EXTRACTION GUIDELINES:
    1. Focus specifically on information that directly addresses the task
    2. Extract key facts, figures, quotes, and findings
    3. Maintain accuracy - don't add information not present in the documents
    4. Note any contradictions or uncertainties in the documents
    5. Organize the extracted information logically
    6. If "STRUCTURE" is a "Graph" or a "Tree," return a tuple of two or three elements.
        
    EXTRACTED INFORMATION:
    """

    extracted_info = get_llm_response(prompt, model="gpt-4o-mini")
    state.extracted_info[step_num] = extracted_info

    state.add_thought("extractor", extracted_info)

    return extracted_info

def information_verifier(step_info: Dict, llm: Optional[BaseLLM], state: DocumentQAState) -> Dict:
    """Verifies extracted information against verification criteria.

    The LLM scores the step's extracted information (completeness, relevance,
    accuracy) and proposes a ``next_action``: "search" (re-retrieve),
    "extract" (re-extract) or "continue". If nothing was extracted, or the
    reply cannot be parsed as JSON, a failing result is built locally.
    Except for the nothing-extracted case, the result is stored in
    ``state.verification_results[step_num]`` and summarised in the thought
    log.

    Args:
        step_info: Current step; reads ``step_number``, ``description`` and
            the optional ``verification_criteria``.
        llm: Not used; the LLM is called via get_llm_response.
        state: Shared state.

    Returns:
        The verification result dict.
    """
    step_num = step_info.get("step_number", state.current_step)
    extracted_info = state.extracted_info.get(step_num, "")
    criteria = step_info.get("verification_criteria", "")

    if not extracted_info:
        state.add_thought("verifier", f"No extracted information to verify for step {step_num}")
        return {
            "verification_passed": False,
            "completeness": 0,
            "relevance": 0,
            "accuracy": 0,
            "needs_refinement": True,
            "next_action": "search",  # Return to search step
            "refinement_suggestions": "No information was extracted. Try different search keywords."
        }

    def failed_result(suggestion: str) -> Dict:
        # Used when the LLM reply cannot be turned into a verdict; routes
        # the workflow back to extraction.
        return {
            "verification_passed": False,
            "completeness": 0,
            "relevance": 0,
            "accuracy": 0,
            "needs_refinement": True,
            "issue_source": "extraction",
            "next_action": "extract",
            "refinement_suggestions": suggestion
        }

    prompt = f"""Verify the following extracted information against the given criteria:
    
    TASK: {step_info.get('description', '')}
    
    VERIFICATION CRITERIA: {criteria}
    
    EXTRACTED INFORMATION:
    {extracted_info}
    
    RETRIEVED CONTEXT:
    {state.retrieved_context.get(step_num, "No context available")}
    
    Evaluate the information based on these dimensions:
    - Completeness (1-5): Does it provide all the necessary information for the task?
    - Relevance (1-5): How relevant is the information to the specific task?
    - Accuracy (1-5): Based on internal consistency, does the information seem accurate?
    - Does the information need refinement? (Yes/No)
    - Is the issue with the extracted information or with the retrieved context?
    - What specific improvements are needed?
    
    Respond in JSON format:
    {{
        "verification_passed": true/false,
        "completeness": 1-5,
        "relevance": 1-5,
        "accuracy": 1-5,
        "needs_refinement": true/false,
        "issue_source": "extraction" or "retrieval" or "none",
        "next_action": "search" or "extract" or "continue",
        "refinement_suggestions": "Specific suggestions for improvement if needed"
    }}
    """

    response = get_llm_response(prompt, model="qwen2-72b-instruct")

    try:
        # Greedy match from the first "{" to the last "}" of the reply.
        json_match = re.search(r'({.*})', response.replace('\n', ' '), re.DOTALL)
        if json_match:
            verification_result = json.loads(json_match.group(1))
        else:
            state.add_thought("verifier", "Could not parse verification response as JSON")
            verification_result = failed_result("Unable to parse verification results. Please retry.")
    except (AttributeError, TypeError, ValueError) as e:
        # AttributeError/TypeError: non-string reply; ValueError: invalid JSON.
        state.add_thought("verifier", f"Error parsing verification response: {str(e)}")
        verification_result = failed_result(f"Error in verification process: {str(e)}")

    # Ensure next_action is set based on verification results
    if "next_action" not in verification_result:
        if verification_result.get("verification_passed", False):
            verification_result["next_action"] = "continue"
        elif verification_result.get("issue_source", "") == "retrieval":
            verification_result["next_action"] = "search"
        else:
            verification_result["next_action"] = "extract"

    state.verification_results[step_num] = verification_result

    if verification_result.get("verification_passed", False):
        state.add_thought("verifier", "Information verification passed, continue analysis")
    else:
        next_action = verification_result.get("next_action", "extract")

        # The prompt and defaults above use "search" for re-retrieval;
        # "retrieve" is accepted as a synonym.
        if next_action in ("search", "retrieve"):
            state.add_thought("verifier",
                              f"Step {step_num} information verification failed. The retrieved document content seems insufficiently relevant or complete. "
                              "Let's try using different keywords to search for more relevant documents. "
                              f"Suggestion: {verification_result.get('refinement_suggestions', 'Try more precise search keywords')}")
        elif next_action == "extract":
            state.add_thought("verifier",
                              f"Step {step_num} information verification failed. Although the retrieved document content may be relevant, our information extraction is not accurate enough. "
                              "Let's reanalyze the document content and extract more accurate information. "
                              f"Suggestion: {verification_result.get('refinement_suggestions', 'More carefully analyze the document content')}")
        else:
            state.add_thought("verifier",
                              f"Step {step_num} information verification failed, but the reason is unclear. "
                              f"Suggestion: {verification_result.get('refinement_suggestions', 'Unknown reason')}")

    return verification_result

def information_refiner(step_info: Dict, llm: Optional[BaseLLM], state: DocumentQAState) -> str:
    """Refines extracted information based on verification feedback.

    Reads the step's extracted text, verification result and retrieved
    documents from ``state``. If the verification result does not request
    refinement (``needs_refinement`` falsy or missing), the existing extracted
    text is returned unchanged. Otherwise the LLM rewrites it, and the
    result replaces ``state.extracted_info[step_num]``.

    Args:
        step_info: Current step; reads ``step_number`` and ``description``.
        llm: Not used; the LLM is called via get_llm_response.
        state: Shared state.

    Returns:
        The refined (or unchanged) extracted information.
    """
    step_num = step_info.get("step_number", state.current_step)
    extracted_info = state.extracted_info.get(step_num, "")
    verification = state.verification_results.get(step_num, {})
    docs = state.retrieved_docs.get(step_num, "")

    if not verification.get("needs_refinement", False):
        # Record the skip in the thought log like the other stages, instead
        # of printing it to stdout.
        state.add_thought("refinement", f"Step {step_num} information is already good, no further refinement needed.")
        return extracted_info

    state.add_thought("refinement", f"Based on verification feedback, I need to refine step {step_num} information. Let me reanalyze the document content...")

    prompt = f"""Refine the previously extracted information based on the verification feedback:
    
    TASK: {step_info.get('description', '')}
    
    ORIGINAL EXTRACTED INFORMATION:
    {extracted_info}
    
    VERIFICATION FEEDBACK:
    - Completeness: {verification.get('completeness', 'N/A')}/5
    - Relevance: {verification.get('relevance', 'N/A')}/5
    - Accuracy: {verification.get('accuracy', 'N/A')}/5
    - Refinement needed: {verification.get('refinement_suggestions', 'No specific suggestions')}
    
    SOURCE DOCUMENT CHUNKS:
    {docs}
    
    REFINED INFORMATION (address all the issues mentioned in the feedback):
    """

    state.add_thought("refinement", "Analyzing document content, extracting more accurate and complete information...")
    refined_info = get_llm_response(prompt, model="qwen2-72b-instruct")
    state.extracted_info[step_num] = refined_info
    # The refined text is logged under the "extractor" role, as the original extraction is.
    state.add_thought("extractor", refined_info)
    state.add_thought("refinement", f"Step {step_num} information has been refined. Now the information is more accurate, relevant, and complete.")
    return refined_info

def answer_synthesizer(query: str, llm: Optional[BaseLLM], state: DocumentQAState) -> Dict:
    """Synthesizes a final answer from all the extracted and verified information.

    For each step in ``state.steps``, the step answer (``state.step_answers``)
    is used when it is truthy; otherwise the extracted information
    (``state.extracted_info``) is used. These per-step results are sent to
    the LLM together with the original query.

    Args:
        query: The original user query.
        llm: Not used; the LLM is called via get_llm_response.
        state: Shared state holding steps, answers and the thought log.

    Returns:
        Dict with ``query``, ``final_answer``, ``step_answers`` (step number
        -> answer text, "" when none is available) and ``thought_process``.
    """
    state.add_thought("synthesizer", "Starting to synthesize the final answer by integrating all collected information...")

    step_results = []
    step_answers = {}
    for step in state.steps:
        step_num = step.get("step_number", 0)
        description = step.get("description", "")
        answer = state.step_answers.get(step_num)

        # Prompt text: fall back to extracted info when the step answer is missing or empty.
        prompt_answer = answer or state.extracted_info.get(step_num, "No information extracted")
        step_results.append(f"STEP {step_num}: {description}\n{prompt_answer}")

        # Returned mapping: same preference, but "" when neither source has an entry.
        step_answers[step_num] = answer or state.extracted_info.get(step_num, "")

    all_extracted_info = "\n\n".join(step_results)

    # Get the thought trail
    thought_trail = state.get_thought_trail()

    state.add_thought("synthesizer", "All information has been integrated, now I will analyze these information and form a comprehensive answer...")

    prompt = f"""Synthesize a comprehensive answer to the original query based on all the information gathered:
    
    ORIGINAL QUERY: {query}
    
    INFORMATION GATHERED:
    {all_extracted_info}
    
    Craft a well-organized, coherent response that:
    1. Directly answers the original query
    2. Integrates information from all steps
    3. Presents information in a logical flow
    4. Cites specific findings from the steps when relevant
    5. Acknowledges any limitations or uncertainties in the information

    Your answer should be thorough but focused on what's most important to address the query.
    """

    final_answer = get_llm_response(prompt, model="qwen2-72b-instruct")

    return {
        "query": query,
        "final_answer": final_answer,
        "step_answers": step_answers,
        "thought_process": thought_trail
    }