"""
VLM Critic Agent for analyzing 3D object designs

This agent uses Vision-Language Models to analyze rendered 3D objects and provide
detailed feedback about design issues, proportions, and improvements needed.
"""

import os
import json
import base64
import logging
from typing import Dict, Any, List, Optional, Tuple
from pydantic import BaseModel
from agents.base_agent import BaseAgent
from utils.output_parser import OutputFormatError


class VLMFeedback(BaseModel):
    """Structured critique produced by the VLM critic.

    Instances are normally built with :meth:`extract_from_response`, which
    accepts either a fenced JSON block or the plain-text layout with
    "Overall Assessment:", "Specific Issues:", "Improvement Suggestions:",
    "Needs Improvement:" and "Confidence Score:" headings.
    """
    # Summary verdict; holds the entire raw response when no heading was found.
    overall_assessment: str = ""
    # Individual problems reported by the model.
    specific_issues: List[str] = []
    # Suggested fixes reported by the model.
    improvement_suggestions: List[str] = []
    # Stays True when the response carries no explicit verdict.
    needs_improvement: bool = True
    # Value reported by the model; stays 0.0 when none is found.
    confidence_score: float = 0.0

    @staticmethod
    def _split_items(block: str) -> List[str]:
        """Turn a multi-line section into list entries.

        Lines starting with ``-`` are treated as the entries (marker removed).
        When no line is dash-prefixed, every non-empty line becomes an entry.
        """
        lines = [line.strip() for line in block.split('\n') if line.strip()]
        dashed = [line.strip('- ').strip() for line in lines if line.startswith('-')]
        return dashed or lines

    @classmethod
    def extract_from_response(cls, text: str) -> "VLMFeedback":
        """
        Extract structured feedback from VLM response

        Parsing is best-effort and never raises for string input:

        1. A fenced ```` ```json ```` block is loaded and passed to the model
           constructor.
        2. Otherwise the headed plain-text layout is parsed section by section.
        3. If anything in steps 1-2 fails, the raw text is kept as the overall
           assessment with a low confidence score (0.3).

        Args:
            text: Raw VLM response

        Returns:
            VLMFeedback instance with parsed feedback
        """
        # Local import: the module does not import ``re`` at file level.
        import re

        try:
            # Try to extract JSON format first
            fenced = re.search(r'```json\s*(.*?)\s*```', text, re.DOTALL)
            if fenced:
                return cls(**json.loads(fenced.group(1)))

            # Fallback to structured text parsing; without an
            # "Overall Assessment" heading the whole response is the assessment.
            feedback = cls()
            feedback.overall_assessment = text

            # Free-text / list sections run until a blank line or the next heading.
            block_patterns = {
                'overall_assessment': r'Overall Assessment[:\s]*(.*?)(?=\n\n|\nSpecific Issues|\nImprovement Suggestions|\nNeeds Improvement|$)',
                'specific_issues': r'Specific Issues[:\s]*(.*?)(?=\n\n|\nImprovement Suggestions|\nNeeds Improvement|$)',
                'improvement_suggestions': r'Improvement Suggestions[:\s]*(.*?)(?=\n\n|\nNeeds Improvement|$)',
            }
            for field, pattern in block_patterns.items():
                match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
                if not match:
                    continue
                value = match.group(1).strip()
                if field == 'overall_assessment':
                    feedback.overall_assessment = value
                else:
                    setattr(feedback, field, cls._split_items(value))

            # The verdict heading must start a line, so a sentence such as
            # "the chair needs improvement." is not mistaken for it. Only the
            # first line after the heading is inspected, which keeps
            # "Yes\nConfidence Score: 0.4" from being read as a negative.
            verdict = re.search(
                r'(?:^|\n)[ \t*#]*Needs Improvement[:*\s]*([^\n]*)',
                text, re.IGNORECASE
            )
            if verdict:
                feedback.needs_improvement = bool(re.match(
                    r'(?:true|yes|1|needs improvement)\b',
                    verdict.group(1).strip(), re.IGNORECASE
                ))

            # Capture exactly one number so a sentence-ending period
            # ("0.8.") does not break the float conversion.
            score = re.search(
                r'(?:^|\n)[ \t*#]*Confidence Score[:*\s]*(\d*\.?\d+)',
                text, re.IGNORECASE
            )
            if score:
                feedback.confidence_score = float(score.group(1))

            # If no specific issues found, try to extract from overall assessment
            if not feedback.specific_issues and feedback.overall_assessment:
                bullet_patterns = [
                    r'(?:issues?|problems?|concerns?)[:\s]*\n((?:[-*]\s*.+\n?)+)',
                    r'(?:the following)[:\s]*\n((?:[-*]\s*.+\n?)+)',
                    # Any list of 2+ bullet lines; anchored at line start so a
                    # hyphen inside a word cannot open the list.
                    r'^((?:[ \t]*[-*]\s*.+\n?){2,})',
                ]
                for pattern in bullet_patterns:
                    match = re.search(pattern, feedback.overall_assessment,
                                      re.IGNORECASE | re.MULTILINE)
                    if match:
                        items = [
                            line.strip().strip('- *').strip()
                            for line in match.group(1).split('\n')
                            if line.strip()
                        ]
                        feedback.specific_issues = items[:5]  # Limit to 5 issues
                        break

            return feedback

        except Exception:
            # Deliberate best-effort: keep the raw text rather than failing.
            feedback = cls()
            feedback.overall_assessment = text
            feedback.needs_improvement = len(text) > 50  # Assume longer responses indicate issues
            feedback.confidence_score = 0.3  # Low confidence for unparsed responses
            return feedback


class VLMCriticAgent(BaseAgent):
    """
    Agent for analyzing 3D objects using Vision-Language Models.

    Analyzes rendered images of 3D objects and provides detailed feedback
    about design quality, proportions, and suggested improvements.
    """

    def __init__(self, config_manager):
        """
        Set up the critic agent.

        Args:
            config_manager: Configuration manager instance, forwarded to the
                base agent together with the 'vlm_critic' agent type.
        """
        # The logger is created before the base initializer runs.
        # NOTE(review): presumably BaseAgent.__init__ can call hooks such as
        # _load_system_prompt that log — confirm before reordering.
        logger_name = type(self).__name__
        self.logger = logging.getLogger(logger_name)

        super().__init__(config_manager, 'vlm_critic')

    def _load_system_prompt(self) -> str:
        """Load system prompt for VLM object analysis."""
        try:
            from prompt.vlm_critic import system_prompt
            return system_prompt
        except ImportError as e:
            self.logger.error(f"Failed to load VLM critic prompt: {e}")
            # Fallback prompt
            return """You are an expert 3D object design critic with extensive knowledge of product design,
            ergonomics, and mechanical engineering. Your task is to analyze 3D objects from multiple viewpoints
            and provide detailed, constructive feedback about their design quality.

            When analyzing an object, consider:
            - Overall proportions and scale relationships
            - Structural integrity and mechanical feasibility
            - Design consistency and aesthetic coherence
            - Functional aspects and usability
            - Manufacturing considerations
            - Detail quality and completeness

            Provide specific, actionable feedback that can guide improvements to the 3D model generation process."""

    def _format_user_prompt(self, input_data: Dict[str, Any]) -> str:
        """
        Format user prompt with object description and images.

        Args:
            input_data: Dictionary containing description, images, and object_json

        Returns:
            Formatted user prompt
        """
        description = input_data.get('description', '')
        object_json = input_data.get('object_json', {})
        iteration_num = input_data.get('iteration_num', 1)

        # Format object JSON information
        json_info = ""
        color_info = ""

        if object_json:
            # Extract color mapping if present
            color_mapping = object_json.get('color_mapping', {})
            if color_mapping:
                color_info = "\n\n**Component Color Mapping:**\n"
                for part_name, color_name in color_mapping.items():
                    color_info += f"- {part_name}: {color_name}\n"
                color_info += "\nUse these colors to identify specific components in your analysis."

                # Remove color_mapping from object_json for cleaner display
                object_json_clean = {k: v for k, v in object_json.items() if k != 'color_mapping'}
                json_info = f"\n\nDetailed Object Specification:\n{json.dumps(object_json_clean, indent=2, ensure_ascii=False)}"
            else:
                json_info = f"\n\nDetailed Object Specification:\n{json.dumps(object_json, indent=2, ensure_ascii=False)}"

        prompt = f"""Analyze this 3D object: {description}{json_info}{color_info}

Iteration {iteration_num}. BE CONCISE AND CRITICAL.

CHECK FOR:
1. FLOATING/DISCONNECTED parts (wheels off ground, handles detached, gaps)
2. MISSING components from description
3. MISALIGNED parts (offset, rotated wrong)
4. BAD PROPORTIONS (too big/small)
5. INTERSECTIONS (parts inside each other)

FORMAT:
Overall Assessment: [1 sentence max]

Specific Issues:
- [Issue with exact location/component]
- [Max 5 issues, 1 line each]

Improvement Suggestions:
- [Natural language: "Move handle closer to door"]
- [Clear directional fixes: "Raise wheels to touch ground"]

Needs Improvement: Yes/No

Confidence Score:
- 0.1-0.2: Floating parts
- 0.3-0.5: Major issues
- 0.6-0.8: Minor issues
- 0.9-1.0: Perfect

KEEP TOTAL RESPONSE UNDER 150 WORDS."""

        return prompt

    def parse_response(self, response: str) -> VLMFeedback:
        """
        Parse VLM response into structured feedback.

        Delegates to ``VLMFeedback.extract_from_response``; any exception it
        raises is logged and converted into ``OutputFormatError``.

        Args:
            response: Raw VLM response

        Returns:
            VLMFeedback instance with parsed feedback

        Raises:
            OutputFormatError: If response format is invalid
        """
        try:
            return VLMFeedback.extract_from_response(response)
        except Exception as e:
            self.logger.error(f"Failed to parse VLM response: {e}")
            # Chain explicitly so the original failure is kept as __cause__.
            raise OutputFormatError(f"Invalid VLM response format: {e}") from e

    def _prepare_input_data(self, description: str, image_paths: List[str],
                           object_json: Optional[Dict] = None,
                           iteration_num: int = 1, **kwargs) -> Dict[str, Any]:
        """
        Prepare input data for VLM analysis.

        Args:
            description: Original object description
            image_paths: List of paths to rendered images
            object_json: Optional detailed object specification
            iteration_num: Current iteration number
            **kwargs: Additional arguments

        Returns:
            Dictionary of input data
        """
        return {
            'description': description,
            'image_paths': image_paths,
            'object_json': object_json or {},
            'iteration_num': iteration_num
        }

    def _encode_images(self, image_paths: List[str]) -> List[Dict[str, Any]]:
        """
        Encode images as base64 for VLM input.

        Args:
            image_paths: List of image file paths

        Returns:
            List of image data dictionaries
        """
        encoded_images = []

        for image_path in image_paths:
            if not os.path.exists(image_path):
                self.logger.warning(f"Image not found: {image_path}")
                continue

            try:
                with open(image_path, 'rb') as f:
                    image_data = base64.b64encode(f.read()).decode('utf-8')

                # Determine image type
                image_type = "image/png"
                if image_path.lower().endswith('.jpg') or image_path.lower().endswith('.jpeg'):
                    image_type = "image/jpeg"

                encoded_images.append({
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:{image_type};base64,{image_data}",
                        "detail": "high"
                    }
                })

            except Exception as e:
                self.logger.error(f"Failed to encode image {image_path}: {e}")

        return encoded_images

    def generate_with_images(self, description: str, image_paths: List[str],
                           object_json: Optional[Dict] = None,
                           iteration_num: int = 1,
                           **kwargs) -> Tuple[VLMFeedback, bool, Dict[str, Any], str]:
        """
        Generate VLM feedback with image analysis.

        Args:
            description: Original object description
            image_paths: List of paths to rendered images
            object_json: Optional detailed object specification
            iteration_num: Current iteration number
            **kwargs: Additional arguments for generation

        Returns:
            Tuple of (feedback, success, metrics, raw_response)
        """
        # Prepare input data
        input_data = self._prepare_input_data(
            description, image_paths, object_json, iteration_num, **kwargs
        )

        # Encode images
        encoded_images = self._encode_images(image_paths)
        if not encoded_images:
            raise ValueError("No valid images provided for analysis")

        # Format text prompt
        text_prompt = self._format_user_prompt(input_data)

        # Prepare message with images for VLM
        messages = [
            {"role": "system", "content": self._load_system_prompt()},
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": text_prompt}
                ] + encoded_images
            }
        ]

        # Generate response with VLM support
        try:
            # Check if this is a Gemini model (supports vision)
            if 'gemini' in self.model_name.lower():
                response, metrics = self._generate_with_gemini_vision(text_prompt, encoded_images, **kwargs)
            else:
                # For non-vision models, use text-only approach
                self.logger.warning(f"Model {self.model_name} may not support vision. Using text-only approach.")
                response, metrics = self._generate_text_only(text_prompt, image_paths, **kwargs)

            if not response:
                return None, False, metrics, ""

            # Parse response
            feedback = self.parse_response(response)

            return feedback, True, metrics, response

        except Exception as e:
            self.logger.error(f"VLM generation failed: {e}")
            return None, False, {}, str(e)

    def _generate_with_gemini_vision(self, text_prompt: str, encoded_images: List[Dict], **kwargs) -> Tuple[str, Dict]:
        """Generate a response from the Gemini provider using text plus images.

        Args:
            text_prompt: Prompt text, sent as the first content part.
            encoded_images: Entries shaped like the output of
                ``_encode_images``; each ``image_url.url`` is a
                ``data:<mime>;base64,<payload>`` URL.
            **kwargs: ``temperature`` and ``max_tokens`` override the provider
                defaults; other keys are not read here.

        Returns:
            Tuple of (response text, metrics dict with timing and token counts).

        Raises:
            ValueError: If Gemini returns no candidates or empty content.
            Exception: Any other failure is logged and re-raised unchanged.
        """
        import io
        import time

        start_time = time.time()

        try:
            # Pillow is only needed on this code path, so it is imported here
            # rather than at module level.
            from PIL import Image

            # Create content parts: text + images
            content_parts = [text_prompt]
            for image_data in encoded_images:
                # Keep only the payload after the first comma of the data URL.
                payload = image_data['image_url']['url'].split(',', 1)[1]
                image_bytes = base64.b64decode(payload)
                content_parts.append(Image.open(io.BytesIO(image_bytes)))

            # Generate response with multimodal content
            response = self.provider.model.generate_content(
                content_parts,
                generation_config={
                    'temperature': kwargs.get('temperature', self.provider.temperature),
                    'max_output_tokens': kwargs.get('max_tokens', self.provider.max_tokens),
                },
                safety_settings=self.provider.SAFETY_SETTINGS
            )

            # Extract response text
            if not response.candidates or not response.candidates[0].content:
                raise ValueError("Gemini returned empty response")

            response_text = response.text

            # Calculate metrics; fall back to a word-count estimate when the
            # response carries no usage metadata (attribute missing or None).
            usage = getattr(response, 'usage_metadata', None)
            if usage is not None:
                input_tokens = usage.prompt_token_count
                output_tokens = usage.candidates_token_count
                total_tokens = usage.total_token_count
            else:
                input_tokens = len(text_prompt.split()) * 2  # Rough estimation
                output_tokens = len(response_text.split()) * 2
                total_tokens = input_tokens + output_tokens

            metrics = {
                'time_cost': time.time() - start_time,
                'model': self.model_name,
                'agent_type': self.agent_type,
                'input_tokens': input_tokens,
                'output_tokens': output_tokens,
                'total_tokens': total_tokens,
                'cost': 0.0  # TODO: Calculate actual cost
            }

            return response_text, metrics

        except Exception as e:
            self.logger.error(f"Gemini vision generation failed: {e}")
            raise

    def _generate_text_only(self, text_prompt: str, image_paths: List[str], **kwargs) -> Tuple[str, Dict]:
        """Fallback to text-only generation for non-vision models."""
        # Modify prompt to mention that images cannot be analyzed
        modified_prompt = f"""{text_prompt}

Note: Images were provided but cannot be analyzed by this model. Please provide general feedback based on the description alone."""

        # Use standard provider interface
        start_time = time.time()
        response, provider_metrics = self.provider.invoke(
            prompt=modified_prompt,
            system_prompt=self._load_system_prompt(),
            **kwargs
        )

        metrics = {
            'time_cost': time.time() - start_time,
            'model': self.model_name,
            'agent_type': self.agent_type,
            **provider_metrics
        }

        return response, metrics

    def save_output(self, result: VLMFeedback, output_folder: str,
                   iteration_num: int = 1, metrics: Dict[str, Any] = None):
        """
        Save VLM feedback to file.

        Args:
            result: VLM feedback result
            output_folder: Directory to save output
            iteration_num: Current iteration number
            metrics: Generation metrics (optional)
        """
        if not result:
            return

        # Create iteration folder
        iteration_folder = os.path.join(output_folder, f"iteration_{iteration_num}")
        os.makedirs(iteration_folder, exist_ok=True)

        # Save feedback text
        feedback_path = os.path.join(iteration_folder, "vlm_feedback.txt")
        with open(feedback_path, 'w', encoding='utf-8') as f:
            f.write(f"Overall Assessment:\n{result.overall_assessment}\n\n")

            if result.specific_issues:
                f.write("Specific Issues:\n")
                for issue in result.specific_issues:
                    f.write(f"- {issue}\n")
                f.write("\n")

            if result.improvement_suggestions:
                f.write("Improvement Suggestions:\n")
                for suggestion in result.improvement_suggestions:
                    f.write(f"- {suggestion}\n")
                f.write("\n")

            f.write(f"Needs Improvement: {'Yes' if result.needs_improvement else 'No'}\n")
            f.write(f"Confidence Score: {result.confidence_score:.2f}\n")

        # Save structured feedback as JSON
        feedback_json_path = os.path.join(iteration_folder, "vlm_feedback.json")
        with open(feedback_json_path, 'w', encoding='utf-8') as f:
            json.dump(result.dict(), f, indent=2, ensure_ascii=False)

        self.logger.info(f"Saved VLM feedback to {feedback_path}")

    def analyze_object(self, description: str, image_paths: List[str],
                      object_json: Optional[Dict] = None,
                      output_folder: Optional[str] = None,
                      iteration_num: int = 1) -> Tuple[VLMFeedback, bool]:
        """
        Convenience method to analyze an object and save results.

        Args:
            description: Original object description
            image_paths: List of paths to rendered images
            object_json: Optional detailed object specification
            output_folder: Optional output folder for saving results
            iteration_num: Current iteration number

        Returns:
            Tuple of (feedback, success)
        """
        feedback, success, metrics, raw_response = self.generate_with_images(
            description, image_paths, object_json, iteration_num
        )

        if success and feedback and output_folder:
            self.save_output(feedback, output_folder, iteration_num, metrics)

        return feedback, success