#!/usr/bin/env python3
"""
Script to regenerate failed images from failed_generations.json file.
Uses the same model and hyperparameters as generate_faces_4_attributes.py.

Usage:
    python regenerate_failed_images.py --failed-file faces4/failed_generations.json --output-dir faces4
"""

import requests
import json
import os
import argparse
import base64
from pathlib import Path
from typing import List, Dict, Optional, Tuple
import time
from dotenv import load_dotenv

# Load environment variables from the .env file in the directory one level above
# this script's folder. abspath() anchors the lookup to the script's real
# location: with a relative __file__, the nested dirname() calls can collapse to
# "" and the .env file would then be looked up in the current working directory.
load_dotenv(os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), '.env'))

# OpenRouter API key from the environment. The placeholder fallback lets the
# module import without a key; requests made with it will presumably be
# rejected by the API (NOTE(review): confirm whether failing fast is preferred).
API_KEY_REF = os.getenv("OPENROUTER_API_KEY", "your-api-key-here")

class FailedImageRegenerator:
    """Re-run image generation for entries recorded in a failed-generations file.

    Each regeneration sends the entry's original prompt to the OpenRouter
    chat-completions endpoint and saves the first base64 data-URL image in the
    reply. Successes accumulate in ``successfully_regenerated`` and entries that
    still fail accumulate in ``still_failed``. ``save_results`` writes both
    lists to disk.
    """

    def __init__(self, api_key: str = API_KEY_REF, output_dir: Optional[str] = None,
                 request_timeout: float = 300.0):
        """Initialize the failed image regenerator.

        Args:
            api_key: OpenRouter API key sent as a Bearer token.
            output_dir: Default directory for the result JSON files written by
                ``save_results`` when it is called without an explicit directory.
            request_timeout: Seconds ``requests.post`` waits for the API before
                giving up on an attempt. Without a timeout, a stalled connection
                would block the script indefinitely.
        """
        self.api_key = api_key
        self.url = "https://openrouter.ai/api/v1/chat/completions"
        self.headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }
        self.output_dir = output_dir
        self.request_timeout = request_timeout
        self.still_failed = []
        self.successfully_regenerated = []

    def regenerate_image(self, prompt: str, output_path: str,
                         max_retries: int = 10) -> Tuple[bool, Optional[str]]:
        """
        Regenerate a single image using the original prompt.

        Args:
            prompt: The original prompt used for generation
            output_path: Path to save the generated image
            max_retries: Maximum number of attempts (each attempt is one API call)

        Returns:
            Tuple[bool, Optional[str]]: (success, error_reason). error_reason is
            None on success, otherwise the message from the last failed attempt.
        """
        payload = {
            "model": "google/gemini-2.5-flash-image-preview",
            "messages": [
                {
                    "role": "user",
                    "content": prompt
                }
            ],
            "modalities": ["image", "text"]
        }

        # Overwritten by every attempt; only survives if the loop never runs.
        last_error: Optional[str] = "No attempts made (max_retries < 1)"

        for attempt in range(max_retries):
            try:
                print(f"Attempt {attempt + 1}/{max_retries} for regeneration...")

                response = requests.post(self.url, headers=self.headers, json=payload,
                                         timeout=self.request_timeout)
                response.raise_for_status()
                result = response.json()

                saved, last_error = self._save_image_from_result(result, output_path, attempt)
                if saved:
                    return True, None

            except requests.exceptions.RequestException as e:
                last_error = f"Request failed: {str(e)}"
                print(f"{last_error} (attempt {attempt + 1})")
            except Exception as e:
                # Malformed responses (missing keys, bad base64) and file errors
                # are treated as retryable failures.
                last_error = f"Exception: {str(e)}"
                print(f"{last_error} (attempt {attempt + 1})")

            # Wait before retry (except on last attempt)
            if attempt < max_retries - 1:
                time.sleep(2)

        # All attempts failed
        print(f"All {max_retries} attempts failed for regeneration")
        return False, last_error

    def _save_image_from_result(self, result: Dict, output_path: str,
                                attempt: int) -> Tuple[bool, Optional[str]]:
        """Save the first usable data-URL image in an API *result* to *output_path*.

        Returns (True, None) once an image has been written with non-zero size.
        Otherwise returns (False, reason). Exceptions from malformed responses
        or file I/O propagate to the caller.
        """
        choices = result.get("choices")
        if not choices:
            error = "No choices found in response"
            print(f"{error} (attempt {attempt + 1})")
            return False, error

        message = choices[0]["message"]
        images = message.get("images")
        if not images:
            error = "No images found in response"
            print(f"{error} (attempt {attempt + 1})")
            return False, error

        error = None
        for image in images:
            image_url = image["image_url"]["url"]

            if not image_url.startswith("data:image"):
                error = f"Unexpected image URL format: {image_url[:50]}..."
                print(error)
                continue

            # Data URL layout is "data:image/<type>;base64,<payload>".
            _, data = image_url.split(",", 1)
            image_data = base64.b64decode(data)

            # A bare filename has no directory part; os.makedirs("") would raise.
            parent_dir = os.path.dirname(output_path)
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)

            with open(output_path, "wb") as f:
                f.write(image_data)

            # Verify image was saved and has content
            if os.path.exists(output_path) and os.path.getsize(output_path) > 0:
                print(f"Successfully regenerated image: {output_path}")
                return True, None

            error = f"Image file was not created or is empty: {output_path}"
            print(error)

        return False, error

    def regenerate_from_failed_file(self, failed_file_path: str, max_retries: int = 10):
        """
        Regenerate all images from the failed_generations.json file.

        Each entry must provide ``output_path`` and ``prompt_used``.
        ``characteristics`` is only printed. Results are appended to
        ``successfully_regenerated`` / ``still_failed``.

        Args:
            failed_file_path: Path to the failed_generations.json file
            max_retries: Maximum number of retry attempts per image
        """
        with open(failed_file_path, 'r', encoding='utf-8') as f:
            failed_data = json.load(f)

        total = len(failed_data)
        print(f"Found {total} failed generations to attempt regeneration...")

        for i, failed_item in enumerate(failed_data):
            output_path = failed_item['output_path']
            print(f"\n--- Regenerating {i+1}/{total} ---")
            print(f"Output path: {output_path}")
            # Display-only field: a missing key must not abort the whole batch.
            print(f"Characteristics: {failed_item.get('characteristics')}")

            # Use the original prompt
            prompt = failed_item['prompt_used']

            success, error_reason = self.regenerate_image(prompt, output_path, max_retries)

            if success:
                self.successfully_regenerated.append({
                    "original_failed_item": failed_item,
                    "regenerated_at": time.strftime("%Y-%m-%d %H:%M:%S"),
                    "output_path": output_path
                })
                print(f"✓ Successfully regenerated: {output_path}")
            else:
                # Keep the original record and annotate it with this run's outcome.
                still_failed_item = failed_item.copy()
                still_failed_item["regeneration_attempts"] = max_retries
                still_failed_item["last_regeneration_error"] = error_reason
                still_failed_item["last_regeneration_attempt"] = time.strftime("%Y-%m-%d %H:%M:%S")
                self.still_failed.append(still_failed_item)
                print(f"✗ Still failed: {output_path} - {error_reason}")

            # Pause between requests to avoid rate limiting (not after the last one).
            if i < total - 1:
                time.sleep(1)

    def save_results(self, output_dir: Optional[str] = None):
        """Save the regeneration results and print a summary.

        Writes ``successfully_regenerated.json`` if anything succeeded. Writes
        ``failed_generations_updated.json`` if anything still failed.

        Args:
            output_dir: Destination directory, created if missing. Defaults to
                the ``output_dir`` passed to the constructor.

        Raises:
            ValueError: if there are results to write but no directory was
                given here or to the constructor.
        """
        if output_dir is None:
            output_dir = self.output_dir

        if self.successfully_regenerated or self.still_failed:
            if output_dir is None:
                raise ValueError("No output directory given for saving regeneration results")
            # "" means the current directory, which always exists.
            if output_dir:
                os.makedirs(output_dir, exist_ok=True)

        # Save successfully regenerated images info
        if self.successfully_regenerated:
            success_file = os.path.join(output_dir, "successfully_regenerated.json")
            with open(success_file, 'w', encoding='utf-8') as f:
                json.dump(self.successfully_regenerated, f, indent=2)
            print(f"\nSaved successfully regenerated info to: {success_file}")

        # Create failed_generations_updated.json file with only still failed items
        if self.still_failed:
            updated_failed_file = os.path.join(output_dir, "failed_generations_updated.json")
            with open(updated_failed_file, 'w', encoding='utf-8') as f:
                json.dump(self.still_failed, f, indent=2)
            print(f"Created failed_generations_updated.json with {len(self.still_failed)} still failed items")
        else:
            print("All images successfully regenerated! No failed_generations_updated.json needed")

        # Print summary
        print("\n--- Regeneration Summary ---")
        print(f"Successfully regenerated: {len(self.successfully_regenerated)}")
        print(f"Still failed: {len(self.still_failed)}")
        print(f"Total processed: {len(self.successfully_regenerated) + len(self.still_failed)}")

def main(argv: Optional[List[str]] = None):
    """Command-line entry point: regenerate the images listed in a failed-generations file.

    Args:
        argv: Arguments to parse instead of ``sys.argv[1:]``. ``None`` keeps the
            normal command-line behavior; a list is mainly useful for testing.

    Exits with a non-zero status (via ``SystemExit``) when the input file is
    missing or ``--limit`` is negative.
    """
    import tempfile

    parser = argparse.ArgumentParser(description="Regenerate failed images from failed_generations.json")
    parser.add_argument("--failed-file", required=True, help="Path to failed_generations.json file")
    parser.add_argument("--output-dir", required=True, help="Output directory for images")
    parser.add_argument("--max-retries", type=int, default=10, help="Maximum retry attempts per image (default: 10)")
    parser.add_argument("--limit", type=int, help="Limit number of images to process (for testing)")

    args = parser.parse_args(argv)

    # A negative slice bound would silently drop items from the end instead.
    if args.limit is not None and args.limit < 0:
        parser.error("--limit must be a non-negative integer")

    # Check if failed file exists; exit non-zero so callers/scripts see the failure.
    if not os.path.isfile(args.failed_file):
        raise SystemExit(f"Error: Failed file not found: {args.failed_file}")

    # Create the results directory up front, before spending any API calls.
    os.makedirs(args.output_dir, exist_ok=True)

    # Initialize regenerator
    regenerator = FailedImageRegenerator(output_dir=args.output_dir)

    # No limit: process the user's file directly; no copy is needed.
    if args.limit is None:
        regenerator.regenerate_from_failed_file(args.failed_file, args.max_retries)
        regenerator.save_results(args.output_dir)
        return

    with open(args.failed_file, 'r', encoding='utf-8') as f:
        failed_data = json.load(f)

    # "is not None" above means --limit 0 processes nothing rather than everything.
    failed_data = failed_data[:args.limit]
    print(f"Limited to first {args.limit} failed images for testing")

    # regenerate_from_failed_file() reads from a path, so write the subset to a
    # private temp file in the system temp directory rather than beside the
    # user's input (which could clobber an existing file or be read-only).
    fd, temp_failed_file = tempfile.mkstemp(suffix=".json")
    try:
        with os.fdopen(fd, 'w', encoding='utf-8') as f:
            json.dump(failed_data, f, indent=2)

        # Regenerate images
        regenerator.regenerate_from_failed_file(temp_failed_file, args.max_retries)

        # Save results
        regenerator.save_results(args.output_dir)

    finally:
        # Clean up temporary file
        if os.path.exists(temp_failed_file):
            os.remove(temp_failed_file)

if __name__ == "__main__":
    # Run the CLI only when this file is executed as a script; importing the
    # module (e.g. to reuse FailedImageRegenerator) does not start a run.
    main()
