
import asyncio
import json
import os
import sys

import openai
from dotenv import load_dotenv
from tqdm import tqdm

# NOTE(review): this relative path is resolved against the current working
# directory, not this file's location, so the script must be run from its own
# directory for the .env file to be found.
load_dotenv("../.env")

# Never echo API keys: stdout is frequently captured in logs / CI output.
# Only report whether the key is present.
if not os.getenv("OPENROUTER_API_KEY"):
    print("Warning: OPENROUTER_API_KEY is not set.", file=sys.stderr)

# Make the parent directory importable so the local `generation` package resolves.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

from generation.main import load_all_configs  # noqa: E402  (needs sys.path tweak above)
from generation.openai_compatible import OpenAICompatibleArgs, openai_compatible_completion  # noqa: E402

# Create an OpenRouter client and a plain OpenAI client; keys come from the
# environment (optionally populated from the .env file loaded above).
openrouter_client = openai.OpenAI(
    api_key=os.getenv("OPENROUTER_API_KEY"),
    base_url="https://openrouter.ai/api/v1",
)
openai_client = openai.OpenAI(
    api_key=os.getenv("OPENAI_API_KEY")
)

NUM_SAMPLES = 5  # Number of completions per submission

async def process_submission_multiple(submission, model_name_to_provider, configs, num_samples=NUM_SAMPLES):
    results = []
    try:
        provider = model_name_to_provider[submission["test_model"]]
        if provider == "openai":
            client = openai_client
        else:
            client = openrouter_client

        args = OpenAICompatibleArgs(
            model=submission["test_model"],
            messages=submission["messages"],
            bucket=None,
            toolArgs=configs[submission["behavior"]],
            system_prompt=submission["system_prompt"]
        )

        for _ in range(num_samples):
            try:
                # If openai_compatible_completion is synchronous, run in thread
                result = await asyncio.to_thread(
                    openai_compatible_completion,
                    client=client,
                    args=args,
                    provider=provider
                )
            except Exception:
                import traceback
                print(f"Error in one of the completions for {submission['test_model']}")
                print(traceback.format_exc())
                result = "ERROR"
            results.append(result)
        return results
    except Exception:
        print(f"Error processing submission {submission['test_model']}")
        import traceback
        print(traceback.format_exc())
        return ["ERROR"] * num_samples

async def main_async(transfer_submissions_path, output_path, num_samples=NUM_SAMPLES, limit=None,
                     model_name_to_provider_path="../data/model_name_to_provider.json"):
    """Sample completions for every transfer submission and write the results.

    Loads the behavior configs, the model -> provider map and the submissions
    (optionally truncated to the first ``limit``). It runs
    :func:`process_submission_multiple` for all submissions concurrently and
    stores each result list under the submission's "responses" key. Finally it
    dumps the whole list as JSON to ``output_path``.

    Args:
        transfer_submissions_path: JSON file containing a list of submission dicts.
        output_path: destination JSON file.
        num_samples: completions to request per submission.
        limit: if not None, only the first ``limit`` submissions are processed.
        model_name_to_provider_path: JSON file mapping model names to providers.
            The default is relative to the current working directory.
    """
    configs = load_all_configs()

    with open(model_name_to_provider_path, "r", encoding="utf-8") as f:
        model_name_to_provider = json.load(f)

    with open(transfer_submissions_path, "r", encoding="utf-8") as f:
        transfer_submissions = json.load(f)
    if limit is not None:
        transfer_submissions = transfer_submissions[:limit]

    async def _run(idx, submission):
        # Tag the result with its index so it can be stored in input order even
        # though tasks are consumed in completion order below.
        return idx, await process_submission_multiple(
            submission, model_name_to_provider, configs, num_samples
        )

    # Launch all submissions at once; each one runs its samples sequentially.
    tasks = [
        asyncio.create_task(_run(idx, submission))
        for idx, submission in enumerate(transfer_submissions)
    ]

    # as_completed advances the progress bar as soon as any submission finishes,
    # instead of blocking on a slow early submission while later ones are done.
    for next_done in tqdm(asyncio.as_completed(tasks), total=len(tasks), desc="Processing submissions"):
        idx, results = await next_done
        transfer_submissions[idx]["responses"] = results  # Store all results in 'responses' array

    with open(output_path, "w") as f:
        json.dump(transfer_submissions, f, indent=4)

def main(transfer_submissions_path, output_path, num_samples=NUM_SAMPLES, limit=None):
    """Blocking entry point: run ``main_async`` to completion on a new event loop."""
    pipeline = main_async(
        transfer_submissions_path,
        output_path,
        num_samples=num_samples,
        limit=limit,
    )
    asyncio.run(pipeline)

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Process transfer submissions.")
    parser.add_argument('--input_path', required=True, help='Path to the transfer submissions file')
    parser.add_argument('--output_path', required=True, help='Path to the output file')
    # NOTE(review): the CLI default (1) differs from NUM_SAMPLES (5), which is
    # only used when main()/main_async() are called directly. Kept at 1 to avoid
    # silently multiplying API usage for existing CLI invocations.
    parser.add_argument('--num_samples', type=int, default=1,
                        help='Number of completions to sample per submission (default: 1)')
    parser.add_argument('--limit', type=int, default=None,
                        help='Only process the first N submissions')

    args = parser.parse_args()
    # Reject values that would silently produce empty "responses" lists or
    # slice from the end of the submissions list.
    if args.num_samples < 1:
        parser.error("--num_samples must be at least 1")
    if args.limit is not None and args.limit < 0:
        parser.error("--limit must be non-negative")

    main(args.input_path, args.output_path, num_samples=args.num_samples, limit=args.limit)