import os
import re
import json

from ...utils.settings import settings
from ..llms import get_llm
from ...utils.log import logger
from ...db.neo4j.neo4j_retriever import Neo4jRetriever
from .agent_prompt import Yes_No, accident_retrieval_grader_prompt, accident_analysis_prompt, Accident_Result

from langgraph.graph import StateGraph, START, END
from typing import TypedDict
from langchain_core.documents import Document


# Helper functions for text parsing fallback
def _parse_yes_no_from_text(text_response: str) -> dict:
    """Parse Yes/No response from text when structured output is not available."""
    text_lower = text_response.lower()
    
    # Try to find JSON first
    json_match = re.search(r'\{[^}]*"ans"[^}]*\}', text_response)
    if json_match:
        try:
            parsed = json.loads(json_match.group())
            if 'ans' in parsed and parsed['ans'] in ['yes', 'no']:
                return parsed
        except json.JSONDecodeError:
            pass
    
    # Fallback to text analysis
    # Look for clear yes/no indicators
    if any(word in text_lower for word in ['yes', 'relevant', 'applicable', 'similar', 'matches']):
        if not any(word in text_lower for word in ['not relevant', 'not applicable', 'not similar', 'no']):
            return {'ans': 'yes'}
    
    if any(word in text_lower for word in ['no', 'not relevant', 'not applicable', 'not similar', 'irrelevant']):
        return {'ans': 'no'}
    
    # Default to 'no' if unclear (conservative approach for retrieval grading)
    return {'ans': 'no'}


def _parse_accident_result_from_text(text_response: str) -> dict:
    """Parse Accident_Result from text when structured output is not available."""
    # Try to find JSON first
    json_match = re.search(r'\{[^}]*"accident"[^}]*\}', text_response, re.DOTALL)
    if json_match:
        try:
            parsed = json.loads(json_match.group())
            if 'accident' in parsed and 'consequence' in parsed:
                if parsed['accident'] in ['found', 'not_found']:
                    return parsed
        except json.JSONDecodeError:
            pass
    
    # Fallback to text analysis
    text_lower = text_response.lower()
    
    # Look for accident indicators
    accident_found = False
    if any(phrase in text_lower for phrase in [
        'accident found', 'similar accident', 'accidents found', 'accident risk',
        'collision', 'crash', 'accident occurred', 'similar situation',
        'historical accident', 'accident scenario'
    ]):
        if not any(phrase in text_lower for phrase in [
            'no accident', 'not found', 'no similar', 'no historical'
        ]):
            accident_found = True
    
    if accident_found:
        # Extract consequence information
        consequence = "Similar traffic accidents found in historical data suggesting potential collision risk."
        
        # Try to find more specific consequence in the response
        consequence_patterns = [
            r'consequence[s]?[:\-\s]+([^.]+)',
            r'result[s]?[:\-\s]+([^.]+)', 
            r'risk[s]?[:\-\s]+([^.]+)',
            r'potential[:\-\s]+([^.]+)'
        ]
        
        for pattern in consequence_patterns:
            match = re.search(pattern, text_response, re.IGNORECASE | re.DOTALL)
            if match:
                extracted = match.group(1).strip()
                if len(extracted) > 10:  # Ensure we have meaningful content
                    consequence = extracted
                    break
        
        return {
            'accident': 'found',
            'consequence': consequence
        }
    else:
        return {
            'accident': 'not_found',
            'consequence': 'No traffic accident found'
        }


def _test_structured_output(model_id: str) -> bool:
    """Test if the model supports structured output."""
    if model_id.startswith("gateway:"):
        return False
    
    # For other models, assume they support structured output
    # OpenAI and Groq models are known to support it
    return True

# Fast model: drives both the Neo4j retriever and the retrieval grader.
fast_model_id = settings.app.llm['fast']
llm = get_llm(fast_model_id)
retriever = Neo4jRetriever(llm=llm)

# Retrieval grader. Models without structured-output support (gateway models)
# get the raw model reply parsed by the text fallback; all others return a
# structured Yes_No result directly.
if not _test_structured_output(fast_model_id):
    retrieval_grader = (
        accident_retrieval_grader_prompt | llm | _parse_yes_no_from_text
    ).with_retry()
else:
    retrieval_grader = (
        accident_retrieval_grader_prompt
        | llm.with_structured_output(Yes_No).with_retry()
    )

# Main model: performs the accident scene analysis, with the same
# structured-output / text-fallback split as the grader above.
main_model_id = settings.app.llm['main']
llm_main = get_llm(main_model_id)
if not _test_structured_output(main_model_id):
    accident_analyst = (
        accident_analysis_prompt | llm_main | _parse_accident_result_from_text
    ).with_retry()
else:
    accident_analyst = (
        accident_analysis_prompt
        | llm_main.with_structured_output(Accident_Result).with_retry()
    )

# Graph
class GraphState(TypedDict):
    """
    State for the traffic accident retriever graph.

    This class is passed to ``StateGraph`` as the state schema. Each node
    function in this module reads keys from the state and returns a dict
    containing only the key it updates.
    """
    # Traffic scene text; used as the retriever query and passed to the grader
    # and analysis prompts.
    scene: str
    # Documents returned by the retriever; the grading node replaces this with
    # the subset it judged relevant.
    retrieved_accidents: list[Document]
    # Result produced by the accident analysis chain.
    consequences: Accident_Result

def retrieve(state: dict) -> dict:
    """
    Look up traffic accidents for the scene held in the graph state.

    Reads ``state['scene']``, passes it to the module-level retriever and
    returns the retriever's result under the ``'retrieved_accidents'`` key.
    """
    logger.debug("-------Retrieving-------")
    scene_description = state['scene']
    return {'retrieved_accidents': retriever.invoke(scene_description)}

def grade_retrieval(state: dict) -> dict:
    """
    Keep only the retrieved accidents that the grader judges relevant to the scene.

    Every document in ``state['retrieved_accidents']`` is graded exactly once
    against ``state['scene']``. Documents graded ``'no'`` are dropped; all
    others are kept in their original order.

    Returns:
        A dict with key ``'retrieved_accidents'`` holding a new list of the
        kept documents. The list in ``state`` is not modified.
    """
    logger.debug("-------Grading Retrieval-------")
    query = state['scene']

    # Collect the kept documents in a new list. Removing items from the list
    # being iterated would skip the element that follows each removal, so
    # some documents would never be graded.
    relevant_accidents = []
    for accident in state['retrieved_accidents']:
        accident_text = accident.page_content
        grade = retrieval_grader.invoke({"traffic_scene": query, "retrieved_accident": accident_text})
        if grade['ans'] == 'no':
            logger.debug(f"  >>> Accident is not relevant: {accident_text[:100]}...")
        else:
            logger.debug("  >>> Accident is relevant...")
            relevant_accidents.append(accident)
    return {'retrieved_accidents': relevant_accidents}

def analyze_accidents(state: dict) -> dict:
    """
    Run the accident analysis chain over the retrieved accidents.

    The page content of every document in ``state['retrieved_accidents']`` is
    joined with blank lines and passed, together with ``state['scene']``, to
    the module-level analysis chain. The chain's output is returned under the
    ``'consequences'`` key.
    """
    logger.debug("-------Analyzing Accidents-------")

    # One text entry per retrieved document, in retrieval order.
    accident_texts = []
    for doc in state['retrieved_accidents']:
        accident_texts.append(f"{doc.page_content}")
    history_block = "\n\n".join(accident_texts)

    analyst_input = {
        'traffic_scene': state['scene'],
        'historical_traffic_accidents': history_block,
    }
    return {'consequences': accident_analyst.invoke(analyst_input)}

# build graph
def _assemble_graph():
    """Create the state graph and wire its nodes into a linear pipeline."""
    builder = StateGraph(GraphState)

    # Register each processing step under the name used by the edges below.
    steps = (
        ("retrieve", retrieve),
        ("grade_retrieval", grade_retrieval),
        ("analyze_accidents", analyze_accidents),
    )
    for step_name, step_fn in steps:
        builder.add_node(step_name, step_fn)

    # START -> retrieve -> grade_retrieval -> analyze_accidents -> END
    route = (START, "retrieve", "grade_retrieval", "analyze_accidents", END)
    for source, target in zip(route, route[1:]):
        builder.add_edge(source, target)

    return builder


graph = _assemble_graph()

# build agent
traffic_accident_agent = graph.compile()