#!/usr/bin/env python3
"""
Script to generate face images using OpenRouter API with Gemini 2.5 Flash Image Preview.
Generates faces with specific characteristics and saves them locally.

Usage:
    python generate_faces.py --num-faces 10 --output-dir faces_output
"""

import argparse
import base64
import json
import os
import random
import sys
import time
from pathlib import Path
from typing import List, Dict, Optional, Tuple

import requests
from dotenv import load_dotenv

# Load environment variables from the .env file in the parent directory of
# this script's folder. abspath() makes the lookup independent of how the
# script was invoked: with a relative __file__ such as "generate_faces.py",
# dirname(dirname(...)) collapses to "" and the .env would silently be
# searched for in the current working directory instead.
load_dotenv(os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), '.env'))

# OpenRouter API key from the environment. The placeholder default marks the
# key as "not configured"; main() compares against this exact string.
API_KEY_REF = os.getenv("OPENROUTER_API_KEY", "your-api-key-here")

class FaceGenerator:
    """Request portrait images from the OpenRouter chat-completions API and save them to disk."""

    # Attribute names that prefix test-image names such as
    # "eye_color_blue_male" or "background_office_female". Matching these
    # first keeps single-word attributes ("accessories", "background") in one
    # folder each instead of splitting them into one folder per option.
    KNOWN_ATTRIBUTE_TYPES = (
        "eye_color", "hair_color", "age_group", "skin_tone", "face_shape",
        "hair_style", "facial_hair", "accessories", "background",
    )

    def __init__(self, api_key: str = API_KEY_REF, timeout: float = 120.0):
        """Initialize the face generator.

        Args:
            api_key: OpenRouter API key, sent as a Bearer token.
            timeout: Timeout in seconds passed to ``requests.post`` so a
                stalled connection cannot hang a batch run indefinitely.
        """
        self.api_key = api_key
        self.timeout = timeout
        self.url = "https://openrouter.ai/api/v1/chat/completions"
        self.headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }

    def generate_face(self, characteristics: str, output_path: str) -> bool:
        """
        Generate a single face image with given characteristics.

        The first image in the response that is an inline base64
        ``data:image`` URL is decoded and written to ``output_path`` (its
        parent directory is created if needed). Problems are reported on
        stdout rather than raised, so a batch run can continue.

        Args:
            characteristics: Description of face characteristics
            output_path: Path to save the generated image

        Returns:
            bool: True if an image was saved, False otherwise
        """
        payload = {
            "model": "google/gemini-2.5-flash-image-preview",
            "messages": [
                {
                    "role": "user",
                    "content": f"Generate a realistic portrait photo of a person with these characteristics: {characteristics}. Make it high quality, professional headshot style, good lighting, clear facial features."
                }
            ],
            "modalities": ["image", "text"]
        }

        try:
            response = requests.post(
                self.url, headers=self.headers, json=payload, timeout=self.timeout
            )
            response.raise_for_status()
            result = response.json()

            choices = result.get("choices")
            if not choices:
                print("No choices found in response")
                return False

            images = choices[0]["message"].get("images")
            if not images:
                print("No images found in response")
                return False

            for image in images:
                image_url = image["image_url"]["url"]
                header, separator, data = image_url.partition(",")
                # Only inline base64 data URLs ("data:image/...;base64,<data>")
                # are supported; anything else is reported and skipped so a
                # later image in the same response can still be used.
                if not (image_url.startswith("data:image") and separator
                        and header.lower().endswith(";base64")):
                    print(f"Unexpected image URL format: {image_url[:50]}...")
                    continue

                target = Path(output_path)
                target.parent.mkdir(parents=True, exist_ok=True)
                target.write_bytes(base64.b64decode(data))

                print(f"Generated face saved to: {output_path}")
                return True

            print("No usable image data found in response")
            return False

        except requests.exceptions.RequestException as e:
            print(f"Request failed: {e}")
            return False
        except Exception as e:
            # Best-effort boundary: a malformed response (missing keys, bad
            # base64) or a file-system error must not abort a whole batch.
            print(f"Error generating face: {e}")
            return False

    def generate_multiple_faces(self, face_descriptions: List, output_dir: str,
                                delay_seconds: float = 2.0) -> List[str]:
        """
        Generate multiple face images, one API request per description.

        Plain string descriptions are saved as ``face_001.png``,
        ``face_002.png``, ... directly in ``output_dir``. For
        ``(name, description)`` tuples the file is ``<name>.png``. It is placed
        in an ``output_dir/<attribute_type>`` subfolder when the name encodes
        an attribute type (see ``_attribute_type_of``).

        Args:
            face_descriptions: List of (name, description) tuples or just descriptions
            output_dir: Directory to save images (created if missing)
            delay_seconds: Pause between consecutive requests to avoid rate
                limiting; no pause is taken after the last request.

        Returns:
            List of successfully generated image paths
        """
        os.makedirs(output_dir, exist_ok=True)
        successful_paths = []
        total = len(face_descriptions)

        for i, item in enumerate(face_descriptions):
            if i > 0 and delay_seconds > 0:
                time.sleep(delay_seconds)

            # Handle both (name, description) tuples and just descriptions
            if isinstance(item, tuple):
                name, description = item
                attribute_type = self._attribute_type_of(name)
                if attribute_type is not None:
                    target_dir = os.path.join(output_dir, attribute_type)
                    os.makedirs(target_dir, exist_ok=True)
                else:
                    target_dir = output_dir
                output_path = os.path.join(target_dir, f"{name}.png")
            else:
                description = item
                output_path = os.path.join(output_dir, f"face_{i+1:03d}.png")

            print(f"Generating {i+1}/{total}: {description[:50]}...")

            if self.generate_face(description, output_path):
                successful_paths.append(output_path)

        return successful_paths

    def _attribute_type_of(self, name: str) -> Optional[str]:
        """Return the attribute type encoded at the start of ``name``, or None.

        Names starting with one of ``KNOWN_ATTRIBUTE_TYPES`` followed by "_"
        map to that type (longest match wins). Any other name with at least
        three "_"-separated parts falls back to its first two parts, following
        the ``attribute_type_option_gender`` convention. Shorter names have no
        attribute type.
        """
        for attribute_type in sorted(self.KNOWN_ATTRIBUTE_TYPES, key=len, reverse=True):
            if name.startswith(attribute_type + "_"):
                return attribute_type
        parts = name.split("_")
        if len(parts) >= 3:
            return "_".join(parts[:2])
        return None

def create_face_descriptions(num_faces: int, rng: Optional[random.Random] = None) -> List[str]:
    """Create ``num_faces`` diverse, randomly combined portrait descriptions.

    Args:
        num_faces: Number of descriptions to create; values <= 0 give ``[]``.
        rng: Optional random generator (e.g. ``random.Random(42)``) for
            reproducible output. Defaults to the module-level ``random``
            functions, as before.

    Returns:
        One prompt per face, shaped like
        "<age> <ethnicity> <gender>, <hair>, <eye colour> eyes, <feature>, <expression> expression".
    """
    chooser = rng if rng is not None else random

    # Base characteristics
    ages = ["young adult", "middle-aged", "elderly"]
    genders = ["man", "woman"]
    ethnicities = ["Caucasian", "African American", "Asian", "Hispanic", "Middle Eastern", "South Asian"]
    hair_colors = ["black", "brown", "blonde", "red", "gray", "white"]
    hair_styles = ["short", "medium length", "long", "curly", "straight", "wavy", "bald"]
    eye_colors = ["brown", "blue", "green", "hazel", "gray"]
    facial_features = ["strong jawline", "round face", "oval face", "high cheekbones", "dimples", "freckles"]
    expressions = ["smiling", "serious", "confident", "friendly", "professional"]

    descriptions = []

    for _ in range(num_faces):
        # Draw every characteristic in a fixed order so a seeded rng always
        # reproduces the same sequence of faces.
        age = chooser.choice(ages)
        gender = chooser.choice(genders)
        ethnicity = chooser.choice(ethnicities)
        hair_color = chooser.choice(hair_colors)
        hair_style = chooser.choice(hair_styles)
        eye_color = chooser.choice(eye_colors)
        facial_feature = chooser.choice(facial_features)
        expression = chooser.choice(expressions)

        # A colour cannot be combined with "bald" ("black bald hair"), so
        # describe the head instead; the colour draw above is kept anyway.
        if hair_style == "bald":
            hair = "bald head"
        else:
            hair = f"{hair_color} {hair_style} hair"

        descriptions.append(
            f"{age} {ethnicity} {gender}, {hair}, {eye_color} eyes, {facial_feature}, {expression} expression"
        )

    return descriptions

def create_attribute_test_descriptions() -> List[str]:
    """Create face descriptions for testing each attribute option."""
    
    # Define all attributes with their options
    attributes = {
        "eye_color": ["dark_brown","light_brown","green","blue","gray","amber"],
        "hair_color": ["black","dark_brown","light_brown","blonde","red","gray_white"],
        "age_group": ["18_29","30_44","45_59","60_80"],
        "skin_tone": ["fp_I_very_fair","fp_II_fair","fp_III_medium","fp_IV_olive","fp_V_brown","fp_VI_dark"],
        "face_shape": ["oval","round","square","heart","diamond","oblong"],
        "hair_style": ["bald","buzz_short","short_straight_part","short_curly_coily","medium_wavy","long_curly_locs"],
        "facial_hair": ["none","light_stubble","heavy_stubble","mustache","short_beard","full_beard"],
        "accessories": ["none","eyeglasses_clear","sunglasses","earrings_visible","headwear_cap_hat","scarf_neck_face"],
        "background": ["solid_mono","studio_gradient","indoor_home","office","outdoor_urban","outdoor_nature"]
    }
    
    descriptions = []
    
    # For each attribute, create one description per option for both male and female
    for attr_name, options in attributes.items():
        for option in options:
            # Create descriptions for both male and female
            for gender in ["male", "female"]:
                # Create a neutral base description
                base_description = f"professional headshot, {gender}, neutral expression, good lighting"
                
                # Add the specific attribute being tested
                if attr_name == "eye_color":
                    description = f"{base_description}, {option.replace('_', ' ')} eyes"
                elif attr_name == "hair_color":
                    description = f"{base_description}, {option.replace('_', ' ')} hair"
                elif attr_name == "age_group":
                    age_map = {"18_29": "young adult", "30_44": "middle-aged", "45_59": "mature adult", "60_80": "elderly"}
                    description = f"{base_description}, {age_map[option]}"
                elif attr_name == "skin_tone":
                    tone_map = {
                        "fp_I_very_fair": "very fair skin", "fp_II_fair": "fair skin", "fp_III_medium": "medium skin tone",
                        "fp_IV_olive": "olive skin tone", "fp_V_brown": "brown skin tone", "fp_VI_dark": "dark skin tone"
                    }
                    description = f"{base_description}, {tone_map[option]}"
                elif attr_name == "face_shape":
                    description = f"{base_description}, {option} face shape"
                elif attr_name == "hair_style":
                    style_map = {
                        "bald": "bald head", "buzz_short": "buzz cut", "short_straight_part": "short straight hair with part",
                        "short_curly_coily": "short curly hair", "medium_wavy": "medium wavy hair", "long_curly_locs": "long curly hair"
                    }
                    description = f"{base_description}, {style_map[option]}"
                elif attr_name == "facial_hair":
                    hair_map = {
                        "none": "clean shaven", "light_stubble": "light stubble", "heavy_stubble": "heavy stubble",
                        "mustache": "mustache", "short_beard": "short beard", "full_beard": "full beard"
                    }
                    description = f"{base_description}, {hair_map[option]}"
                elif attr_name == "accessories":
                    acc_map = {
                        "none": "no accessories", "eyeglasses_clear": "clear eyeglasses", "sunglasses": "sunglasses",
                        "earrings_visible": "visible earrings", "headwear_cap_hat": "cap or hat", "scarf_neck_face": "scarf around neck"
                    }
                    description = f"{base_description}, {acc_map[option]}"
                elif attr_name == "background":
                    bg_map = {
                        "solid_mono": "solid monochrome background", "studio_gradient": "studio gradient background",
                        "indoor_home": "indoor home background", "office": "office background",
                        "outdoor_urban": "outdoor urban background", "outdoor_nature": "outdoor nature background"
                    }
                    description = f"{base_description}, {bg_map[option]}"
                
                descriptions.append((f"{attr_name}_{option}_{gender}", description))
    
    return descriptions

def main() -> int:
    """Command-line entry point: parse arguments, generate faces, print a summary.

    Description sources are used in priority order: ``--descriptions``,
    ``--test-attributes``, ``--specific``, otherwise ``--num-faces`` random
    faces.

    Returns:
        Process exit status: 1 if the API key is missing or no image could be
        generated, otherwise 0.
    """
    parser = argparse.ArgumentParser(
        description="Generate face images using OpenRouter API with Gemini 2.5 Flash Image Preview",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        # %(prog)s is substituted by argparse with the actual script name.
        epilog="""
Examples:
  # Generate 10 random faces
  python %(prog)s --num-faces 10 --output-dir faces_output

  # Generate test images for each attribute option (108 total images: 54 attribute options × 2 genders)
  python %(prog)s --test-attributes --output-dir attribute_tests

  # Generate specific faces
  python %(prog)s --specific --output-dir specific_faces

  # Generate with custom descriptions
  python %(prog)s --descriptions "young woman with red hair" "elderly man with beard" --output-dir custom_faces
        """
    )

    parser.add_argument("--num-faces", type=int, default=5,
                       help="Number of random faces to generate (default: 5)")
    parser.add_argument("--output-dir", required=True,
                       help="Output directory for generated images")
    parser.add_argument("--specific", action="store_true",
                       help="Generate specific predefined faces instead of random")
    parser.add_argument("--test-attributes", action="store_true",
                       help="Generate test images for each attribute option")
    parser.add_argument("--descriptions", nargs="+",
                       help="Custom face descriptions to generate")
    parser.add_argument("--api-key",
                       help="OpenRouter API key (or set OPENROUTER_API_KEY env var)")

    args = parser.parse_args()

    # --num-faces only matters for the random fallback mode.
    uses_random_faces = not (args.descriptions or args.test_attributes or args.specific)
    if uses_random_faces and args.num_faces < 1:
        parser.error("--num-faces must be a positive integer")

    # Get API key; an empty value (e.g. OPENROUTER_API_KEY="") is as unusable
    # as the placeholder.
    api_key = args.api_key or API_KEY_REF
    if not api_key or api_key == "your-api-key-here":
        print("Error: Please set your OpenRouter API key using --api-key or OPENROUTER_API_KEY environment variable",
              file=sys.stderr)
        return 1

    # Initialize generator
    generator = FaceGenerator(api_key)

    # Determine face descriptions
    if args.descriptions:
        face_descriptions = args.descriptions
    elif args.test_attributes:
        face_descriptions = create_attribute_test_descriptions()
    elif args.specific:
        # The predefined list comes from an optional module-level helper;
        # report its absence as a CLI error instead of crashing with NameError.
        specific_factory = globals().get("create_specific_face_descriptions")
        if specific_factory is None:
            parser.error("--specific is not available: no predefined face descriptions "
                         "(create_specific_face_descriptions) are defined in this script")
        face_descriptions = specific_factory()
    else:
        face_descriptions = create_face_descriptions(args.num_faces)

    print(f"Generating {len(face_descriptions)} face images...")
    print(f"Output directory: {args.output_dir}")
    print()

    # Generate faces
    successful_paths = generator.generate_multiple_faces(face_descriptions, args.output_dir)

    # Summary
    print("\nGeneration complete!")
    print(f"Successfully generated: {len(successful_paths)}/{len(face_descriptions)} faces")
    print(f"Images saved to: {args.output_dir}")

    if successful_paths:
        print("\nGenerated images:")
        for path in successful_paths:
            print(f"  - {path}")

    # Fail only when there was work to do and none of it succeeded.
    return 0 if successful_paths or not face_descriptions else 1

if __name__ == "__main__":
    sys.exit(main())
