import sys
import os
import pandas as pd
import numpy as np
import math
import asyncio
import time
import json

from openai import AsyncOpenAI, BadRequestError

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))

from utils import llm
from utils import keys
from utils import string_utils


# Async client shared by every image-generation request below.
client = AsyncOpenAI(api_key=keys.OPENAI_KEY)

# Prompt table for this run. main() later overwrites this same CSV with only
# the rows whose request failed for a non-policy reason.
df = pd.read_csv("all_unsure_prompts.csv")
all_prompts_to_eval = df["harmful_original_task"].tolist()

# Outcome buckets filled in by process_prompt().
successful_prompts = []        # prompts whose request returned normally
policy_violation_prompts = []  # prompts refused with content_policy_violation
other_error_prompts = []       # {"prompt": ..., "error": ...} for any other failure
policy_violations_count = 0    # incremented alongside policy_violation_prompts

# Serialises mutations of the shared buckets above.
lock = asyncio.Lock()

async def process_prompt(i, prompt):
    """
    Send one prompt to DALL-E 3 and record how the request ended.

    The outcome is appended to exactly one module-level bucket:
    ``successful_prompts`` when the request returns normally,
    ``policy_violation_prompts`` (plus ``policy_violations_count``) when the
    API rejects it with a ``content_policy_violation`` error, or
    ``other_error_prompts`` for any other failure. Exceptions derived from
    ``Exception`` are recorded instead of raised, so one bad prompt never
    aborts the surrounding ``asyncio.gather``.

    Args:
        i: Zero-based position of ``prompt`` in ``all_prompts_to_eval``;
            used only for progress messages.
        prompt: Prompt text sent to the image API.
    """
    global policy_violations_count

    try:
        # Kept inside the try: a non-string prompt (e.g. NaN read from the
        # CSV) fails on the slice and is then recorded as an error.
        print(f"Processing prompt {i+1}/{len(all_prompts_to_eval)}: {prompt[:50]}...")

        # Only whether the request is accepted matters; the generated image
        # itself is not used.
        await client.images.generate(
            model="dall-e-3",
            prompt=prompt,
            n=1,
            size="1024x1024",
        )

    except BadRequestError as e:
        error_message = str(e)

        # Prefer the structured error code exposed by the SDK; fall back to
        # the message text in case the attribute is missing or empty.
        is_policy_violation = (
            getattr(e, "code", None) == "content_policy_violation"
            or "content_policy_violation" in error_message
        )
        if is_policy_violation:
            async with lock:
                policy_violations_count += 1
                policy_violation_prompts.append(prompt)
            print(f"Content policy violation for prompt {i+1}")
        else:
            async with lock:
                other_error_prompts.append({
                    "prompt": prompt,
                    "error": error_message,
                })
            print(f"Other error for prompt {i+1}: {error_message[:100]}")

    except Exception as e:
        # Task boundary: record the failure (e.g. rate limits, network errors)
        # so a later run can retry it, instead of crashing gather().
        print(f"Unexpected error for prompt {i+1}: {str(e)}")
        async with lock:
            other_error_prompts.append({
                "prompt": prompt,
                "error": str(e),
            })

    else:
        # Success path lives outside the try so only the API call is guarded.
        async with lock:
            successful_prompts.append(prompt)
        print(f"Success for prompt {i+1}")


async def main():
    tasks = []
    # Create a task for each prompt
    for i, prompt in enumerate(all_prompts_to_eval):
        tasks.append(asyncio.create_task(process_prompt(i, prompt)))

    # Run them all concurrently
    await asyncio.gather(*tasks)

    # Print summary
    print("\nSummary:")
    print(f"Total prompts processed: {len(all_prompts_to_eval)}")
    print(f"Successful prompts: {len(successful_prompts)}")
    print(f"Content policy violations: {policy_violations_count}")
    print(f"Other errors: {len(other_error_prompts)}")

    # Save results to JSON
    results = {
        "policy_violation_count": policy_violations_count,
        "policy_violation_prompts": policy_violation_prompts,
        "other_error_prompts": other_error_prompts
    }

    with open("error_prompts.json", "w") as f:
        json.dump(results, f, indent=2)
    
    filename = "all_policy_violation_prompts.csv"
    filtered_df = df[df["harmful_original_task"].isin(policy_violation_prompts)]
    if os.path.exists(filename):
        filtered_df.to_csv(filename, mode='a', index=False, header=False)
    else:
        filtered_df.to_csv(filename, mode='w', index=False, header=True)
        
    all_other_prompts= [x["prompt"] for x in other_error_prompts]
    df[df["harmful_original_task"].isin(all_other_prompts)].to_csv("all_unsure_prompts.csv", index=False)
    
if __name__ == "__main__":
    # Build the top-level coroutine, then drive it on a fresh event loop.
    entry = main()
    asyncio.run(entry)

