"""
PointLLM Critic Agent for 3D Point Cloud Analysis

This agent uses PointLLM to analyze 3D objects from point cloud data,
providing geometric and structural feedback that complements 2D visual analysis.
"""

import os
import json
import logging
from typing import Dict, Any, List, Optional, Tuple
from pydantic import BaseModel

from agents.base_agent import BaseAgent
from agents.vlm_critic import VLMFeedback


class PointLLMCriticAgent(BaseAgent):
    """
    Agent for analyzing 3D objects using PointLLM on point cloud data.

    Provides structural and geometric analysis that complements 2D VLM feedback
    by understanding 3D spatial relationships and component connectivity.
    """

    def __init__(self, config_manager):
        """
        Initialize the PointLLM Critic Agent.

        Args:
            config_manager: Configuration manager instance
        """
        self.logger = logging.getLogger(self.__class__.__name__)
        self.config_manager = config_manager

        # Get PointLLM configuration
        self.pointllm_config = config_manager.config.get('pointllm_critic', {})
        self.enabled = self.pointllm_config.get('enabled', False)

        if not self.enabled:
            self.logger.info("PointLLM Critic is disabled in configuration")
            self.provider = None
            self.converter = None
            return

        # Import dependencies only when enabled
        try:
            from utils.pointcloud_converter import PointCloudConverter
            from providers.pointllm_provider import PointLLMProvider, PointLLMProviderStub
        except ImportError as e:
            self.logger.error(f"Failed to import PointLLM dependencies: {e}")
            self.logger.info("Disabling PointLLM critic due to missing dependencies")
            self.enabled = False
            self.provider = None
            self.converter = None
            return

        # Initialize components
        self.sample_points = self.pointllm_config.get('sample_points', 8192)
        self.converter = PointCloudConverter(sample_points=self.sample_points)

        # Initialize provider (use stub if model not available)
        try:
            self.provider = PointLLMProvider(config_manager.config)
        except Exception as e:
            self.logger.warning(f"Failed to initialize PointLLM provider: {e}")
            self.logger.info("Using stub provider for testing")
            self.provider = PointLLMProviderStub(config_manager.config)

        # Initialize base agent with placeholder since PointLLM doesn't use API models
        try:
            super().__init__(config_manager, 'pointllm_critic')
        except ValueError:
            # PointLLM doesn't use API models, so this is expected
            pass

    def analyze_object(self,
                      mesh_path: str,
                      description: str,
                      object_json: Optional[Dict] = None,
                      output_folder: Optional[str] = None,
                      iteration_num: int = 1,
                      links_json_path: Optional[str] = None) -> Tuple[VLMFeedback, bool]:
        """
        Analyze a 3D object mesh using PointLLM.

        Args:
            mesh_path: Path to the combined OBJ mesh file
            description: Original object description
            object_json: Optional detailed object specification
            output_folder: Optional folder to save analysis results
            iteration_num: Current iteration number
            links_json_path: Optional path to links_hierarchy.json for link-based grouping

        Returns:
            Tuple of (feedback, success)
        """
        if not self.enabled:
            self.logger.warning("PointLLM Critic is disabled")
            return None, False

        try:
            self.logger.info(f"Analyzing mesh with PointLLM: {mesh_path}")

            # Convert mesh to point cloud with link-based grouping if available
            point_cloud, component_mapping, color_names = self.converter.convert_obj_to_pointcloud(
                mesh_path,
                links_json_path=links_json_path
            )

            # Create color mapping for prompt
            color_mapping = {}
            for idx, color_name in enumerate(color_names):
                if idx in component_mapping:
                    color_mapping[color_name] = component_mapping[idx]
                else:
                    color_mapping[color_name] = f"component_{idx}"

            # Create analysis prompt
            prompt = self._create_analysis_prompt(description, object_json, color_mapping, iteration_num)

            # Analyze with PointLLM
            analysis_text = self.provider.analyze_point_cloud(
                point_cloud=point_cloud,
                prompt=prompt,
                color_mapping=color_mapping
            )

            # Parse response into structured feedback
            feedback = self._parse_analysis(analysis_text)

            # Save results if output folder provided
            if output_folder:
                self._save_analysis(feedback, analysis_text, output_folder, iteration_num)

                # Save the colored point cloud for debugging
                point_cloud_path = os.path.join(output_folder, f"pointllm_input_iter{iteration_num}.npy")
                self.converter.save_colored_pointcloud(point_cloud, point_cloud_path)
                self.logger.info(f"Saved point cloud visualization to {point_cloud_path}")

            self.logger.info("PointLLM analysis completed successfully")
            return feedback, True

        except Exception as e:
            self.logger.error(f"PointLLM analysis failed: {e}")
            return None, False

    def _create_analysis_prompt(self,
                               description: str,
                               object_json: Optional[Dict],
                               color_mapping: Dict[str, str],
                               iteration_num: int) -> str:
        """
        Create the analysis prompt for PointLLM.

        Args:
            description: Object description
            object_json: Optional object specification
            color_mapping: Mapping of colors to component names
            iteration_num: Current iteration

        Returns:
            Formatted prompt for 3D analysis
        """
        prompt = f"""You are reviewing a 3D point cloud reconstruction of: {description}

This is iteration {iteration_num} of the design process. Your goal is to CRITICALLY evaluate the geometry and call out any structural or spatial problems that could break the real object. Do NOT simply describe the model—focus on failure cases and design flaws.

Each colored cluster represents a different articulated component or sub-assembly:
"""

        component_details = self._extract_component_details(object_json)

        # Add color mapping information
        for color_name, component_name in color_mapping.items():
            # Make color names more readable
            display_color = color_name.replace('_', ' ').upper()
            info = component_details.get(component_name, {})
            type_str = info.get('type')
            shape_desc = info.get('shape')
            position_desc = info.get('position')
            description_snippet = info.get('description')

            extra_bits = []
            if type_str:
                extra_bits.append(f"type={type_str}")
            if position_desc:
                extra_bits.append(position_desc)
            if shape_desc:
                extra_bits.append(shape_desc)
            if description_snippet and description_snippet not in extra_bits:
                extra_bits.append(description_snippet)

            if extra_bits:
                prompt += f"- {display_color}: {component_name} ({'; '.join(extra_bits)})\n"
            else:
                prompt += f"- {display_color}: {component_name}\n"

        # Add object specification if available
        if object_json:
            # Clean version without color mapping
            clean_json = {k: v for k, v in object_json.items() if k != 'color_mapping'}
            if clean_json:
                prompt += f"\nDetailed specification:\n{json.dumps(clean_json, indent=2)}\n"

        prompt += """
Respond ONLY with valid JSON (no prose outside JSON) following this schema:
{
  "overall_structure": "<critical summary of stability>",
  "detected_issues": [
    {
      "severity": "CRITICAL | MAJOR | MINOR",
      "component": "<component_name>",
      "color": "<COLOR>",
      "issue": "<what is wrong>",
      "evidence": "<geometric clues from point cloud>"
    }, ... (at least 2 entries even if you must note 'No critical issue observed')
  ],
  "fix_suggestions": [
    {
      "component": "<component_name>",
      "color": "<COLOR>",
      "action": "<specific geometry fix>",
      "reason": "<why this fix helps>"
    }
  ],
  "confidence": 0.0-1.0
}

Always attempt to find real faults: check for intersections, floating parts, impossible articulation, grounding problems, missing volume, inconsistent thickness, and misaligned hinges/sliders. If a section truly has no problems, include an entry explaining the verification that indicates the design is sound.
"""

        return prompt

    def _extract_component_details(self, object_json: Optional[Dict[str, Any]]) -> Dict[str, Dict[str, str]]:
        """Flatten hierarchy data so the prompt can reference detailed link info."""
        details: Dict[str, Dict[str, str]] = {}

        if not object_json:
            return details

        def record(node: Dict[str, Any]):
            name = node.get('name')
            if not name:
                return
            details[name] = {
                'type': node.get('type', ''),
                'description': node.get('description', ''),
                'shape': node.get('description_shape', ''),
                'position': node.get('description_position', ''),
            }
            for child in node.get('children', []) or []:
                if isinstance(child, dict):
                    record(child)

        # Handle possible structures
        if isinstance(object_json, dict):
            if 'hierarchy' in object_json and isinstance(object_json['hierarchy'], dict):
                for node in object_json['hierarchy'].get('structure', []) or []:
                    if isinstance(node, dict):
                        record(node)
            elif 'structure' in object_json:
                for node in object_json.get('structure', []) or []:
                    if isinstance(node, dict):
                        record(node)
            else:
                record(object_json)
        elif isinstance(object_json, list):
            for node in object_json:
                if isinstance(node, dict):
                    record(node)

        return details

    def _parse_analysis(self, analysis_text: str) -> VLMFeedback:
        """
        Parse PointLLM analysis text into structured feedback.

        A JSON object containing ``detected_issues`` is preferred; it is
        accepted even when the model wraps it in Markdown code fences or
        surrounding prose. Any other response is handed to the legacy
        keyword-based text parser. If parsing raises, a generic
        "parsing errors" feedback is returned instead of propagating.

        Args:
            analysis_text: Raw text from PointLLM

        Returns:
            Structured VLMFeedback object
        """
        try:
            data = self._extract_json_payload(analysis_text)

            if isinstance(data, dict) and 'detected_issues' in data:
                return self._feedback_from_json(data)

            # Fallback to legacy parsing if JSON not provided
            return self._feedback_from_text(analysis_text)

        except Exception as e:
            self.logger.error(f"Failed to parse PointLLM analysis: {e}")

            return VLMFeedback(
                overall_assessment="[3D Analysis] Analysis completed with parsing errors.",
                specific_issues=["[3D] Unable to parse detailed issues from analysis"],
                improvement_suggestions=["[3D] Manual review of 3D structure recommended"],
                needs_improvement=True,
                confidence_score=0.3
            )

    @staticmethod
    def _extract_json_payload(text: str) -> Any:
        """Return the JSON value found in ``text``, or ``None`` when there is none."""
        try:
            return json.loads(text)
        except json.JSONDecodeError:
            pass

        # Models frequently wrap the object in ```json fences or add a
        # sentence around it; retry on the outermost {...} span.
        start = text.find('{')
        end = text.rfind('}')
        if start == -1 or end <= start:
            return None
        try:
            return json.loads(text[start:end + 1])
        except json.JSONDecodeError:
            return None

    @staticmethod
    def _coerce_text(value: Any, default: str = '') -> str:
        """Return ``value`` as a stripped string, or ``default`` when it is ``None``."""
        if value is None:
            return default
        return str(value).strip()

    @staticmethod
    def _coerce_confidence(value: Any, default: float = 0.5) -> float:
        """Return ``value`` as a confidence clamped to [0, 1], or ``default`` if unusable."""
        if isinstance(value, (int, float)):
            number = float(value)
        elif isinstance(value, str):
            try:
                number = float(value)
            except ValueError:
                return default
        else:
            return default
        # NaN compares unequal to itself and would otherwise clamp to 1.0.
        if number != number:
            return default
        return max(0.0, min(1.0, number))

    def _feedback_from_json(self, data: Dict[str, Any]) -> VLMFeedback:
        """Build feedback from a response that followed the requested JSON schema."""
        overall = self._coerce_text(data.get('overall_structure'))
        issues_data = data.get('detected_issues')
        suggestions_data = data.get('fix_suggestions')
        # A null or scalar where a list was requested carries no entries.
        if not isinstance(issues_data, list):
            issues_data = []
        if not isinstance(suggestions_data, list):
            suggestions_data = []

        specific_issues = []
        for issue in issues_data:
            if not isinstance(issue, dict):
                continue
            severity = self._coerce_text(issue.get('severity'), 'MAJOR').upper()
            component = self._coerce_text(issue.get('component'), 'unknown')
            color = self._coerce_text(issue.get('color')).upper()
            # 'description' is only consulted when the 'issue' key is absent.
            desc = self._coerce_text(issue.get('issue', issue.get('description')))
            evidence = self._coerce_text(issue.get('evidence'))
            text = f"[{severity}] {component} ({color})"
            if desc:
                text += f": {desc}"
            if evidence:
                text += f" Evidence: {evidence}"
            specific_issues.append(f"[3D] {text.strip()}")

        improvement_suggestions = []
        for suggestion in suggestions_data:
            if not isinstance(suggestion, dict):
                continue
            component = self._coerce_text(suggestion.get('component'), 'unknown')
            color = self._coerce_text(suggestion.get('color')).upper()
            # 'suggestion' is only consulted when the 'action' key is absent.
            action = self._coerce_text(suggestion.get('action', suggestion.get('suggestion')))
            reason = self._coerce_text(suggestion.get('reason'))
            text = f"{component} ({color}): {action}"
            if reason:
                text += f" (Reason: {reason})"
            improvement_suggestions.append(f"[3D] {text.strip()}")

        # NOTE(review): any listed issue sets needs_improvement, including
        # "no issue observed" entries - confirm this is what callers expect.
        needs_improvement = len(specific_issues) > 0
        overall_assessment = f"[3D Analysis] {overall}" if overall else "[3D Analysis] 3D point cloud analysis completed."

        return VLMFeedback(
            overall_assessment=overall_assessment,
            specific_issues=specific_issues,
            improvement_suggestions=improvement_suggestions,
            needs_improvement=needs_improvement,
            confidence_score=self._coerce_confidence(data.get('confidence'))
        )

    def _feedback_from_text(self, analysis_text: str) -> VLMFeedback:
        """Build feedback from free-form text using section-keyword heuristics."""
        overall_assessment = ""
        specific_issues = []
        improvement_suggestions = []

        current_section = None
        for line in analysis_text.split('\n'):
            line = line.strip()
            if not line:
                continue

            # A line containing a section keyword switches the current
            # section and is itself discarded.
            # NOTE(review): this also drops content lines that merely mention
            # words such as "issue" or "component".
            line_lower = line.lower()
            if 'overall' in line_lower and ('structure' in line_lower or 'assessment' in line_lower):
                current_section = 'overall'
                continue
            elif 'issue' in line_lower or 'problem' in line_lower:
                current_section = 'issues'
                continue
            elif 'suggestion' in line_lower or 'improvement' in line_lower or 'recommend' in line_lower:
                current_section = 'suggestions'
                continue
            elif 'connectivity' in line_lower or 'component' in line_lower:
                current_section = 'issues'
                continue

            if current_section == 'overall':
                overall_assessment += line + " "
            elif current_section == 'issues':
                if line.startswith(('-', '*', '•')):
                    specific_issues.append(line.lstrip('-*• ').strip())
                elif line and not line.endswith(':'):
                    specific_issues.append(line)
            elif current_section == 'suggestions':
                if line.startswith(('-', '*', '•')):
                    improvement_suggestions.append(line.lstrip('-*• ').strip())
                elif line and not line.endswith(':'):
                    improvement_suggestions.append(line)

        # Nothing recognizable: fall back to the head of the raw text.
        if not overall_assessment and not specific_issues:
            overall_assessment = analysis_text[:500]

        overall_assessment = overall_assessment.strip() or "3D point cloud analysis completed."
        needs_improvement = len(specific_issues) > 0

        # Free-form text carries no confidence value; derive one from the
        # number of issues found.
        if len(specific_issues) > 3:
            confidence_score = 0.3
        elif len(specific_issues) > 0:
            confidence_score = 0.5
        else:
            confidence_score = 0.8

        return VLMFeedback(
            overall_assessment=f"[3D Analysis] {overall_assessment}",
            specific_issues=[f"[3D] {issue}" for issue in specific_issues],
            improvement_suggestions=[f"[3D] {suggestion}" for suggestion in improvement_suggestions],
            needs_improvement=needs_improvement,
            confidence_score=confidence_score
        )

    def _save_analysis(self,
                      feedback: VLMFeedback,
                      raw_text: str,
                      output_folder: str,
                      iteration_num: int):
        """
        Save PointLLM analysis results.

        Args:
            feedback: Structured feedback
            raw_text: Raw analysis text
            output_folder: Folder to save results
            iteration_num: Current iteration
        """
        try:
            # Save raw analysis
            raw_path = os.path.join(output_folder, f"pointllm_analysis_iter{iteration_num}.txt")
            with open(raw_path, 'w', encoding='utf-8') as f:
                f.write(raw_text)

            # Save structured feedback
            feedback_path = os.path.join(output_folder, f"pointllm_feedback_iter{iteration_num}.json")
            with open(feedback_path, 'w', encoding='utf-8') as f:
                json.dump({
                    'overall_assessment': feedback.overall_assessment,
                    'specific_issues': feedback.specific_issues,
                    'improvement_suggestions': feedback.improvement_suggestions,
                    'needs_improvement': feedback.needs_improvement,
                    'confidence_score': feedback.confidence_score
                }, f, indent=2, ensure_ascii=False)

            self.logger.info(f"Saved PointLLM analysis to {output_folder}")

        except Exception as e:
            self.logger.warning(f"Failed to save PointLLM analysis: {e}")

    def _load_system_prompt(self) -> str:
        """Not used directly, but required by BaseAgent."""
        return "You are a 3D geometry expert analyzing point cloud data."

    def _format_user_prompt(self, input_data: Dict[str, Any]) -> str:
        """Not used directly, but required by BaseAgent."""
        return ""

    def parse_response(self, response: str) -> Any:
        """Not used directly, but required by BaseAgent."""
        return self._parse_analysis(response)

    def _prepare_input_data(self, **kwargs) -> Dict[str, Any]:
        """Not used directly, but required by BaseAgent."""
        return kwargs

    def save_output(self, result: Any, output_folder: str, **kwargs):
        """Not used directly, but required by BaseAgent."""
        pass
