"""
Iterative image generation script that processes the dataset by type and difficulty level.
Generates images and organizes them in folders by API model, type and level.
Tracks generation times for each problem.
"""

import pandas as pd
import os
from get_image import T2IModelTester
from collections import defaultdict
import json
from datetime import datetime
import time # Added for time.sleep

class IterativeImageGenerator:
    """
    Generates images iteratively from the dataset, organized by API model, type and difficulty level.
    """
    
    def __init__(self, csv_path: str, base_output_dir: str = "generated_images"):
        """
        Set up bookkeeping, make sure the output root exists and load the dataset.

        Args:
            csv_path: Path to the dataset CSV (read with pandas; the 'type' and
                'level' columns are summarised right after loading)
            base_output_dir: Root directory under which generated images are organized
        """
        # Bookkeeping containers filled in by the generate_* methods
        self.failed_generations = []
        self.generation_stats = defaultdict(int)
        self.generation_times = defaultdict(list)  # per-API lists of generation times

        # Remember where the data comes from and where images go
        self.csv_path = csv_path
        self.base_output_dir = base_output_dir

        # The output root is created before the CSV is read
        os.makedirs(self.base_output_dir, exist_ok=True)

        # Load the dataset and print its type/level distribution
        self.df = pd.read_csv(self.csv_path)
        self._analyze_dataset()
    
    def _analyze_dataset(self):
        """Print how many problems the dataset holds per type, per level and per (type, level) pair."""
        frame = self.df

        print("📊 Dataset Analysis:")
        print(f"Total problems: {len(frame)}")

        # Single-column tallies: value_counts() lists the most frequent value first
        for column, heading in (('type', "\nBy Type:"), ('level', "\nBy Level:")):
            tally = frame[column].value_counts()
            print(heading)
            for label, how_many in tally.items():
                print(f"  {label}: {how_many}")

        # Joint tally over both columns
        print("\nBy Type and Level:")
        pair_sizes = frame.groupby(['type', 'level']).size()
        for (ptype, level), how_many in pair_sizes.items():
            print(f"  {ptype} - {level}: {how_many}")
    
    def _create_directory_structure(self, api_model: str, problem_type: str, level: str) -> str:
        """
        Create directory structure for organizing images by API model, type and level.

        The type and level components are cleaned with the same rules that
        _get_output_path applies (lowercase, spaces -> "_", "&" -> "and"), so the
        directory created here is the one image files are written into.

        Args:
            api_model: API model used (openai, gemini, stability); lowercased here
            problem_type: Type of problem (Algebra, Geometry, etc.)
            level: Difficulty level (Level 1, Level 2, etc.)

        Returns:
            Path to the created directory
        """
        # NOTE(review): the API name is lowercased here but used verbatim by
        # _get_output_path; the two agree only for lowercase provider names.
        clean_api = api_model.lower()

        # "&" must become "and" exactly as in _get_output_path; otherwise a type
        # such as "Counting & Probability" yields a stray "counting_&_probability"
        # folder that no image is ever saved to.
        clean_type = problem_type.lower().replace(" ", "_").replace("&", "and")
        clean_level = level.lower().replace(" ", "_")

        # Create directory path: base_output_dir/api_model/type/level/
        dir_path = os.path.join(self.base_output_dir, clean_api, clean_type, clean_level)
        os.makedirs(dir_path, exist_ok=True)

        return dir_path
    
    def _get_output_path(self, api_provider: str, problem_type: str, level: str, 
                        original_index: int, sequence_num: int) -> str:
        """
        Build the organized file path for one image, creating its folder as a side effect.

        Layout: base_output_dir/api_provider/<type>/<level>/
                <seq>_<type>_<level>_<original_index>.png
        where <type> and <level> are lowercased with spaces turned into "_"
        (and "&" into "and" for the type), and <seq> is zero-padded to 3 digits.

        Args:
            api_provider: API provider used (used verbatim as a folder name)
            problem_type: Type of problem
            level: Difficulty level
            original_index: Original index in dataset
            sequence_num: Sequence number for this type/level combination

        Returns:
            Full path to the output file
        """
        type_slug = problem_type.lower().replace(" ", "_").replace("&", "and")
        level_slug = level.lower().replace(" ", "_")

        # The folder is created eagerly so callers can write to the path right away
        target_dir = os.path.join(self.base_output_dir, api_provider, type_slug, level_slug)
        os.makedirs(target_dir, exist_ok=True)

        image_name = f"{sequence_num:03d}_{type_slug}_{level_slug}_{original_index}.png"
        return os.path.join(target_dir, image_name)
    
    def _check_file_exists(self, api_provider: str, problem_type: str, level: str, 
                          original_index: int, sequence_num: int) -> tuple:
        """
        Report whether the organized image for a problem is already on disk.

        The path is resolved through _get_output_path, which also creates the
        containing folder if it is missing.

        Args:
            api_provider: API provider used
            problem_type: Type of problem
            level: Difficulty level
            original_index: Original index in dataset
            sequence_num: Sequence number for this type/level combination

        Returns:
            Tuple of (exists: bool, file_path: str)
        """
        candidate = self._get_output_path(
            api_provider, problem_type, level, original_index, sequence_num
        )
        already_there = os.path.exists(candidate)
        return already_there, candidate
    
    def _save_generated_image(self, result: dict, original_path: str, output_path: str) -> bool:
        """
        Move a generated image into the organized directory structure.

        This is best-effort: any error is printed and reported through the
        return value instead of being raised.

        Args:
            result: Result dictionary from image generation (not read by this method)
            original_path: Path where the image was originally saved by test_single_problem
            output_path: Full path where the image should be moved to

        Returns:
            True if the file was moved, False if the source is missing or the move failed
        """
        try:
            # Local import: shutil is only needed here
            import shutil

            if not os.path.exists(original_path):
                print(f"    ⚠️  Original image file not found: {original_path}")
                return False

            # A bare filename has an empty dirname and os.makedirs("") raises,
            # so only create the parent directory when there is one.
            target_dir = os.path.dirname(output_path)
            if target_dir:
                os.makedirs(target_dir, exist_ok=True)

            # Move the file to the organized location
            shutil.move(original_path, output_path)
            return True

        except Exception as e:
            # Deliberately broad: callers rely on a False return, never an exception
            print(f"    ⚠️  Error moving image: {str(e)}")
            return False
    
    def generate_by_type_and_level(self, problem_type: str, level: str,
                                 max_problems: int = 10, style: str = "blackboard",
                                 api_provider: str = "openai", skip_existing: bool = True) -> list:
        """
        Generate images for a specific problem type and level using specified API.

        The first ``max_problems`` rows matching the type and level are handled
        one by one:

        1. If ``skip_existing`` is set and the organized file already exists, a
           placeholder result flagged ``"skipped"`` is recorded.
        2. Otherwise, if an unorganized image for the problem already sits in the
           base output directory, it is moved into place.
        3. Otherwise the image is generated through T2IModelTester (retrying on
           exceptions and on errors containing "500") and then moved into place.

        Every failure, including a failed move, is appended to
        ``self.failed_generations``.

        Args:
            problem_type: Type of problem to generate
            level: Difficulty level to generate
            max_problems: Maximum number of problems to generate
            style: Visual style for generation
            api_provider: API provider to use (openai, gemini, stability, flux, wan22, kling)
            skip_existing: Whether to skip generation if file already exists

        Returns:
            List of generation results, one dict per processed problem
        """
        # Filter dataset for the specific type and level
        filtered_df = self.df[
            (self.df['type'] == problem_type) &
            (self.df['level'] == level)
        ]

        if filtered_df.empty:
            print(f"❌ No problems found for {problem_type} - {level}")
            return []

        # Limit to max_problems; clamp at zero so a non-positive limit processes nothing
        num_to_generate = max(0, min(max_problems, len(filtered_df)))
        print(f"\n🎯 Generating {num_to_generate} images for {problem_type} - {level} using {api_provider}")

        # Make sure the directory for this API/type/level combination exists
        self._create_directory_structure(api_provider, problem_type, level)

        # The tester saves into the base output directory; organizing the files is done here
        tester = T2IModelTester(self.csv_path, self.base_output_dir)

        # Loop-invariant pieces.
        # NOTE(review): the "<type>_<level>_<index>.png" name is assumed to be what
        # test_single_problem writes into base_output_dir - confirm against get_image.
        clean_type = problem_type.lower().replace(" ", "_")
        clean_level = level.lower().replace(" ", "_")
        stats_key = f"{api_provider}_{problem_type}_{level}"
        max_retries = 2

        def record_failure(index, error):
            """Append one failure entry for this type/level/provider."""
            self.failed_generations.append({
                'index': index,
                'type': problem_type,
                'level': level,
                'api_provider': api_provider,
                'error': error
            })

        results = []
        skipped_count = 0
        generated_count = 0

        for i in range(num_to_generate):
            row = filtered_df.iloc[i]
            # The row label is the problem's index in the original dataset
            original_index = row.name

            # Organized destination; the position within this type/level is the sequence number
            output_path = self._get_output_path(api_provider, problem_type, level, original_index, i)

            # Case 1: organized file already present
            if skip_existing and os.path.exists(output_path):
                print(f"  ⏭️  Skipping image {i+1}/{num_to_generate} (index {original_index}) - File exists")
                print(f"      {output_path}")

                # Placeholder result for the existing file
                results.append({
                    "index": original_index,
                    "problem": row['problem'],
                    "solution": row['solution'],
                    "level": level,
                    "type": problem_type,
                    "style": style,
                    "api_provider": api_provider,
                    "generation_time": 0.0,
                    "image_data": "existing_file",
                    "success": True,
                    "timestamp": datetime.now().isoformat(),
                    "skipped": True
                })
                skipped_count += 1
                continue

            # Where an unorganized image for this problem would sit
            original_path = os.path.join(
                self.base_output_dir,
                f"{clean_type}_{clean_level}_{original_index}.png"
            )

            # Case 2: an unorganized image already exists - just move it
            if os.path.exists(original_path):
                print(f"  📁 Found existing image {i+1}/{num_to_generate} (index {original_index}) - Moving to organized location")

                if self._save_generated_image({}, original_path, output_path):
                    result = {
                        "index": original_index,
                        "problem": row['problem'],
                        "solution": row['solution'],
                        "level": level,
                        "type": problem_type,
                        "style": style,
                        "api_provider": api_provider,
                        "generation_time": 0.0,
                        "success": True,
                        "timestamp": datetime.now().isoformat(),
                        "existing_moved": True
                    }
                    self.generation_stats[stats_key] += 1
                    generated_count += 1
                    print("    ✅ Moved existing image")
                else:
                    print("    ⚠️  Failed to move existing image")
                    result = {"success": False, "error": "Failed to move existing image"}
                    # Track the failure like every other failure path
                    record_failure(original_index, "Failed to move existing image")

                results.append(result)
                continue

            # Case 3: generate a fresh image
            print(f"  Generating image {i+1}/{num_to_generate} (index {original_index})")

            retry_delay = 2  # seconds; doubled after every failed attempt
            result = None

            for retry in range(max_retries + 1):
                try:
                    result = tester.test_single_problem(
                        index=original_index,
                        style=style,
                        api_provider=api_provider,
                        save_image=True  # Let test_single_problem save it first
                    )

                    # If we got a 500 error, retry after delay
                    if not result.get('success') and '500' in str(result.get('error', '')):
                        if retry < max_retries:
                            print(f"    ⚠️  Got 500 error, retrying in {retry_delay}s... (attempt {retry+1}/{max_retries})")
                            time.sleep(retry_delay)
                            retry_delay *= 2  # Exponential backoff
                            continue
                    break  # Success, non-500 error, or retries exhausted

                except Exception as e:
                    if retry < max_retries:
                        print(f"    ⚠️  Exception occurred, retrying in {retry_delay}s... (attempt {retry+1}/{max_retries})")
                        time.sleep(retry_delay)
                        retry_delay *= 2
                        continue
                    else:
                        result = {'success': False, 'error': str(e)}
                        break

            if result and result.get('success'):
                # Check if the file was actually saved by test_single_problem
                if os.path.exists(original_path):
                    if self._save_generated_image(result, original_path, output_path):
                        self.generation_stats[stats_key] += 1

                        # Track generation time
                        generation_time = result.get('generation_time', 0)
                        self.generation_times[api_provider].append(generation_time)

                        print(f"    ✅ Success - {generation_time:.2f}s")
                        generated_count += 1
                    else:
                        print("    ⚠️  Generated but failed to move image")
                        # Previously untracked: a failed move leaves no organized file
                        record_failure(original_index, 'Generated but failed to move image')
                else:
                    print(f"    ⚠️  Generated but file not found at expected location: {original_path}")
                    # Still count as failed since we can't find the file
                    record_failure(original_index, 'Generated but file not found at expected location')
            else:
                print(f"    ❌ Failed: {result.get('error', 'Unknown error')}")
                record_failure(original_index, result.get('error', 'Unknown error'))

            results.append(result)

        # Print summary for this combination
        if skip_existing and skipped_count > 0:
            print(f"  📊 Summary: {generated_count} generated, {skipped_count} skipped")

        return results
    
    def generate_test_set(self, api_providers: list = None, skip_existing: bool = True) -> dict:
        """
        Generate images for ALL problems in the dataset using multiple APIs.

        For every problem and every provider the organized file is either
        skipped (already present), moved from an unorganized copy in the base
        output directory, or generated through T2IModelTester and then moved.
        One tester is created per provider on first use and reused afterwards.
        A report is written through ``self.save_generation_report`` at the end.

        Args:
            api_providers: List of API providers to test (default: ['openai', 'gemini'])
            skip_existing: Whether to skip generation if file already exists

        Returns:
            Dictionary with generation results (counts, per-call wall-clock
            times keyed "<provider>_<index>", start/end timestamps and a short
            record per problem)
        """
        if api_providers is None:
            api_providers = ['openai', 'gemini']

        total_problems = len(self.df)

        print("🚀 Starting FULL dataset generation...")
        print(f"📊 Total problems: {total_problems}")
        print(f" API providers: {', '.join(api_providers)}")
        print(f"⏭️  Skip existing: {skip_existing}")

        # Initialize results tracking
        results = {
            'total_problems': total_problems,
            'api_providers': api_providers,
            'generation_times': {},
            'success_count': 0,
            'skip_count': 0,
            'error_count': 0,
            'start_time': datetime.now().isoformat(),
            'problems': []
        }

        # One tester per provider, built lazily so a fully skipped run builds none.
        # Reusing a tester across problems mirrors generate_by_type_and_level.
        testers = {}

        # `position` drives the progress display so it stays correct even when the
        # DataFrame index is not a 0-based integer range.
        for position, (idx, row) in enumerate(self.df.iterrows(), start=1):
            problem_type = row['type']
            level = row['level']
            problem_text = row['problem']

            print(f"\n📝 Processing problem {position}/{total_problems}: {problem_type} - {level}")

            # Unorganized location; identical for every provider, so computed once per problem.
            # NOTE(review): assumed to be the name test_single_problem saves under - confirm.
            clean_type = problem_type.lower().replace(" ", "_")
            clean_level = level.lower().replace(" ", "_")
            original_path = os.path.join(
                self.base_output_dir,
                f"{clean_type}_{clean_level}_{idx}.png"
            )

            # Process with each API provider
            for api_provider in api_providers:
                # NOTE(review): the sequence number is always 0 here, unlike the
                # per-type/level counters used elsewhere in this class - confirm intent.
                output_path = self._get_output_path(api_provider, problem_type, level, idx, 0)

                # Check if file already exists
                if skip_existing and os.path.exists(output_path):
                    print(f"⏭️  Skipping {api_provider} - file exists: {os.path.basename(output_path)}")
                    results['skip_count'] += 1
                    continue

                # If the original file already exists, move it to organized location
                if os.path.exists(original_path):
                    print(f"📁 {api_provider}: Found existing image - Moving to organized location")

                    if self._save_generated_image({}, original_path, output_path):
                        print(f"✅ {api_provider}: Moved existing image")
                        results['success_count'] += 1
                    else:
                        print(f"⚠️  {api_provider}: Failed to move existing image")
                        results['error_count'] += 1
                    continue

                # Fetch (or create on first use) the tester for this API
                tester = testers.get(api_provider)
                if tester is None:
                    tester = T2IModelTester(self.csv_path, self.base_output_dir)
                    testers[api_provider] = tester

                # Generate image
                start_time = time.time()
                try:
                    result = tester.test_single_problem(
                        index=idx,
                        style="blackboard",
                        api_provider=api_provider,
                        save_image=True  # Let test_single_problem save it first
                    )

                    generation_time = time.time() - start_time
                    results['generation_times'][f"{api_provider}_{idx}"] = generation_time

                    if not result.get('success'):
                        print(f"❌ {api_provider}: Failed - {result.get('error', 'Unknown error')}")
                        results['error_count'] += 1
                    elif not os.path.exists(original_path):
                        # Reported success, but nothing was written where expected
                        print(f"⚠️  {api_provider}: Generated but file not found at expected location: {original_path}")
                        results['error_count'] += 1
                    elif self._save_generated_image(result, original_path, output_path):
                        print(f"✅ {api_provider}: Generated in {generation_time:.2f}s")
                        results['success_count'] += 1
                    else:
                        print(f"⚠️  {api_provider}: Generated but failed to move image")
                        results['error_count'] += 1

                except Exception as e:
                    # Best-effort: one failing call must not abort the whole run
                    generation_time = time.time() - start_time
                    results['generation_times'][f"{api_provider}_{idx}"] = generation_time
                    print(f"❌ {api_provider}: Error - {str(e)}")
                    results['error_count'] += 1

                # Add small delay between API calls
                time.sleep(0.5)

            # Add problem info to results (problem text truncated to 100 characters)
            results['problems'].append({
                'index': idx,
                'type': problem_type,
                'level': level,
                'problem': problem_text[:100] + "..." if len(problem_text) > 100 else problem_text
            })

        # Finalize results
        results['end_time'] = datetime.now().isoformat()
        results['total_time'] = (datetime.fromisoformat(results['end_time']) - 
                               datetime.fromisoformat(results['start_time'])).total_seconds()

        # Save report
        report_path = self.save_generation_report(results)
        print("\n🎉 Generation complete!")
        print(f"✅ Success: {results['success_count']}")
        print(f"⏭️  Skipped: {results['skip_count']}")
        print(f"❌ Errors: {results['error_count']}")
        print(f"⏱️  Total time: {results['total_time']:.2f}s")
        print(f"📄 Report saved: {report_path}")

        return results
    
    def generate_by_priority(self, priority_combinations: list, problems_per_combination: int = 5,
                           api_providers: list = None, skip_existing: bool = True) -> dict:
        """
        Run generate_by_type_and_level for each (type, level) pair, provider by provider.

        Providers are processed in the order given; within a provider the pairs
        are processed in the order of ``priority_combinations``. The style is
        always "blackboard".

        Args:
            priority_combinations: List of (type, level) tuples in priority order
            problems_per_combination: Number of problems to generate per combination
            api_providers: List of API providers to test (openai, gemini, stability, flux, wan22, kling);
                defaults to ['openai', 'gemini']
            skip_existing: Whether to skip generation if file already exists

        Returns:
            Dictionary mapping each provider to a dictionary that maps
            "<type>_<level>" to that combination's list of results
        """
        providers = ['openai', 'gemini'] if api_providers is None else api_providers

        print("🎯 Starting priority-based generation...")
        print(f"Testing APIs: {', '.join(providers)}")
        if skip_existing:
            print("⏭️  Skipping existing files")

        # Settings shared by every generate_by_type_and_level call below
        shared_options = {
            'max_problems': problems_per_combination,
            'style': "blackboard",
            'skip_existing': skip_existing,
        }

        all_results = {}
        for provider in providers:
            print(f"\n Testing {provider.upper()} API")
            print("=" * 50)

            per_combination = {}
            for ptype, difficulty in priority_combinations:
                print(f"\n📚 Processing {ptype} - {difficulty}")
                per_combination[f"{ptype}_{difficulty}"] = self.generate_by_type_and_level(
                    problem_type=ptype,
                    level=difficulty,
                    api_provider=provider,
                    **shared_options
                )

            all_results[provider] = per_combination

        return all_results
    
    def generate_all_problems_single_api(self, api_provider: str = "openai", skip_existing: bool = True) -> dict:
        """
        Generate images for ALL problems using a single API (faster for large datasets).

        Args:
            api_provider: API provider to use (openai, gemini, stability, flux, wan22, kling)
            skip_existing: Whether to skip generation if file already exists

        Returns:
            Dictionary with generation results
        """
        print("🚀 Starting FULL dataset generation with single API...")
        print(f"Total problems in dataset: {len(self.df)}")
        print(f"Using API: {api_provider.upper()}")
        if skip_existing:
            print("⏭️  Skipping existing files")
        
        # Initialize tester - use base output directory for tester
        tester = T2IModelTester(self.csv_path, self.base_output_dir)
        
        results = []
        skipped_count = 0
        generated_count = 0
        failed_count = 0
        
        # Track sequence numbers for each type/level combination
        sequence_counters = defaultdict(int)
        
        print(f"\n Processing all {len(self.df)} problems...")
        print("=" * 60)
        
        for i, (_, row) in enumerate(self.df.iterrows()):
            problem_type = row['type']
            level = row['level']
            original_index = row.name
            
            # Get sequence number for this type/level combination
            type_level_key = f"{problem_type}_{level}"
            sequence_num = sequence_counters[type_level_key]
            
            # Get the organized output path
            output_path = self._get_output_path(api_provider, problem_type, level, original_index, sequence_num)
            
            # Check if file already exists
            if skip_existing and os.path.exists(output_path):
                if i % 10 == 0:  # Show progress every 10 problems
                    print(f"  ⏭️  Progress: {i+1}/{len(self.df)} - Skipping existing file")
                
                # Create a mock result for existing file
                result = {
                    "index": original_index,
                    "problem": row['problem'],
                    "solution": row['solution'],
                    "level": level,
                    "type": problem_type,
                    "style": "blackboard",
                    "api_provider": api_provider,
                    "generation_time": 0.0,
                    "image_data": "existing_file",
                    "success": True,
                    "timestamp": datetime.now().isoformat(),
                    "skipped": True
                }
                results.append(result)
                skipped_count += 1
                sequence_counters[type_level_key] += 1
                continue
            # Check if the original file (that test_single_problem would create) already exists
            clean_type = problem_type.lower().replace(" ", "_")
            clean_level = level.lower().replace(" ", "_")
            original_filename = f"{clean_type}_{clean_level}_{original_index}.png"
            original_path = os.path.join(self.base_output_dir, original_filename)
            
            # If the original file already exists, move it to organized location
            if os.path.exists(original_path):
                print(f"  📁 Found existing image for {i+1}/{len(self.df)} - Moving to organized location")
                
                # Move the existing image to organized location
                saved = self._save_generated_image({}, original_path, output_path)
                
                if saved:
                    result = {
                        "index": original_index,
                        "problem": row['problem'],
                        "solution": row['solution'],
                        "level": level,
                        "type": problem_type,
                        "style": "blackboard",
                        "api_provider": api_provider,
                        "generation_time": 0.0,
                        "success": True,
                        "timestamp": datetime.now().isoformat(),
                        "existing_moved": True
                    }
                    self.generation_stats[f"{api_provider}_{problem_type}_{level}"] += 1
                    generated_count += 1
                    sequence_counters[type_level_key] += 1
                else:
                    print(f"    ⚠️  Failed to move existing image")
                    result = {"success": False, "error": "Failed to move existing image"}
                    failed_count += 1
                
                results.append(result)
                continue
            
            # Show progress every 10 problems
            if i % 5 == 0:
                print(f"  Generating: {i+1}/{len(self.df)} - {problem_type} {level}")
            
            # Generate image with retry logic for API errors
            max_retries = 2
            retry_delay = 2  # seconds
            result = None
            
            for retry in range(max_retries + 1):
                try:
                    result = tester.test_single_problem(
                        index=original_index,
                        style="blackboard",
                        api_provider=api_provider,
                        save_image=True  # Let test_single_problem save it first
                    )
                    
                    # If we got a 500 error, retry after delay
                    if not result.get('success') and '500' in str(result.get('error', '')):
                        if retry < max_retries:
                            print(f"    ⚠️  Got 500 error, retrying in {retry_delay}s... (attempt {retry+1}/{max_retries})")
                            time.sleep(retry_delay)
                            retry_delay *= 2  # Exponential backoff
                            continue
                    break  # Success or non-500 error, don't retry
                    
                except Exception as e:
                    if retry < max_retries:
                        print(f"    ⚠️  Exception occurred, retrying in {retry_delay}s... (attempt {retry+1}/{max_retries})")
                        time.sleep(retry_delay)
                        retry_delay *= 2
                        continue
                    else:
                        result = {'success': False, 'error': str(e)}
                        break
            
            if result and result.get('success'):
                # Build the original filename that test_single_problem would have created
                clean_type = problem_type.lower().replace(" ", "_")
                clean_level = level.lower().replace(" ", "_")
                original_filename = f"{clean_type}_{clean_level}_{original_index}.png"
                original_path = os.path.join(self.base_output_dir, original_filename)
                
                # Check if the file was actually saved by test_single_problem
                if os.path.exists(original_path):
                    # Move the image to organized location
                    saved = self._save_generated_image(result, original_path, output_path)
                    
                    if saved:
                        self.generation_stats[f"{api_provider}_{problem_type}_{level}"] += 1
                        
                        # Track generation time
                        generation_time = result.get('generation_time', 0)
                        self.generation_times[api_provider].append(generation_time)
                        
                        generated_count += 1
                        sequence_counters[type_level_key] += 1
                    else:
                        print(f"    ⚠️  Generated but failed to move image")
                        failed_count += 1
                else:
                    print(f"    ⚠️  Generated but file not found at expected location: {original_path}")
                    # Still count as failed since we can't find the file
                    self.failed_generations.append({
                        'index': original_index,
                        'type': problem_type,
                        'level': level,
                        'api_provider': api_provider,
                        'error': 'Generated but file not found at expected location'
                    })
                    failed_count += 1
            else:
                print(f"    ❌ Failed: {result.get('error', 'Unknown error')}")
                self.failed_generations.append({
                    'index': original_index,
                    'type': problem_type,
                    'level': level,
                    'api_provider': api_provider,
                    'error': result.get('error', 'Unknown error')
                })
                failed_count += 1
            
            results.append(result)
            
            # Show progress summary every 50 problems
            if (i + 1) % 50 == 0:
                print(f"   Progress: {i+1}/{len(self.df)} - Generated: {generated_count}, Skipped: {skipped_count}, Failed: {failed_count}")
        
        print(f"\n🎉 Generation completed!")
        print(f"  📊 Final Summary: Generated: {generated_count}, Skipped: {skipped_count}, Failed: {failed_count}")
        
        return {"all_problems": results}
    
    def _convert_to_json_serializable(self, obj):
        """
        Convert numpy/pandas types to JSON-serializable Python types.
        
        Args:
            obj: Object that may contain numpy/pandas types
        
        Returns:
            JSON-serializable version of the object
        """
        import numpy as np
        
        if isinstance(obj, dict):
            return {str(k): self._convert_to_json_serializable(v) for k, v in obj.items()}
        elif isinstance(obj, list):
            return [self._convert_to_json_serializable(item) for item in obj]
        elif isinstance(obj, np.integer):
            return int(obj)
        elif isinstance(obj, np.floating):
            return float(obj)
        elif isinstance(obj, np.ndarray):
            return obj.tolist()
        elif hasattr(obj, 'item'):  # numpy scalar
            return obj.item()
        elif isinstance(obj, (int, float, str, bool, type(None))):
            return obj
        else:
            return str(obj)
    
    def save_generation_report(self, results: dict = None) -> str:
        """
        Save a detailed report of the generation process including timing data.
        
        Args:
            results: Generation results dictionary
        
        Returns:
            Path to the saved report file
        """
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        report_file = os.path.join(self.base_output_dir, f"generation_report_{timestamp}.json")
        
        # Calculate timing statistics
        timing_stats = {}
        for api_provider, times in self.generation_times.items():
            if times:
                timing_stats[api_provider] = {
                    'total_time': float(sum(times)),
                    'average_time': float(sum(times) / len(times)),
                    'min_time': float(min(times)),
                    'max_time': float(max(times)),
                    'count': int(len(times))
                }
        
        # Convert all data to JSON-serializable format
        report_data = {
            'timestamp': timestamp,
            'total_problems_in_dataset': int(len(self.df)),
            'generation_stats': self._convert_to_json_serializable(dict(self.generation_stats)),
            'failed_generations': self._convert_to_json_serializable(self.failed_generations),
            'timing_statistics': self._convert_to_json_serializable(timing_stats),
            'results': self._convert_to_json_serializable(results or {})
        }
        
        with open(report_file, 'w') as f:
            json.dump(report_data, f, indent=2)
        
        print(f"📊 Generation report saved to: {report_file}")
        return report_file
    
    def print_summary(self):
        """Print a summary of the generation process including timing data."""
        print("\n" + "="*60)
        print("📊 GENERATION SUMMARY")
        print("="*60)
        
        total_generated = sum(self.generation_stats.values())
        total_failed = len(self.failed_generations)
        
        print(f"Total images generated: {total_generated}")
        print(f"Total failed: {total_failed}")
        print(f"Success rate: {total_generated/(total_generated + total_failed)*100:.1f}%")
        
        print("\nBy API Provider and Type/Level:")
        for combination, count in self.generation_stats.items():
            print(f"  {combination}: {count} images")
        
        print("\n⏱️  TIMING STATISTICS:")
        print("-" * 40)
        for api_provider, times in self.generation_times.items():
            if times:
                avg_time = sum(times) / len(times)
                min_time = min(times)
                max_time = max(times)
                total_time = sum(times)
                print(f"{api_provider.upper()}:")
                print(f"  - Average: {avg_time:.2f}s")
                print(f"  - Min: {min_time:.2f}s")
                print(f"  - Max: {max_time:.2f}s")
                print(f"  - Total: {total_time:.2f}s")
                print(f"  - Count: {len(times)}")
                print()
        
        if self.failed_generations:
            print(f"\nFailed generations: {len(self.failed_generations)}")
            for failure in self.failed_generations[:5]:  # Show first 5 failures
                print(f"  - {failure['api_provider']} {failure['type']} {failure['level']}: {failure['error']}")

def main(csv_path: str = None, api_provider: str = 'wan22'):
    """
    Main function to demonstrate iterative image generation with multiple APIs.

    Checks the required API keys, generates images for every problem in the
    dataset with a single provider, then saves a report and prints a summary.

    Args:
        csv_path: Path to the dataset CSV. When omitted, the original default
            location of balanced_500.csv is used.
        api_provider: Provider name passed to generate_all_problems_single_api:
            'openai', 'gemini', 'stability', 'flux', 'flux-fal', 'wan22', or 'kling'.
    """
    # Check if API keys are set
    missing_keys = [
        key for key in ("OPENAI_API_KEY", "GEMINI_API_KEY") if not os.getenv(key)
    ]
    # Optional: Check for Flux API key
    if not os.getenv("BFL_API_KEY"):
        print("⚠️  BFL_API_KEY not set - Flux API will not be available")

    if missing_keys:
        print("❌ Please set the following environment variables:")
        for key in missing_keys:
            print(f"export {key}='your-{key.lower()}-api-key'")
        return

    # Resolve the dataset location; the default is kept for backward compatibility
    if csv_path is None:
        csv_path = "/Users/shangwu/Downloads/questions/competition_math_data/subsets/balanced_500.csv"
    if not os.path.isfile(csv_path):
        # Fail with a clear message instead of a pandas traceback
        print(f"❌ Dataset CSV not found: {csv_path}")
        return

    # Initialize generator
    generator = IterativeImageGenerator(csv_path)

    # Choose generation method:

    # Option 1: Generate ALL problems with single API (recommended for large datasets)
    print("🚀 Generating ALL problems with single API...")
    test_results = generator.generate_all_problems_single_api(
        api_provider=api_provider,
        skip_existing=True
    )

    # Option 2: Generate ALL problems with multiple APIs (slower but compares APIs)
    # print("🚀 Generating ALL problems with multiple APIs...")
    # test_results = generator.generate_test_set(
    #     api_providers=['gemini', 'flux-fal'],  # Add more APIs as needed: ['openai', 'gemini', 'stability', 'flux', 'flux-fal', 'wan22', 'kling']
    #     skip_existing=True
    # )

    # Option 3: Generate by priority with specific APIs (including Kling AI)
    # print("🚀 Generating priority problems with Kling AI...")
    # priority_combinations = [
    #     ('Algebra', 'Level 1'),
    #     ('Geometry', 'Level 2'),
    #     ('Number Theory', 'Level 3')
    # ]
    # test_results = generator.generate_by_priority(
    #     priority_combinations=priority_combinations,
    #     problems_per_combination=2,
    #     api_providers=['kling'],  # Test with Kling AI API
    #     skip_existing=True
    # )

    # Save report and print summary
    generator.save_generation_report(test_results)
    generator.print_summary()

    print(f"\n Images organized in: {generator.base_output_dir}")
    # Illustrative layout of the organized output tree
    layout_lines = (
        "Directory structure:",
        "  generated_images/",
        "    ├── openai/",
        "    │   ├── algebra/",
        "    │   │   ├── level_1/",
        "    │   │   │   ├── 000_algebra_level_1_123.png",
        "    │   │   │   └── 001_algebra_level_1_456.png",
        "    │   │   └── level_2/",
        "    │   │       └── 000_algebra_level_2_789.png",
        "    │   └── geometry/",
        "    │       └── level_1/",
        "    │           └── 000_geometry_level_1_321.png",
        "    ├── gemini/",
        "    │   ├── algebra/",
        "    │   │   └── level_1/",
        "    │   │       └── 000_algebra_level_1_123.png",
        "    │   └── geometry/",
        "    │       └── level_1/",
        "    │           └── 000_geometry_level_1_321.png",
        "    └── flux/",
        "        ├── algebra/",
        "        │   └── level_1/",
        "        │       └── 000_algebra_level_1_123.png",
        "        └── geometry/",
        "            └── level_1/",
        "                └── 000_geometry_level_1_321.png",
    )
    for line in layout_lines:
        print(line)

# Run the generation pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()