import pandas as pd
import requests
import json
import os
import time
from typing import List, Dict, Optional
import random
from datetime import datetime
from openai import OpenAI
import base64
from google import genai
from google.genai import types
from PIL import Image
from io import BytesIO

# Import fal_client for WAN 2.2 model (install with: pip install fal-client)
try:
    import fal_client
    FAL_CLIENT_AVAILABLE = True
except ImportError:
    FAL_CLIENT_AVAILABLE = False
    print("⚠️  fal_client not available. Install with: pip install fal-client")

# Import JWT for Kling AI authentication (install with: pip install PyJWT)
try:
    import jwt
    JWT_AVAILABLE = True
except ImportError:
    JWT_AVAILABLE = False
    print("⚠️  PyJWT not available. Install with: pip install PyJWT")

class T2IModelTester:
    """
    A comprehensive system for testing Text-to-Image models with math problems
    from the competition math dataset.
    """
    
    def __init__(self, csv_path: str, output_dir: str = "t2i_results"):
        """
        Prepare the tester: output folder, dataset, API clients and style prompts.

        Args:
            csv_path: Path to the balanced_sample_1000.csv file
            output_dir: Directory that receives generated images and results;
                it is created when it does not exist yet
        """
        self.csv_path, self.output_dir = csv_path, output_dir
        self.results = []

        # The output folder has to exist before anything is written into it
        os.makedirs(output_dir, exist_ok=True)

        # Read the complete problem set into memory
        self.df = pd.read_csv(csv_path)

        # API clients; both keys come from environment variables
        openai_key = os.getenv("OPENAI_API_KEY")
        self.openai_client = OpenAI(api_key=openai_key)
        gemini_key = os.getenv("GEMINI_API_KEY")
        self.gemini_client = genai.Client(api_key=gemini_key)

        # One system prompt per supported visual style, looked up by name
        style_names = ("blackboard", "whiteboard", "notebook", "paper", "digital")
        self.system_prompts = {
            style: getattr(self, f"_get_{style}_prompt")() for style in style_names
        }
    
    def _get_blackboard_prompt(self) -> str:
        """Return the system prompt describing a blackboard rendering."""
        # Every line after the first carries an eight-space indent (the blank
        # separator line included); that whitespace is part of the prompt text.
        pad = " " * 8
        first_line = "Create a clean, professional blackboard that shows a math problem and its solution. "
        remaining = [
            "The blackboard should have:",
            "- A dark green or black background",
            "- White or yellow chalk text",
            "- Clear, readable mathematical notation",
            "- Proper formatting with clear sections for problem and solution",
            "- A clean, academic appearance",
            "- No clutter or distractions",
            "",
            "Render the mathematical expressions exactly as provided, including:",
            "- Proper superscripts and subscripts",
            "- Mathematical symbols (√, π, ∑, etc.)",
            "- Fractions and equations",
            "- Clear step-by-step solutions",
        ]
        return "\n".join([first_line] + [pad + text for text in remaining])
    
    def _get_whiteboard_prompt(self) -> str:
        """Return the system prompt describing a whiteboard rendering."""
        # Every line after the first carries an eight-space indent (the blank
        # separator line included); that whitespace is part of the prompt text.
        pad = " " * 8
        first_line = "Create a clean whiteboard that displays a math problem and its solution."
        remaining = [
            "The whiteboard should have:",
            "- A bright white background",
            "- Black or blue marker text",
            "- Clear, professional handwriting style",
            "- Well-organized layout with clear sections",
            "- Proper mathematical notation",
            "- Clean, modern appearance",
            "",
            "Ensure all mathematical expressions are rendered clearly and accurately.",
        ]
        return "\n".join([first_line] + [pad + text for text in remaining])
    
    def _get_notebook_prompt(self) -> str:
        """Return the system prompt describing a notebook-page rendering."""
        # Every line after the first carries an eight-space indent; that
        # whitespace is part of the prompt text.
        pad = " " * 8
        first_line = "Create a neat notebook page showing a math problem and its solution."
        remaining = [
            "The notebook should have:",
            "- A lined or grid paper background",
            "- Handwritten text in blue or black ink",
            "- Clear, organized layout",
            "- Proper mathematical notation",
            "- A student-like but neat handwriting style",
            "- Clear problem and solution sections",
        ]
        return "\n".join([first_line] + [pad + text for text in remaining])
    
    def _get_paper_prompt(self) -> str:
        """Return the system prompt describing a sheet-of-paper rendering."""
        # Every line after the first carries an eight-space indent; that
        # whitespace is part of the prompt text.
        pad = " " * 8
        first_line = "Create a clean sheet of paper with a math problem and its solution."
        remaining = [
            "The paper should have:",
            "- A white or off-white background",
            "- Black ink text",
            "- Professional, typed appearance",
            "- Clear mathematical notation",
            "- Well-formatted equations",
            "- Clean, academic presentation",
        ]
        return "\n".join([first_line] + [pad + text for text in remaining])
    
    def _get_digital_prompt(self) -> str:
        """Return the system prompt describing a digital-interface rendering."""
        # Every line after the first carries an eight-space indent; that
        # whitespace is part of the prompt text.
        pad = " " * 8
        first_line = "Create a modern digital interface showing a math problem and its solution."
        remaining = [
            "The interface should have:",
            "- A clean, modern design",
            "- High contrast text",
            "- Professional mathematical typesetting",
            "- Clear visual hierarchy",
            "- Modern UI elements",
            "- Excellent readability",
        ]
        return "\n".join([first_line] + [pad + text for text in remaining])
    
    def format_math_problem(self, problem: str, solution: str, level: str, problem_type: str) -> str:
        """
        Build the problem-specific part of a T2I prompt.

        Args:
            problem: The math problem text (surrounding whitespace is stripped)
            solution: The solution text (surrounding whitespace is stripped)
            level: The difficulty level, shown in the title line
            problem_type: The type of problem (Algebra, Geometry, etc.)

        Returns:
            A prompt with a title line, the problem, the solution and the
            rendering instructions, separated by blank lines
        """
        # Trim both texts up front
        problem_text, solution_text = problem.strip(), solution.strip()

        # One entry per output line; empty entries are the blank separator lines
        pieces = [
            f"Title: {problem_type} Problem - {level}",
            "",
            f"Problem: {problem_text}",
            "",
            f"Solution: {solution_text}",
            "",
            "Instructions: Create a clean, professional presentation of this math problem and its complete solution. ",
            "Ensure all mathematical notation is rendered clearly and accurately.",
        ]
        return "\n".join(pieces)
    
    def call_openai(self, prompt: str, model: str = "gpt-image-1", size: str = "1024x1024") -> Optional[Dict]:
        """
        Call OpenAI's ChatGPT image generation API to generate an image.
        
        Args:
            prompt: The text prompt for image generation
            model: gpt-image-1, dall-e-3
            size: Image size (1024x1024, 1792x1024, or 1024x1792)
        
        Returns:
            API response or None if failed
        """
        try:
            if model == "gpt-image-1":
                response = self.openai_client.images.generate(
                    model=model,
                    prompt=prompt,
                    n=1,
                    size=size,
                )
            else:
                response = self.openai_client.images.generate(
                    model=model,
                    prompt=prompt,
                    n=1,
                    size=size,
                    response_format="b64_json"
                )
            return response
        except Exception as e:
            print(f"❌ OpenAI API Error: {e}")
            print(f"   Model: {model}")
            print(f"   Size: {size}")
            print(f"   Prompt length: {len(prompt)} characters")
            return None
    
    def call_gemini(self, prompt: str, model: str = "gemini-2.5-flash-image-preview") -> Optional[Dict]:
        """
        Call Google's Gemini API to generate an image.
        
        Args:
            prompt: The text prompt for image generation
            model: Gemini model to use
        
        Returns:
            API response or None if failed
        """
        try:
            response = self.gemini_client.models.generate_content(
                model=model,
                contents=[prompt]
            )
            return response
        except Exception as e:
            print(f"❌ Gemini API Error: {e}")
            print(f"   Model: {model}")
            print(f"   Prompt length: {len(prompt)} characters")
            return None
    
    def call_stability_ai(self, prompt: str, model: str = "sd3.5-ultra", aspect_ratio: str = "1:1",
                          negative_prompt: str = "", output_format: str = "png") -> Optional[Dict]:
        """
        Call Stability AI v2beta API to generate an image using SD3 models.

        Args:
            prompt: The text prompt for image generation
            model: Stability AI model (sd3, sd3-turbo)
            aspect_ratio: Image aspect ratio (1:1, 16:9, 9:16, 4:3, 3:2, etc.)
            negative_prompt: Optional negative prompt
            output_format: Image format (png, jpeg)

        Returns:
            API response with image data or None if failed
        """
        api_key = os.getenv("STABILITY_API_KEY")
        if not api_key:
            print("❌ STABILITY_API_KEY not found in environment variables")
            return None

        print(f"🎨 Generating image with Stability AI {model.upper()}...")
        print(f"   Aspect ratio: {aspect_ratio}, Format: {output_format}")

        # Use v2beta API endpoint for SD3
        if model == 'sd3.5-ultra':
            url = "https://api.stability.ai/v2beta/stable-image/generate/ultra"
            data = {
                "prompt": prompt,
                "aspect_ratio": aspect_ratio,
                "output_format": output_format
            }
        else:
            url = "https://api.stability.ai/v2beta/stable-image/generate/sd3"
            # Prepare form data (v2beta uses multipart/form-data, not JSON)
            data = {
                "prompt": prompt,
                "model": model,
                "aspect_ratio": aspect_ratio,
                "output_format": output_format
            }

        headers = {
            "Authorization": f"Bearer {api_key}",
            "Accept": "image/*"
        }

        if negative_prompt:
            data["negative_prompt"] = negative_prompt

        # Add empty files parameter (required by the API)
        files = {"none": ""}

        try:
            response = requests.post(url, headers=headers, files=files, data=data, timeout=60)
            response.raise_for_status()

            # Check for content filtering
            finish_reason = response.headers.get("finish-reason")
            seed = response.headers.get("seed")

            if finish_reason == "CONTENT_FILTERED":
                print("❌ Content filtered by NSFW classifier")
                return None

            if response.status_code == 200:
                # Save image data and return metadata
                image_data = response.content

                print(f"✅ Stability AI image generated successfully")
                if seed:
                    print(f"   Seed: {seed}")

                return {
                    "image_data": image_data,
                    "seed": seed,
                    "finish_reason": finish_reason,
                    "model_used": model,
                    "format": output_format
                }
            else:
                print(f"❌ Unexpected status code: {response.status_code}")
                return None

        except requests.exceptions.RequestException as e:
            print(f"❌ Stability AI API Error: {e}")
            return None

    def call_flux(self, prompt: str, model: str = "flux-pro-1.1", width: int = 1024, height: int = 1024, retry_on_500: bool = True) -> Optional[Dict]:
        """
        Call Black Forest Labs Flux API to generate an image.

        The call is asynchronous on the API side: a POST submits the job and
        returns a polling URL, which is then polled every 0.5 s (at most 60
        times) until the status is "Ready", "Error" or "Failed".

        Args:
            prompt: The text prompt for image generation
            model: Flux model to use (flux-pro-1.1, flux-dev, flux-schnell)
            width: Image width (default 1024)
            height: Image height (default 1024)
            retry_on_500: When True, the other Flux models are appended as
                fallbacks. Note that a fallback is tried after ANY failure of
                the current model (not only HTTP 500), and that the
                three-attempt retry of a 500 on the same model happens
                regardless of this flag.

        Returns:
            Dict with keys image_url, request_id and model_used (the model
            that actually produced the image), or None if the API key is
            missing or every attempted model failed
        """
        api_key = os.getenv("BFL_API_KEY")
        if not api_key:
            print("❌ BFL_API_KEY not found in environment variables")
            return None

        # Map model names to API endpoints
        model_endpoints = {
            "flux-pro-1.1": "https://api.bfl.ml/v1/flux-pro-1.1",
            "flux-dev": "https://api.bfl.ml/v1/flux-dev",
            "flux-schnell": "https://api.bfl.ml/v1/flux-schnell"
        }

        # Define model fallback order for reliability
        model_fallback_order = {
            "flux-pro-1.1": ["flux-dev", "flux-schnell"],
            "flux-dev": ["flux-schnell", "flux-pro-1.1"],
            "flux-schnell": ["flux-dev", "flux-pro-1.1"]
        }

        # Requested model first, then (optionally) its fallbacks in order
        models_to_try = [model]
        if retry_on_500:
            models_to_try.extend(model_fallback_order.get(model, []))

        for current_model in models_to_try:
            url = model_endpoints.get(current_model)
            if not url:
                # Unknown names are skipped; with no fallbacks this ends in the
                # "all models failed" message below
                print(f"❌ Unknown Flux model: {current_model}")
                continue

            if current_model != model:
                print(f"   Trying fallback model: {current_model}")

            headers = {
                "accept": "application/json",
                "x-key": api_key,
                "Content-Type": "application/json"
            }

            data = {
                "prompt": prompt,
                "width": width,
                "height": height
            }

            # Retry logic with exponential backoff (1 s, 2 s, ...), reset for each model
            max_retries = 3
            retry_delay = 1  # Start with 1 second

            # In this loop `continue` means "retry the same model" and every
            # `break` means "give up on this model and move to the next one"
            for retry in range(max_retries):
                try:
                    # Make the initial request (submits the generation job)
                    response = requests.post(url, headers=headers, json=data, timeout=30)

                    # Check for 500 error before raise_for_status so it can be retried
                    if response.status_code == 500:
                        if retry < max_retries - 1:
                            print(f"   ⚠️  Got 500 error from {current_model}, retrying in {retry_delay}s... (attempt {retry+1}/{max_retries})")
                            time.sleep(retry_delay)
                            retry_delay *= 2  # Exponential backoff
                            continue
                        else:
                            print(f"   ❌ 500 error persists for {current_model} after {max_retries} attempts")
                            break  # Try next model

                    response.raise_for_status()
                    request_data = response.json()

                    # Get request ID and polling URL
                    request_id = request_data.get('id')
                    polling_url = request_data.get('polling_url')

                    if not polling_url:
                        print(f"❌ No polling URL returned from Flux API ({current_model})")
                        break  # Try next model

                    # Poll for results
                    max_attempts = 60  # 60 polls with a 0.5 s pause each (request time comes on top)
                    for attempt in range(max_attempts):
                        time.sleep(0.5)

                        result = requests.get(
                            polling_url,
                            headers={
                                "accept": "application/json",
                                "x-key": api_key
                            },
                            timeout=10
                        )
                        # A failing poll raises into the handlers below, which may
                        # re-submit the whole generation request
                        result.raise_for_status()
                        result_data = result.json()

                        status = result_data.get('status')

                        if status == "Ready":
                            # Image generation successful
                            image_url = result_data.get('result', {}).get('sample')
                            if image_url:
                                print(f"✅ Flux image generated successfully with {current_model}")
                                return {"image_url": image_url, "request_id": request_id, "model_used": current_model}
                            else:
                                print(f"❌ No image URL in result from {current_model}")
                                break  # Leaves the polling loop; the code below then moves to the next model

                        elif status in ["Error", "Failed"]:
                            print(f"❌ Flux generation failed with {current_model}: {result_data}")
                            break  # Leaves the polling loop; the code below then moves to the next model

                        # Still processing, continue polling
                        if attempt % 10 == 0:  # Print status every 5 seconds
                            print(f"   Status: {status} (attempt {attempt + 1}/{max_attempts})")

                    # Polling ended without a terminal status: the attempts ran out
                    if status not in ["Ready", "Error", "Failed"]:
                        print(f"❌ Flux generation timed out with {current_model}")
                        break  # Try next model

                    # If we got here and status was Ready, we already returned
                    # Otherwise, break to try next model
                    break

                except requests.exceptions.Timeout:
                    # Covers timeouts of both the submit request and the polls
                    print(f"   ⚠️  Request timeout for {current_model}")
                    if retry < max_retries - 1:
                        print(f"   Retrying in {retry_delay}s... (attempt {retry+1}/{max_retries})")
                        time.sleep(retry_delay)
                        retry_delay *= 2
                        continue
                    break  # Try next model

                except requests.exceptions.RequestException as e:
                    # NOTE(review): the substring test matches "500" anywhere in the
                    # error text (e.g. inside a URL), not only an HTTP 500 status
                    if "500" in str(e) and retry < max_retries - 1:
                        print(f"   ⚠️  500 error: {e}")
                        print(f"   Retrying in {retry_delay}s... (attempt {retry+1}/{max_retries})")
                        time.sleep(retry_delay)
                        retry_delay *= 2
                        continue
                    else:
                        print(f"❌ Flux API Error with {current_model}: {e}")
                        break  # Try next model

        # All models failed (also reached when only a single model was attempted)
        print("❌ All Flux models failed to generate image")
        return None

    def call_kling(self, prompt: str, aspect_ratio: str = "1:1", image_count: int = 1,
                   model_name: str = "kling-v1-5", camera_type: Optional[str] = None,
                   retry_on_rate_limit: bool = True) -> Optional[Dict]:
        """
        Call Kling AI API to generate an image with rate limiting handling.

        A short-lived HS256 JWT is built from KLING_ACCESS_KEY / KLING_SECRET_KEY,
        a generation task is submitted, and the task is then polled every
        2 seconds (at most 60 times) until it succeeds or fails.

        Args:
            prompt: The text prompt for image generation
            aspect_ratio: Aspect ratio (1:1, 9:16, 16:9, 2:3, 3:2, 3:4, 4:3)
            image_count: Number of images to generate (1-9); only the first
                generated image is returned
            model_name: Model to use (kling-v1-5, kolors)
            camera_type: Optional camera movement type; only sent when set
            retry_on_rate_limit: Whether to retry the submission (up to three
                attempts in total) on HTTP 429 or on request errors

        Returns:
            Dict with keys image_url, task_id and model_used, or None if
            PyJWT or the API keys are missing, the request fails, the task
            fails, the task succeeds without an image URL, or polling runs
            out of attempts
        """
        if not JWT_AVAILABLE:
            print("❌ PyJWT not available. Install with: pip install PyJWT")
            return None

        # Check for Kling API keys
        access_key = os.getenv("KLING_ACCESS_KEY")
        secret_key = os.getenv("KLING_SECRET_KEY")

        if not access_key or not secret_key:
            print("❌ KLING_ACCESS_KEY or KLING_SECRET_KEY not found in environment variables")
            return None

        print(f"🎨 Generating image with Kling AI...")
        print(f"   Aspect ratio: {aspect_ratio}, Model: {model_name}")

        try:
            # Generate JWT token for authentication
            current_time = int(time.time())
            payload = {
                "iss": access_key,
                "exp": current_time + 1800,  # Expires in 30 minutes
                "nbf": current_time - 5,     # Not before: 5 seconds ago
                "iat": current_time          # Issued at time
            }

            jwt_token = jwt.encode(
                payload,
                secret_key,
                algorithm="HS256"
            )

            # Older PyJWT releases return bytes; normalise to str
            if isinstance(jwt_token, bytes):
                jwt_token = jwt_token.decode('utf-8')

            print(f"   🔑 Generated JWT token (first 20 chars): {jwt_token[:20]}...")

            # Validate token format
            if not jwt_token or len(jwt_token) < 20:
                print("❌ JWT token generation failed or too short")
                return None

            # Create image generation task
            url = "https://api.klingai.com/v1/images/generations"
            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {jwt_token}",
                "User-Agent": "Python/Kling-API-Client"
            }

            data = {
                "model_name": model_name,
                "prompt": prompt,
                "aspect_ratio": aspect_ratio,
                "image_count": image_count
            }

            if camera_type:
                data["camera_type"] = camera_type

            print(f"   📝 Request data: model={model_name}, aspect_ratio={aspect_ratio}")
            print(f"   🔗 API URL: {url}")
            print(f"   📋 Authorization header present: {'Authorization' in headers}")
            print(f"   📝 Payload keys: {list(data.keys())}")

            # Submit generation request with rate limit handling
            max_retries = 3 if retry_on_rate_limit else 1
            retry_delay = 30  # Start with 30 seconds for rate limit

            for retry_attempt in range(max_retries):
                try:
                    response = requests.post(url, headers=headers, json=data, timeout=30)

                    # Handle rate limiting (429)
                    if response.status_code == 429:
                        if retry_attempt < max_retries - 1 and retry_on_rate_limit:
                            # Default to exponential backoff ...
                            wait_time = retry_delay * (2 ** retry_attempt)

                            # ... but prefer the server's hint when it is a number of
                            # seconds. Retry-After may also be an HTTP-date; in that
                            # case the backoff delay is kept instead of aborting.
                            retry_after = response.headers.get('Retry-After')
                            if retry_after:
                                try:
                                    wait_time = max(0, int(retry_after))
                                except ValueError:
                                    pass

                            print(f"   ⚠️  Rate limited. Waiting {wait_time}s before retry (attempt {retry_attempt + 1}/{max_retries})")
                            time.sleep(wait_time)
                            continue
                        else:
                            print("❌ Rate limit exceeded. Please wait before trying again.")
                            return None

                    response.raise_for_status()
                    result = response.json()
                    break  # Success, exit retry loop

                except requests.exceptions.RequestException as e:
                    if retry_attempt < max_retries - 1:
                        print(f"   ⚠️  Request failed: {e}. Retrying in {retry_delay}s...")
                        time.sleep(retry_delay)
                        retry_delay *= 2
                        continue
                    else:
                        # Out of attempts: hand over to the outer handler,
                        # keeping the original traceback
                        raise

            if result.get('code') == 0 and 'data' in result:
                task_id = result['data'].get('task_id')

                if task_id:
                    # Poll for task completion
                    max_attempts = 60
                    for attempt in range(max_attempts):
                        time.sleep(2)  # Wait 2 seconds between checks

                        # Check task status
                        status_url = f"https://api.klingai.com/v1/images/generations/{task_id}"
                        status_response = requests.get(status_url, headers=headers, timeout=30)
                        status_response.raise_for_status()
                        status_result = status_response.json()

                        if status_result.get('code') == 0 and 'data' in status_result:
                            task_data = status_result['data']

                            if task_data.get('task_status') == 'succeed':
                                # Get the first generated image
                                works = task_data.get('task_result', {}).get('works', [])
                                image_url = works[0].get('resource', {}).get('resource') if works else None
                                if image_url:
                                    print(f"✅ Kling AI image generated successfully")
                                    return {
                                        "image_url": image_url,
                                        "task_id": task_id,
                                        "model_used": "kling"
                                    }
                                # The task is finished, so further polling cannot help;
                                # report the real cause instead of a timeout
                                print("❌ Kling AI task succeeded but no image URL was returned")
                                return None
                            elif task_data.get('task_status') == 'failed':
                                print(f"❌ Kling AI generation failed")
                                return None

                        # Show progress
                        if attempt % 10 == 0:
                            print(f"   Waiting for generation... (attempt {attempt + 1}/{max_attempts})")

                    print("❌ Kling AI generation timed out")
                    return None
                else:
                    print("❌ No task ID returned from Kling AI")
                    return None
            else:
                print(f"❌ Kling AI API error: {result.get('message', 'Unknown error')}")
                return None

        except Exception as e:
            print(f"❌ Kling AI Error: {e}")
            return None

    def call_wan22(self, prompt: str, num_inference_steps: int = 35, image_size: str = "square",
                   guidance_scale: float = 3.5, seed: Optional[int] = None) -> Optional[Dict]:
        """
        Generate an image with the WAN 2.2 text-to-image model on Fal.ai.

        Args:
            prompt: The text prompt for image generation
            num_inference_steps: Number of inference steps (default 35)
            image_size: Image size (square_hd, square, portrait, landscape, landscape_4_3, portrait_4_3)
            guidance_scale: Guidance scale for prompt adherence (default 3.5)
            seed: Random seed for reproducibility; only sent when not None

        Returns:
            Dict with keys image_url, seed and model_used, or None when
            fal_client or FAL_KEY is missing, the call raises, or the result
            contains no image
        """
        # Both the optional dependency and the credential must be present
        if not FAL_CLIENT_AVAILABLE:
            print("❌ fal_client not available. Install with: pip install fal-client")
            return None
        if not os.getenv("FAL_KEY"):
            print("❌ FAL_KEY not found in environment variables")
            return None

        print(f"🎨 Generating image with WAN 2.2...")
        print(f"   Steps: {num_inference_steps}, Size: {image_size}")

        try:
            request_args = {
                "prompt": prompt,
                "num_inference_steps": num_inference_steps,
                "image_size": image_size,
                "guidance_scale": guidance_scale
            }
            if seed is not None:
                request_args["seed"] = seed

            # subscribe() returns the finished result of the hosted model run
            outcome = fal_client.subscribe(
                "fal-ai/wan/v2.2-a14b/text-to-image",
                arguments=request_args
            )

            if not outcome or 'image' not in outcome:
                print("❌ No image returned from WAN 2.2 API")
                return None

            print(f"✅ WAN 2.2 image generated successfully")
            return {
                "image_url": outcome['image']['url'],
                "seed": outcome.get('seed'),
                "model_used": "wan-2.2"
            }

        except Exception as err:
            print(f"❌ WAN 2.2 API Error: {err}")
            return None

    def call_flux_fal(self, prompt: str, num_images: int = 1, image_size: str = "square",
                      guidance_scale: float = 3.5, seed: Optional[int] = None) -> Optional[Dict]:
        """
        Call fal.ai FLUX Pro API to generate an image.

        Args:
            prompt: The text prompt for image generation
            num_images: Number of images to generate (max 4)
            image_size: Aspect ratio (landscape_4_3, portrait_3_4, square, etc.)
            guidance_scale: How closely the model follows the prompt (default 3.5)
            seed: Optional seed for reproducible generation

        Returns:
            API response with image URL or None if failed
        """
        api_key = os.getenv("FAL_KEY")
        if not api_key:
            print("❌ FAL_KEY not found in environment variables")
            print("   Please set: export FAL_KEY='your-fal-api-key'")
            return None

        print(f"🎨 Generating image with fal.ai FLUX Pro...")
        print(f"   Image size: {image_size}, Guidance: {guidance_scale}")

        try:
            import fal_client

            # Configure fal_client with API key if needed
            fal_client.api_key = api_key

            # Prepare arguments for FLUX Pro
            arguments = {
                "prompt": prompt,
                "num_images": num_images,
                "image_size": image_size,
                "guidance_scale": guidance_scale,
                "safety_tolerance": "2",  # Default safety level
                "output_format": "jpeg"
            }

            if seed is not None:
                arguments["seed"] = seed

            # Call FLUX Pro model via fal_client
            result = fal_client.subscribe(
                "fal-ai/flux-pro/kontext/text-to-image",
                arguments=arguments
            )

            # Check if we got images back
            if result and 'images' in result and len(result['images']) > 0:
                # Get the first image
                image_data = result['images'][0]

                # Handle different response formats
                if isinstance(image_data, dict) and 'url' in image_data:
                    image_url = image_data['url']
                elif isinstance(image_data, str):
                    image_url = image_data
                else:
                    print(f"❌ Unexpected image data format: {type(image_data)}")
                    return None

                print(f"✅ fal.ai FLUX Pro image generated successfully")
                return {
                    "image_url": image_url,
                    "seed": result.get('seed'),
                    "model_used": "flux-pro-fal",
                    "has_nsfw_concepts": result.get('has_nsfw_concepts', [])
                }
            else:
                print("❌ No images returned from fal.ai FLUX Pro API")
                print(f"   Response: {result}")
                return None

        except ImportError:
            print("❌ fal_client not installed. Please install with: pip install fal-client")
            return None
        except Exception as e:
            print(f"❌ fal.ai FLUX Pro API Error: {e}")
            return None

    def save_image_from_url(self, image_url: str, filename: str) -> bool:
        """
        Fetch an image over HTTP and write it to a local file.

        Args:
            image_url: URL of the image to download
            filename: Local filename to save the image

        Returns:
            True when the download and the write both succeeded, False
            otherwise (the error is printed)
        """
        try:
            download = requests.get(image_url, timeout=30)
            download.raise_for_status()  # HTTP errors count as failures
            with open(filename, "wb") as out_file:
                out_file.write(download.content)
        except Exception as err:
            print(f"❌ Failed to download image from URL: {err}")
            return False
        else:
            return True

    def save_image_from_base64(self, image_data: str, filename: str) -> bool:
        """
        Decode base64 image data and write the resulting bytes to disk.

        Failures are reported on stdout (in the same style as
        save_image_from_url) instead of being swallowed silently, so a
        missing output file can be diagnosed.

        Args:
            image_data: Base64 encoded image data
            filename: Local filename to save the image

        Returns:
            True if successful, False otherwise (including when image_data
            is empty or None)
        """
        # An empty payload would otherwise produce a zero-byte "image" that
        # is reported as a success.
        if not image_data:
            print("❌ Failed to save image from base64 data: no image data provided")
            return False

        try:
            image_bytes = base64.b64decode(image_data)
            with open(filename, "wb") as f:
                f.write(image_bytes)
            return True
        except Exception as e:
            print(f"❌ Failed to save image from base64 data: {e}")
            return False
    
    def save_image_from_bytes(self, image_bytes: bytes, filename: str) -> bool:
        """
        Save an image from bytes data.

        Failures are reported on stdout (in the same style as
        save_image_from_url) instead of being swallowed silently.

        Args:
            image_bytes: Image data as bytes
            filename: Local filename to save the image

        Returns:
            True if successful, False otherwise (including when image_bytes
            is empty or None)
        """
        # Writing an empty payload would create a zero-byte file and report
        # success, hiding the fact that no image was produced.
        if not image_bytes:
            print("❌ Failed to save image from bytes: no image data provided")
            return False

        try:
            with open(filename, "wb") as f:
                f.write(image_bytes)
            return True
        except Exception as e:
            print(f"❌ Failed to save image from bytes: {e}")
            return False
    
    def test_single_problem(self, index: int, style: str = "blackboard", 
                          api_provider: str = "openai", save_image: bool = True) -> Dict:
        """
        Test a single math problem with T2I model.
        
        Args:
            index: Index of the problem in the dataset
            style: Visual style (blackboard, whiteboard, notebook, paper, digital)
            api_provider: API provider (openai, gemini, stability)
            save_image: Whether to save the generated image
        
        Returns:
            Dictionary with test results
        """
        if index >= len(self.df):
            return {"error": "Index out of range"}
        
        row = self.df.iloc[index]
        problem = row['problem']
        solution = row['solution']
        level = row['level']
        problem_type = row['type']
        
        # Format the problem
        formatted_prompt = self.format_math_problem(problem, solution, level, problem_type)
        
        # Add system prompt for the chosen style
        system_prompt = self.system_prompts.get(style, self.system_prompts["blackboard"])
        full_prompt = f"{system_prompt}\n\n{formatted_prompt}"
        
        # Call the appropriate API
        start_time = time.time()

        if api_provider == "openai":
            response = self.call_openai(full_prompt)
        elif api_provider == "gemini":
            response = self.call_gemini(full_prompt)
            time.sleep(1)
        elif api_provider == "stability":
            response = self.call_stability_ai(full_prompt)
        elif api_provider == "flux":
            response = self.call_flux(full_prompt)
        elif api_provider == "flux-fal" or api_provider == "flux_pro_fal":
            response = self.call_flux_fal(full_prompt)
        elif api_provider == "wan22" or api_provider == "wan-2.2":
            response = self.call_wan22(full_prompt)
        elif api_provider == "kling":
            response = self.call_kling(full_prompt)
        else:
            return {"error": f"Unknown API provider: {api_provider}"}

        generation_time = time.time() - start_time
        
        if not response:
            return {"error": "API call failed"}
        
        # Save image if requested
        if save_image:
            # Clean problem type and level for filename
            clean_type = problem_type.lower().replace(" ", "_")
            clean_level = level.lower().replace(" ", "_")
            filename = f"{self.output_dir}/{clean_type}_{clean_level}_{index}.png"
            
            if api_provider == "openai":
                # Save image from base64 data
                image_data = response.data[0].b64_json
                self.save_image_from_base64(image_data, filename)
            elif api_provider == "gemini":
                # Save image from bytes data
                for part in response.candidates[0].content.parts:
                    if part.inline_data is not None:
                        image_bytes = part.inline_data.data
                        self.save_image_from_bytes(image_bytes, filename)
                        break
            elif api_provider == "stability":
                # Handle Stability AI v2beta response (raw bytes)
                image_data = response.get('image_data')
                if image_data:
                    self.save_image_from_bytes(image_data, filename)
                else:
                    print("❌ No image data found in Stability AI response")
            elif api_provider == "flux-fal":
                # Handle Flux API response (URL)
                image_url = response.get('image_url')
                if image_url:
                    self.save_image_from_url(image_url, filename)
                else:
                    print("❌ No image URL found in Flux response")
            elif api_provider == "wan22" or api_provider == "wan-2.2":
                # Handle WAN 2.2 API response (URL)
                image_url = response.get('image_url')
                if image_url:
                    self.save_image_from_url(image_url, filename)
                else:
                    print("❌ No image URL found in WAN 2.2 response")
            elif api_provider == "kling":
                # Handle Kling AI response (URL)
                image_url = response.get('image_url')
                if image_url:
                    self.save_image_from_url(image_url, filename)
                else:
                    print("❌ No image URL found in Kling response")
        
        # Store results
        result = {
            "index": index,
            "problem": problem,
            "solution": solution,
            "level": level,
            "type": problem_type,
            "style": style,
            "api_provider": api_provider,
            "generation_time": generation_time,
            "image_data": "base64_data",
            "success": True,
            "timestamp": datetime.now().isoformat()
        }
        
        self.results.append(result)
        return result
    
    def test_batch(self, indices: List[int], style: str = "blackboard", 
                   api_provider: str = "openai", delay: float = 1.0) -> List[Dict]:
        """
        Test multiple problems in batch.
        
        Args:
            indices: List of problem indices to test
            style: Visual style
            api_provider: API provider
            delay: Delay between API calls (seconds)
        
        Returns:
            List of test results
        """
        results = []
        
        for i, index in enumerate(indices):
            result = self.test_single_problem(index, style, api_provider)
            results.append(result)
            
            if i < len(indices) - 1:  # Don't delay after the last item
                time.sleep(delay)
        
        return results
    
    def test_random_sample(self, n: int = 10, style: str = "blackboard", 
                          api_provider: str = "openai") -> List[Dict]:
        """
        Test a random sample of problems.
        
        Args:
            n: Number of problems to test
            style: Visual style
            api_provider: API provider
        
        Returns:
            List of test results
        """
        indices = random.sample(range(len(self.df)), min(n, len(self.df)))
        return self.test_batch(indices, style, api_provider)
    
    def save_results(self, filename: str = None) -> str:
        """
        Save test results to JSON file.
        
        Args:
            filename: Output filename (optional)
        
        Returns:
            Path to saved file
        """
        if filename is None:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            filename = f"{self.output_dir}/t2i_test_results_{timestamp}.json"
        
        with open(filename, 'w') as f:
            json.dump(self.results, f, indent=2)
        
        return filename
    
    def generate_report(self) -> str:
        """
        Generate a summary report of the test results.
        
        Returns:
            Report text
        """
        if not self.results:
            return "No results to report."
        
        successful_tests = [r for r in self.results if r.get('success', False)]
        failed_tests = [r for r in self.results if not r.get('success', False)]
        
        # Group by level and type
        level_counts = {}
        type_counts = {}
        style_counts = {}
        api_counts = {}
        
        for result in successful_tests:
            level = result.get('level', 'Unknown')
            problem_type = result.get('type', 'Unknown')
            style = result.get('style', 'Unknown')
            api_provider = result.get('api_provider', 'Unknown')
            
            level_counts[level] = level_counts.get(level, 0) + 1
            type_counts[problem_type] = type_counts.get(problem_type, 0) + 1
            style_counts[style] = style_counts.get(style, 0) + 1
            api_counts[api_provider] = api_counts.get(api_provider, 0) + 1
        
        report = f"""
# T2I Model Test Report

## Summary
- Total tests: {len(self.results)}
- Successful: {len(successful_tests)}
- Failed: {len(failed_tests)}
- Success rate: {len(successful_tests)/len(self.results)*100:.1f}%

## Results by Level
{chr(10).join([f"- {level}: {count}" for level, count in sorted(level_counts.items())])}

## Results by Problem Type
{chr(10).join([f"- {ptype}: {count}" for ptype, count in sorted(type_counts.items())])}

## Results by Style
{chr(10).join([f"- {style}: {count}" for style, count in sorted(style_counts.items())])}

## Results by API Provider
{chr(10).join([f"- {api}: {count}" for api, count in sorted(api_counts.items())])}

## Average Generation Time
{sum(r.get('generation_time', 0) for r in successful_tests) / len(successful_tests):.2f} seconds

## Failed Tests
{chr(10).join([f"- Index {r.get('index', '?')}: {r.get('error', 'Unknown error')}" for r in failed_tests])}
"""
        return report

def main(csv_path: Optional[str] = None):
    """
    Main function to demonstrate the T2I testing system.

    Runs a two-problem random sample against each demo provider, then writes
    the JSON results and a Markdown report into the tester's output directory.

    Args:
        csv_path: Path to the problem CSV. When omitted, the original
            developer-local balanced_sample_1000.csv location is used.
    """
    # Initialize the tester
    if csv_path is None:
        csv_path = "/Users/shangwu/Downloads/questions/competition_math_data/balanced_sample_1000.csv"
    tester = T2IModelTester(csv_path)

    # Test a few problems with different APIs
    demo_runs = [
        ("OpenAI", "openai"),
        ("Gemini", "gemini"),
        ("Flux", "flux"),
        ("WAN 2.2", "wan22"),
        ("Kling AI", "kling"),
    ]
    for position, (label, provider) in enumerate(demo_runs):
        # Blank line between provider sections, but not before the first one.
        separator = "" if position == 0 else "\n"
        print(f"{separator}Testing with {label}...")
        tester.test_random_sample(n=2, style="blackboard", api_provider=provider)

    # Save results
    results_file = tester.save_results()
    print(f"\nResults saved to: {results_file}")

    # Generate report
    report = tester.generate_report()
    print(report)

    # Save report
    report_file = f"{tester.output_dir}/test_report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.md"
    # Explicit encoding so the file does not depend on the platform default.
    with open(report_file, 'w', encoding='utf-8') as f:
        f.write(report)
    print(f"Report saved to: {report_file}")

# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()