#!/usr/bin/env python3
"""
Script to generate headshot variations of existing face images with different backgrounds.
Takes face images (a single file, or every face_*.png in a directory) and generates three
additional headshots of each person, one per background:
- Office background
- Nature background
- Urban background

Usage:
    python generate_person_variations.py --input-dir faces4 --output-dir person_variations
"""

import argparse
import base64
import json
import mimetypes
import os
import random
import shutil
import time
from pathlib import Path
from typing import Dict, List, Optional

import requests
from dotenv import load_dotenv

# Load environment variables from the .env file in the parent of this script's
# directory (the project's top directory). abspath() anchors the lookup to the
# script location even on interpreters where __file__ is a relative path, so
# the result does not depend on the current working directory.
load_dotenv(os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), '.env'))

# OpenRouter API key from the environment. The placeholder default is rejected
# by main() with an explanatory error message.
API_KEY_REF = os.getenv("OPENROUTER_API_KEY", "your-api-key-here")

class PersonVariationGenerator:
    """Generate background variations of face headshots through the OpenRouter API.

    For each face image the generator copies the original into a per-person
    directory, requests one new headshot per entry in ``SETTINGS`` and writes a
    per-person metadata JSON file. Failed attempts and per-person results are
    accumulated in memory and written out by ``save_failed_generations_to_file``
    and ``save_summary_metadata``.
    """

    # Image model requested from OpenRouter for every variation.
    MODEL = "google/gemini-2.5-flash-image-preview"

    # Backgrounds generated for each person, in generation order.
    SETTINGS = ("office", "nature", "urban")

    # Shared prompt fragments. Every setting prompt is assembled as
    # "<opening> <identity rules> [<background detail>] <accessory/framing rules>"
    # so the rules common to all settings are maintained in one place.
    _PROMPT_IDENTITY_RULES = (
        "Keep the same facial features, hair color, hair style, eye color, "
        "and overall appearance as the original image."
    )
    _PROMPT_ACCESSORY_AND_FRAMING_RULES = (
        "IMPORTANT: Keep the exact same accessories as the original - "
        "if the person has earrings, they must have earrings; "
        "if they have a headband, they must have the same headband; "
        "if they have glasses, they must have the same glasses; "
        "if they have no accessories, they must have no accessories. "
        "Don't add accessories if the person didn't have any "
        "(e.g., if there are no earrings in the original image, don't add them). "
        "The headshot can be in different positions (slight angle, different head tilt) "
        "and can show different amounts of shoulders - it doesn't have to match "
        "the exact same framing as the original. "
        "This should be a headshot/portrait style, not a full-body image."
    )

    # Complete prompt sent alongside the original image for each setting.
    SETTING_PROMPTS = {
        "office": (
            "Generate a professional headshot of this exact same person with an office background. "
            f"{_PROMPT_IDENTITY_RULES} {_PROMPT_ACCESSORY_AND_FRAMING_RULES}"
        ),
        "nature": (
            "Generate a professional headshot of this exact same person with a nature background. "
            f"{_PROMPT_IDENTITY_RULES} "
            "The background should be a natural outdoor environment. "
            f"{_PROMPT_ACCESSORY_AND_FRAMING_RULES}"
        ),
        "urban": (
            "Generate a professional headshot of this exact same person with an urban background. "
            f"{_PROMPT_IDENTITY_RULES} "
            "The background should be a city environment. "
            f"{_PROMPT_ACCESSORY_AND_FRAMING_RULES}"
        ),
    }

    def __init__(self, api_key: str = API_KEY_REF, output_dir: Optional[str] = None,
                 request_timeout: float = 180.0):
        """Initialize the generator.

        Args:
            api_key: OpenRouter API key sent as a Bearer token.
            output_dir: Output directory (stored for reference; the methods
                below take their output directory as an explicit argument).
            request_timeout: Timeout in seconds passed to ``requests.post`` for
                each API call. Without a timeout a stalled connection would
                block the whole run indefinitely.
        """
        self.api_key = api_key
        self.url = "https://openrouter.ai/api/v1/chat/completions"
        self.headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        self.output_dir = output_dir
        self.request_timeout = request_timeout
        # Records of failed API generations and processing failures.
        self.failed_generations: List[dict] = []
        # One summary entry per person handled by generate_all_variations.
        self.processed_people: List[dict] = []

    def encode_image_to_base64(self, image_path: str) -> str:
        """Return the contents of ``image_path`` as a base64-encoded string."""
        with open(image_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode('utf-8')

    @staticmethod
    def _guess_image_mime_type(image_path: str) -> str:
        """Return the image MIME type implied by the file extension, defaulting to image/png."""
        mime_type, _ = mimetypes.guess_type(image_path)
        if mime_type and mime_type.startswith("image/"):
            return mime_type
        return "image/png"

    def generate_person_variation(self, original_image_path: str, setting: str, output_path: str, max_retries: int = 10) -> bool:
        """
        Generate a variation of a person in a specific setting with retry logic.

        The original image is sent together with ``SETTING_PROMPTS[setting]``.
        The first base64 ``data:image`` URL in the response is decoded and
        written to ``output_path``. Each failed attempt waits 2 seconds before
        retrying. If every attempt fails, the failure is recorded in
        ``self.failed_generations``.

        Args:
            original_image_path: Path to the original face image
            setting: The setting name; must be a key of SETTING_PROMPTS
            output_path: Path to save the generated image
            max_retries: Maximum number of attempts

        Returns:
            bool: True if an image was saved, False otherwise

        Raises:
            KeyError: if ``setting`` is not a known setting.
            OSError: if the original image cannot be read.
        """
        base64_image = self.encode_image_to_base64(original_image_path)
        prompt = self.SETTING_PROMPTS[setting]
        # Label the upload with its real type; single-image mode may pass non-PNG files.
        mime_type = self._guess_image_mime_type(original_image_path)

        payload = {
            "model": self.MODEL,
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": prompt,
                        },
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:{mime_type};base64,{base64_image}",
                            },
                        },
                    ],
                }
            ],
            "modalities": ["image", "text"],
        }

        last_error = None
        last_response = None

        for attempt in range(max_retries):
            try:
                print(f"Attempt {attempt + 1}/{max_retries} for {setting} setting...")

                response = requests.post(
                    self.url,
                    headers=self.headers,
                    json=payload,
                    timeout=self.request_timeout,
                )
                response.raise_for_status()
                result = response.json()
                last_response = result

                # Extract image from response
                if result.get("choices"):
                    message = result["choices"][0]["message"]
                    if message.get("images"):
                        for image in message["images"]:
                            image_url = image["image_url"]["url"]

                            # Only inline base64 data URLs are supported.
                            if image_url.startswith("data:image"):
                                header, data = image_url.split(",", 1)
                                image_data = base64.b64decode(data)

                                with open(output_path, "wb") as f:
                                    f.write(image_data)

                                print(f"Generated {setting} variation saved to: {output_path}")
                                return True
                            else:
                                last_error = f"Unexpected image URL format: {image_url[:50]}..."
                                print(last_error)
                    else:
                        last_error = "No images found in response"
                        print(f"{last_error} for {setting} setting (attempt {attempt + 1})")
                else:
                    last_error = "No choices found in response"
                    print(f"{last_error} for {setting} setting (attempt {attempt + 1})")

            except requests.exceptions.RequestException as e:
                # Includes timeouts, connection errors and HTTP error statuses.
                last_error = f"Request failed: {str(e)}"
                print(f"{last_error} for {setting} setting (attempt {attempt + 1})")
            except Exception as e:
                # Malformed response bodies, bad base64, write errors: retry as well.
                last_error = f"Exception: {str(e)}"
                print(f"{last_error} for {setting} setting (attempt {attempt + 1})")

            # Wait before retry (except on last attempt)
            if attempt < max_retries - 1:
                time.sleep(2)

        # All attempts failed - save the failure info
        print(f"All {max_retries} attempts failed for {setting} setting")
        self._save_failed_generation(original_image_path, setting, prompt, last_response, last_error, output_path)
        return False

    def _get_attributes_from_metadata(self, original_image_path: str) -> dict:
        """Return the "attributes" recorded for this image in face_metadata.json.

        The metadata file is looked up next to the image. It is expected to hold
        a list of entries with "image_name" and "attributes" keys. Returns {}
        when the file or the entry is missing, or when reading fails (a warning
        is printed in that case).
        """
        try:
            image_dir = os.path.dirname(original_image_path)
            metadata_file = os.path.join(image_dir, "face_metadata.json")

            if os.path.exists(metadata_file):
                with open(metadata_file, 'r', encoding='utf-8') as f:
                    metadata = json.load(f)

                image_name = os.path.basename(original_image_path)
                for entry in metadata:
                    if entry.get("image_name") == image_name:
                        return entry.get("attributes", {})

            return {}
        except Exception as e:
            # Attributes are informational only; never fail a run over them.
            print(f"Warning: Could not load attributes for {original_image_path}: {e}")
            return {}

    def _save_failed_generation(self, original_image_path: str, setting: str, prompt: str, response: Optional[dict], error_reason: Optional[str], output_path: str):
        """Record a generation whose API attempts all failed (kept in memory)."""
        attributes = self._get_attributes_from_metadata(original_image_path)

        failed_info = {
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
            "original_image": original_image_path,
            "setting": setting,
            "prompt_used": prompt,
            "error_type": "api_generation_failure",
            "error_reason": error_reason,
            "api_response": response,
            "decided_attributes": attributes,
            "output_path": output_path,
        }
        self.failed_generations.append(failed_info)

    def _save_processing_failure(self, person_id: str, input_path: str, error_message: str):
        """Record a non-API failure (file system errors, etc.) for a person."""
        failed_info = {
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
            "person_id": person_id,
            "input_path": input_path,
            "error_type": "processing_failure",
            "error_reason": error_message,
            "api_response": None,
        }
        self.failed_generations.append(failed_info)

    def save_failed_generations_to_file(self, output_dir: str):
        """Write all recorded failures to output_dir/failed_generations.json (only if any exist)."""
        if not self.failed_generations:
            return
        failed_file = os.path.join(output_dir, "failed_generations.json")
        with open(failed_file, 'w', encoding='utf-8') as f:
            json.dump(self.failed_generations, f, indent=2)
        print(f"Saved {len(self.failed_generations)} failed generation attempts to: {failed_file}")

    def save_summary_metadata(self, output_dir: str):
        """Write aggregate statistics to output_dir/processing_summary.json (only if anyone was processed)."""
        if not self.processed_people:
            return

        attempted = sum(person["total_variations"] for person in self.processed_people)
        successful = sum(person["successful_variations"] for person in self.processed_people)
        summary_metadata = {
            "processing_summary": {
                "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
                "total_people_processed": len(self.processed_people),
                "total_variations_attempted": attempted,
                "total_variations_successful": successful,
                "total_variations_failed": attempted - successful,
                # Percentage; the copied original counts as a successful variation.
                "success_rate": round(successful / attempted * 100, 2) if attempted else 0,
            },
            "processed_people": self.processed_people,
        }

        summary_file = os.path.join(output_dir, "processing_summary.json")
        with open(summary_file, 'w', encoding='utf-8') as f:
            json.dump(summary_metadata, f, indent=2)
        print(f"Saved processing summary to: {summary_file}")

    def generate_all_variations(self, original_image_path: str, person_id: str, output_dir: str) -> Dict[str, str]:
        """
        Generate all variations of a person, one per entry in SETTINGS.

        The original image is copied into ``output_dir/person_id``. Generated
        images are saved there as ``{person_id}_{setting}.png``, and a
        ``{person_id}_metadata.json`` file is written. There is a 3 second
        pause after each setting to space out API requests.

        Args:
            original_image_path: Path to the original face image
            person_id: ID of the person (e.g., "0894")
            output_dir: Base output directory

        Returns:
            Dict mapping "original" and each successfully generated setting to its path
        """
        person_dir = os.path.join(output_dir, person_id)
        os.makedirs(person_dir, exist_ok=True)

        original_output_path = os.path.join(person_dir, os.path.basename(original_image_path))
        shutil.copy2(original_image_path, original_output_path)
        print(f"Original image copied to: {original_output_path}")

        variation_paths = {"original": original_output_path}
        generation_results = {"original": True}  # Original is always successful (copied)

        for setting in self.SETTINGS:
            output_path = os.path.join(person_dir, f"{person_id}_{setting}.png")

            print(f"Generating {setting} variation for person {person_id}...")

            if self.generate_person_variation(original_image_path, setting, output_path):
                variation_paths[setting] = output_path
                generation_results[setting] = True
            else:
                print(f"Failed to generate {setting} variation")
                generation_results[setting] = False

            # Space out API requests.
            time.sleep(3)

        self._save_person_metadata(person_id, original_image_path, variation_paths, generation_results, person_dir)

        self.processed_people.append({
            "person_id": person_id,
            "original_image_path": original_image_path,
            "successful_variations": sum(generation_results.values()),
            "total_variations": len(generation_results),
            "metadata_file": os.path.join(person_dir, f"{person_id}_metadata.json"),
        })

        return variation_paths

    def _save_person_metadata(self, person_id: str, original_image_path: str, variation_paths: Dict[str, str], generation_results: Dict[str, bool], person_dir: str):
        """Write ``{person_id}_metadata.json`` describing one person's variations."""
        original_attributes = self._get_attributes_from_metadata(original_image_path)
        successful = sum(generation_results.values())

        metadata = {
            "person_id": person_id,
            "original_image_path": original_image_path,
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
            "original_attributes": original_attributes,
            "variations": {
                name: {
                    "path": variation_paths.get(name),
                    "success": generation_results.get(name, False),
                }
                for name in ("original", *self.SETTINGS)
            },
            "summary": {
                "total_variations": 1 + len(self.SETTINGS),  # original + one per setting
                "successful_variations": successful,
                "failed_variations": len(generation_results) - successful,
            },
        }

        metadata_path = os.path.join(person_dir, f"{person_id}_metadata.json")
        with open(metadata_path, 'w', encoding='utf-8') as f:
            json.dump(metadata, f, indent=2)

        print(f"Metadata saved to: {metadata_path}")
        print(f"Metadata saved to: {metadata_path}")

def person_id_from_filename(filename: str) -> str:
    """Derive a person ID from a face image path or filename.

    Drops any directory, the file extension and a leading "face_" prefix,
    e.g. "faces4/face_0894.png" -> "0894". Names without the prefix keep
    their stem ("photo.jpg" -> "photo").
    """
    stem = os.path.splitext(os.path.basename(filename))[0]
    prefix = "face_"
    return stem[len(prefix):] if stem.startswith(prefix) else stem


def face_number(filename: str) -> int:
    """Return the numeric ID of a face image filename, or 0 if it is not numeric.

    Used to sort directory listings numerically (face_0278.png -> 278) and to
    apply --start-from.
    """
    try:
        return int(person_id_from_filename(filename))
    except ValueError:
        return 0


def main():
    """Command-line entry point.

    Processes either a single image (--input-image) or every face_*.png in a
    directory (--input-dir). Afterwards it writes failed_generations.json (when
    there were failures) and processing_summary.json into --output-dir. The
    reports are written even if processing is interrupted. Exits with status 2
    on invalid arguments and status 1 when the API key is missing or the input
    path does not exist.
    """
    parser = argparse.ArgumentParser(
        description="Generate variations of existing face images in different settings",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Generate headshot variations for all people in directory
  python generate_person_variations.py --input-dir faces4 --output-dir person_variations
  
  # Generate variations for a specific person
  python generate_person_variations.py --input-image faces4/face_0894.png --output-dir person_variations
  
  # Generate variations for only the first 5 people (useful for testing)
  python generate_person_variations.py --input-dir faces4 --output-dir person_variations --limit 5
  
  # Start processing from a specific ID (e.g., from face_0276.png onwards)
  python generate_person_variations.py --input-dir faces4 --output-dir person_variations --start-from 276
        """
    )

    # Exactly one input source is required; argparse reports violations and exits with status 2.
    input_group = parser.add_mutually_exclusive_group(required=True)
    input_group.add_argument("--input-image",
                             help="Path to the original face image")
    input_group.add_argument("--input-dir",
                             help="Directory containing face images to process")
    parser.add_argument("--output-dir", required=True,
                        help="Output directory for generated variations")
    parser.add_argument("--limit", type=int,
                        help="Limit the number of images to process (useful for testing)")
    parser.add_argument("--start-from", type=int,
                        help="Start processing from this ID number (e.g., 276 to start from face_0276.png)")
    parser.add_argument("--api-key",
                        help="OpenRouter API key (or set OPENROUTER_API_KEY env var)")

    args = parser.parse_args()

    if args.limit is not None and args.limit < 0:
        parser.error("--limit must be zero or a positive integer")

    api_key = args.api_key or API_KEY_REF
    if api_key == "your-api-key-here":
        print("Error: Please set your OpenRouter API key using --api-key or OPENROUTER_API_KEY environment variable")
        raise SystemExit(1)

    # Validate the input path before doing any work.
    if args.input_image is not None:
        if not os.path.exists(args.input_image):
            print(f"Error: Input image not found: {args.input_image}")
            raise SystemExit(1)
    elif not os.path.isdir(args.input_dir):
        print(f"Error: Input directory not found: {args.input_dir}")
        raise SystemExit(1)

    # The report files are written here even if no person directory gets created.
    os.makedirs(args.output_dir, exist_ok=True)
    generator = PersonVariationGenerator(api_key, args.output_dir)

    try:
        if args.input_image is not None:
            person_id = person_id_from_filename(args.input_image)

            print(f"Processing single image: {args.input_image}")
            print(f"Person ID: {person_id}")
            print(f"Output directory: {args.output_dir}")
            print()

            variation_paths = generator.generate_all_variations(args.input_image, person_id, args.output_dir)

            print(f"\nGeneration complete for person {person_id}!")
            print("Generated files:")
            for setting, path in variation_paths.items():
                print(f"  - {setting}: {path}")
        else:
            png_files = [f for f in os.listdir(args.input_dir)
                         if f.startswith('face_') and f.endswith('.png')]

            if not png_files:
                print(f"No face images found in {args.input_dir}")
                return

            # Numeric order (face_0278.png -> 278) rather than lexical listing order.
            png_files.sort(key=face_number)

            if args.start_from is not None:
                png_files = [f for f in png_files if face_number(f) >= args.start_from]
                print(f"Starting from ID {args.start_from}, found {len(png_files)} face images to process")

            if args.limit is not None:
                png_files = png_files[:args.limit]
                print(f"Limited to {args.limit} images, processing {len(png_files)} face images")
            elif args.start_from is None:
                print(f"Found {len(png_files)} face images to process")

            print(f"Output directory: {args.output_dir}")
            print()

            successful_count = 0

            for png_file in png_files:
                input_path = os.path.join(args.input_dir, png_file)
                person_id = person_id_from_filename(png_file)

                print(f"Processing {png_file} (Person ID: {person_id})...")

                try:
                    generator.generate_all_variations(input_path, person_id, args.output_dir)
                except Exception as e:
                    # Keep going with the next person; record the failure for the report.
                    print(f"✗ Failed to process person {person_id}: {e}")
                    generator._save_processing_failure(person_id, input_path, str(e))
                else:
                    successful_count += 1
                    print(f"✓ Successfully processed person {person_id}")

                print()

            print("Batch processing complete!")
            print(f"Successfully processed: {successful_count}/{len(png_files)} people")
    finally:
        # Persist reports even if the run is interrupted (e.g. Ctrl+C) or the
        # single-image path raises, so results of completed work are not lost.
        generator.save_failed_generations_to_file(args.output_dir)
        generator.save_summary_metadata(args.output_dir)

if __name__ == "__main__":
    # Run the command-line interface only when executed as a script, not on import.
    main()
