#!/usr/bin/env python3
"""
Script to test Gemini model's ability to identify face attributes from generated images.
Uses multiple choice questions to test attribute recognition accuracy.

Usage:
    python check_attributes.py --metadata-file faces/faces4/face_metadata.json --output-dir results
"""

import requests
import json
import os
import argparse
import base64
from pathlib import Path
from typing import List, Dict, Optional, Tuple
import time
import random
from dotenv import load_dotenv

# Read settings from the .env file located one directory above this script's
# own directory, so the key below can be supplied without exporting it.
load_dotenv(
    os.path.join(
        os.path.dirname(os.path.dirname(__file__)),
        ".env",
    )
)

# OpenRouter API key taken from the environment; when the variable is unset
# the placeholder string is used instead.
API_KEY_REF = os.environ.get("OPENROUTER_API_KEY", "your-api-key-here")

class AttributeChecker:
    """Quiz a vision model about face attributes with multiple-choice questions.

    One request is sent to the OpenRouter chat-completions endpoint per
    (image, attribute) pair. The letter found in the model's reply is compared
    with the letter of the ground-truth option. Every exchange is printed and,
    when a log file is configured, appended to that file.
    """

    # Upper bound, in seconds, for a single API request. ``requests`` applies
    # no timeout by default, so a stalled connection would otherwise block the
    # whole run indefinitely.
    REQUEST_TIMEOUT_SECONDS = 120

    # Pause, in seconds, after each API request.
    REQUEST_DELAY_SECONDS = 1

    def __init__(self, api_key: str = API_KEY_REF, log_file: Optional[str] = None):
        """Store the API credentials and start a fresh log file if requested.

        Args:
            api_key: OpenRouter API key sent as a bearer token.
            log_file: Path of the interaction log. When given, the file is
                truncated and a header is written; when ``None``, nothing is
                logged to disk.
        """
        self.api_key = api_key
        self.url = "https://openrouter.ai/api/v1/chat/completions"
        self.headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }

        # Candidate answers per attribute. The values are compared verbatim
        # with the ground-truth strings passed to the test methods, so the
        # spellings (including "sholder_lenght_straight") must not be
        # "corrected" here alone.
        # NOTE(review): presumably they mirror the metadata file - confirm
        # before renaming.
        self.attribute_options = {
            "eye_color": ["dark_brown", "light_brown", "green", "blue", "gray", "red"],
            "hair_color": ["black", "brown", "light_brown", "blonde", "red", "gray_white"],
            "hair_style": ["buzz_short", "sholder_lenght_straight", "sholder_length_curly_coily", "long_curly_locs", "long_straight"],
            "accessories": ["none", "eyeglasses_clear", "earrings_visible", "headband", "scarf_neck_face"],
            "gender": ["male", "female"]
        }

        self.log_file = log_file
        if self.log_file:
            # Mode 'w' truncates any previous log. UTF-8 is set explicitly
            # because model replies may contain characters the platform
            # default encoding cannot represent.
            with open(self.log_file, 'w', encoding="utf-8") as f:
                f.write("=== GEMINI ATTRIBUTE TESTING LOG ===\n")
                f.write(f"Started at: {time.strftime('%Y-%m-%d %H:%M:%S')}\n")
                f.write("=" * 50 + "\n\n")

    def log_to_file(self, message: str):
        """Append ``message`` plus a newline to the log file, if one is set."""
        if self.log_file:
            with open(self.log_file, 'a', encoding="utf-8") as f:
                f.write(f"{message}\n")

    def _emit(self, text: str) -> None:
        """Print ``text`` to stdout and mirror it into the log file."""
        print(text)
        self.log_to_file(text)

    def encode_image_to_base64(self, image_path: str) -> str:
        """Return the file's bytes as a base64 string.

        Raises:
            OSError: If the file cannot be opened or read.
        """
        with open(image_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode('utf-8')

    @staticmethod
    def _image_mime_type(image_path: str) -> str:
        """Guess the image MIME type from the file name, defaulting to PNG."""
        # Imported locally because the module is not imported at file level.
        import mimetypes

        guessed_type, _ = mimetypes.guess_type(image_path)
        if guessed_type and guessed_type.startswith("image/"):
            return guessed_type
        return "image/png"

    @staticmethod
    def _format_letter_choices(letters: List[str]) -> str:
        """Render letters as ``'A'``, ``'A or B'`` or ``'A, B, C, or D'``."""
        if len(letters) <= 2:
            return " or ".join(letters)
        return f"{', '.join(letters[:-1])}, or {letters[-1]}"

    @staticmethod
    def _extract_answer_letter(response_text: str, valid_letters: List[str]) -> str:
        """Pull the chosen option letter out of a model reply.

        The reply is upper-cased and split into alphanumeric tokens, so
        decorated answers such as ``"A)"``, ``"(a)"``, ``"**A**"`` or
        ``"A. blue"`` all yield ``"A"``. A letter is returned only when exactly
        one distinct valid letter appears as a standalone token. Otherwise the
        whole upper-cased reply is returned, which then simply fails the
        comparison with the correct letter.
        """
        normalized = response_text.strip().upper()
        tokens = "".join(ch if ch.isalnum() else " " for ch in normalized).split()
        chosen = {token for token in tokens if token in valid_letters}
        if len(chosen) == 1:
            return chosen.pop()
        return normalized

    def create_multiple_choice_question(
        self, attribute_type: str, correct_answer: str
    ) -> Tuple[str, List[str], List[str]]:
        """Build a shuffled multiple-choice question for one attribute.

        Args:
            attribute_type: Key of ``self.attribute_options``.
            correct_answer: Ground-truth value. If it is not among the known
                options a warning is printed and it is added to them.

        Returns:
            ``(question, formatted_options, options)`` where ``options`` is the
            shuffled list of raw values and ``formatted_options`` holds the
            same values prefixed with ``"A) "``, ``"B) "``, ...

        Raises:
            KeyError: If ``attribute_type`` is not a known attribute.
        """
        # Copy so that shuffling/appending never touches the shared table.
        options = list(self.attribute_options[attribute_type])

        if correct_answer not in options:
            print(f"Warning: Correct answer '{correct_answer}' not in options for {attribute_type}")
            options.append(correct_answer)

        # Randomise the order so the correct letter differs between calls.
        random.shuffle(options)

        question_map = {
            "eye_color": "What is the eye color of the person in this image?",
            "hair_color": "What is the hair color of the person in this image?",
            "hair_style": "What is the hair style of the person in this image?",
            "accessories": "What accessory is the person wearing in this image?",
            "gender": "What is the gender of the person in this image?"
        }
        question = question_map[attribute_type]

        # Label the options with consecutive capital letters starting at 'A'.
        formatted_options = [
            f"{chr(65 + i)}) {option}" for i, option in enumerate(options)
        ]

        return question, formatted_options, options

    def ask_attribute_question(self, image_path: str, attribute_type: str, correct_answer: str) -> Dict:
        """Ask the model one multiple-choice question about an image.

        Args:
            image_path: Path of the image file to send.
            attribute_type: Key of ``self.attribute_options``.
            correct_answer: Ground-truth value for that attribute.

        Returns:
            On success, a dict with the question, the options, the correct
            letter, the model's reply and ``is_correct``. ``model_answer`` is
            the extracted letter, or the whole upper-cased reply when no
            single letter could be identified. On failure, a dict with only an
            ``"error"`` key.

        Raises:
            OSError: If the image file cannot be read (raised before the
                request is attempted).
            KeyError: If ``attribute_type`` is not a known attribute.
        """
        base64_image = self.encode_image_to_base64(image_path)

        question, formatted_options, options = self.create_multiple_choice_question(attribute_type, correct_answer)

        # One letter per option, so the prompt and the answer check stay in
        # step with however many options this attribute has (2 to 6 or more).
        letters = [chr(65 + i) for i in range(len(options))]
        correct_letter = letters[options.index(correct_answer)]

        options_text = "\n".join(formatted_options)
        letter_choices = self._format_letter_choices(letters)
        prompt = (
            "Look at this image and answer the following multiple choice question:\n\n"
            f"{question}\n\n"
            f"{options_text}\n\n"
            f"Please respond with only the letter of your answer ({letter_choices})."
        )

        payload = {
            "model": "google/gemini-2.5-flash-image-preview",
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": prompt
                        },
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:{self._image_mime_type(image_path)};base64,{base64_image}"
                            }
                        }
                    ]
                }
            ],
            "modalities": ["image", "text"]
        }

        self._emit(
            f"    MODEL INPUT:\n       Question: {question}\n       Options: {formatted_options}\n       Correct answer: {correct_answer} ({correct_letter})\n       Image: {os.path.basename(image_path)}"
        )

        # Deliberately broad: one failed request (network error, HTTP error,
        # malformed JSON or unexpected response shape) is logged and reported
        # for this attribute only, so the rest of the run can continue.
        try:
            response = requests.post(
                self.url,
                headers=self.headers,
                json=payload,
                timeout=self.REQUEST_TIMEOUT_SECONDS,
            )
            response.raise_for_status()
            result = response.json()

            content = None
            if result.get("choices"):
                content = result["choices"][0]["message"].get("content", "")

            # No choices at all, or a choice without textual content (for
            # example a null ``content``), cannot be scored.
            if not isinstance(content, str):
                self._emit("    ERROR: No response from model")
                return {"error": "No response from model"}

            full_response = content.strip()
            answer = self._extract_answer_letter(full_response, letters)

            self._emit(
                f"    MODEL OUTPUT:\n       Full response: '{full_response}'\n       Extracted answer: '{answer}'"
            )

            is_correct = answer == correct_letter

            self._emit(
                f"    RESULT: {'CORRECT' if is_correct else 'INCORRECT'} ({answer} vs {correct_letter})"
            )

            return {
                "attribute_type": attribute_type,
                "correct_answer": correct_answer,
                "correct_letter": correct_letter,
                "model_answer": answer,
                "model_full_response": full_response,
                "is_correct": is_correct,
                "options": options,
                "formatted_options": formatted_options,
                "model_input": {
                    "question": question,
                    "options": formatted_options,
                    "image_file": os.path.basename(image_path)
                }
            }

        except Exception as e:
            self._emit(f"    ERROR: Request failed: {e}")
            return {"error": f"Request failed: {e}"}

    def test_all_attributes(self, image_path: str, ground_truth: Dict) -> Dict:
        """Ask one question per known attribute present in ``ground_truth``.

        Args:
            image_path: Path of the image file to test.
            ground_truth: Mapping of attribute name to its true value. Keys
                that are not in ``self.attribute_options`` are ignored.

        Returns:
            A dict with ``image_path``, ``ground_truth`` and
            ``attribute_results`` (attribute name -> result of
            ``ask_attribute_question``).
        """
        results = {
            "image_path": image_path,
            "ground_truth": ground_truth,
            "attribute_results": {}
        }

        for attribute_type in self.attribute_options:
            if attribute_type not in ground_truth:
                continue

            expected = ground_truth[attribute_type]
            print(f"  Testing {attribute_type}: {expected}")

            results["attribute_results"][attribute_type] = self.ask_attribute_question(
                image_path,
                attribute_type,
                expected
            )

            # Space the requests out a little.
            time.sleep(self.REQUEST_DELAY_SECONDS)

        return results

    def test_multiple_images(self, metadata_list: List[Dict], max_images: Optional[int] = None) -> List[Dict]:
        """Run the attribute test over a list of image metadata entries.

        Args:
            metadata_list: Entries providing ``"image_path"`` and
                ``"attributes"``; ``"image_name"`` is used for progress output
                and falls back to the file name of ``image_path``.
            max_images: When truthy, only the first ``max_images`` entries are
                tested; ``None`` or ``0`` means all entries.

        Returns:
            One result dict per entry. Entries whose image file does not exist
            get an ``"error"`` record instead of attribute results.
        """
        if max_images:
            metadata_list = metadata_list[:max_images]

        all_results = []
        total = len(metadata_list)

        for position, metadata in enumerate(metadata_list, start=1):
            image_path = metadata["image_path"]
            attributes = metadata["attributes"]
            image_name = metadata.get("image_name", os.path.basename(image_path))

            print(f"Testing image {position}/{total}: {image_name}")

            if os.path.exists(image_path):
                all_results.append(self.test_all_attributes(image_path, attributes))
            else:
                print(f"  Image not found: {image_path}")
                all_results.append({
                    "image_path": image_path,
                    "error": "Image file not found",
                    "ground_truth": attributes
                })

        return all_results

def load_metadata(metadata_file: str) -> List[Dict]:
    """Parse the metadata JSON file and return its content.

    The file is read as UTF-8 (the encoding JSON interchange uses) instead of
    the platform default, so non-ASCII text loads the same way on every OS.

    Args:
        metadata_file: Path of the JSON file to read.

    Returns:
        The decoded top-level JSON value; callers expect a list of dicts.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    with open(metadata_file, 'r', encoding="utf-8") as f:
        return json.load(f)

def save_results(results: List[Dict], output_dir: str):
    """Persist the detailed results and their summary as two JSON files.

    ``output_dir`` is created when missing. The raw results go to
    ``attribute_test_results.json`` and the statistics computed by
    ``calculate_summary_stats`` go to ``attribute_test_summary.json``; both
    paths are printed afterwards.
    """
    os.makedirs(output_dir, exist_ok=True)

    def write_json(payload, destination):
        # Same serialisation settings for both output files.
        with open(destination, 'w') as handle:
            json.dump(payload, handle, indent=2)

    # Detailed per-image results first ...
    results_path = os.path.join(output_dir, "attribute_test_results.json")
    write_json(results, results_path)

    # ... then the aggregated statistics.
    summary_path = os.path.join(output_dir, "attribute_test_summary.json")
    write_json(calculate_summary_stats(results), summary_path)

    for label, location in (("Results", results_path), ("Summary", summary_path)):
        print(f"{label} saved to: {location}")

def calculate_summary_stats(results: List[Dict]) -> Dict:
    """Aggregate per-attribute and overall accuracy from test results.

    Only attribute results that carry an ``"is_correct"`` key are counted;
    error records (per attribute or per image) add nothing to the test counts.
    The five standard attributes always appear in the breakdown, even with
    zero tests. Any other attribute type found in the results is added to the
    breakdown instead of raising ``KeyError``.

    Args:
        results: Per-image result dicts; entries without
            ``"attribute_results"`` are skipped when counting tests.

    Returns:
        A dict with ``total_images_tested`` (the number of entries in
        ``results``, including error entries), ``total_attribute_tests``,
        ``total_correct``, ``overall_accuracy`` and ``attribute_breakdown``.
        Accuracies are fractions in ``[0, 1]`` and ``0.0`` when nothing was
        tested.
    """
    def new_counter() -> Dict:
        return {"total_tests": 0, "correct": 0, "accuracy": 0.0}

    standard_attributes = ("eye_color", "hair_color", "hair_style", "accessories", "gender")
    attribute_stats = {attribute_type: new_counter() for attribute_type in standard_attributes}

    # Tally scored answers per attribute.
    for result in results:
        for attribute_type, attr_result in result.get("attribute_results", {}).items():
            if "is_correct" not in attr_result:
                continue
            stats = attribute_stats.setdefault(attribute_type, new_counter())
            stats["total_tests"] += 1
            if attr_result["is_correct"]:
                stats["correct"] += 1

    # Per-attribute accuracy; attributes never tested keep the 0.0 default.
    for stats in attribute_stats.values():
        if stats["total_tests"] > 0:
            stats["accuracy"] = stats["correct"] / stats["total_tests"]

    total_tests = sum(stats["total_tests"] for stats in attribute_stats.values())
    total_correct = sum(stats["correct"] for stats in attribute_stats.values())
    overall_accuracy = total_correct / total_tests if total_tests > 0 else 0.0

    return {
        "total_images_tested": len(results),
        "total_attribute_tests": total_tests,
        "total_correct": total_correct,
        "overall_accuracy": overall_accuracy,
        "attribute_breakdown": attribute_stats
    }

def main():
    """Command-line entry point: run the attribute test and report accuracy.

    Parses the arguments, loads the metadata file, runs every selected image
    through ``AttributeChecker``, writes the result and summary files to the
    output directory and prints a summary.

    Raises:
        SystemExit: With status 1 when no API key is configured or the
            metadata file cannot be read or parsed, so that calling shells
            see the failure instead of exit status 0.
    """
    parser = argparse.ArgumentParser(
        description="Test Gemini model's ability to identify face attributes from images",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Test all images in metadata file
  python check_attributes.py --metadata-file faces/faces4/face_metadata.json --output-dir results
  
  # Test only first 5 images
  python check_attributes.py --metadata-file faces/faces4/face_metadata.json --output-dir results --max-images 5
  
  # Test with custom API key
  python check_attributes.py --metadata-file faces/faces4/face_metadata.json --output-dir results --api-key your-key
        """
    )

    parser.add_argument("--metadata-file", required=True,
                       help="Path to face_metadata.json file")
    parser.add_argument("--output-dir", required=True,
                       help="Output directory for test results")
    parser.add_argument("--max-images", type=int,
                       help="Maximum number of images to test (default: all)")
    parser.add_argument("--api-key",
                       help="OpenRouter API key (or set OPENROUTER_API_KEY env var)")

    args = parser.parse_args()

    # The command-line key wins over the environment-derived one. The
    # placeholder value means neither source supplied a real key.
    api_key = args.api_key or API_KEY_REF
    if api_key == "your-api-key-here":
        print("Error: Please set your OpenRouter API key using --api-key or OPENROUTER_API_KEY environment variable")
        raise SystemExit(1)

    # Load metadata; report an unreadable or malformed file as a plain error
    # message rather than a traceback.
    print(f"Loading metadata from: {args.metadata_file}")
    try:
        metadata_list = load_metadata(args.metadata_file)
    except (OSError, json.JSONDecodeError) as exc:
        print(f"Error: Could not load metadata file: {exc}")
        raise SystemExit(1)
    print(f"Found {len(metadata_list)} images in metadata")

    # The output directory must exist before the checker creates its log file.
    os.makedirs(args.output_dir, exist_ok=True)
    log_file = os.path.join(args.output_dir, "model_interactions.log")

    checker = AttributeChecker(api_key, log_file)

    print("Testing attributes for images...")
    print(f"Logging model interactions to: {log_file}")
    if args.max_images:
        print(f"Limiting to first {args.max_images} images")

    results = checker.test_multiple_images(metadata_list, args.max_images)

    save_results(results, args.output_dir)

    # Recompute the summary for the console report (save_results does not
    # return the one it wrote).
    summary = calculate_summary_stats(results)
    print("\n=== TEST SUMMARY ===")
    print(f"Images tested: {summary['total_images_tested']}")
    print(f"Total attribute tests: {summary['total_attribute_tests']}")
    print(f"Overall accuracy: {summary['overall_accuracy']:.2%}")
    print("\nPer-attribute accuracy:")
    for attr, stats in summary['attribute_breakdown'].items():
        # Attributes that were never tested are left out of the report.
        if stats['total_tests'] > 0:
            print(f"  {attr}: {stats['correct']}/{stats['total_tests']} ({stats['accuracy']:.2%})")

# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
