from pymilvus import MilvusClient
from pathlib import Path
import re
import json

from ...db.milvus.milvus_db import MilvusDB
from ...db.milvus.milvus_retriever import MilvusRetriever
from ...utils.log import logger
from ...utils.settings import settings
from ...llm.llms import get_embedding, get_llm

# Retriever
# Module-level setup: everything below runs once, at import time, and builds
# the embedding model, a MilvusClient on a local DB path, and the retriever
# used by the `retrieve` graph node.
emb_model = get_embedding(settings.app.embedding['main'])
# Query-embedding callable handed to MilvusDB (presumably used to embed the
# search query before the vector lookup — confirm in MilvusDB).
emb_func = emb_model.embed_query
collection_name = settings.app.setting['traffic_rule_id']
# The DB path is resolved relative to four `.parent` hops up from this file
# (presumably the project root — confirm), not the current working directory.
db_file = Path(__file__).parent.parent.parent.parent / settings.app.path['milvus_db']
db_client = MilvusClient(db_file.as_posix())
db = MilvusDB(db_client, emb_func, collection_name)
# text_key="text" names the stored field holding each rule's text (assumed from
# the parameter name — verify against the collection schema); top_k caps how
# many candidate rules are returned per query.
retriever = MilvusRetriever(milvus=db, text_key="text", top_k=settings.app.setting['retrieved_traffic_rule_top_k'])

# Helper functions for text parsing fallback
def _parse_yes_no_from_text(text_response: str) -> dict:
    """Parse Yes/No response from text when structured output is not available."""
    text_lower = text_response.lower()
    
    # Try to find JSON first
    json_match = re.search(r'\{[^}]*"ans"[^}]*\}', text_response)
    if json_match:
        try:
            parsed = json.loads(json_match.group())
            if 'ans' in parsed and parsed['ans'] in ['yes', 'no']:
                return parsed
        except json.JSONDecodeError:
            pass
    
    # Fallback to text analysis
    # Look for clear yes/no indicators
    if any(word in text_lower for word in ['yes', 'relevant', 'applicable', 'matches']):
        if not any(word in text_lower for word in ['not relevant', 'not applicable', 'no']):
            return {'ans': 'yes'}
    
    if any(word in text_lower for word in ['no', 'not relevant', 'not applicable', 'irrelevant']):
        return {'ans': 'no'}
    
    # Default to 'no' if unclear (conservative approach for retrieval grading)
    return {'ans': 'no'}


def _parse_rule_result_from_text(text_response: str) -> dict:
    """Parse Rule_Result from text when structured output is not available."""
    # Try to find JSON first
    json_match = re.search(r'\{[^}]*"violation"[^}]*\}', text_response, re.DOTALL)
    if json_match:
        try:
            parsed = json.loads(json_match.group())
            if 'violation' in parsed and 'reason' in parsed:
                if parsed['violation'] in ['found', 'not_found']:
                    return parsed
        except json.JSONDecodeError:
            pass
    
    # Fallback to text analysis
    text_lower = text_response.lower()
    
    # Look for violation indicators
    violation_found = False
    if any(phrase in text_lower for phrase in [
        'violation found', 'violates', 'breaking', 'illegal', 'against the rule',
        'traffic violation', 'rule violation', 'not following', 'disobeying'
    ]):
        violation_found = True
    
    # Look for no violation indicators
    if any(phrase in text_lower for phrase in [
        'no violation', 'not violating', 'following the rules', 'legal', 
        'compliant', 'no traffic rule violation', 'within the rules'
    ]):
        violation_found = False
    
    # Extract reason - try to find the main explanation
    reason_text = text_response.strip()
    
    # Clean up the reason text - remove JSON-like formatting if present
    reason_text = re.sub(r'^[{\[\'"]*', '', reason_text)
    reason_text = re.sub(r'[}\]\'"]*$', '', reason_text)
    
    # If text is too long, try to extract the key sentence
    if len(reason_text) > 200:
        sentences = re.split(r'[.!?]+', reason_text)
        # Find the most informative sentence
        for sentence in sentences:
            if len(sentence.strip()) > 10 and any(word in sentence.lower() for word in 
                ['violation', 'rule', 'traffic', 'legal', 'illegal', 'because', 'since']):
                reason_text = sentence.strip()
                break
    
    return {
        'violation': 'found' if violation_found else 'not_found',
        'reason': reason_text if reason_text else ('Traffic rule violation found' if violation_found else 'No traffic rule violation found')
    }


# Retriever Grader and Rule Verifier with Gateway Model Support
from .agent_prompt import rule_retrieval_grader_prompt, Yes_No, Rule_Result, rule_verifier_prompt

def _supports_structured_output(model_id: str) -> bool:
    """Check if model supports structured output (json_schema)."""
    return not model_id.startswith("gateway:")

def _create_retrieval_grader(llm, model_id: str):
    """Build the chain that grades one retrieved rule's relevance to a query.

    Args:
        llm: Chat model obtained from ``get_llm``.
        model_id: Identifier the model was resolved from; ``gateway:`` ids
            are treated as lacking structured-output support.

    Returns:
        ``rule_retrieval_grader_prompt`` piped into the model with retries.
        When structured output can be configured the chain yields ``Yes_No``
        output; otherwise it yields the raw model reply, which the caller
        must parse (e.g. with ``_parse_yes_no_from_text``).
    """
    if _supports_structured_output(model_id):
        try:
            return rule_retrieval_grader_prompt | llm.with_structured_output(Yes_No).with_retry()
        except Exception as exc:
            # Best-effort fallback to the text chain, but leave a trace so a
            # misconfigured model does not silently lose structured output.
            logger.warning(
                f"Structured output unavailable for {model_id!r} "
                f"(retrieval grader); falling back to text parsing: {exc}"
            )

    # Text-based chain (gateway models, or structured-output setup failed).
    return rule_retrieval_grader_prompt | llm.with_retry()

def _create_rule_verifier(llm, model_id: str):
    """Build the chain that checks a query against the relevant traffic rules.

    Args:
        llm: Chat model obtained from ``get_llm``.
        model_id: Identifier the model was resolved from; ``gateway:`` ids
            are treated as lacking structured-output support.

    Returns:
        ``rule_verifier_prompt`` piped into the model with retries. When
        structured output can be configured the chain yields ``Rule_Result``
        output; otherwise it yields the raw model reply, which the caller
        must parse (e.g. with ``_parse_rule_result_from_text``).
    """
    if _supports_structured_output(model_id):
        try:
            return rule_verifier_prompt | llm.with_structured_output(Rule_Result).with_retry()
        except Exception as exc:
            # Best-effort fallback to the text chain, but leave a trace so a
            # misconfigured model does not silently lose structured output.
            logger.warning(
                f"Structured output unavailable for {model_id!r} "
                f"(rule verifier); falling back to text parsing: {exc}"
            )

    # Text-based chain (gateway models, or structured-output setup failed).
    return rule_verifier_prompt | llm.with_retry()


# Initialize LLM and chains.
# Resolve the "fast" model id once and reuse it for the model itself, the
# structured-output check, and both chain factories.
model_id = settings.app.llm['fast']
llm = get_llm(model_id)
uses_structured_output = _supports_structured_output(model_id)

retrieval_grader = _create_retrieval_grader(llm, model_id)
rule_verifier = _create_rule_verifier(llm, model_id)

# Graph
from langgraph.graph import StateGraph, START, END
from typing import TypedDict
from langchain_core.documents import Document

class GraphState(TypedDict):
    """State shared between the nodes of the traffic-rule agent graph."""

    # Query text used both for retrieval and for rule verification.
    query: str
    # Rule documents returned by the retriever (nodes read `.page_content`);
    # `grade_retrieval` narrows this to the ones graded relevant.
    retrieved_traffic_rules: list[Document]
    # Verdict written by `verify_rule`: the verifier's structured output, or a
    # dict with 'violation' and 'reason' keys parsed from a text reply.
    result: Rule_Result

# Nodes
def retrieve(state: dict) -> dict:
    """Graph node: fetch candidate traffic rules for ``state['query']``.

    Returns a partial state update holding whatever ``retriever.invoke``
    produced under the ``'retrieved_traffic_rules'`` key.
    """
    logger.debug("-------Retrieving-------")
    return {'retrieved_traffic_rules': retriever.invoke(state['query'])}

def grade_retrieval(state: dict) -> dict:
    """Graph node: keep only the retrieved rules graded relevant to the query.

    Args:
        state: Graph state holding ``'query'`` and the retriever's
            ``'retrieved_traffic_rules'`` documents.

    Returns:
        Partial state update whose ``'retrieved_traffic_rules'`` is a new list
        of the documents not graded ``'no'``. The input list is not mutated.
    """
    logger.debug("-------Grading Retrieval-------")
    query = state['query']
    relevant_rules = []
    # Collect survivors into a new list rather than calling .remove() on the
    # list being iterated, which skips the element after each removal.
    for rule in state['retrieved_traffic_rules']:
        rule_text = rule.page_content
        result = retrieval_grader.invoke({"query": query, "retrieved_traffic_rule": rule_text})

        # Branch on what the chain actually returned rather than on the
        # module-level `uses_structured_output` flag, which stays True even
        # when the grader chain had to fall back to plain text.
        if isinstance(result, str):
            grade = _parse_yes_no_from_text(result)
        elif hasattr(result, 'content'):
            content = result.content
            grade = _parse_yes_no_from_text(content if isinstance(content, str) else str(content))
        else:
            grade = result

        if grade['ans'] == 'no':
            logger.debug(f"  >>> Rule is not relevant: {rule_text[:100]}...")
        else:
            logger.debug("  >>> Rule is relevant...")
            relevant_rules.append(rule)
    return {'retrieved_traffic_rules': relevant_rules}

def format_rules(doc: "list[Document]") -> str:
    """Join the text of each rule document into a single prompt string.

    Args:
        doc: Collection of rule documents exposing ``page_content`` (the
            singular name is kept for keyword-argument compatibility).

    Returns:
        The documents' texts separated by blank lines; ``""`` when empty.
    """
    return "\n\n".join(str(rule.page_content) for rule in doc)

def verify_rule(state: dict) -> dict:
    """Graph node: ask the verifier whether the query violates the kept rules.

    Args:
        state: Graph state holding ``'query'`` and the graded
            ``'retrieved_traffic_rules'`` documents.

    Returns:
        Partial state update ``{'result': ...}`` holding the verifier's
        structured output, or a dict with ``'violation'`` and ``'reason'``
        keys parsed from its text reply.
    """
    logger.debug("-------Verifying Rules-------")
    query = state['query']
    formatted_rules = format_rules(state['retrieved_traffic_rules'])
    result = rule_verifier.invoke({'query': query, 'relevant_traffic_rules': formatted_rules})

    # Branch on what the chain actually returned rather than on the
    # module-level `uses_structured_output` flag, which stays True even when
    # the verifier chain had to fall back to plain text. Non-string message
    # content (e.g. a list of content blocks) is stringified before parsing.
    if isinstance(result, str):
        final_result = _parse_rule_result_from_text(result)
    elif hasattr(result, 'content'):
        content = result.content
        final_result = _parse_rule_result_from_text(content if isinstance(content, str) else str(content))
    else:
        final_result = result

    return {'result': final_result}

# build graph
def _build_graph() -> StateGraph:
    """Assemble the linear retrieve -> grade_retrieval -> rule_verifier graph.

    All nodes are registered first; then edges are added pairwise along
    START -> retrieve -> grade_retrieval -> rule_verifier -> END.
    """
    builder = StateGraph(GraphState)
    pipeline = (
        ("retrieve", retrieve),
        ("grade_retrieval", grade_retrieval),
        ("rule_verifier", verify_rule),
    )
    for node_name, node_fn in pipeline:
        builder.add_node(node_name, node_fn)

    hops = [START, *(node_name for node_name, _ in pipeline), END]
    for source, target in zip(hops, hops[1:]):
        builder.add_edge(source, target)
    return builder


graph = _build_graph()

# build agent: compile the wired graph into the runnable agent
traffic_rule_agent = graph.compile()


