#!/usr/bin/env python3
"""
FLUX Image Generator Script

This script uses the FLUX.1-schnell model from Black Forest Labs to generate images from text prompts.
It provides a user-friendly interface with command-line arguments and interactive features.
"""

import os

from huggingface_hub import login

# Authenticate with the Hugging Face Hub only when a token is supplied through
# the environment. Access tokens are credentials and must never be committed
# to source control; export HF_TOKEN (or the legacy HUGGING_FACE_HUB_TOKEN)
# before running the script instead. When neither variable is set the login
# step is skipped rather than calling login() without a token.
_hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
if _hf_token:
    login(token=_hf_token)

import argparse
import os
import sys
import torch
from datetime import datetime
from pathlib import Path

try:
    from diffusers import FluxPipeline
except ImportError:
    print("Error: diffusers library not found. Please install it with:")
    print("pip install diffusers torch transformers")
    sys.exit(1)


class FluxImageGenerator:
    """Generate images from text prompts with a diffusers ``FluxPipeline``.

    The pipeline is loaded when the object is constructed. Besides one-shot
    generation via :meth:`generate_image`, the class offers a small
    read-eval-print loop (:meth:`interactive_mode`) in which guidance scale,
    step count and image size can be changed between generations.
    """

    def __init__(self, model_name="black-forest-labs/FLUX.1-schnell", use_cpu_offload=False,
                 enable_memory_efficient=True):
        """Initialize the FLUX image generator and load the model.

        Args:
            model_name: Model identifier passed to ``FluxPipeline.from_pretrained``.
            use_cpu_offload: When running on CUDA, enable model CPU offload
                instead of moving the whole pipeline to the GPU.
            enable_memory_efficient: Stored on the instance as-is.
                NOTE(review): nothing in this class currently reads this flag.

        The constructor prints the selected device (and total GPU memory on
        CUDA) and then calls :meth:`load_model`, which terminates the process
        with exit status 1 if loading fails.
        """
        self.model_name = model_name
        self.use_cpu_offload = use_cpu_offload
        self.enable_memory_efficient = enable_memory_efficient
        self.pipe = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Interactive mode settings (overwritten from CLI args when
        # interactive_mode() starts).
        self.current_guidance_scale = 0.0
        self.current_steps = 4
        self.current_width = 224
        self.current_height = 224

        print(f"Using device: {self.device}")
        print(f"Loading model: {model_name}")

        # Report total memory of CUDA device 0 to help diagnose OOM problems.
        if self.device == "cuda":
            gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1024**3
            print(f"GPU Memory: {gpu_memory:.1f} GB")

        self.load_model()

    @staticmethod
    def _clear_cuda_cache():
        """Release cached CUDA memory; does nothing when CUDA is unavailable."""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    @staticmethod
    def _numbered_output_path(base_path, index):
        """Return the output path for the ``index``-th image of a session.

        ``None`` (auto-generated names) and the first image (``index == 0``)
        are passed through unchanged; later images get ``_<index>`` inserted
        before the file suffix so they do not overwrite earlier ones.
        """
        if base_path is None or index == 0:
            return base_path
        base = Path(base_path)
        return str(base.with_name(f"{base.stem}_{index}{base.suffix}"))

    def load_model(self):
        """Load the FLUX pipeline into ``self.pipe``.

        Uses bfloat16 weights on CUDA and float32 on CPU. On CUDA the pipeline
        is either set up with model CPU offload (``use_cpu_offload``) or moved
        to the GPU. Any exception raised while loading is reported and the
        process exits with status 1.
        """
        try:
            # Free cached GPU memory before allocating the new pipeline.
            self._clear_cuda_cache()

            load_kwargs = {
                "torch_dtype": torch.bfloat16 if self.device == "cuda" else torch.float32,
                "low_cpu_mem_usage": True,
            }

            self.pipe = FluxPipeline.from_pretrained(self.model_name, **load_kwargs)

            if self.device == "cuda":
                if self.use_cpu_offload:
                    self.pipe.enable_model_cpu_offload()
                    print("✓ Model CPU offload enabled")
                else:
                    self.pipe = self.pipe.to("cuda")

            print("Model loaded successfully!")

        except Exception as e:
            print(f"Error loading model: {e}")
            print("\nTry running with different settings for lower memory usage")
            sys.exit(1)

    def generate_image(self, prompt, output_path=None, guidance_scale=0.0,
                      num_inference_steps=4, max_sequence_length=256, seed=None,
                      width=224, height=224):
        """Generate an image from a text prompt and save it to disk.

        Args:
            prompt: Text prompt passed to the pipeline.
            output_path: Destination file. When ``None`` a name of the form
                ``flux_generated_<YYYYmmdd_HHMMSS>.png`` is used. Missing
                parent directories are created.
            guidance_scale: Forwarded to the pipeline.
            num_inference_steps: Forwarded to the pipeline.
            max_sequence_length: Forwarded to the pipeline.
            seed: Optional integer seed; when given, a ``torch.Generator`` on
                ``self.device`` is seeded with it for reproducibility.
            width: Requested image width, forwarded to the pipeline.
            height: Requested image height, forwarded to the pipeline.

        Returns:
            ``(image, output_path)`` on success, ``(None, None)`` if generation
            or saving raised an exception (the error is printed, not raised).
        """
        if self.pipe is None:
            self.load_model()

        # Set up the generator with seed for reproducibility.
        generator = None
        if seed is not None:
            generator = torch.Generator(self.device).manual_seed(seed)

        # Generate output filename if not provided.
        if output_path is None:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            output_path = f"flux_generated_{timestamp}.png"

        # Create output directory if it doesn't exist.
        Path(output_path).parent.mkdir(parents=True, exist_ok=True)

        print(f"Generating image with prompt: '{prompt}'")
        print(f"Parameters: steps={num_inference_steps}, guidance_scale={guidance_scale}")
        print(f"Image size: {width}x{height}")
        if seed is not None:
            print(f"Seed: {seed}")

        try:
            # Clear cache before generation.
            self._clear_cuda_cache()

            image = self.pipe(
                prompt,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                max_sequence_length=max_sequence_length,
                generator=generator,
                width=width,
                height=height,
            ).images[0]

            image.save(output_path)
            print(f"Image saved to: {output_path}")

            return image, output_path

        except Exception as e:
            print(f"Error generating image: {e}")
            if "out of memory" in str(e).lower():
                print("\n💡 Memory optimization suggestions:")
                print("   - Reduce image size with smaller --width and --height")
                print("   - Close other applications to free memory")
            return None, None

        finally:
            # Release cached GPU memory after every attempt, including failed
            # ones (an out-of-memory error is exactly when this matters most).
            self._clear_cuda_cache()

    def _print_interactive_help(self):
        """Print the command reference shown by the ``help`` command."""
        print("\nAvailable commands:")
        print("  help - Show this help")
        print("  guidance <value> - Set guidance scale (e.g., 'guidance 3.5')")
        print("  steps <value> - Set inference steps (e.g., 'steps 8')")
        print("  size <width> <height> - Set image size (e.g., 'size 512 512')")
        print("  settings - Show current settings")
        print("  quit/exit - Exit the program")
        print("  <any other text> - Generate image from prompt")

    def _set_guidance(self, parts):
        """Handle ``guidance <value>``: parse a float and store it."""
        if len(parts) < 2:
            print("❌ Please specify a guidance value (e.g., 'guidance 3.5')")
            return
        try:
            new_guidance = float(parts[1])
        except ValueError:
            print("❌ Invalid guidance value. Please use a number (e.g., 'guidance 3.5')")
            return
        self.current_guidance_scale = new_guidance
        print(f"✓ Guidance scale set to: {new_guidance}")

    def _set_steps(self, parts):
        """Handle ``steps <value>``: accept positive integers only."""
        if len(parts) < 2:
            print("❌ Please specify number of steps (e.g., 'steps 8')")
            return
        try:
            new_steps = int(parts[1])
        except ValueError:
            print("❌ Invalid steps value. Please use an integer (e.g., 'steps 8')")
            return
        if new_steps <= 0:
            print("❌ Steps must be a positive integer")
            return
        self.current_steps = new_steps
        print(f"✓ Inference steps set to: {new_steps}")

    def _set_size(self, parts):
        """Handle ``size <width> <height>``: accept positive integers only."""
        if len(parts) < 3:
            print("❌ Please specify width and height (e.g., 'size 512 512')")
            return
        try:
            new_width = int(parts[1])
            new_height = int(parts[2])
        except ValueError:
            print("❌ Invalid size values. Please use integers (e.g., 'size 512 512')")
            return
        if new_width <= 0 or new_height <= 0:
            print("❌ Width and height must be positive integers")
            return
        self.current_width = new_width
        self.current_height = new_height
        print(f"✓ Image size set to: {new_width}x{new_height}")

    def interactive_mode(self, args):
        """Run a prompt loop in which settings can be changed between images.

        Args:
            args: Parsed CLI namespace. The attributes ``guidance``, ``steps``,
                ``width`` and ``height`` seed the session settings; ``output``
                and ``seed`` are passed to every generation.

        Each input line is split on whitespace and its first word is matched
        case-insensitively against the commands ``quit``/``exit``, ``help``,
        ``guidance``, ``steps``, ``size`` and ``settings``; any other line is
        used verbatim as a prompt. The loop ends on ``quit``/``exit``, on
        Ctrl-C, or when standard input reaches end-of-file.

        When ``args.output`` is set, the first successfully generated image is
        written to that exact path and later ones get a ``_<n>`` suffix before
        the extension, so one session does not overwrite its own results.
        """
        print("=== FLUX Image Generator - Interactive Mode ===")
        print("Commands:")
        print("  help - Show available commands")
        print("  guidance <value> - Set guidance scale (e.g., 'guidance 3.5')")
        print("  steps <value> - Set inference steps (e.g., 'steps 8')")
        print("  size <width> <height> - Set image size (e.g., 'size 512 512')")
        print("  settings - Show current settings")
        print("  quit/exit - Exit the program")
        print("  <prompt> - Generate image from prompt")
        print()

        # Initialize current settings from the command line.
        self.current_guidance_scale = args.guidance
        self.current_steps = args.steps
        self.current_width = args.width
        self.current_height = args.height

        self.show_current_settings()

        setting_handlers = {
            "guidance": self._set_guidance,
            "steps": self._set_steps,
            "size": self._set_size,
        }
        images_saved = 0  # successful generations so far, used for file naming

        while True:
            try:
                user_input = input("\n> ").strip()

                if not user_input:
                    continue

                parts = user_input.split()
                command = parts[0].lower()

                if command in ("quit", "exit"):
                    print("Goodbye!")
                    break

                if command == "help":
                    self._print_interactive_help()
                    continue

                if command in setting_handlers:
                    setting_handlers[command](parts)
                    continue

                if command == "settings":
                    self.show_current_settings()
                    continue

                # Anything else is treated as a prompt for image generation.
                print("\nGenerating image...")
                image, output_path = self.generate_image(
                    prompt=user_input,
                    output_path=self._numbered_output_path(args.output, images_saved),
                    guidance_scale=self.current_guidance_scale,
                    num_inference_steps=self.current_steps,
                    seed=args.seed,
                    width=self.current_width,
                    height=self.current_height,
                )

                if image is not None:
                    images_saved += 1
                    print(f"✓ Image generated successfully: {output_path}")
                else:
                    print("❌ Image generation failed")

            except (KeyboardInterrupt, EOFError):
                # EOFError must end the loop as well: once stdin is exhausted
                # every further input() call raises it again, so swallowing it
                # in the generic handler below would spin forever.
                print("\nExiting...")
                break
            except Exception as e:
                print(f"Error: {e}")

    def show_current_settings(self):
        """Display current interactive settings."""
        print("\nCurrent settings:")
        print(f"  Guidance scale: {self.current_guidance_scale}")
        print(f"  Inference steps: {self.current_steps}")
        print(f"  Image size: {self.current_width}x{self.current_height}")


def main():
    """Parse command-line arguments and run the generator.

    With ``-i/--interactive`` the interactive prompt loop is started;
    otherwise a single image is generated from the positional prompt (or from
    a prompt read from standard input when none was given).

    Exits with status 2 (argparse usage error) for non-positive ``--steps``,
    ``--width`` or ``--height``, and with status 1 when no prompt is provided
    or the single generation fails.
    """
    parser = argparse.ArgumentParser(description="Generate images using FLUX model")
    parser.add_argument("prompt", nargs="?", help="Text prompt for image generation")
    parser.add_argument("-o", "--output", help="Output file path (default: auto-generated)")
    parser.add_argument("-s", "--steps", type=int, default=4,
                       help="Number of inference steps (default: 4)")
    parser.add_argument("-g", "--guidance", type=float, default=0.0,
                       help="Guidance scale (default: 0.0)")
    parser.add_argument("--seed", type=int, help="Random seed for reproducibility")
    parser.add_argument("--width", type=int, default=224, help="Image width (default: 224)")
    parser.add_argument("--height", type=int, default=224, help="Image height (default: 224)")
    parser.add_argument("--model", default="black-forest-labs/FLUX.1-schnell",
                       help="Model to use (default: black-forest-labs/FLUX.1-schnell)")
    # store_true: interactive mode is opt-in. With store_false the flag
    # defaulted to True, so a plain `script.py "prompt"` ignored the prompt
    # and dropped into the interactive loop.
    parser.add_argument("--interactive", "-i", action="store_true",
                       help="Interactive mode for multiple generations")
    parser.add_argument("--cpu-offload", action="store_true",
                       help="Enable model CPU offload to reduce GPU memory usage")

    args = parser.parse_args()

    # Reject invalid numeric settings before the (slow) model load, mirroring
    # the checks the interactive 'steps' and 'size' commands perform.
    if args.steps <= 0:
        parser.error("--steps must be a positive integer")
    if args.width <= 0 or args.height <= 0:
        parser.error("--width and --height must be positive integers")

    # Initialize the generator (loads the model).
    generator = FluxImageGenerator(
        model_name=args.model,
        use_cpu_offload=args.cpu_offload,
        enable_memory_efficient=True,
    )

    if args.interactive:
        generator.interactive_mode(args)
        return

    # Single generation mode.
    prompt = args.prompt
    if not prompt:
        prompt = input("Enter your prompt: ").strip()
        if not prompt:
            print("Error: No prompt provided")
            sys.exit(1)

    image, output_path = generator.generate_image(
        prompt=prompt,
        output_path=args.output,
        guidance_scale=args.guidance,
        num_inference_steps=args.steps,
        seed=args.seed,
        width=args.width,
        height=args.height,
    )

    if image is not None:
        print(f"✓ Image generated successfully: {output_path}")
    else:
        print("✗ Image generation failed")
        sys.exit(1)


if __name__ == "__main__":
    main()
