import os
import json
import logging
from dotenv import load_dotenv
from typing import List

from langchain_openai import ChatOpenAI
from langchain_core.runnables import RunnableConfig

from agents.state import OverallState, SearcherState, BrowserState
from agents.prompts import *
from agents.schemas import *
from agents.openai.config import Configuration
from agents.utils import get_user_question, get_search_results, visit, split_webpage_content

# Load variables from a local .env file into the process environment so the
# credential lookups below can see them.
load_dotenv(dotenv_path=".env")

# Module-wide logging: timestamped, INFO-level output through the root logger.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Service credentials; each is None when the corresponding variable is unset.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
SERP_API_KEY = os.environ.get("SERP_API_KEY")
JINA_API_KEY = os.environ.get("JINA_API_KEY")

def searcher(state: OverallState, config: RunnableConfig) -> OverallState:
    """
    General searcher block: generate search queries, run them, and keep the relevant results.

    Steps:
      1. Ask the searcher LLM (``QueryWriter`` schema) for
         ``configurable.number_of_queries_per_search`` queries.
      2. Run each query through ``get_search_results`` (re-using ``search_cache`` when the
         query was seen before) and let the LLM pick the relevant hits
         (``SelectedSearchResults`` schema).
      3. Drop any selected item that lacks ``url`` / ``snippet`` attributes.

    Args:
        state: Graph state. Reads ``messages``, ``current_sub_question``, ``current_summary``,
            ``instruction_state.searcher_instructions`` and
            ``searcher_state.used_keywords`` / ``searcher_state.search_cache``.
        config: Runnable config used to build ``Configuration``.

    Returns:
        ``{"searcher_state": SearcherState(...)}`` with the number of queries run, the
        queries used this round, the validated results and the updated search cache.
    """
    logger.info("🔍 Starting search process...")
    configurable = Configuration.from_runnable_config(config)
    llm = ChatOpenAI(
        model=configurable.searcher_model,
        api_key=OPENAI_API_KEY,
    )

    # `or {}` also covers keys that are present but set to None.
    previous_searcher_state = state.get('searcher_state') or {}
    searcher_instructions = (state.get('instruction_state') or {}).get('searcher_instructions')
    if searcher_instructions:
        # Join outside the f-string: a backslash inside an f-string replacement
        # field is a SyntaxError before Python 3.12 (PEP 701).
        joined_instructions = "\n".join(searcher_instructions)
        experiences = f"**Some experiences that might be useful:**\n{joined_instructions}"
    else:
        experiences = ""

    original_question = get_user_question(state['messages'])
    sub_question = state.get('current_sub_question', '')

    def _get_search_queries():
        """Ask the LLM for this round's list of search queries."""
        structured_llm = llm.with_structured_output(QueryWriter)
        formatted_prompt = query_writer_prompt.format(
            query_count=configurable.number_of_queries_per_search,
            current_date=get_current_date(),
            original_question=original_question,
            sub_question=sub_question,
            used_search_keywords_and_phrases=previous_searcher_state.get('used_keywords', []),
            current_summary=state.get('current_summary', ''),
            experiences=experiences,
        )
        result = structured_llm.invoke(formatted_prompt)
        return result.search_query_list or []

    def _search_one_query(query: str, search_cache: dict):
        """Search one query (cache first) and return the LLM-selected results."""
        if query in search_cache:
            search_results = search_cache[query]
        else:
            search_results = get_search_results(query=query, api_key=SERP_API_KEY)
            # A str result is the error case (see below). Don't cache it, so a
            # transient failure can be retried in a later round.
            if not isinstance(search_results, str):
                search_cache[query] = search_results

        # Error strings are passed through verbatim; structured results are
        # pretty-printed as JSON for the selection prompt.
        if isinstance(search_results, str):
            formatted_results = search_results
        else:
            formatted_results = json.dumps(search_results, indent=4)

        structured_llm = llm.with_structured_output(SelectedSearchResults)
        formatted_prompt = search_result_selection_prompt.format(
            query=query,
            original_question=original_question,
            sub_question=sub_question,
            search_results=formatted_results,
            current_date=get_current_date(),
            experiences=experiences,
        )
        result = structured_llm.invoke(formatted_prompt)
        return result.selected_results

    query_list = _get_search_queries()
    logger.info(f"🔍 Searching {len(query_list)} queries: {query_list}")
    search_cache = previous_searcher_state.get('search_cache') or {}

    search_results = []
    search_count = 0
    for i, query in enumerate(query_list, 1):
        logger.info(f"🔎 Processing query {i}/{len(query_list)}: {query}")
        search_count += 1
        query_results = _search_one_query(query, search_cache)
        if query_results:
            search_results.extend(query_results)

    # Final safety check: keep only items that look like SearchResult objects.
    validated_results = []
    for result in search_results:
        if hasattr(result, 'url') and hasattr(result, 'snippet'):
            validated_results.append(result)
        else:
            logger.warning(f"Invalid search result found: {type(result)} - {result}")

    searcher_state = SearcherState(
        search_count=search_count,
        used_keywords=query_list,
        search_results=validated_results,
        search_cache=search_cache,
    )
    # Report the number of results actually returned, not the pre-validation count.
    logger.info(f"✅ Search completed! Found {len(validated_results)} results")
    return {"searcher_state": searcher_state}

# def parallel_searcher(state: OverallState, config: RunnableConfig) -> OverallState:
# """
# Parallel searcher is a searcher block that uses multiple language models to search for information. (To be implemented)
# """
#     logger.info("🔍 Starting parallel search process...")
#     configurable = Configuration.from_runnable_config(config)
#     llm = ChatOpenAI(
#         model=configurable.searcher_model,
#         api_key=OPENAI_API_KEY,
#     )

def fast_searcher(state: OverallState, config: RunnableConfig) -> OverallState:
    """
    Fast searcher block: issue one query at a time until some result is selected.

    Each round asks the LLM (``FastQueryWriter`` schema) for a single query, searches it
    (re-using ``search_cache`` when possible) and lets the LLM pick the relevant hits.
    Rounds stop as soon as a round yields results or after ``max_attempts`` rounds. The
    original loop had no bound and could spin forever.

    Args:
        state: Graph state. Reads ``messages``, ``current_sub_question``, ``current_summary``,
            ``instruction_state.searcher_instructions`` and
            ``searcher_state.used_keywords`` / ``searcher_state.search_cache``.
        config: Runnable config used to build ``Configuration``.

    Returns:
        ``{"searcher_state": SearcherState(...)}`` with the number of rounds, the queries
        tried this call, the selected results (possibly empty if every round failed) and
        the updated search cache.
    """
    logger.info("🔍 Starting fast search process...")
    configurable = Configuration.from_runnable_config(config)
    llm = ChatOpenAI(
        model=configurable.searcher_model,
        api_key=OPENAI_API_KEY,
    )

    # Hard cap on query/search rounds, so a question with no findable results
    # cannot loop forever.
    max_attempts = 5

    previous_searcher_state = state.get('searcher_state') or {}
    searcher_instructions = (state.get('instruction_state') or {}).get('searcher_instructions')
    if searcher_instructions:
        # Join outside the f-string: a backslash inside an f-string replacement
        # field is a SyntaxError before Python 3.12 (PEP 701).
        joined_instructions = "\n".join(searcher_instructions)
        experiences = f"**Some experiences that might be useful:**\n{joined_instructions}"
    else:
        experiences = ""

    original_question = get_user_question(state['messages'])
    sub_question = state.get('current_sub_question', '')
    previous_keywords = list(previous_searcher_state.get('used_keywords') or [])

    def _get_search_query(tried_this_call: list):
        """Ask the LLM for the next query, telling it about every query already tried."""
        structured_llm = llm.with_structured_output(FastQueryWriter)
        formatted_prompt = fast_query_writer_prompt.format(
            current_date=get_current_date(),
            original_question=original_question,
            sub_question=sub_question,
            # Include this call's failed queries so a retry does not regenerate the
            # same (cached, empty) query.
            used_search_keywords_and_phrases=previous_keywords + tried_this_call,
            current_summary=state.get('current_summary', ''),
            experiences=experiences,
        )
        result = structured_llm.invoke(formatted_prompt)
        return result.search_query

    def _search_one_query(query: str, search_cache: dict):
        """Search one query (cache first) and return the LLM-selected results."""
        if query in search_cache:
            search_results = search_cache[query]
        else:
            search_results = get_search_results(query=query, api_key=SERP_API_KEY)
            # A str result is the error case (see below). Don't cache it, so a
            # transient failure can be retried.
            if not isinstance(search_results, str):
                search_cache[query] = search_results

        if isinstance(search_results, str):
            formatted_results = search_results
        else:
            formatted_results = json.dumps(search_results, indent=4)

        structured_llm = llm.with_structured_output(SelectedSearchResults)
        formatted_prompt = search_result_selection_prompt.format(
            query=query,
            original_question=original_question,
            sub_question=sub_question,
            search_results=formatted_results,
            current_date=get_current_date(),
            experiences=experiences,
        )
        result = structured_llm.invoke(formatted_prompt)
        return result.selected_results

    search_results = []
    used_keywords = []
    search_count = 0
    search_cache = previous_searcher_state.get('search_cache') or {}
    while not search_results and search_count < max_attempts:
        search_count += 1
        query = _get_search_query(used_keywords)
        used_keywords.append(query)
        logger.info(f"🔍 Searching {query}")
        # Guard against a None selection so the loop condition never calls len(None).
        search_results = _search_one_query(query, search_cache) or []

    if not search_results:
        logger.warning(f"No search results selected after {search_count} attempts; giving up.")

    searcher_state = SearcherState(
        search_count=search_count,
        used_keywords=used_keywords,
        search_results=search_results,
        # Keep the cache in the returned state, as `searcher` does; otherwise it is lost.
        search_cache=search_cache,
    )
    logger.info(f"✅ Fast search completed! Found {len(search_results)} results")
    return {"searcher_state": searcher_state}

def browser(state: OverallState, config: RunnableConfig) -> OverallState:
    """
    General browser block: pick URLs from the search results, read each page, and turn
    the extracted information into references.

    For every URL chosen by the LLM (``UrlSelection`` schema), the page comes from
    ``visit`` or from ``visit_cache``. It is divided with ``split_webpage_content`` using
    ``configurable.max_webpage_content_length``. The parts go one by one to the
    extraction LLM (``ExtractInformation`` schema), stopping early once it returns
    ``should_continue=False``.

    Args:
        state: Graph state. Reads ``messages``, ``current_sub_question``,
            ``instruction_state.browser_instructions``, ``searcher_state.search_results``
            and ``browser_state.visit_cache``.
        config: Runnable config used to build ``Configuration``.

    Returns:
        ``{"browser_state": BrowserState(...)}`` with one ``Reference`` per selected URL
        (even when nothing was extracted from it) and the updated visit cache.
    """
    logger.info("🌐 Starting web browsing process...")
    configurable = Configuration.from_runnable_config(config)
    llm = ChatOpenAI(
        model=configurable.browser_model,
        api_key=OPENAI_API_KEY,
    )

    browser_instructions = (state.get('instruction_state') or {}).get('browser_instructions')
    if browser_instructions:
        # Join outside the f-string: a backslash inside an f-string replacement
        # field is a SyntaxError before Python 3.12 (PEP 701).
        joined_instructions = "\n".join(browser_instructions)
        experiences = f"**Some experiences that might be useful:**\n{joined_instructions}"
    else:
        experiences = ""

    original_question = get_user_question(state['messages'])
    sub_question = state.get('current_sub_question', '')

    def _get_url_list():
        """Ask the LLM which of the current search results are worth visiting."""
        search_results = (state.get('searcher_state') or {}).get('search_results') or []
        structured_llm = llm.with_structured_output(UrlSelection)
        formatted_prompt = url_selection_prompt.format(
            original_question=original_question,
            sub_question=sub_question,
            list_of_urls_and_snippets=json.dumps([item.dict() for item in search_results], indent=4),
            current_date=get_current_date(),
            experiences=experiences,
        )
        result = structured_llm.invoke(formatted_prompt)
        return result.selected_urls or []

    def _extract_information(webpage_content_parts: List[str]):
        """Extract information part by part until the LLM says to stop; return a flat list."""
        structured_llm = llm.with_structured_output(ExtractInformation)
        information_list = []
        total_parts = len(webpage_content_parts)
        for i, webpage_content_part in enumerate(webpage_content_parts, 1):
            logger.info(f"📄 Processing part {i}/{total_parts}")
            labelled_part = f"Part {i}/{total_parts}:\n{webpage_content_part.strip()}"
            formatted_prompt = extract_information_prompt.format(
                webpage_content=labelled_part,
                original_question=original_question,
                sub_question=sub_question,
                current_date=get_current_date(),
            )
            result = structured_llm.invoke(formatted_prompt)
            # Extend (not append) so the per-part lists flatten into a single list.
            information_list.extend(result.information_list or [])
            if not result.should_continue:
                break
        return information_list

    url_list = _get_url_list()
    logger.info(f"🌐 Visiting {len(url_list)} URLs: {url_list}")

    reference_list = []
    visit_count = 0
    visit_cache = (state.get('browser_state') or {}).get('visit_cache') or {}
    for i, url in enumerate(url_list, 1):
        logger.info(f"📄 Processing webpage {i}/{len(url_list)}: {url}")
        visit_count += 1
        if url in visit_cache:
            logger.info(f"📄 Cached visit for {url}")
            webpage_content = visit_cache[url]
        else:
            webpage_content = visit(url, api_key=JINA_API_KEY)
            visit_cache[url] = webpage_content
        webpage_content_parts = split_webpage_content(webpage_content, configurable.max_webpage_content_length)
        logger.info(f"📄 Split webpage content into {len(webpage_content_parts)} parts")
        information_list = _extract_information(webpage_content_parts)
        reference_list.append(Reference(url=url, information_list=information_list))

    browser_state = BrowserState(
        visit_count=visit_count,
        visited_urls=url_list,
        found_references=reference_list,
        visit_cache=visit_cache,
    )
    logger.info(f"✅ Browser completed! Extracted information from {len(reference_list)} references")
    return {"browser_state": browser_state}

def fast_browser(state: OverallState, config: RunnableConfig) -> OverallState:
    """
    Fast browser block: visit the selected URLs in order and stop at the first page that
    yields any information.

    URL selection, page fetching (with ``visit_cache``), splitting and part-by-part
    extraction work as in ``browser``. The difference is that the loop breaks as soon as
    one page produces a non-empty information list.

    Args:
        state: Graph state. Reads ``messages``, ``current_sub_question``,
            ``instruction_state.browser_instructions``, ``searcher_state.search_results``
            and ``browser_state.visit_cache``.
        config: Runnable config used to build ``Configuration``.

    Returns:
        ``{"browser_state": BrowserState(...)}``. ``found_references`` holds at most one
        ``Reference``, and ``visited_urls`` lists only the URLs processed before stopping.
    """
    logger.info("🌐 Starting fast web browsing process...")
    configurable = Configuration.from_runnable_config(config)
    llm = ChatOpenAI(
        model=configurable.browser_model,
        api_key=OPENAI_API_KEY,
    )

    browser_instructions = (state.get('instruction_state') or {}).get('browser_instructions')
    if browser_instructions:
        # Join outside the f-string: a backslash inside an f-string replacement
        # field is a SyntaxError before Python 3.12 (PEP 701).
        joined_instructions = "\n".join(browser_instructions)
        experiences = f"**Some experiences that might be useful:**\n{joined_instructions}"
    else:
        experiences = ""

    original_question = get_user_question(state['messages'])
    sub_question = state.get('current_sub_question', '')

    def _get_url_list():
        """Ask the LLM which of the current search results are worth visiting."""
        search_results = (state.get('searcher_state') or {}).get('search_results') or []
        structured_llm = llm.with_structured_output(UrlSelection)
        formatted_prompt = url_selection_prompt.format(
            original_question=original_question,
            sub_question=sub_question,
            list_of_urls_and_snippets=json.dumps([item.dict() for item in search_results], indent=4),
            current_date=get_current_date(),
            experiences=experiences,
        )
        result = structured_llm.invoke(formatted_prompt)
        return result.selected_urls or []

    def _extract_information(webpage_content_parts: List[str]):
        """Extract information part by part until the LLM says to stop; return a flat list."""
        structured_llm = llm.with_structured_output(ExtractInformation)
        information_list = []
        total_parts = len(webpage_content_parts)
        for i, webpage_content_part in enumerate(webpage_content_parts, 1):
            logger.info(f"📄 Processing part {i}/{total_parts}")
            labelled_part = f"Part {i}/{total_parts}:\n{webpage_content_part.strip()}"
            formatted_prompt = extract_information_prompt.format(
                webpage_content=labelled_part,
                original_question=original_question,
                sub_question=sub_question,
                current_date=get_current_date(),
            )
            result = structured_llm.invoke(formatted_prompt)
            # Extend (not append) so the per-part lists flatten into a single list.
            information_list.extend(result.information_list or [])
            if not result.should_continue:
                break
        return information_list

    url_list = _get_url_list()
    logger.info(f"🌐 Visiting {len(url_list)} URLs: {url_list}")

    reference_list = []
    visited_urls = []
    visit_cache = (state.get('browser_state') or {}).get('visit_cache') or {}
    for i, url in enumerate(url_list, 1):
        logger.info(f"📄 Processing webpage {i}/{len(url_list)}: {url}")
        visited_urls.append(url)
        if url in visit_cache:
            logger.info(f"📄 Cached visit for {url}")
            webpage_content = visit_cache[url]
        else:
            webpage_content = visit(url, api_key=JINA_API_KEY)
            visit_cache[url] = webpage_content
        webpage_content_parts = split_webpage_content(webpage_content, configurable.max_webpage_content_length)
        logger.info(f"📄 Split webpage content into {len(webpage_content_parts)} parts")
        information_list = _extract_information(webpage_content_parts)
        if information_list:
            reference_list.append(Reference(url=url, information_list=information_list))
            # Fast mode: the first page that yields anything is enough.
            break

    browser_state = BrowserState(
        visit_count=len(visited_urls),
        visited_urls=visited_urls,
        found_references=reference_list,
        visit_cache=visit_cache,
    )
    logger.info(f"✅ Browser completed! Extracted information from {len(reference_list)} references")
    return {"browser_state": browser_state}

def thinker(state: OverallState, config: RunnableConfig) -> OverallState:
    """
    Thinker block: reason about the current sub-question with the thinker LLM and record
    that reasoning as an extra reference.

    Formats ``thinker_prompt`` with the original question, ``current_sub_question``,
    ``current_summary`` and the current date. It then requests a ``Thinker`` structured
    output and appends its ``thinking_process`` as a ``Reference`` with
    ``url="from thinking agent"``.

    Args:
        state: Graph state. Reads ``messages``, ``current_sub_question``,
            ``current_summary`` and ``browser_state``.
        config: Runnable config used to build ``Configuration``.

    Returns:
        ``{"browser_state": ...}``: a copy of the incoming browser state whose
        ``found_references`` list has the thinking reference appended.
    """
    logger.info("🤔 Starting thinking process...")
    configurable = Configuration.from_runnable_config(config)
    llm = ChatOpenAI(model=configurable.thinker_model, api_key=OPENAI_API_KEY)
    structured_llm = llm.with_structured_output(Thinker)
    formatted_prompt = thinker_prompt.format(original_question=get_user_question(state['messages']),
                                             current_sub_question=state.get('current_sub_question', ''),
                                             current_summary=state.get('current_summary', ''),
                                             current_date=get_current_date())
    result = structured_llm.invoke(formatted_prompt)
    logger.info(f"🤔 Thinking process generated! Length: {len(result.thinking_process)} characters")

    reference = Reference(
        url="from thinking agent",
        information_list=[result.thinking_process]
    )
    # Build a new browser_state rather than appending in place. The original raised
    # KeyError when no browser step had populated 'found_references' (including the
    # `{}` default), and it mutated the list held by the incoming state.
    browser_state = dict(state.get('browser_state') or {})
    browser_state['found_references'] = [*(browser_state.get('found_references') or []), reference]

    return {"browser_state": browser_state}

def summarizer(state: OverallState, config: RunnableConfig) -> OverallState:
    """
    Summarizer block: fold the references gathered so far into an updated running summary.

    The summarizer LLM (``Summarizer`` schema) receives the original question,
    ``current_sub_question``, ``current_summary``, a JSON dump of
    ``browser_state.found_references`` and the current date.

    Returns:
        ``{"current_summary": <summary text>}``.
    """
    logger.info("📝 Starting summarization process...")
    settings = Configuration.from_runnable_config(config)
    chat_model = ChatOpenAI(
        model=settings.summarizer_model,
        api_key=OPENAI_API_KEY,
    )

    logger.info("📊 Generating summary from collected information...")
    summary_model = chat_model.with_structured_output(Summarizer)
    question = get_user_question(state['messages'])
    sub_question = state.get('current_sub_question', '')
    running_summary = state.get('current_summary', '')
    collected = [ref.dict() for ref in state.get('browser_state', {}).get('found_references', [])]
    prompt_text = summarizer_prompt.format(
        original_question=question,
        current_sub_question=sub_question,
        current_summary=running_summary,
        list_of_information_and_their_source_url_links=json.dumps(collected, indent=4),
        current_date=get_current_date(),
    )
    summary = summary_model.invoke(prompt_text).summary
    logger.info(f"✅ Summary completed! Generated {len(summary)} characters of summary")
    return {"current_summary": summary}


def verifier(state: OverallState, config: RunnableConfig) -> OverallState:
    """
    Verifier block: ask the verifier LLM whether ``current_summary`` answers a question.

    The node variant comes from ``config["configurable"]["verifier_variant"]``
    (default ``"sub"``):
      * ``"final"``: verify against the user's original question and return
        ``{"final_verified": bool}``.
      * anything else: verify against ``current_sub_question``, falling back to the
        original question if that is empty. Returns ``{"sub_verified": bool}`` plus
        ``current_sub_question_iteration`` incremented by one (treated as 1 when unset).
    """
    node_settings = config.get("configurable", {}) or {}
    variant = node_settings.get("verifier_variant", "sub")
    logger.info(f"✅ Starting {variant} verification process...")

    configurable = Configuration.from_runnable_config(config)
    llm = ChatOpenAI(model=configurable.verifier_model, api_key=OPENAI_API_KEY)

    is_final = variant == "final"
    question_text = (
        get_user_question(state["messages"])
        if is_final
        else (state.get("current_sub_question") or get_user_question(state["messages"]))
    )

    verdict = llm.with_structured_output(Verifier).invoke(
        verifier_prompt.format(
            question=question_text,
            current_summary=state.get("current_summary") or "",
            current_date=get_current_date(),
        )
    ).can_answer_question
    outcome = '✅ PASSED' if verdict else '❌ FAILED'

    if is_final:
        logger.info(f"🔍 Final verification result: {outcome}")
        return {"final_verified": verdict}

    logger.info(f"🔍 Sub verification result: {outcome}")
    next_iteration = state.get("current_sub_question_iteration", 1) + 1
    return {"sub_verified": verdict, "current_sub_question_iteration": next_iteration}

def next_sub_question_writer(state: OverallState, config: RunnableConfig) -> OverallState:
    """
    Next-sub-question writer block: ask the LLM for a new sub-question to pursue.

    The prompt carries the user's original question, ``current_summary`` and the current
    date; the reply follows the ``NextSubQuestionWriter`` schema.

    Returns:
        ``{"current_sub_question": <new question>, "current_sub_question_iteration": 1}``.
        The iteration counter restarts at 1 for the new sub-question.
    """
    logger.info("❓ Generating next sub-question...")
    settings = Configuration.from_runnable_config(config)
    chat_model = ChatOpenAI(model=settings.next_sub_question_writer_model, api_key=OPENAI_API_KEY)
    question_writer = chat_model.with_structured_output(NextSubQuestionWriter)
    prompt_text = next_sub_question_writer_prompt.format(
        question=get_user_question(state['messages']),
        current_summary=state.get('current_summary', ''),
        current_date=get_current_date(),
    )
    new_sub_question = question_writer.invoke(prompt_text).new_sub_question
    logger.info(f"❓ New sub-question generated: {new_sub_question}")
    return {"current_sub_question": new_sub_question, "current_sub_question_iteration": 1}
    
def finalizer(state: OverallState, config: RunnableConfig) -> OverallState:
    """
    Finalizer block: write the final answer to the user's question from the accumulated summary.

    The prompt carries the user's original question, ``current_summary`` and the current
    date; the reply follows the ``Finalizer`` schema.

    Returns:
        ``{"final_answer": <answer text>}``.
    """
    logger.info("🎯 Starting finalization process...")
    settings = Configuration.from_runnable_config(config)
    chat_model = ChatOpenAI(model=settings.finalizer_model, api_key=OPENAI_API_KEY)
    answer_writer = chat_model.with_structured_output(Finalizer)
    prompt_text = finalizer_prompt.format(
        question=get_user_question(state['messages']),
        current_summary=state.get('current_summary', ''),
        current_date=get_current_date(),
    )
    final_answer = answer_writer.invoke(prompt_text).final_answer
    logger.info(f"🎯 Final answer generated! Length: {len(final_answer)} characters")
    return {"final_answer": final_answer}