"""
Articulation Fixer Agent

This agent fixes articulation issues identified by the VLM critic,
adjusting joint configurations based on colored child link feedback.
"""

import os
import re
import json
import logging
import numpy as np
from typing import Dict, Any, List, Tuple, Optional
from agents.base_agent import BaseAgent
from utils.output_parser import OutputFormatError


class ImprovedArticulation:
    """Container for an improved articulation specification.

    Attributes:
        articulation_json: List of joint dictionaries, each carrying at least
            ``joint_name``, ``parent``, ``child`` and ``type``.
    """

    # Keys a dictionary must contain to be accepted as a joint specification.
    _REQUIRED_KEYS = ('joint_name', 'parent', 'child', 'type')

    def __init__(self, articulation_json: List[Dict[str, Any]]):
        self.articulation_json = articulation_json

    @classmethod
    def _is_joint(cls, obj: Any) -> bool:
        """Return True if *obj* is a dict that carries every required joint key."""
        return isinstance(obj, dict) and all(key in obj for key in cls._REQUIRED_KEYS)

    @staticmethod
    def _clean_json(text: str) -> str:
        """Remove ``//`` line comments and trailing commas from JSON-like text."""
        cleaned = re.sub(r'//.*', '', text)  # Remove comments
        return re.sub(r',\s*([}\]])', r'\1', cleaned)  # Remove trailing commas

    @classmethod
    def _collect_joints(cls, text: str, decoder: json.JSONDecoder) -> List[Dict[str, Any]]:
        """Collect every standalone joint object that decodes cleanly in *text*.

        Decoding is attempted at each ``{``. Once a joint is accepted, scanning
        resumes after its end, so its nested ``origin``/``limit`` objects are
        not examined on their own.
        """
        joints = []
        pos = text.find('{')
        while pos != -1:
            next_start = pos + 1
            try:
                obj, end = decoder.raw_decode(text, pos)
            except json.JSONDecodeError:
                pass
            else:
                if cls._is_joint(obj):
                    joints.append(obj)
                    next_start = end
            pos = text.find('{', next_start)
        return joints

    @classmethod
    def extract_from_response(cls, text: str) -> "ImprovedArticulation":
        """
        Extract improved articulation from LLM response.

        Strategy:
          1. Strip a surrounding markdown code fence, if any.
          2. Find the first non-empty JSON array whose items are all joint
             dicts. Any prose, wrapper object or nesting around it is
             tolerated.
          3. Otherwise, collect individual joint objects, for example from a
             truncated array.
        Each step first tries the text as written, then a copy with ``//``
        comments and trailing commas removed.

        Args:
            text: Raw LLM response

        Returns:
            ImprovedArticulation instance

        Raises:
            OutputFormatError: If no joint specification can be recovered.
        """
        # Clean up markdown formatting
        text = text.strip()
        if text.startswith('```json'):
            text = text[7:]
        if text.startswith('```'):
            text = text[3:]
        if text.endswith('```'):
            text = text[:-3]

        decoder = json.JSONDecoder()

        # Try the verbatim text first. The comment-stripping cleanup would
        # corrupt string values containing "//" (e.g. URLs), so it is only a
        # fallback for responses that are not valid JSON as written.
        variants = [text]
        cleaned_text = cls._clean_json(text)
        if cleaned_text != text:
            variants.append(cleaned_text)

        # Pass 1: a complete, non-empty JSON array of joints. raw_decode
        # understands arbitrary nesting. A non-greedy regex would stop at the
        # first ']', which usually closes an inner axis/xyz list.
        for variant in variants:
            for start, char in enumerate(variant):
                if char != '[':
                    continue
                try:
                    articulation, _ = decoder.raw_decode(variant, start)
                except json.JSONDecodeError:
                    continue
                if (isinstance(articulation, list) and articulation and
                        all(cls._is_joint(item) for item in articulation)):
                    return cls(articulation)

        # Pass 2: gather individual joint objects, keeping whichever variant
        # recovers the most (ties favour the verbatim text).
        joints = max((cls._collect_joints(variant, decoder) for variant in variants), key=len)
        if joints:
            return cls(joints)

        # If no valid JSON found, raise error with more context
        raise OutputFormatError(
            f"No valid articulation JSON found in response.\n"
            f"Response preview: {text[:500]}..."
        )


class ArticulationFixer(BaseAgent):
    """
    Agent for fixing articulation issues based on VLM feedback.

    Takes colored child link feedback and adjusts joint configurations
    to resolve collisions, separations, and motion range issues.
    """

    def __init__(self, config_manager):
        """
        Initialize the Articulation Fixer Agent.

        Args:
            config_manager: Configuration manager instance
        """
        # Created before super().__init__(), presumably so it already exists if
        # BaseAgent's initializer calls overridden hooks such as
        # _load_system_prompt (which logs on failure) -- confirm against BaseAgent.
        self.logger = logging.getLogger(self.__class__.__name__)

        # Initialize with articulation_fixer agent type
        super().__init__(config_manager, 'articulation_fixer')

    def _load_system_prompt(self) -> str:
        """Load the system prompt, falling back to a built-in prompt if the module is missing."""
        try:
            from prompt.articulation_fixer import system_prompt
            return system_prompt
        except ImportError as e:
            self.logger.error(f"Failed to load articulation fixer prompt: {e}")
            # Fallback prompt
            return """You are an expert in robotic articulation and joint configuration.
            Your task is to fix articulation issues identified in VLM feedback by adjusting
            joint specifications. Focus on resolving collisions, fixing separations,
            correcting axes, and adjusting motion ranges.

            Output the complete corrected articulation.json as a JSON array."""

    @staticmethod
    def _format_bullets(items: Any) -> str:
        """
        Render *items* as newline-separated ``- item`` bullet lines.

        ``None`` yields an empty string. A bare string is treated as a single
        item, because iterating it would emit one bullet per character.
        Non-string items are rendered with ``str()`` instead of raising
        ``TypeError``.
        """
        if items is None:
            return ""
        if isinstance(items, str):
            items = [items]
        return "\n".join(f"- {item}" for item in items)

    def _format_user_prompt(self, input_data: Dict[str, Any]) -> str:
        """
        Format user prompt with articulation and feedback.

        Args:
            input_data: Dictionary containing current articulation and VLM feedback

        Returns:
            Formatted user prompt
        """
        articulation_json = input_data.get('articulation_json', [])
        # `or {}` also covers keys that are present but explicitly None.
        vlm_feedback = input_data.get('vlm_feedback') or {}
        color_mapping = input_data.get('color_mapping') or {}
        description = input_data.get('description', '')

        # Format color mapping for reference
        color_lines = ["\n**COLOR-TO-JOINT REFERENCE:**\n"]
        for joint_name, info in color_mapping.items():
            if not isinstance(info, dict):
                info = {}  # Malformed entry: fall back to the placeholder values below.
            color_name = info.get('color_name', 'UNKNOWN')
            child_link = info.get('child_link', '')
            joint_type = info.get('joint_type', '')
            color_lines.append(f"- {color_name} part ({child_link}) = {joint_name} ({joint_type})\n")
        color_info = "".join(color_lines)

        issues_text = self._format_bullets(vlm_feedback.get('specific_issues', []))
        suggestions_text = self._format_bullets(vlm_feedback.get('improvement_suggestions', []))

        # Format VLM feedback
        feedback_text = f"""
VLM Analysis Results:
{color_info}

Overall Assessment: {vlm_feedback.get('overall_assessment', 'No assessment')}

Specific Issues (MUST FIX):
{issues_text}

Improvement Suggestions (IMPLEMENT):
{suggestions_text}

Confidence Score: {vlm_feedback.get('confidence_score', 0.0)}
"""

        # Format current articulation
        current_articulation = json.dumps(articulation_json, indent=2)

        prompt = f"""Fix the articulation issues identified by the VLM critic.

Object Description: {description}

{feedback_text}

Current Articulation JSON:
```json
{current_articulation}
```

CRITICAL FIXING GUIDELINES:

1. **Map Colors to Joints**: Use the COLOR-TO-JOINT REFERENCE to understand which joint each colored part represents.

2. **Fix Collision Issues (HIGHEST PRIORITY)**:
   - If a colored part penetrates another part, adjust:
     * Joint limits (reduce upper/lower bounds)
     * Joint origin (shift to proper pivot point)
   - Example: "RED blade penetrates jar" → Reduce blade joint upper limit

3. **Fix Separation Issues**:
   - If colored parts separate from parents, adjust:
     * Joint origin xyz values
     * Parent-child relationships
   - Example: "BLUE spout separates from lid" → Correct spout joint origin

4. **Fix Wrong Axes**:
   - If rotation/translation is in wrong direction:
     * Change axis vector (e.g., [1,0,0] to [0,1,0])
   - Example: "GREEN drawer slides vertically" → Change axis to [1,0,0]

5. **Fix Excessive Ranges**:
   - If movement range is unrealistic:
     * Adjust limit.lower and limit.upper
   - Example: "Door opens 270 degrees" → Set upper to 1.57 (90 degrees)

6. **Maintain Existing Correct Joints**:
   - Only modify joints mentioned in the issues
   - Keep all other joints unchanged

OUTPUT REQUIREMENTS:
- Provide the COMPLETE fixed articulation as a JSON array
- Include ALL joints (both modified and unmodified)
- Ensure proper JSON formatting with no comments
- Each joint must have: joint_name, parent, child, type
- Add origin and limits where needed for movable joints
- Output ONLY the JSON array, with no extra text before or after
- Do NOT include markdown code blocks or any explanations

Return the corrected articulation specification as a valid JSON array:"""

        return prompt

    def parse_response(self, response: str) -> ImprovedArticulation:
        """
        Parse LLM response into improved articulation.

        Args:
            response: Raw LLM response

        Returns:
            ImprovedArticulation instance

        Raises:
            OutputFormatError: If response format is invalid
        """
        try:
            return ImprovedArticulation.extract_from_response(response)
        except OutputFormatError as e:
            self.logger.error(f"Failed to parse articulation fixer response: {e}")
            raise

    def _prepare_input_data(self, articulation_json: List[Dict[str, Any]],
                           vlm_feedback: Dict[str, Any],
                           color_mapping: Optional[Dict[str, Any]] = None,
                           description: str = "",
                           **kwargs) -> Dict[str, Any]:
        """
        Prepare input data for articulation fixing.

        Args:
            articulation_json: Current articulation specification
            vlm_feedback: Feedback from VLM critic
            color_mapping: Joint-to-color mapping (``None`` becomes ``{}``)
            description: Object description
            **kwargs: Accepted for interface compatibility; ignored here

        Returns:
            Dictionary of input data
        """
        return {
            'articulation_json': articulation_json,
            'vlm_feedback': vlm_feedback,
            'color_mapping': color_mapping or {},
            'description': description
        }

    def save_output(self, result: ImprovedArticulation, output_folder: str,
                   metrics: Dict[str, Any] = None):
        """
        Save improved articulation to file.

        Writes ``improved_articulation.json`` and ``changes_summary.txt`` into
        *output_folder*, creating the folder if needed. Does nothing when
        *result* is empty.

        Args:
            result: ImprovedArticulation result
            output_folder: Directory to save output
            metrics: Generation metrics (optional; currently unused)
        """
        if not result or not result.articulation_json:
            return

        # open() below would raise FileNotFoundError for a fresh directory.
        # An empty folder means "current directory", which needs no creation.
        if output_folder:
            os.makedirs(output_folder, exist_ok=True)

        # Save improved articulation
        improved_path = os.path.join(output_folder, "improved_articulation.json")
        with open(improved_path, 'w', encoding='utf-8') as f:
            json.dump(result.articulation_json, f, indent=2, ensure_ascii=False)

        self.logger.info(f"Saved improved articulation to {improved_path}")

        # Save a summary of changes
        self._save_change_summary(result.articulation_json, output_folder)

    def _save_change_summary(self, improved_articulation: List[Dict[str, Any]],
                            output_folder: str):
        """
        Save a human-readable summary of the improved articulation.

        Writes ``changes_summary.txt`` in *output_folder*. The original
        articulation is not available here, so every joint of the improved
        specification is listed under the "Modified Joints" heading, not only
        the joints that actually changed. Non-dict joints and malformed
        ``origin``/``limit`` entries are skipped rather than raising.

        Args:
            improved_articulation: Fixed articulation specification
            output_folder: Output directory
        """
        summary_path = os.path.join(output_folder, "changes_summary.txt")

        with open(summary_path, 'w', encoding='utf-8') as f:
            f.write("Articulation Improvement Summary\n")
            f.write("=" * 50 + "\n\n")

            f.write("Modified Joints:\n")
            for joint in improved_articulation:
                if not isinstance(joint, dict):
                    continue

                joint_name = joint.get('joint_name', 'unnamed')
                joint_type = joint.get('type', 'unknown')

                f.write(f"\n{joint_name} ({joint_type}):\n")
                f.write(f"  Parent: {joint.get('parent', 'N/A')}\n")
                f.write(f"  Child: {joint.get('child', 'N/A')}\n")

                if 'axis' in joint:
                    f.write(f"  Axis: {joint['axis']}\n")

                origin = joint.get('origin')
                if isinstance(origin, dict):
                    if 'xyz' in origin:
                        f.write(f"  Origin XYZ: {origin['xyz']}\n")
                    if 'rpy' in origin:
                        f.write(f"  Origin RPY: {origin['rpy']}\n")

                limit = joint.get('limit')
                if isinstance(limit, dict):
                    f.write(f"  Limits: [{limit.get('lower', 'N/A')}, {limit.get('upper', 'N/A')}]\n")

    def fix_articulation(self, articulation_json: List[Dict[str, Any]],
                        vlm_feedback: Dict[str, Any],
                        color_mapping: Optional[Dict[str, Any]] = None,
                        description: str = "",
                        output_folder: Optional[str] = None,
                        **kwargs) -> Tuple[List[Dict[str, Any]], bool, Dict[str, Any], str]:
        """
        Fix articulation issues based on VLM feedback.

        Args:
            articulation_json: Current articulation specification
            vlm_feedback: Feedback from VLM critic
            color_mapping: Joint-to-color mapping
            description: Object description
            output_folder: Optional output folder for saving results
            **kwargs: Additional arguments, forwarded to ``self.generate``

        Returns:
            Tuple of (fixed_articulation, success, metrics, raw_response).
            On failure, the input *articulation_json* is returned unchanged
            with ``success=False``.
        """
        # Prepare input data
        input_data = self._prepare_input_data(
            articulation_json, vlm_feedback, color_mapping, description, **kwargs
        )

        # Generate fix
        result, success, metrics, raw_response = self.generate(**input_data, **kwargs)

        # Save output if successful and folder provided
        if success and result and output_folder:
            self.save_output(result, output_folder, metrics)

        # Return the fixed articulation
        if success and result:
            return result.articulation_json, True, metrics, raw_response

        self.logger.warning(
            "Articulation fixing failed; returning the original articulation unchanged"
        )
        return articulation_json, False, metrics, raw_response

    @staticmethod
    def _has_length(value: Any, length: int) -> bool:
        """Return True if *value* is sized and has exactly *length* elements."""
        try:
            return len(value) == length
        except TypeError:
            return False

    @staticmethod
    def _limit_errors(name: Any, limit: Any) -> List[str]:
        """
        Return validation errors for one joint's ``limit`` entry.

        The entry must be a dict with ``lower`` and ``upper``. Both values must
        convert to ``float``, and ``lower`` must not exceed ``upper``.
        """
        if not isinstance(limit, dict):
            return [f"Joint {name} limit must be a dictionary"]
        if 'lower' not in limit or 'upper' not in limit:
            return [f"Joint {name} missing lower/upper limits"]
        try:
            inverted = float(limit['lower']) > float(limit['upper'])
        except (TypeError, ValueError):
            return [f"Joint {name} lower/upper limits must be numeric"]
        if inverted:
            return [f"Joint {name} lower limit exceeds upper limit"]
        return []

    def validate_articulation(self, articulation: List[Dict[str, Any]]) -> Tuple[bool, List[str]]:
        """
        Validate articulation specification for correctness.

        Each joint is checked for the following:
          * it is a dict containing ``joint_name``, ``parent``, ``child``
            and ``type``;
          * ``type`` is one of fixed/revolute/continuous/prismatic;
          * movable joints have a 3-element ``axis``;
          * revolute/prismatic joints have a ``limit`` dict with numeric
            ``lower`` <= ``upper``.
        Malformed values are reported as errors instead of raising.

        Args:
            articulation: Articulation specification to validate

        Returns:
            Tuple of (is_valid, list_of_errors)
        """
        errors = []

        if not isinstance(articulation, list):
            errors.append("Articulation must be a list of joint specifications")
            return False, errors

        required_fields = {'joint_name', 'parent', 'child', 'type'}
        valid_types = {'fixed', 'revolute', 'continuous', 'prismatic'}

        for idx, joint in enumerate(articulation):
            if not isinstance(joint, dict):
                errors.append(f"Joint {idx} is not a dictionary")
                continue

            name = joint.get('joint_name', idx)

            # Check required fields
            missing = required_fields - set(joint.keys())
            if missing:
                errors.append(f"Joint {name} missing fields: {missing}")

            # Check joint type (isinstance guard: an unhashable value would
            # otherwise raise TypeError on the set lookup)
            joint_type = joint.get('type')
            if not (isinstance(joint_type, str) and joint_type in valid_types):
                errors.append(f"Joint {name} has invalid type: {joint_type}")

            # Check movable joint requirements
            if joint_type in ('revolute', 'continuous', 'prismatic'):
                if 'axis' not in joint:
                    errors.append(f"Movable joint {name} missing axis")
                elif not self._has_length(joint['axis'], 3):
                    errors.append(f"Joint {name} axis must be 3D vector")

                # Check limits for limited joints
                if joint_type in ('revolute', 'prismatic'):
                    if 'limit' not in joint:
                        errors.append(f"Limited joint {name} missing limits")
                    else:
                        errors.extend(self._limit_errors(name, joint['limit']))

        return len(errors) == 0, errors