"""
Feedback Fusion System for Combining 2D VLM and 3D PointLLM Analysis

This module merges feedback from visual (2D) and geometric (3D) analysis
to provide comprehensive improvement suggestions for 3D object generation.
"""

import logging
from typing import Optional, List
from agents.vlm_critic import VLMFeedback


class FeedbackFusion:
    """
    Merges feedback from multiple analysis sources (2D VLM and 3D PointLLM).

    Combines insights from different perspectives to provide more comprehensive
    and accurate feedback for shape improvement. Every merged issue and
    suggestion carries a ``[2D]`` or ``[3D]`` label naming the analysis it
    came from.
    """

    # Labels prepended to entries so their origin stays visible after merging.
    _LABEL_2D = "[2D]"
    _LABEL_3D = "[3D]"

    # Substrings (matched case-insensitively) that mark an issue as critical;
    # critical issues are moved to the front of the merged issue list.
    _CRITICAL_KEYWORDS = ('disconnect', 'float', 'detach', 'gap', 'separate')

    # Ordered (header, keywords) pairs. A suggestion goes into the FIRST group
    # whose keywords contain a substring match, so the order is significant.
    _SUGGESTION_GROUPS = (
        ("=== Connectivity Fixes ===", ('connect', 'attach', 'join', 'link')),
        ("=== Position Adjustments ===", ('position', 'move', 'align', 'place')),
        ("=== Proportion Corrections ===", ('scale', 'size', 'proportion', 'dimension')),
    )

    # Header for suggestions that match none of the keyword groups.
    _OTHER_GROUP_HEADER = "=== Other Improvements ==="

    def __init__(self, config: Optional[dict] = None):
        """
        Initialize the feedback fusion system.

        Args:
            config: Configuration dictionary; stored as ``self.config``
                (an empty dict when omitted or falsy).
        """
        self.logger = logging.getLogger(self.__class__.__name__)
        self.config = config or {}

    def merge_feedback(self,
                       vlm_feedback: Optional["VLMFeedback"],
                       pointllm_feedback: Optional["VLMFeedback"]) -> "VLMFeedback":
        """
        Merge 2D VLM and 3D PointLLM feedback into unified feedback.

        Args:
            vlm_feedback: Feedback from 2D visual analysis (VLM)
            pointllm_feedback: Feedback from 3D geometric analysis (PointLLM)

        Returns:
            Merged VLMFeedback combining both perspectives. When only one
            input is present it is returned unchanged; when neither is
            present a placeholder with ``needs_improvement=False`` and
            ``confidence_score=0.0`` is returned.
        """
        # Handle cases where one feedback is missing. Truthiness checks are
        # kept as-is so dispatch is unchanged for existing callers.
        if not vlm_feedback and not pointllm_feedback:
            self.logger.warning("No feedback to merge")
            return VLMFeedback(
                overall_assessment="No analysis available",
                specific_issues=[],
                improvement_suggestions=[],
                needs_improvement=False,
                confidence_score=0.0
            )

        if not vlm_feedback:
            self.logger.info("Only PointLLM feedback available")
            return pointllm_feedback

        if not pointllm_feedback:
            self.logger.info("Only VLM feedback available")
            return vlm_feedback

        # Both feedbacks available - merge them
        self.logger.info("Merging 2D and 3D feedback")

        return VLMFeedback(
            overall_assessment=self._merge_overall_assessments(
                vlm_feedback.overall_assessment,
                pointllm_feedback.overall_assessment
            ),
            specific_issues=self._merge_issues(
                vlm_feedback.specific_issues,
                pointllm_feedback.specific_issues
            ),
            improvement_suggestions=self._merge_suggestions(
                vlm_feedback.improvement_suggestions,
                pointllm_feedback.improvement_suggestions
            ),
            # Conservative: if either analysis asks for improvement, so do we.
            needs_improvement=(
                vlm_feedback.needs_improvement or pointllm_feedback.needs_improvement
            ),
            # Conservative: the merged result is only as confident as the
            # less confident source.
            confidence_score=min(
                vlm_feedback.confidence_score,
                pointllm_feedback.confidence_score
            )
        )

    def _merge_overall_assessments(self, vlm_assessment: str, pointllm_assessment: str) -> str:
        """
        Merge overall assessments from both sources.

        Args:
            vlm_assessment: 2D visual assessment (``None`` is treated as empty)
            pointllm_assessment: 3D geometric assessment (``None`` is treated
                as empty)

        Returns:
            Combined assessment text with one labeled section per source
        """
        # Strip any source tags already embedded in the text; the section
        # headings below take over that role.
        vlm_text = (vlm_assessment or "").replace("[2D Analysis]", "").replace("[2D]", "").strip()
        pointllm_text = (pointllm_assessment or "").replace("[3D Analysis]", "").replace("[3D]", "").strip()

        return (
            "=== Combined Analysis ===\n\n"
            "**Visual Analysis (2D):**\n"
            f"{vlm_text}\n\n"
            "**Geometric Analysis (3D):**\n"
            f"{pointllm_text}"
        )

    def _label_and_dedupe(self,
                          vlm_items: Optional[List[str]],
                          pointllm_items: Optional[List[str]]) -> List[str]:
        """
        Label entries with their source and drop duplicates, keeping order.

        Shared by ``_merge_issues`` and ``_merge_suggestions``.

        Args:
            vlm_items: Entries from 2D analysis (``None`` is treated as empty)
            pointllm_items: Entries from 3D analysis (``None`` is treated as
                empty)

        Returns:
            2D entries followed by 3D entries, each starting with its source
            label. Two entries count as duplicates when they are equal after
            removing the labels, trimming whitespace and lowercasing; the
            first occurrence wins.
        """
        labeled = []
        for label, items in ((self._LABEL_2D, vlm_items), (self._LABEL_3D, pointllm_items)):
            for item in items or ():
                labeled.append(item if item.startswith(label) else f"{label} {item}")

        seen = set()
        unique = []
        for item in labeled:
            # Compare on the label-free, case-folded text so the same remark
            # reported by both analyses is kept only once.
            key = item.replace(self._LABEL_2D, "").replace(self._LABEL_3D, "").strip().lower()
            if key not in seen:
                seen.add(key)
                unique.append(item)
        return unique

    def _merge_issues(self,
                      vlm_issues: List[str],
                      pointllm_issues: List[str]) -> List[str]:
        """
        Merge and deduplicate issues from both sources.

        Args:
            vlm_issues: Issues from 2D analysis
            pointllm_issues: Issues from 3D analysis

        Returns:
            Combined list of labeled issues, critical ones (those containing
            any of ``_CRITICAL_KEYWORDS``) first; relative order is otherwise
            preserved
        """
        critical_issues = []
        normal_issues = []

        for issue in self._label_and_dedupe(vlm_issues, pointllm_issues):
            lowered = issue.lower()
            if any(keyword in lowered for keyword in self._CRITICAL_KEYWORDS):
                critical_issues.append(issue)
            else:
                normal_issues.append(issue)

        return critical_issues + normal_issues

    def _merge_suggestions(self,
                           vlm_suggestions: List[str],
                           pointllm_suggestions: List[str]) -> List[str]:
        """
        Merge improvement suggestions from both sources.

        Args:
            vlm_suggestions: Suggestions from 2D analysis
            pointllm_suggestions: Suggestions from 3D analysis

        Returns:
            Combined list of labeled suggestions grouped under section header
            entries; non-empty groups are separated by an empty-string entry.
            An empty list is returned when there are no suggestions.
        """
        # Dicts preserve insertion order, which fixes the output section order.
        grouped = {header: [] for header, _ in self._SUGGESTION_GROUPS}
        grouped[self._OTHER_GROUP_HEADER] = []

        for suggestion in self._label_and_dedupe(vlm_suggestions, pointllm_suggestions):
            lowered = suggestion.lower()
            for header, keywords in self._SUGGESTION_GROUPS:
                if any(word in lowered for word in keywords):
                    grouped[header].append(suggestion)
                    break
            else:
                # No keyword group matched.
                grouped[self._OTHER_GROUP_HEADER].append(suggestion)

        final_suggestions = []
        for header, members in grouped.items():
            if not members:
                continue
            if final_suggestions:
                final_suggestions.append("")  # Add spacing between sections
            final_suggestions.append(header)
            final_suggestions.extend(members)

        return final_suggestions


def create_feedback_fusion(config_manager) -> Optional["FeedbackFusion"]:
    """
    Create a feedback fusion instance based on configuration.

    Reads the ``'feedback_fusion'`` section from ``config_manager.config``
    and builds a FeedbackFusion from it when its ``'enabled'`` entry is truthy.

    Args:
        config_manager: Configuration manager; must expose a ``config``
            mapping

    Returns:
        FeedbackFusion instance if enabled, None otherwise (including when
        the section is missing, empty or set to None)
    """
    # `or {}` also covers a section that is present but has no value (None),
    # which would otherwise fail on the `.get` call below.
    fusion_config = config_manager.config.get('feedback_fusion') or {}

    if not fusion_config.get('enabled', False):
        logging.info("Feedback fusion is disabled")
        return None

    logging.info("Creating feedback fusion system")
    return FeedbackFusion(fusion_config)