#!/usr/bin/env python3
"""
Script to generate face images using OpenRouter API with Gemini 2.5 Flash Image Preview.
Generates faces with 4 primary attributes (eye_color, hair_color, hair_style, accessories),
plus a randomly assigned gender and randomly sampled secondary attributes, and saves
metadata in JSON format.

Usage:
    python generate_faces_4_attributes.py --output-dir faces_output
"""

import requests
import json
import os
import argparse
import base64
from pathlib import Path
from typing import List, Dict, Optional, Tuple
import time
import random
from dotenv import load_dotenv

# Load environment variables from the .env file located in the parent directory of
# the directory that contains this script (the project's top directory).
load_dotenv(os.path.join(os.path.dirname(os.path.dirname(__file__)), '.env'))

# Get OpenRouter API key from environment. The placeholder default is treated by
# main() as "no key configured" and causes it to stop before generating anything.
API_KEY_REF = os.getenv("OPENROUTER_API_KEY", "your-api-key-here")

# Secondary attributes for facial diversity. create_face_descriptions() draws one
# value from each list with random.choice for every face; they are recorded in the
# per-image metadata but are not enumerated as combinations.
SECONDARY_ATTRIBUTES = {
    "age_groups": ["young_adult", "middle_aged", "elderly"],
    # Fitzpatrick skin types; rendered as "<type> skin type on Fitzpatrick scale".
    "skin_tones": ["I", "II", "III", "IV", "V", "VI"],
    "face_shapes": ["oval face", "round face", "square face", "heart-shaped face", "diamond face", "long face"],
    "eyebrow_shapes": ["thick eyebrows", "thin eyebrows", "arched eyebrows", "straight eyebrows", "bushy eyebrows", "defined eyebrows"],
    "lip_shapes": ["full lips", "thin lips", "wide lips", "narrow lips", "natural lips", "prominent lips", "lipstick on lips"],
    "facial_features": ["defined cheekbones", "strong jawline", "soft features", "angular features", "prominent features", "distinctive features"]
}

class FaceGenerator:
    """Generate face images through the OpenRouter chat-completions endpoint.

    Each request asks ``google/gemini-2.5-flash-image-preview`` for an image and
    writes the first usable ``data:image`` URL from the reply to disk.
    Generations that still fail after all retries are collected in
    ``self.failed_generations`` and can be written out with
    :meth:`save_failed_generations_to_file`.
    """

    # HTTP statuses a retry a couple of seconds later cannot fix:
    # 401 Unauthorized, 402 Payment Required, 403 Forbidden.
    _NON_RETRYABLE_STATUS = frozenset({401, 402, 403})

    def __init__(self, api_key: str = API_KEY_REF, output_dir: Optional[str] = None,
                 request_timeout: float = 120.0):
        """Initialize the face generator.

        Args:
            api_key: OpenRouter API key, sent as a Bearer token.
            output_dir: Output directory. Only stored on the instance; the
                generation/saving methods take their own ``output_dir`` argument.
            request_timeout: Seconds to wait for each HTTP request before the
                attempt is abandoned, so a stalled connection cannot hang the run.
        """
        self.api_key = api_key
        self.url = "https://openrouter.ai/api/v1/chat/completions"
        self.headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }
        self.output_dir = output_dir
        self.request_timeout = request_timeout
        self.failed_generations = []

    def generate_face(self, characteristics: str, output_path: str, max_retries: int = 10) -> Tuple[bool, str]:
        """
        Generate a single face image with given characteristics.

        The request is attempted up to ``max_retries`` times with a 2 second
        pause between attempts. Retrying stops early on HTTP 401/402/403,
        which another attempt cannot fix. If no attempt succeeds, the failure
        is recorded in ``self.failed_generations``.

        Args:
            characteristics: Description of face characteristics
            output_path: Path to save the generated image
            max_retries: Maximum number of attempts

        Returns:
            Tuple[bool, str]: (success, full_prompt_used)
        """
        # Create the full prompt that will be sent to the API
        full_prompt = (
            "Generate a realistic color portrait photo of a person with natural human skin tone "
            f"and these characteristics: {characteristics}. Make it high quality, professional "
            "headshot style, good lighting, clear facial features, full color image with natural "
            "skin color and unambiguous eye color. Only the background should be monochromatic gray. "
            "IMPORTANT: Make this person look unique and not generic - vary facial structure, bone "
            "structure, skin texture, and overall appearance to ensure maximum uniqueness and diversity."
        )

        payload = {
            "model": "google/gemini-2.5-flash-image-preview",
            "messages": [
                {
                    "role": "user",
                    "content": full_prompt
                }
            ],
            "modalities": ["image", "text"]
        }

        last_error = None
        last_response = None
        attempts_made = 0

        for attempt in range(max_retries):
            attempts_made = attempt + 1
            try:
                print(f"Attempt {attempt + 1}/{max_retries} for face generation...")

                response = requests.post(self.url, headers=self.headers, json=payload,
                                         timeout=self.request_timeout)
                response.raise_for_status()
                result = response.json()
                last_response = result

                # Extract image from response
                if result.get("choices"):
                    message = result["choices"][0]["message"]
                    if message.get("images"):
                        # Use the first image that decodes into a non-empty file.
                        for image in message["images"]:
                            image_url = image["image_url"]["url"]

                            if image_url.startswith("data:image"):
                                # Data URL: everything after the first comma is base64.
                                _, data = image_url.split(",", 1)
                                image_data = base64.b64decode(data)

                                with open(output_path, "wb") as f:
                                    f.write(image_data)

                                # Verify image was saved and has content
                                if os.path.exists(output_path) and os.path.getsize(output_path) > 0:
                                    print(f"Generated face saved to: {output_path}")
                                    return True, full_prompt
                                last_error = f"Image file was not created or is empty: {output_path}"
                                print(last_error)
                            else:
                                last_error = f"Unexpected image URL format: {image_url[:50]}..."
                                print(last_error)
                    else:
                        last_error = "No images found in response"
                        print(f"{last_error} (attempt {attempt + 1})")
                else:
                    last_error = "No choices found in response"
                    print(f"{last_error} (attempt {attempt + 1})")

            except requests.exceptions.HTTPError as e:
                last_error = f"Request failed: {str(e)}"
                print(f"{last_error} (attempt {attempt + 1})")
                # `is not None` matters: a Response with an error status is falsy.
                status = e.response.status_code if e.response is not None else None
                if status in self._NON_RETRYABLE_STATUS:
                    print(f"HTTP {status} will not succeed on retry; giving up")
                    break
            except requests.exceptions.RequestException as e:
                last_error = f"Request failed: {str(e)}"
                print(f"{last_error} (attempt {attempt + 1})")
            except Exception as e:
                # Malformed replies (missing keys, undecodable base64, ...) end up here.
                last_error = f"Exception: {str(e)}"
                print(f"{last_error} (attempt {attempt + 1})")

            # Wait before retry (except on last attempt)
            if attempt < max_retries - 1:
                time.sleep(2)

        # All attempts failed - save the failure info
        print(f"All {attempts_made} attempts failed for face generation")
        self._save_failed_generation(characteristics, output_path, payload, last_response, last_error)
        return False, full_prompt

    def _save_failed_generation(self, characteristics: str, output_path: str, payload: dict,
                                response: Optional[dict], error_reason: Optional[str]):
        """Record one failed generation (prompt, last error, last API reply) in memory."""
        failed_info = {
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
            "characteristics": characteristics,
            "output_path": output_path,
            "prompt_used": payload["messages"][0]["content"],
            "error_reason": error_reason,
            "api_response": response
        }
        self.failed_generations.append(failed_info)

    def save_failed_generations_to_file(self, output_dir: str):
        """Write all recorded failures to ``<output_dir>/failed_generations.json``.

        Does nothing when there were no failures.
        """
        if self.failed_generations:
            failed_file = os.path.join(output_dir, "failed_generations.json")
            with open(failed_file, 'w') as f:
                json.dump(self.failed_generations, f, indent=2)
            print(f"Saved {len(self.failed_generations)} failed generation attempts to: {failed_file}")

    def generate_multiple_faces(self, face_descriptions: List[Tuple[str, str, Dict]], output_dir: str) -> List[Dict]:
        """
        Generate multiple face images with metadata.

        Args:
            face_descriptions: List of (name, description, attributes) tuples
            output_dir: Directory to save images (created if missing)

        Returns:
            List of successfully generated image metadata
        """
        os.makedirs(output_dir, exist_ok=True)
        successful_metadata = []
        total = len(face_descriptions)

        for i, (name, description, attributes) in enumerate(face_descriptions):
            output_path = os.path.join(output_dir, f"{name}.png")

            print(f"Generating {i+1}/{total}: {description[:50]}...")

            success, full_prompt = self.generate_face(description, output_path)
            if success:
                # Create metadata entry
                metadata = {
                    "image_name": f"{name}.png",
                    "image_path": output_path,
                    "description": description,
                    "full_prompt": full_prompt,
                    "attributes": attributes
                }
                successful_metadata.append(metadata)
            else:
                print(f"❌ FAILED after all attempts: {name}.png")
                # .get() so an unexpected attribute schema cannot abort the whole batch.
                print(f"   Primary Attributes: eye_color: {attributes.get('eye_color')}, hair_color: {attributes.get('hair_color')}, hair_style: {attributes.get('hair_style')}, accessories: {attributes.get('accessories')}, gender: {attributes.get('gender')}")
                sec_attrs = attributes.get('secondary_attributes') or {}
                print(f"   Secondary Attributes: age_group: {sec_attrs.get('age_group')}, skin_tone: {sec_attrs.get('skin_tone')}, face_shape: {sec_attrs.get('face_shape')}, eyebrow_shape: {sec_attrs.get('eyebrow_shape')}, lip_shape: {sec_attrs.get('lip_shape')}, facial_features: {sec_attrs.get('facial_features')}")

            # Pause between requests to avoid rate limiting (not needed after the last one).
            if i < total - 1:
                time.sleep(2)

        return successful_metadata

def create_face_descriptions(test_mode: bool = True, max_test_images: int = 10) -> Tuple[List[Tuple[str, str, Dict]], Dict]:
    """Create face descriptions for the primary-attribute combinations.

    Combinations are enumerated in nested-loop order (eye_color outermost,
    accessories innermost) and named ``face_0001``, ``face_0002``, ... . For
    every face, the gender and each secondary attribute are drawn with
    ``random.choice``.

    Args:
        test_mode: When True (the default, matching the previously hard-coded
            behaviour), stop after ``max_test_images`` faces. Pass False to
            generate every combination.
        max_test_images: Number of faces produced in test mode.

    Returns:
        ``(descriptions, attributes)``: a list of ``(name, description,
        attribute_dict)`` tuples, and the dict of primary attribute options.
    """
    from itertools import islice, product

    # The 4 primary attributes: 4 x 6 x 5 x 5 = 600 combinations.
    attributes = {
        "eye_color": ["blue", "dark_brown", "gray", "red"],
        "hair_color": ["black", "light_brown", "blonde", "red", "gray_white", "blue"],
        "hair_style": ["shoulder_straight", "shoulder_afro", "long_wavy", "long_straight", "buzz_cut"],
        "accessories": ["none", "eyeglasses_clear", "earrings_visible", "headband", "scarf_neck_face"]
    }

    # Prompt phrase per option. A few entries ("light_brown" eyes, "shoulder_wavy")
    # are not in the option lists above; they are kept so those options can be
    # re-enabled without rewriting the phrasing.
    eye_phrases = {
        "dark_brown": "dark brown eyes",
        "light_brown": "light brown eyes",
        "gray": "gray eyes",
        "blue": "very saturated and vivid blue eyes, not ambiguous (not gray or green, very obvious blue)",
    }
    hair_color_phrases = {
        "light_brown": "light brown hair",
        "gray_white": "gray or white hair",
        "red": "vibrant red hair",
        "blue": "blue hair",
        "blonde": "blonde/yellow hair",
    }
    hair_style_phrases = {
        "shoulder_straight": "shoulder length straight hair, clearly above shoulders, with clear gap between hair and shoulders",
        "shoulder_wavy": "shoulder length wavy hair, clearly above shoulders, with clear gap between hair and shoulders",
        "shoulder_afro": "shoulder length afro style very curly, clearly above shoulders",
        "long_wavy": "long wavy hair but not afro style",
        "long_straight": "long straight hair",
        "buzz_cut": "buzz cut with straight hair",
    }
    # Each accessory phrase insists that only one accessory is present.
    accessory_phrases = {
        "eyeglasses_clear": "clear eyeglasses, no other accessories",
        "earrings_visible": "quite visible earrings, no other accessories",
        "headband": "vivid bright colored headband (e.g. pink), no other accessories",
        "scarf_neck_face": "scarf around neck, no other accessories",
        "none": "no accessories at all",
    }
    age_phrases = {
        "young_adult": "young adult",
        "middle_aged": "middle-aged",
        "elderly": "elderly",
    }

    combos = product(attributes["eye_color"], attributes["hair_color"],
                     attributes["hair_style"], attributes["accessories"])
    if test_mode:
        combos = islice(combos, max(0, max_test_images))

    descriptions = []
    for image_number, (eye_color, hair_color, hair_style, accessory) in enumerate(combos, start=1):
        # Randomly assign gender (50% male, 50% female)
        gender = random.choice(["male", "female"])
        name = f"face_{image_number:04d}"

        description_parts = [
            gender,
            eye_phrases.get(eye_color, f"{eye_color} eyes"),
            hair_color_phrases.get(hair_color, f"{hair_color} hair"),
        ]
        # Styles/accessories without a phrase contribute nothing to the prompt.
        if hair_style in hair_style_phrases:
            description_parts.append(hair_style_phrases[hair_style])
        if accessory in accessory_phrases:
            description_parts.append(accessory_phrases[accessory])

        # Draw order is unchanged, so a seeded `random` yields the same draws as before.
        random_face_shape = random.choice(SECONDARY_ATTRIBUTES["face_shapes"])
        random_eyebrow_shape = random.choice(SECONDARY_ATTRIBUTES["eyebrow_shapes"])
        random_lip_shape = random.choice(SECONDARY_ATTRIBUTES["lip_shapes"])
        random_facial_feature = random.choice(SECONDARY_ATTRIBUTES["facial_features"])
        age_group = random.choice(SECONDARY_ATTRIBUTES["age_groups"])
        skin_tone = random.choice(SECONDARY_ATTRIBUTES["skin_tones"])

        # Unknown age groups fall back to a readable form instead of raising.
        age_description = age_phrases.get(age_group, age_group.replace("_", " "))
        skin_tone_description = f"{skin_tone} skin type on Fitzpatrick scale"

        description = (
            "professional headshot, neutral expression, good lighting, monochromatic gray background, "
            f"{random_face_shape}, {random_eyebrow_shape}, {random_lip_shape}, {random_facial_feature}, "
            f"{age_description}, {skin_tone_description}, " + ", ".join(description_parts)
        )

        attr_dict = {
            "gender": gender,
            "eye_color": eye_color,
            "hair_color": hair_color,
            "hair_style": hair_style,
            "accessories": accessory,
            "secondary_attributes": {
                "age_group": age_group,
                "skin_tone": skin_tone,
                "face_shape": random_face_shape,
                "eyebrow_shape": random_eyebrow_shape,
                "lip_shape": random_lip_shape,
                "facial_features": random_facial_feature
            }
        }

        descriptions.append((name, description, attr_dict))

    return descriptions, attributes

def save_metadata(metadata_list: List[Dict], output_dir: str, attributes_info: Dict):
    """Write ``face_metadata.json`` into *output_dir*.

    The file contains an ``attributes_summary`` section (option lists and
    counts for the primary and secondary attributes, plus combination totals)
    followed by the per-image entries under ``generated_faces``.
    """
    target = os.path.join(output_dir, "face_metadata.json")

    primary_keys = ("eye_color", "hair_color", "hair_style", "accessories")
    # (key written to JSON, key in SECONDARY_ATTRIBUTES) — the global uses plurals.
    secondary_keys = (
        ("age_group", "age_groups"),
        ("skin_tone", "skin_tones"),
        ("face_shape", "face_shapes"),
        ("eyebrow_shape", "eyebrow_shapes"),
        ("lip_shape", "lip_shapes"),
        ("facial_features", "facial_features"),
    )

    primary_summary = {}
    primary_total = 1
    for key in primary_keys:
        choices = attributes_info[key]
        primary_summary[key] = {"options": choices, "count": len(choices)}
        primary_total *= len(choices)

    secondary_summary = {}
    secondary_total = 1
    for out_key, src_key in secondary_keys:
        choices = SECONDARY_ATTRIBUTES[src_key]
        secondary_summary[out_key] = {"options": choices, "count": len(choices)}
        secondary_total *= len(choices)

    document = {
        "attributes_summary": {
            "primary_attributes": primary_summary,
            "secondary_attributes": secondary_summary,
            "total_combinations": {
                "primary": primary_total,
                "secondary": secondary_total,
                "total": primary_total * secondary_total,
            },
        },
        "generated_faces": metadata_list,
    }

    with open(target, 'w') as handle:
        json.dump(document, handle, indent=2)

    print(f"Metadata saved to: {target}")

def main():
    """Command-line entry point: parse arguments, generate faces, write metadata.

    Images, ``face_metadata.json`` and (when any generation failed)
    ``failed_generations.json`` are written into ``--output-dir``.
    """
    parser = argparse.ArgumentParser(
        description="Generate face images with 4 specific attributes and save metadata",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Generate faces for the primary attribute combinations
  # (every combination, unless the script's test mode limits the count)
  python generate_faces_4_attributes.py --output-dir faces_4_attributes
  
  # Generate with custom API key
  python generate_faces_4_attributes.py --output-dir faces_4_attributes --api-key your-key
        """
    )

    parser.add_argument("--output-dir", required=True,
                       help="Output directory for generated images")
    parser.add_argument("--api-key",
                       help="OpenRouter API key (or set OPENROUTER_API_KEY env var)")

    args = parser.parse_args()

    # Get API key; an empty value or the module's placeholder means "not configured".
    api_key = args.api_key or API_KEY_REF
    if not api_key or api_key == "your-api-key-here":
        # Prints usage + message to stderr and exits with a non-zero status.
        parser.error("Please set your OpenRouter API key using --api-key or OPENROUTER_API_KEY environment variable")

    # Initialize generator
    generator = FaceGenerator(api_key, args.output_dir)

    # Create face descriptions
    face_descriptions, attributes = create_face_descriptions()

    # Calculate totals dynamically
    eye_color_count = len(attributes["eye_color"])
    hair_color_count = len(attributes["hair_color"])
    hair_style_count = len(attributes["hair_style"])
    accessories_count = len(attributes["accessories"])
    total_combinations = eye_color_count * hair_color_count * hair_style_count * accessories_count

    print(f"Generating {len(face_descriptions)} face images...")
    print(f"Output directory: {args.output_dir}")
    print(f"Attributes: eye_color ({eye_color_count}), hair_color ({hair_color_count}), hair_style ({hair_style_count}), accessories ({accessories_count})")
    print("Gender: 50% male, 50% female (randomly assigned)")
    print(f"Total combinations: {eye_color_count} × {hair_color_count} × {hair_style_count} × {accessories_count} = {total_combinations}")
    print()

    # Generate faces
    successful_metadata = generator.generate_multiple_faces(face_descriptions, args.output_dir)

    # Save metadata
    save_metadata(successful_metadata, args.output_dir, attributes)

    # Save failed generations info
    generator.save_failed_generations_to_file(args.output_dir)

    # Summary
    print("\nGeneration complete!")
    print(f"Successfully generated: {len(successful_metadata)}/{len(face_descriptions)} faces")
    print(f"Images saved to: {args.output_dir}")
    print(f"Metadata saved to: {os.path.join(args.output_dir, 'face_metadata.json')}")

# Run the CLI only when the file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
