#!/usr/bin/env python3
"""
Image Generator Script for Tree-based Dataset

This script generates an image for each node in the tree structure of every
datapoint in the dataset, using the FLUX text-to-image model through the
FluxImageGenerator class (originating from chat_image_gen.py).
"""

import os
import json
import sys
from pathlib import Path
from typing import Dict, List, Any, Optional, Tuple

from config import CS_DJ_parser, parse_categories
import torch
import datetime
from diffusers import FluxPipeline

class FluxImageGenerator:
    """Wrapper around diffusers' ``FluxPipeline`` for text-to-image generation.

    The pipeline is loaded eagerly in ``__init__`` (and again lazily in
    ``generate_image`` if ``self.pipe`` is ``None``). Besides one-shot
    generation, ``interactive_mode`` offers a small REPL whose guidance,
    step count and image size can be changed at runtime.
    """

    def __init__(self, model_name="black-forest-labs/FLUX.1-schnell", use_cpu_offload=False,
                 enable_memory_efficient=True):
        """Initialize the FLUX image generator and load the model.

        Args:
            model_name: Model id passed to ``FluxPipeline.from_pretrained``.
            use_cpu_offload: On CUDA, call ``enable_model_cpu_offload()``
                instead of moving the whole pipeline to the GPU.
            enable_memory_efficient: Stored on the instance; not currently
                read by any code in this class.
        """
        self.model_name = model_name
        self.use_cpu_offload = use_cpu_offload
        self.enable_memory_efficient = enable_memory_efficient
        self.pipe = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Interactive mode settings (overwritten from CLI args in interactive_mode)
        self.current_guidance_scale = 0.0
        self.current_steps = 4
        self.current_width = 224
        self.current_height = 224

        print(f"Using device: {self.device}")
        print(f"Loading model: {model_name}")

        # Report total GPU memory to help choose offload / image-size settings
        if self.device == "cuda":
            gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1024**3
            print(f"GPU Memory: {gpu_memory:.1f} GB")

        self.load_model()

    @staticmethod
    def _clear_cuda_cache():
        """Release cached CUDA allocator memory; does nothing without CUDA."""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def load_model(self):
        """Load the FLUX pipeline into ``self.pipe``.

        Uses bfloat16 weights on CUDA and float32 on CPU. On CUDA the pipeline
        is either CPU-offloaded or moved to the GPU, depending on
        ``self.use_cpu_offload``. Any failure is printed and the process exits
        with status 1.
        """
        try:
            # Drop cached allocations before loading a large model
            self._clear_cuda_cache()

            load_kwargs = {
                "torch_dtype": torch.bfloat16 if self.device == "cuda" else torch.float32,
                "low_cpu_mem_usage": True,
            }

            self.pipe = FluxPipeline.from_pretrained(self.model_name, **load_kwargs)

            if self.device == "cuda":
                if self.use_cpu_offload:
                    self.pipe.enable_model_cpu_offload()
                    print("✓ Model CPU offload enabled")
                else:
                    self.pipe = self.pipe.to("cuda")

            print("Model loaded successfully!")

        except Exception as e:
            print(f"Error loading model: {e}")
            print("\nTry running with different settings for lower memory usage")
            sys.exit(1)

    def generate_image(self, prompt, output_path=None, guidance_scale=0.0,
                       num_inference_steps=4, max_sequence_length=256, seed=None,
                       width=224, height=224):
        """Generate one image from ``prompt`` and save it to disk.

        Args:
            prompt: Text prompt passed to the pipeline.
            output_path: Destination file. Defaults to
                ``flux_generated_<YYYYmmdd_HHMMSS>.png`` in the current
                directory. Missing parent directories are created.
            guidance_scale, num_inference_steps, max_sequence_length, width,
                height: Forwarded unchanged to the pipeline call.
            seed: If given, seeds a ``torch.Generator`` on ``self.device`` so
                the output is reproducible.

        Returns:
            ``(image, output_path)`` on success, or ``(None, None)`` if
            generation or saving raised. The error is printed, not re-raised.
        """
        if self.pipe is None:
            self.load_model()

        # Set up the generator with seed for reproducibility
        generator = None
        if seed is not None:
            generator = torch.Generator(self.device).manual_seed(seed)

        if output_path is None:
            # `datetime` is the imported *module*, so the class is datetime.datetime
            timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
            output_path = f"flux_generated_{timestamp}.png"

        Path(output_path).parent.mkdir(parents=True, exist_ok=True)

        print(f"Generating image with prompt: '{prompt}'")
        print(f"Parameters: steps={num_inference_steps}, guidance_scale={guidance_scale}")
        print(f"Image size: {width}x{height}")
        if seed is not None:
            print(f"Seed: {seed}")

        try:
            self._clear_cuda_cache()

            image = self.pipe(
                prompt,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                max_sequence_length=max_sequence_length,
                generator=generator,
                width=width,
                height=height,
            ).images[0]

            image.save(output_path)
            print(f"Image saved to: {output_path}")

            self._clear_cuda_cache()

            return image, output_path

        except Exception as e:
            print(f"Error generating image: {e}")
            if "out of memory" in str(e).lower():
                print("\n💡 Memory optimization suggestions:")
                print("   - Reduce image size with smaller --width and --height")
                print("   - Close other applications to free memory")
            return None, None

    def interactive_mode(self, args):
        """Run a read-eval loop for changing settings and generating images.

        ``args`` must provide ``guidance``, ``steps``, ``width`` and ``height``
        (the initial settings), plus ``output`` and ``seed``, which are passed
        to every ``generate_image`` call. Input that is not a recognized
        command is used as the prompt. The loop ends on quit/exit or Ctrl-C.
        """
        print("=== FLUX Image Generator - Interactive Mode ===")
        print("Commands:")
        print("  help - Show available commands")
        print("  guidance <value> - Set guidance scale (e.g., 'guidance 3.5')")
        print("  steps <value> - Set inference steps (e.g., 'steps 8')")
        print("  size <width> <height> - Set image size (e.g., 'size 512 512')")
        print("  settings - Show current settings")
        print("  quit/exit - Exit the program")
        print("  <prompt> - Generate image from prompt")
        print()

        # Initialize current settings
        self.current_guidance_scale = args.guidance
        self.current_steps = args.steps
        self.current_width = args.width
        self.current_height = args.height

        self.show_current_settings()

        while True:
            try:
                user_input = input("\n> ").strip()

                if not user_input:
                    continue

                parts = user_input.split()
                command = parts[0].lower()

                if command in ['quit', 'exit']:
                    print("Goodbye!")
                    break

                elif command == 'help':
                    print("\nAvailable commands:")
                    print("  help - Show this help")
                    print("  guidance <value> - Set guidance scale (e.g., 'guidance 3.5')")
                    print("  steps <value> - Set inference steps (e.g., 'steps 8')")
                    print("  size <width> <height> - Set image size (e.g., 'size 512 512')")
                    print("  settings - Show current settings")
                    print("  quit/exit - Exit the program")
                    print("  <any other text> - Generate image from prompt")
                    continue

                elif command == 'guidance':
                    if len(parts) >= 2:
                        try:
                            new_guidance = float(parts[1])
                            self.current_guidance_scale = new_guidance
                            print(f"✓ Guidance scale set to: {new_guidance}")
                        except ValueError:
                            print("❌ Invalid guidance value. Please use a number (e.g., 'guidance 3.5')")
                    else:
                        print("❌ Please specify a guidance value (e.g., 'guidance 3.5')")
                    continue

                elif command == 'steps':
                    if len(parts) >= 2:
                        try:
                            new_steps = int(parts[1])
                            if new_steps > 0:
                                self.current_steps = new_steps
                                print(f"✓ Inference steps set to: {new_steps}")
                            else:
                                print("❌ Steps must be a positive integer")
                        except ValueError:
                            print("❌ Invalid steps value. Please use an integer (e.g., 'steps 8')")
                    else:
                        print("❌ Please specify number of steps (e.g., 'steps 8')")
                    continue

                elif command == 'size':
                    if len(parts) >= 3:
                        try:
                            new_width = int(parts[1])
                            new_height = int(parts[2])
                            if new_width > 0 and new_height > 0:
                                self.current_width = new_width
                                self.current_height = new_height
                                print(f"✓ Image size set to: {new_width}x{new_height}")
                            else:
                                print("❌ Width and height must be positive integers")
                        except ValueError:
                            print("❌ Invalid size values. Please use integers (e.g., 'size 512 512')")
                    else:
                        print("❌ Please specify width and height (e.g., 'size 512 512')")
                    continue

                elif command == 'settings':
                    self.show_current_settings()
                    continue

                else:
                    # Anything that is not a command is the prompt itself
                    prompt = user_input

                    print("\nGenerating image...")
                    image, output_path = self.generate_image(
                        prompt=prompt,
                        output_path=args.output,
                        guidance_scale=self.current_guidance_scale,
                        num_inference_steps=self.current_steps,
                        seed=args.seed,
                        width=self.current_width,
                        height=self.current_height
                    )

                    if image is not None:
                        print(f"✓ Image generated successfully: {output_path}")
                    else:
                        print("❌ Image generation failed")

            except KeyboardInterrupt:
                print("\nExiting...")
                break
            except Exception as e:
                print(f"Error: {e}")

    def show_current_settings(self):
        """Print the current interactive guidance scale, step count and image size."""
        print("\nCurrent settings:")
        print(f"  Guidance scale: {self.current_guidance_scale}")
        print(f"  Inference steps: {self.current_steps}")
        print(f"  Image size: {self.current_width}x{self.current_height}")

# Directory containing the processed_<category>_with_trees_enhanced.json inputs;
# the processed_<category>_with_trees_and_images.json outputs are written here too.
data_path = "./processed_results"

# Defaults for ImageGeneratorForDataset and the matching
# --width / --height / --guidance-scale / --inference-steps CLI options.
DEFAULT_WIDTH, DEFAULT_HEIGHT = 224, 224
DEFAULT_GUIDANCE_SCALE = 10.0
DEFAULT_INFERENCE_STEPS = 20


class ImageGeneratorForDataset:
    """Annotate every tree node of a dataset with a generated FLUX image.

    For each category this class does four things:
    - reads ``processed_<category>_with_trees_enhanced.json`` from ``data_path``;
    - generates one image per named node of each datapoint's ``"structured"``
      tree into ``<output_base_dir>/<category>/<datapoint index>/``;
    - records the saved path, or ``None``, under the node's ``"img"`` key;
    - writes the annotated data to
      ``processed_<category>_with_trees_and_images.json`` in ``data_path``.
    """

    def __init__(self, width: int = DEFAULT_WIDTH, height: int = DEFAULT_HEIGHT,
                 guidance_scale: float = DEFAULT_GUIDANCE_SCALE,
                 inference_steps: int = DEFAULT_INFERENCE_STEPS):
        """Store generation settings and load the FLUX.1-schnell pipeline.

        Args:
            width: Width in pixels of every generated image.
            height: Height in pixels of every generated image.
            guidance_scale: Guidance scale forwarded to the pipeline.
            inference_steps: Number of inference steps per image.
        """
        self.width = width
        self.height = height
        self.guidance_scale = guidance_scale
        self.inference_steps = inference_steps

        # Question of the datapoint being processed; embedded in every node prompt
        self.initial_prompt = ""

        print("Initializing FLUX image generator...")
        self.flux_generator = FluxImageGenerator(
            model_name="black-forest-labs/FLUX.1-schnell",
            use_cpu_offload=False,
            enable_memory_efficient=True
        )
        print("FLUX generator initialized successfully!")

    def traverse_and_generate_images(self, node: Dict[str, Any], output_dir: str,
                                     node_path: str = "") -> int:
        """Recursively generate an image for ``node`` and all of its descendants.

        Only nodes with a ``"name"`` key get an image. Their ``"img"`` key is
        set to the saved path, or to ``None`` on failure. ``node`` is mutated
        in place. Entries in ``"children"`` that are not dicts are skipped,
        and sibling indices are preserved.

        File naming: the root is ``0.png``. Every other node is
        ``<node_path>_<depth>.png``, where ``node_path`` is the ``_``-joined
        child indices from the root and ``depth`` is how many indices it has.
        For example, the root's first child is ``0_1.png`` and that child's
        second child is ``0_1_2.png``.

        Args:
            node: Current tree node (dict with optional "name"/"children").
            output_dir: Directory in which to save the generated images.
            node_path: ``_``-joined child indices of this node ("" for the root).

        Returns:
            Number of images successfully generated in this subtree.
        """
        images_generated = 0

        if "name" in node:
            node_name = node["name"]

            if node_path:
                current_node_path = f"{node_path}_{len(node_path.split('_'))}"
            else:
                current_node_path = "0"

            output_path = os.path.join(output_dir, f"{current_node_path}.png")

            print(f"Generating image for node: '{node_name}'")
            try:
                image, saved_path = self.flux_generator.generate_image(
                    prompt=f'Anime key visual of "{node_name}", serving the overall goal of "{self.initial_prompt}" Art style tags: anime, cel-shaded, crisp line art, vibrant colors, detailed shading, studio-quality illustration, key visual, 2D digital painting. No text, no watermarks.',
                    output_path=output_path,
                    guidance_scale=self.guidance_scale,
                    num_inference_steps=self.inference_steps,
                    width=self.width,
                    height=self.height
                )

                if image is not None:
                    node["img"] = saved_path
                    images_generated += 1
                    print(f"✓ Generated image for '{node_name}' -> {saved_path}")
                else:
                    print(f"✗ Failed to generate image for '{node_name}'")
                    node["img"] = None

            except Exception as e:
                print(f"✗ Error generating image for '{node_name}': {e}")
                node["img"] = None

        if "children" in node and isinstance(node["children"], list):
            for i, child in enumerate(node["children"]):
                # A non-dict child would otherwise raise on the key lookups above
                if not isinstance(child, dict):
                    print(f"Warning: skipping non-dict child {i} under node path '{node_path}'")
                    continue
                child_path = f"{node_path}_{i}" if node_path else str(i)
                images_generated += self.traverse_and_generate_images(
                    child, output_dir, child_path
                )

        return images_generated

    def process_category(self, category: str, output_base_dir: str) -> Tuple[int, int]:
        """Generate images for every datapoint of one category and save the JSON.

        Datapoints that are not dicts, or that lack a dict ``"structured"``
        tree or a ``"question"``, are skipped with a warning.

        Args:
            category: Category name (used in the input/output JSON file names).
            output_base_dir: Base directory for output images.

        Returns:
            Tuple of (datapoints_processed, total_images_generated). Returns
            ``(0, 0)`` if the input file is missing or is not a JSON list.
        """
        json_file = os.path.join(data_path, f"processed_{category}_with_trees_enhanced.json")

        if not os.path.exists(json_file):
            print(f"Warning: JSON file not found for category {category}: {json_file}")
            return 0, 0

        print(f"\n=== Processing category: {category} ===")

        with open(json_file, 'r', encoding='utf-8') as f:
            data = json.load(f)

        if not isinstance(data, list):
            print(f"Error: Expected list format for {category} data")
            return 0, 0

        datapoints_processed = 0
        total_images_generated = 0

        for data_idx, datapoint in enumerate(data):
            print(f"\nProcessing datapoint {data_idx + 1}/{len(data)} for {category}")

            if not isinstance(datapoint, dict):
                print(f"Warning: datapoint {data_idx} is not a dictionary")
                continue

            if "structured" not in datapoint:
                print(f"Warning: No 'structured' field found in datapoint {data_idx}")
                continue

            structured_data = datapoint["structured"]
            if not isinstance(structured_data, dict):
                print(f"Warning: 'structured' field is not a dictionary in datapoint {data_idx}")
                continue

            # Without this check a missing question raises KeyError and aborts the whole
            # category before its JSON is saved
            if "question" not in datapoint:
                print(f"Warning: No 'question' field found in datapoint {data_idx}")
                continue

            output_dir = os.path.join(output_base_dir, category, str(data_idx))
            os.makedirs(output_dir, exist_ok=True)

            self.initial_prompt = datapoint["question"]
            images_generated = self.traverse_and_generate_images(
                structured_data, output_dir
            )

            total_images_generated += images_generated
            datapoints_processed += 1

            print(f"✓ Processed datapoint {data_idx}: {images_generated} images generated")

        # Save updated JSON (nodes now carry "img" paths)
        output_json = os.path.join(data_path, f"processed_{category}_with_trees_and_images.json")
        with open(output_json, 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False, indent=2)

        print(f"✓ Updated JSON saved to: {output_json}")

        return datapoints_processed, total_images_generated

    def process_all_categories(self, output_base_dir: str,
                               categories: Optional[List[str]] = None) -> None:
        """Process each category in turn and print overall totals.

        An exception in one category is printed and the remaining categories
        are still processed.

        Args:
            output_base_dir: Base directory for the generated images.
            categories: Categories to process. When omitted, uses the
                module-level ``category_list`` assigned by ``main()``.
        """
        if categories is None:
            categories = category_list

        print("=== Starting image generation for all categories ===")

        total_datapoints = 0
        total_images = 0

        for category in categories:
            try:
                datapoints, images = self.process_category(category, output_base_dir)
                total_datapoints += datapoints
                total_images += images
                print(f"✓ Category {category} completed: {datapoints} datapoints, {images} images")
            except Exception as e:
                print(f"✗ Error processing category {category}: {e}")
                continue

        print("\n=== Generation Complete ===")
        print(f"Total datapoints processed: {total_datapoints}")
        print(f"Total images generated: {total_images}")


def main():
    """Command-line entry point for dataset-wide image generation.

    Extends the project parser returned by ``CS_DJ_parser()`` with output,
    size, guidance and step options. It stores the selected categories in the
    module-level ``category_list`` (read by
    ``ImageGeneratorForDataset.process_all_categories``), creates the output
    directory, and runs generation. If the generator cannot be constructed,
    it exits with status 1.
    """
    global category_list

    cli = CS_DJ_parser()
    cli.add_argument("--output-dir", default="generated_images",
                     help="Base directory for generated images (default: generated_images)")

    # (flag, type, default, help) for the numeric generation options, registered in order
    numeric_options = (
        ("--width", int, DEFAULT_WIDTH,
         f"Image width (default: {DEFAULT_WIDTH})"),
        ("--height", int, DEFAULT_HEIGHT,
         f"Image height (default: {DEFAULT_HEIGHT})"),
        ("--guidance-scale", float, DEFAULT_GUIDANCE_SCALE,
         f"Guidance scale for image generation (default: {DEFAULT_GUIDANCE_SCALE})"),
        ("--inference-steps", int, DEFAULT_INFERENCE_STEPS,
         f"Number of inference steps (default: {DEFAULT_INFERENCE_STEPS})"),
    )
    for flag, value_type, default_value, help_text in numeric_options:
        cli.add_argument(flag, type=value_type, default=default_value, help=help_text)

    options = cli.parse_args()
    category_list = parse_categories(options)

    os.makedirs(options.output_dir, exist_ok=True)

    try:
        dataset_generator = ImageGeneratorForDataset(
            width=options.width,
            height=options.height,
            guidance_scale=options.guidance_scale,
            inference_steps=options.inference_steps,
        )
    except Exception as e:
        print(f"Error initializing image generator: {e}")
        sys.exit(1)

    dataset_generator.process_all_categories(options.output_dir)


if __name__ == "__main__":
    # Run the generation pipeline only when executed as a script, not on import
    main()

