"""
Poster Judge Evaluator
Evaluate generated posters against reference using LLM judge
"""

import base64
import json
import mimetypes
import os
import sys
from pathlib import Path
from typing import Dict, Optional

# Make the original "src" directory importable. The directory is located by
# taking three `.parent` hops from this file's path and appending "src"; it is
# prepended to sys.path (only when not already listed) so it is searched before
# the existing entries. This must run before the project imports below.
# NOTE(review): presumably this is what lets `multiagent.llm` resolve --
# confirm against the repository layout.
_original_src_path = Path(__file__).parent.parent.parent / "src"
if str(_original_src_path) not in sys.path:
    sys.path.insert(0, str(_original_src_path))

from multiagent.llm import LLM
from langchain_core.messages import HumanMessage


class PosterJudgeEvaluator:
    """Evaluate poster similarity using an LLM judge.

    The judge receives the evaluation prompt plus the reference and generated
    poster images in one multimodal message. Its reply is searched for a JSON
    object, whose ``scores`` are normalised to integers in the 1-5 range.
    """

    # Dimensions the judge is expected to score, in reporting order.
    SCORE_KEYS = (
        "Overall_Color",
        "Layout_Composition",
        "Button_Style",
        "Image_Content",
        "Text_Content",
    )
    MIN_SCORE = 1
    MAX_SCORE = 5
    # Neutral value used when a score is missing or cannot be interpreted.
    DEFAULT_SCORE = 3
    # MIME type assumed when the image type cannot be guessed from the name.
    DEFAULT_MIME_TYPE = "image/png"
    # Prompt used when the prompt file does not exist on disk.
    FALLBACK_PROMPT = (
        "Evaluate visual similarity between a generated poster and a reference "
        "poster across five dimensions, using integer scores 1-5."
    )

    def __init__(self, config, model_version="gpt41"):
        """
        Initialize evaluator

        Args:
            config: Configuration object
            model_version: Model version ("gpt41", "gpt52", or "gemini25")
        """
        self.config = config
        self.model_version = model_version
        self.llm = LLM(config, model_version=model_version)
        self._load_evaluation_prompt()

    def _load_evaluation_prompt(self):
        """Load the evaluation prompt into ``self.evaluation_prompt_template``.

        The prompt is read from ``../prompts/poster_judge_evaluation_similarity.txt``
        relative to this file's directory; if that file does not exist, the
        built-in one-line fallback prompt is used instead.
        """
        prompts_dir = Path(__file__).parent.parent / "prompts"
        prompt_path = prompts_dir / "poster_judge_evaluation_similarity.txt"

        # EAFP: read directly instead of exists()+open(), which could race.
        try:
            self.evaluation_prompt_template = prompt_path.read_text(encoding="utf-8")
        except FileNotFoundError:
            self.evaluation_prompt_template = self.FALLBACK_PROMPT

    def _encode_image(self, image_path: str) -> str:
        """Return the file's bytes encoded as a base64 ASCII string."""
        with open(image_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode("utf-8")

    @classmethod
    def _guess_mime_type(cls, image_path: str) -> str:
        """Guess the image MIME type from the file name.

        Falls back to ``DEFAULT_MIME_TYPE`` when the extension is unknown or
        does not map to an ``image/*`` type, so extension-less files are still
        sent as PNG.
        """
        guessed, _ = mimetypes.guess_type(str(image_path))
        if guessed and guessed.startswith("image/"):
            return guessed
        return cls.DEFAULT_MIME_TYPE

    def _image_part(self, image_path: str) -> Dict:
        """Build the ``image_url`` content part (base64 data URL) for one image."""
        mime_type = self._guess_mime_type(image_path)
        encoded = self._encode_image(image_path)
        return {
            "type": "image_url",
            "image_url": {"url": f"data:{mime_type};base64,{encoded}"},
        }

    @staticmethod
    def _response_text(content) -> str:
        """Flatten a model reply into plain text.

        A reply may be a plain string or a list of content blocks (strings
        and/or dicts carrying a ``"text"`` entry); non-text blocks are skipped.
        """
        if isinstance(content, str):
            return content
        if isinstance(content, (list, tuple)):
            pieces = []
            for block in content:
                if isinstance(block, str):
                    pieces.append(block)
                elif isinstance(block, dict) and isinstance(block.get("text"), str):
                    pieces.append(block["text"])
            return "".join(pieces)
        return "" if content is None else str(content)

    @staticmethod
    def _extract_json_object(text: str) -> Optional[Dict]:
        """Return the first JSON object embedded in *text*, or ``None``.

        Each ``{`` is tried as a possible start of a JSON object, so prose,
        code fences or stray braces before/after the object do not break
        parsing (a greedy first-``{``-to-last-``}`` match would).
        """
        decoder = json.JSONDecoder()
        start = text.find("{")
        while start != -1:
            try:
                parsed, _ = decoder.raw_decode(text, start)
            except ValueError:  # json.JSONDecodeError is a ValueError
                start = text.find("{", start + 1)
            else:
                return parsed
        return None

    @classmethod
    def _clamp_score(cls, score) -> int:
        """Round *score* to an integer within ``MIN_SCORE``..``MAX_SCORE``.

        Values that cannot be interpreted as a finite number (``None``,
        ``"n/a"``, NaN, infinity, ...) yield ``DEFAULT_SCORE``.
        """
        try:
            rounded = int(round(float(score)))
        except (TypeError, ValueError, OverflowError):
            return cls.DEFAULT_SCORE
        return max(cls.MIN_SCORE, min(cls.MAX_SCORE, rounded))

    @classmethod
    def _default_evaluation(cls) -> Dict:
        """Return the neutral evaluation used when no JSON can be parsed."""
        return {
            "scores": {key: cls.DEFAULT_SCORE for key in cls.SCORE_KEYS},
            "average_score": float(cls.DEFAULT_SCORE),
            "feedback": {},
        }

    @classmethod
    def _parse_evaluation(cls, content) -> Dict:
        """Turn the judge's reply into an evaluation dictionary.

        If the reply holds no parseable JSON object, the neutral default
        evaluation is returned. Otherwise the parsed object is returned; when
        it has a ``"scores"`` entry, that entry is replaced by a dict with
        exactly ``SCORE_KEYS``, each clamped to an integer (missing or
        unusable values become ``DEFAULT_SCORE``), and ``"average_score"`` is
        filled in when absent. A parsed object without ``"scores"`` is
        returned unchanged.
        """
        evaluation = cls._extract_json_object(cls._response_text(content))
        if evaluation is None:
            return cls._default_evaluation()

        if "scores" in evaluation:
            raw_scores = evaluation["scores"]
            if not isinstance(raw_scores, dict):
                # e.g. a list or null: nothing usable, fall back to defaults.
                raw_scores = {}
            normalized = {
                key: cls._clamp_score(raw_scores.get(key, cls.DEFAULT_SCORE))
                for key in cls.SCORE_KEYS
            }
            evaluation["scores"] = normalized

            # Keep a judge-provided average; only compute it when missing.
            if "average_score" not in evaluation:
                evaluation["average_score"] = sum(normalized.values()) / len(normalized)

        return evaluation

    def evaluate_similarity(
        self,
        reference_image_path: str,
        generated_image_path: str,
        item_description: Optional[str] = None
    ) -> Dict:
        """
        Evaluate similarity between reference and generated poster

        Args:
            reference_image_path: Path to reference poster image
            generated_image_path: Path to generated poster image
            item_description: Optional item description for context

        Returns:
            Dictionary with scores and feedback. When the reply cannot be
            parsed as JSON, a neutral evaluation (all scores 3) is returned.

        Raises:
            FileNotFoundError: If either image path does not exist.
        """
        if not os.path.exists(reference_image_path):
            raise FileNotFoundError(f"Reference image not found: {reference_image_path}")
        if not os.path.exists(generated_image_path):
            raise FileNotFoundError(f"Generated image not found: {generated_image_path}")

        # Build prompt
        prompt = self.evaluation_prompt_template
        if item_description:
            prompt += f"\n\nItem Description: {item_description}"

        # Multimodal message: prompt text, then reference image, then generated image.
        messages = [HumanMessage(
            content=[
                {"type": "text", "text": prompt},
                self._image_part(reference_image_path),
                self._image_part(generated_image_path),
            ]
        )]

        response = self.llm.llm.invoke(messages)
        return self._parse_evaluation(response.content)


if __name__ == "__main__":
    # Command-line entry point: evaluate one generated poster against one
    # reference poster and print the resulting evaluation as JSON.
    import argparse
    # json is needed to print the result; import it with the other
    # script-only imports so the final print cannot fail with a NameError.
    import json

    parser = argparse.ArgumentParser(description="Poster Judge Evaluator")
    parser.add_argument("config_path", help="Config file path")
    parser.add_argument("reference_image", help="Reference image path")
    parser.add_argument("generated_image", help="Generated image path")
    parser.add_argument("--item-description", default="", help="Item description")
    parser.add_argument("--model-version", default="gpt41", choices=["gpt41", "gpt52", "gemini25"])

    args = parser.parse_args()

    from html_ad_workflow import load_config
    # NOTE(review): args.config_path is parsed (and required) but never used;
    # load_config() is called without arguments. Confirm whether the path
    # should be forwarded to load_config.
    config = load_config()

    evaluator = PosterJudgeEvaluator(config, model_version=args.model_version)
    result = evaluator.evaluate_similarity(
        reference_image_path=args.reference_image,
        generated_image_path=args.generated_image,
        item_description=args.item_description
    )

    print(json.dumps(result, indent=2, ensure_ascii=False))

