"""
Articulation VLM Critic Agent

This agent analyzes rendered articulation images with colored child links
to identify issues in joint configurations and movements.
"""

import os
import json
import logging
from typing import Dict, Any, List, Tuple, Optional
from pydantic import BaseModel
from agents.base_agent import BaseAgent
from utils.output_parser import OutputFormatError


class ArticulationFeedback(BaseModel):
    """
    Data model for articulation VLM feedback.

    Instances are normally built with ``extract_from_response`` from the raw
    text returned by the articulation critic VLM.
    """
    overall_assessment: str = ""  # Free-text summary of articulation quality
    specific_issues: List[str] = []  # One entry per reported problem
    improvement_suggestions: List[str] = []  # One entry per suggested fix
    needs_improvement: bool = True  # Defaults to "rework needed" when unparsed
    confidence_score: float = 0.5  # Clamped to [0.0, 1.0] by extract_from_response
    joint_issues: Dict[str, Any] = {}  # Joint-specific issues

    @classmethod
    def extract_from_response(cls, text: str) -> "ArticulationFeedback":
        """
        Extract structured feedback from an LLM response.

        A JSON object embedded in the response (first ``{`` to last ``}``) is
        tried first. If there is none, or it fails to decode or validate, the
        labelled text format requested by the critic prompt is parsed instead
        ("Overall Assessment:", "Specific Issues:", bullet lists, etc.).
        Headings may be wrapped in markdown emphasis (``**Heading:**``) and
        bullets may be indented. In both paths ``confidence_score`` is clamped
        to [0.0, 1.0].

        Args:
            text: Raw LLM response (``None`` is treated as an empty string)

        Returns:
            ArticulationFeedback instance; fields that cannot be found keep
            their defaults.
        """
        import re

        text = text or ""

        # Greedy match: spans from the first '{' to the last '}' in the text.
        json_match = re.search(r'\{.*\}', text, re.DOTALL)
        if json_match:
            try:
                feedback = cls(**json.loads(json_match.group()))
            except (ValueError, TypeError) as e:
                # ValueError covers json.JSONDecodeError as well as pydantic's
                # ValidationError (e.g. a non-numeric confidence_score), so a
                # malformed JSON payload falls back to text parsing below.
                logging.warning(f"Failed to parse JSON feedback: {e}")
            else:
                feedback.confidence_score = max(0.0, min(1.0, feedback.confidence_score))
                return feedback

        # Fallback: parse text format
        feedback = cls()

        # Headings that terminate the free-text overall assessment.
        next_heading = (
            r'(?:Specific Issues|Improvement Suggestions|'
            r'Needs Improvement|Confidence Score)'
        )
        sections = {
            # Lazily capture (possibly multi-line) text up to the next known
            # heading or the end of the response. The separator class excludes
            # newlines so an empty assessment cannot swallow the next section.
            'overall_assessment': (
                r'Overall Assessment[: \t*]*((?s:.*?))'
                r'(?=\s*\n[\s*#]*' + next_heading + r'|\s*\Z)'
            ),
            'specific_issues': r'Specific Issues[:\s*]*\n((?:[ \t]*[-•*]\s.*\n?)+)',
            'improvement_suggestions': r'Improvement Suggestions[:\s*]*\n((?:[ \t]*[-•*]\s.*\n?)+)',
            'confidence_score': r'Confidence Score[:\s*]*([\d.]+)',
            'needs_improvement': r'Needs Improvement[:\s*]*(true|false|yes|no)',
        }

        for key, pattern in sections.items():
            match = re.search(pattern, text, re.MULTILINE | re.IGNORECASE)
            if not match:
                continue
            if key in ('specific_issues', 'improvement_suggestions'):
                # Strip the bullet marker from each line of the list.
                items = re.findall(r'[-•*]\s*(.+)', match.group(1))
                setattr(feedback, key, items)
            elif key == 'confidence_score':
                try:
                    score = float(match.group(1))
                except ValueError:
                    # e.g. "0.8.5" - keep the default score.
                    continue
                feedback.confidence_score = max(0.0, min(1.0, score))
            elif key == 'needs_improvement':
                feedback.needs_improvement = match.group(1).lower() in ('true', 'yes')
            else:
                setattr(feedback, key, match.group(1).strip())

        return feedback


class ArticulationVLMCritic(BaseAgent):
    """
    VLM agent for analyzing articulated object renders with colored child links.

    Identifies issues in joint movements, collisions, and articulation quality
    by analyzing colored parts across different joint states.
    """

    def __init__(self, config_manager):
        """
        Initialize the Articulation VLM Critic.

        Args:
            config_manager: Configuration manager instance
        """
        # Set before super().__init__ - presumably so hooks that use
        # self.logger (e.g. _load_system_prompt) work if the base constructor
        # calls them. TODO confirm against BaseAgent.
        self.logger = logging.getLogger(self.__class__.__name__)

        # Initialize with VLM critic agent type
        super().__init__(config_manager, 'articulation_vlm_critic')

    def _load_system_prompt(self) -> str:
        """
        Load the system prompt for articulation analysis.

        Returns ``system_prompt`` from ``prompt.articulation_critic``. If that
        module cannot be imported, logs an error and returns a built-in
        fallback prompt.
        """
        try:
            from prompt.articulation_critic import system_prompt
            return system_prompt
        except ImportError as e:
            self.logger.error(f"Failed to load articulation critic prompt: {e}")
            # Fallback prompt
            return """You are an expert robotics engineer specializing in articulated mechanisms and joint analysis.
            Your task is to analyze rendered images of articulated objects where child links of movable joints
            are colored to show their movement. Fixed joints have gray children.

            Analyze the object across different joint states to identify:
            1. Collision or penetration between moving parts
            2. Parts separating that should stay connected
            3. Unrealistic or excessive motion ranges
            4. Incorrect rotation axes or directions
            5. Missing articulation where expected
            6. Unnecessary or redundant joints

            Focus on the colored parts as they represent movable joint children.
            Provide specific, actionable feedback referencing the colored parts."""

    def _format_user_prompt(self, input_data: Dict[str, Any]) -> str:
        """
        Format the user prompt describing the renders and color mapping.

        Args:
            input_data: Dictionary with optional keys ``image_paths`` (list of
                paths), ``color_mapping`` (joint name -> dict read via
                ``color_name``, ``child_link``, ``joint_type``,
                ``parent_link``) and ``description``. ``articulation_json``
                may also be present but is not included in the prompt.

        Returns:
            Formatted user prompt. Images are listed by basename and grouped
            by whether the filename contains ``'initial'`` or ``'moved'``;
            files matching neither are not listed.
        """
        image_paths = input_data.get('image_paths', [])
        color_mapping = input_data.get('color_mapping', {})
        object_description = input_data.get('description', '')
        # NOTE(review): input_data['articulation_json'] is prepared by
        # _prepare_input_data but never sent to the VLM; decide whether the
        # joint spec (axes, limits) should be appended to the prompt.

        # Format color mapping for clarity
        color_descriptions = []
        for joint_name, info in color_mapping.items():
            color_name = info.get('color_name', 'UNKNOWN')
            child_link = info.get('child_link', '')
            joint_type = info.get('joint_type', '')
            parent_link = info.get('parent_link', '')

            desc = f"- {color_name} part ({child_link}): {joint_name} ({joint_type} joint, parent: {parent_link})"
            color_descriptions.append(desc)

        # Group images by state; 'initial' takes precedence if both appear.
        state_images = {
            'initial': [],
            'moved': []
        }

        for path in image_paths:
            filename = os.path.basename(path)
            if 'initial' in filename:
                state_images['initial'].append(filename)
            elif 'moved' in filename:
                state_images['moved'].append(filename)

        prompt = f"""Please analyze this articulated object: {object_description}

COLOR-TO-JOINT MAPPING (colored parts show movable joints):
{chr(10).join(color_descriptions)}

Gray/default colored parts are fixed and should not move between states.

IMAGES PROVIDED:
- Initial state (0% position): {', '.join(state_images['initial'])}
- Moved state (75% position): {', '.join(state_images['moved'])}

CRITICAL ANALYSIS TASKS:

1. **Collision Detection**:
   - Check if any colored part penetrates or collides with other parts when moved
   - Look for colored parts going through walls, bases, or other components
   - Example: "RED blade_assembly penetrates the jar wall in moved state"

2. **Connection Integrity**:
   - Verify colored parts stay properly attached to their parents
   - Check for gaps or separations that shouldn't exist
   - Example: "BLUE spout separates from lid body when rotated"

3. **Motion Range Validation**:
   - Assess if the movement range is realistic for the joint type
   - Check if joints move too far or in wrong directions
   - Example: "GREEN drawer extends beyond reasonable limits"

4. **Joint Axis Verification**:
   - Confirm colored parts rotate/translate along correct axes
   - Identify if movement direction matches expected behavior
   - Example: "YELLOW door rotates on wrong axis"

5. **Missing or Redundant Articulation**:
   - Identify parts that should move but don't (missing joints)
   - Find unnecessary articulation (redundant joints)
   - Note if fixed parts should be movable or vice versa

CONFIDENCE SCORING GUIDELINES:
- 0.9-1.0: No articulation issues, all joints work perfectly
- 0.7-0.8: Minor issues like slightly excessive ranges
- 0.5-0.6: Moderate issues like wrong joint types or axes
- 0.3-0.4: Major issues like collisions or separated parts
- 0.1-0.2: Critical failures like multiple collisions or completely wrong articulation

Please provide your analysis in the following format:

Overall Assessment: [Brief summary of articulation quality]

Specific Issues:
- [Issue 1 with colored part reference]
- [Issue 2 with colored part reference]
- [etc.]

Improvement Suggestions:
- [Specific fix for issue 1]
- [Specific fix for issue 2]
- [etc.]

Needs Improvement: [true/false]
Confidence Score: [0.0-1.0]"""

        return prompt

    def parse_response(self, response: str) -> ArticulationFeedback:
        """
        Parse LLM response into structured feedback.

        This never raises: any parsing failure is logged and a placeholder
        ArticulationFeedback (needs_improvement=True, confidence 0.5) is
        returned instead.

        Args:
            response: Raw LLM response

        Returns:
            ArticulationFeedback instance
        """
        try:
            return ArticulationFeedback.extract_from_response(response)
        except Exception as e:
            # Deliberate best-effort boundary: a malformed VLM reply must not
            # abort the pipeline. Log with traceback and return a
            # conservative placeholder that requests improvement.
            self.logger.exception(f"Failed to parse articulation VLM response: {e}")
            return ArticulationFeedback(
                overall_assessment="Failed to parse VLM response",
                specific_issues=["Response parsing error"],
                improvement_suggestions=["Retry analysis"],
                needs_improvement=True,
                confidence_score=0.5
            )

    def _prepare_input_data(self, image_paths: List[str],
                           color_mapping: Dict[str, Any],
                           articulation_json: Optional[List[Dict[str, Any]]] = None,
                           description: str = "",
                           **kwargs) -> Dict[str, Any]:
        """
        Prepare input data for VLM analysis.

        Args:
            image_paths: List of rendered image paths
            color_mapping: Joint-to-color mapping
            articulation_json: Current articulation specification
                (``None`` becomes an empty list)
            description: Object description
            **kwargs: Additional arguments (ignored here)

        Returns:
            Dictionary with keys ``image_paths``, ``color_mapping``,
            ``articulation_json`` and ``description``
        """
        return {
            'image_paths': image_paths,
            'color_mapping': color_mapping,
            'articulation_json': articulation_json or [],
            'description': description
        }

    def save_output(self, result: ArticulationFeedback, output_folder: str,
                   metrics: Optional[Dict[str, Any]] = None):
        """
        Save VLM feedback as ``vlm_feedback.json`` and ``vlm_feedback.txt``.

        Args:
            result: ArticulationFeedback result; nothing is written if falsy
            output_folder: Directory to save output (created if missing)
            metrics: Generation metrics (optional; currently not written)
        """
        if not result:
            return

        if output_folder:
            # Writing below would fail with FileNotFoundError otherwise.
            os.makedirs(output_folder, exist_ok=True)

        # Save structured feedback. Fields are listed explicitly rather than
        # via a pydantic dump method so the output does not depend on the
        # pydantic major version.
        feedback_path = os.path.join(output_folder, "vlm_feedback.json")
        feedback_dict = {
            'overall_assessment': result.overall_assessment,
            'specific_issues': result.specific_issues,
            'improvement_suggestions': result.improvement_suggestions,
            'needs_improvement': result.needs_improvement,
            'confidence_score': result.confidence_score,
            'joint_issues': result.joint_issues
        }

        with open(feedback_path, 'w', encoding='utf-8') as f:
            json.dump(feedback_dict, f, indent=2, ensure_ascii=False)

        # Save human-readable version
        readable_path = os.path.join(output_folder, "vlm_feedback.txt")
        with open(readable_path, 'w', encoding='utf-8') as f:
            f.write("Articulation VLM Analysis Results\n")
            f.write("=" * 50 + "\n\n")
            f.write(f"Overall Assessment:\n{result.overall_assessment}\n\n")
            f.write(f"Confidence Score: {result.confidence_score:.2f}\n\n")
            f.write("Specific Issues:\n")
            for issue in result.specific_issues:
                f.write(f"  - {issue}\n")
            f.write("\nImprovement Suggestions:\n")
            for suggestion in result.improvement_suggestions:
                f.write(f"  - {suggestion}\n")
            f.write(f"\nNeeds Improvement: {result.needs_improvement}\n")

        self.logger.info(f"Saved VLM feedback to {output_folder}")

    def analyze_articulation(self, image_paths: List[str],
                            color_mapping: Dict[str, Any],
                            articulation_json: Optional[List[Dict[str, Any]]] = None,
                            description: str = "",
                            output_folder: Optional[str] = None,
                            **kwargs) -> Tuple[ArticulationFeedback, bool, Dict[str, Any], str]:
        """
        Analyze articulated object renders with VLM.

        Args:
            image_paths: List of rendered image paths
            color_mapping: Joint-to-color mapping
            articulation_json: Current articulation specification
            description: Object description
            output_folder: Optional output folder for saving results
            **kwargs: Additional arguments, forwarded to ``self.generate``

        Returns:
            Tuple of (feedback, success, metrics, raw_response)
        """
        # Prepare input
        input_data = self._prepare_input_data(
            image_paths, color_mapping, articulation_json, description, **kwargs
        )

        # Generate VLM analysis
        result, success, metrics, raw_response = self.generate(**input_data, **kwargs)

        # Save output only for successful runs when a folder was provided
        if success and result and output_folder:
            self.save_output(result, output_folder, metrics)

        return result, success, metrics, raw_response

    def needs_improvement(self, feedback: ArticulationFeedback,
                          confidence_threshold: float = 0.7) -> bool:
        """
        Determine if articulation needs improvement based on feedback.

        Improvement is needed when feedback is missing, explicitly flagged,
        scored below ``confidence_threshold``, or when any reported issue
        mentions a critical failure (collision, penetration, separation,
        disconnection or a wrong axis).

        Args:
            feedback: ArticulationFeedback from VLM analysis
            confidence_threshold: Scores strictly below this need improvement

        Returns:
            True if improvement is needed
        """
        if not feedback:
            return True

        # Check explicit flag
        if feedback.needs_improvement:
            return True

        # Check confidence score
        if feedback.confidence_score < confidence_threshold:
            return True

        # Word stems so inflections match too ("collides", "penetration",
        # "separated", "wrong axes").
        critical_keywords = (
            'collision', 'collid', 'penetrat', 'separat', 'disconnect',
            'wrong axis', 'wrong axes',
        )
        return any(
            keyword in issue.lower()
            for issue in feedback.specific_issues
            for keyword in critical_keywords
        )