"""Driving mentor agent for comprehensive safety assessment and improvement suggestions."""

import re
import json
from typing import TypedDict

from langgraph.graph import StateGraph, START, END

from .agent_prompt import driving_mentor_prompt, DrivingAssessment
from .scene_extraction import SceneExtractor
from ..llms import get_llm
from ...utils.settings import settings
from ...utils.log import logger


# Keys an embedded JSON object must carry to be accepted as a DrivingAssessment payload.
_ASSESSMENT_FIELDS = (
    'safety_score',
    'overall_evaluation',
    'strengths',
    'weaknesses',
    'improvement_advice',
    'risk_level',
)

# An explicit label such as "Risk level: **High**" or '"risk_level": "high"'
# (matched against lowercased text).
_RISK_LABEL_PATTERN = re.compile(r'risk[\s_-]*level[\s:"\'*=-]*(critical|high|medium|low)\b')

# Keyword fallback, checked from most to least severe; the first hit wins.
# Word boundaries around "safe" keep it from matching "safety", which appears in
# nearly every assessment (e.g. "safety score") and would otherwise force "low".
_RISK_KEYWORD_PATTERNS = (
    ('critical', re.compile(r'critical|severe|extremely dangerous')),
    ('high', re.compile(r'high risk|dangerous|unsafe')),
    ('low', re.compile(r'low risk|\bsafe\b|minor')),
)

# A bullet ("-", "*", "•") or numbered ("1.", "2)", "3") list item at the start of a
# line. Only the leading marker is consumed, so hyphenated words such as
# "lane-keeping" stay inside a single item.
_LIST_ITEM_PATTERN = re.compile(r'^[ \t]*(?:[-*•]|\d+[.)]?)[ \t]*(.+)$', re.MULTILINE)


def _response_text(response) -> str:
    """Return the text payload of an LLM response.

    Accepts a plain ``str`` or a message-like object exposing ``.content``. In a
    ``prompt | llm | parser`` pipeline where ``llm`` is a chat model, the parser
    receives such a message rather than a ``str``. ``None`` yields ``''``.
    """
    if response is None:
        return ''
    content = getattr(response, 'content', response)
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        # Multi-part content: keep plain strings and the "text" of content blocks.
        parts = []
        for part in content:
            if isinstance(part, str):
                parts.append(part)
            elif isinstance(part, dict) and isinstance(part.get('text'), str):
                parts.append(part['text'])
        return '\n'.join(parts)
    return str(content)


def _extract_json_assessment(text: str) -> "dict | None":
    """Return the first JSON object in ``text`` that has every assessment field, else ``None``.

    Each ``{`` is tried as the start of a JSON value, so objects containing nested
    objects are handled. Nested candidates are also tried, which covers a payload
    wrapped in an outer object.
    """
    decoder = json.JSONDecoder()
    start = text.find('{')
    while start != -1:
        try:
            candidate, _ = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            candidate = None
        if isinstance(candidate, dict) and all(field in candidate for field in _ASSESSMENT_FIELDS):
            return candidate
        start = text.find('{', start + 1)
    return None


def _extract_list_section(text: str, section_pattern: str, default: list[str]) -> list[str]:
    """Collect up to 5 list items (longer than 5 chars) from the section matched by ``section_pattern``.

    ``section_pattern`` must capture the section body in group 1. ``default`` is
    returned when the section is missing or yields no usable items.
    """
    section = re.search(section_pattern, text, re.IGNORECASE | re.DOTALL)
    if not section:
        return default
    items = [item.strip() for item in _LIST_ITEM_PATTERN.findall(section.group(1))]
    items = [item for item in items if len(item) > 5][:5]
    return items or default


def _parse_driving_assessment_from_text(text_response: str) -> dict:
    """Parse a DrivingAssessment dict from free-form LLM output.

    Used when the model cannot produce structured output. An embedded JSON object
    that carries every assessment field is returned as-is. Otherwise each field is
    recovered heuristically from the text, with a neutral default for each one.

    Args:
        text_response: Raw model output — a ``str`` or a message object with a
            ``.content`` attribute.

    Returns:
        dict: On the heuristic path, ``safety_score`` (int 1-10, default 5),
        ``overall_evaluation`` (str), ``strengths`` / ``weaknesses`` /
        ``improvement_advice`` (lists of at most 5 str) and ``risk_level``
        (``"critical"``, ``"high"``, ``"medium"`` or ``"low"``).
    """
    text = _response_text(text_response)

    parsed = _extract_json_assessment(text)
    if parsed is not None:
        return parsed

    text_lower = text.lower()

    # Safety score: the first in-range (1-10) number following a score-like label.
    safety_score = 5
    score_patterns = (
        r'safety score[:\s]*(\d+)',
        r'score[:\s]*(\d+)(?:\s*/\s*10)?',
        r'rating[:\s]*(\d+)',
    )
    for pattern in score_patterns:
        match = re.search(pattern, text_lower)
        if match and 1 <= int(match.group(1)) <= 10:
            safety_score = int(match.group(1))
            break

    # Risk level: prefer an explicit label, then fall back to severity keywords.
    label = _RISK_LABEL_PATTERN.search(text_lower)
    if label:
        risk_level = label.group(1)
    else:
        risk_level = next(
            (level for level, pattern in _RISK_KEYWORD_PATTERNS if pattern.search(text_lower)),
            'medium',
        )

    # Overall evaluation: up to three sentences after an evaluation/assessment label.
    evaluation = "The driving behavior shows mixed safety performance with areas for improvement."
    eval_patterns = (
        r'overall[^.]*?evaluation[:\s]*([^.]+(?:\.[^.]+){0,2})',
        r'assessment[:\s]*([^.]+(?:\.[^.]+){0,2})',
        r'evaluation[:\s]*([^.]+(?:\.[^.]+){0,2})',
    )
    for pattern in eval_patterns:
        match = re.search(pattern, text, re.IGNORECASE | re.DOTALL)
        if match:
            extracted = match.group(1).strip()
            if len(extracted) > 20:  # Ensure meaningful content
                evaluation = extracted
                break

    strengths = _extract_list_section(
        text,
        r'strengths?[:\s]*(.+?)(?=weaknesses?|improvement|advice|$)',
        ["Maintained basic vehicle control"],
    )
    weaknesses = _extract_list_section(
        text,
        r'weaknesses?[:\s]*(.+?)(?=improvement|advice|strengths?|$)',
        ["Needs improvement in hazard awareness"],
    )
    advice = _extract_list_section(
        text,
        r'(?:improvement|advice|recommendations?)[:\s]*(.+?)(?=strengths?|weaknesses?|$)',
        ["Practice defensive driving techniques"],
    )

    return {
        'safety_score': safety_score,
        'overall_evaluation': evaluation,
        'strengths': strengths,
        'weaknesses': weaknesses,
        'improvement_advice': advice,
        'risk_level': risk_level,
    }


def _test_structured_output(model_id: str) -> bool:
    """Report whether ``model_id`` should be used with ``with_structured_output``.

    Ids routed through the gateway (prefixed ``"gateway:"``) are treated as
    text-only. Every other id (e.g. OpenAI or Groq models) is assumed to support
    structured output.
    """
    is_gateway_model = model_id.startswith("gateway:")
    return not is_gateway_model


# Build the main LLM and the assessment chain. Models with native structured
# output return a DrivingAssessment directly. Gateway models return free text,
# which _parse_driving_assessment_from_text turns into the same shape.
main_model_id = settings.app.llm['main']
llm = get_llm(main_model_id)

if not _test_structured_output(main_model_id):
    # Text-parsing fallback: the retry wraps the whole prompt -> LLM -> parser pipeline.
    driving_assessor = (driving_mentor_prompt | llm | _parse_driving_assessment_from_text).with_retry()
else:
    # Structured output: the retry wraps only the structured-output LLM step.
    driving_assessor = driving_mentor_prompt | llm.with_structured_output(DrivingAssessment).with_retry()


# Graph State
class GraphState(TypedDict):
    """State for the driving mentor graph.

    In this module, DrivingMentor invokes the graph with only ``annotation``. The
    other keys are filled in by the nodes: ``extract_scenes`` -> ``scenes``,
    ``analyze_accidents`` -> ``accident_results``, ``check_rule_violations`` ->
    ``rule_results``, ``generate_assessment`` -> ``assessment``.
    """
    annotation: str  # complex traffic-scene annotation supplied by the caller
    scenes: list[str]  # simple scene descriptions produced by SceneExtractor
    accident_results: list[dict]  # one {'scene', 'analysis'} entry per scene, in scene order
    rule_results: list[dict]  # one {'scene', 'analysis'} entry per scene, in scene order
    assessment: DrivingAssessment  # final output of driving_assessor


def extract_scenes(state: GraphState) -> dict:
    """Break the state's complex annotation into simple scene descriptions.

    Returns:
        dict: A partial state update containing only ``scenes``.
    """
    logger.debug("-------Extracting Scenes-------")
    source_annotation = state['annotation']
    scene_list = SceneExtractor().extract(source_annotation)
    logger.debug(f"Extracted {len(scene_list)} scenes for analysis")
    return {'scenes': scene_list}


def analyze_accidents(state: GraphState) -> dict:
    """Run the traffic-accident agent on every extracted scene, one at a time.

    Returns:
        dict: ``{'accident_results': [...]}`` with one ``{'scene', 'analysis'}``
        entry per scene, in scene order. ``analysis`` is the agent's
        ``'consequences'`` output.
    """
    logger.debug("-------Analyzing Accident Risks-------")
    scene_list = state['scenes']
    findings = []

    # Deferred import, kept local to avoid database connection issues.
    from .traffic_accident_retriever import traffic_accident_agent

    for position, scene_text in enumerate(scene_list, start=1):
        logger.debug(f"  Analyzing scene {position}: {scene_text[:50]}...")
        agent_output = traffic_accident_agent.invoke({'scene': scene_text})
        findings.append({'scene': scene_text, 'analysis': agent_output['consequences']})

    logger.debug(f"Completed accident analysis for {len(scene_list)} scenes")
    return {'accident_results': findings}


def check_rule_violations(state: GraphState) -> dict:
    """Run the traffic-rule agent on every extracted scene, one at a time.

    Returns:
        dict: ``{'rule_results': [...]}`` with one ``{'scene', 'analysis'}`` entry
        per scene, in scene order. ``analysis`` is the agent's ``'result'`` output.
    """
    logger.debug("-------Checking Rule Violations-------")
    scene_list = state['scenes']
    verdicts = []

    # Deferred import, kept local to avoid database connection issues.
    from .traffic_rule_checker import traffic_rule_agent

    for position, scene_text in enumerate(scene_list, start=1):
        logger.debug(f"  Checking scene {position}: {scene_text[:50]}...")
        agent_output = traffic_rule_agent.invoke({'query': scene_text})
        verdicts.append({'scene': scene_text, 'analysis': agent_output['result']})

    logger.debug(f"Completed rule checking for {len(scene_list)} scenes")
    return {'rule_results': verdicts}


def generate_assessment(state: GraphState) -> dict:
    """Combine the annotation and per-scene analyses into the final assessment.

    The accident and rule results are rendered as plain-text summaries, one
    ``Scene: ...`` entry per result. They are passed to ``driving_assessor``
    together with the original annotation.

    Returns:
        dict: ``{'assessment': ...}`` holding the assessor's output.
    """
    logger.debug("-------Generating Assessment-------")

    annotation = state['annotation']
    accident_results = state['accident_results']
    rule_results = state['rule_results']

    # Each entry ends in "\n" and entries are later joined with "\n", so
    # consecutive scenes are separated by a blank line.
    accident_lines = []
    for record in accident_results:
        scene_text = record['scene']
        risk = record['analysis']
        accident_lines.append(
            f"Scene: {scene_text}\n"
            f"Accident Risk: {risk['accident']} - {risk['consequence']}\n"
        )

    rule_lines = []
    for record in rule_results:
        scene_text = record['scene']
        verdict = record['analysis']
        rule_lines.append(
            f"Scene: {scene_text}\n"
            f"Rule Violation: {verdict['violation']} - {verdict['reason']}\n"
        )

    prompt_inputs = {
        'annotation': annotation,
        'accident_results': "\n".join(accident_lines),
        'rule_results': "\n".join(rule_lines),
    }
    assessment = driving_assessor.invoke(prompt_inputs)

    logger.debug(f"Generated assessment with safety score: {assessment['safety_score']}/10")
    return {'assessment': assessment}


def _build_graph() -> StateGraph:
    """Wire the driving-mentor workflow into an uncompiled ``StateGraph``.

    ``extract_scenes`` fans out to ``analyze_accidents`` and
    ``check_rule_violations``. The list-source edge makes ``generate_assessment``
    wait for both branches before the graph ends.
    """
    workflow = StateGraph(GraphState)

    nodes = (
        ("extract_scenes", extract_scenes),
        ("analyze_accidents", analyze_accidents),
        ("check_rule_violations", check_rule_violations),
        ("generate_assessment", generate_assessment),
    )
    for node_name, node_fn in nodes:
        workflow.add_node(node_name, node_fn)

    edges = (
        (START, "extract_scenes"),
        ("extract_scenes", "analyze_accidents"),
        ("extract_scenes", "check_rule_violations"),
        (["analyze_accidents", "check_rule_violations"], "generate_assessment"),
        ("generate_assessment", END),
    )
    for source, target in edges:
        workflow.add_edge(source, target)

    return workflow


# Uncompiled graph and the compiled, invokable agent built from it.
graph = _build_graph()
driving_mentor_agent = graph.compile()


class DrivingMentor:
    """Facade over the compiled driving-mentor graph.

    Each public method runs the full graph on one annotation and returns a
    different view of the resulting state.
    """

    def __init__(self):
        """Bind the module-level compiled graph as this mentor's agent."""
        self.agent = driving_mentor_agent

    def _run_graph(self, annotation: str) -> dict:
        """Invoke the agent with ``annotation`` and return the final graph state."""
        return self.agent.invoke({'annotation': annotation})

    def assess_driving(self, annotation: str) -> DrivingAssessment:
        """Assess driving behavior from complex traffic annotation.

        Args:
            annotation (str): Complex traffic scene annotation from VideoAnnotator.

        Returns:
            DrivingAssessment: The ``assessment`` entry of the final graph state.
        """
        final_state = self._run_graph(annotation)
        return final_state['assessment']

    def assess_with_details(self, annotation: str) -> dict:
        """Assess driving and expose the intermediate results as well.

        Args:
            annotation (str): Complex traffic scene annotation.

        Returns:
            dict: The input annotation, the extracted scenes, the per-scene accident
            and rule analyses, and the final assessment.
        """
        final_state = self._run_graph(annotation)
        details = {'annotation': annotation}
        details['scenes'] = final_state['scenes']
        details['accident_analysis'] = final_state['accident_results']
        details['rule_analysis'] = final_state['rule_results']
        details['assessment'] = final_state['assessment']
        return details

    def get_safety_summary(self, annotation: str) -> dict:
        """Condense an assessment into its key metrics and top recommendations.

        Args:
            annotation (str): Complex traffic scene annotation.

        Returns:
            dict: Safety score, risk level, the first three weaknesses and the
            first three improvement tips, and the overall evaluation.
        """
        report = self.assess_driving(annotation)
        top_n = 3
        return {
            'safety_score': report['safety_score'],
            'risk_level': report['risk_level'],
            'key_issues': report['weaknesses'][:top_n],
            'top_advice': report['improvement_advice'][:top_n],
            'overall_evaluation': report['overall_evaluation'],
        }
