import os
import json
import random
import argparse
import traceback
from tqdm import tqdm
from openai import OpenAI
from timeit import default_timer as timer
from utils.ordered_commongen import complete_prompt_ordered_commongen

def is_reasoning_model(model):
    """Return True if *model* names an OpenAI o-series reasoning model.

    Matches ``o`` followed by a digit at the start of the name, so sized or
    dated variants (``o1-mini``, ``o3-mini``, ``o4-mini-2025-04-16``) are
    recognised as well as the bare ``o1``/``o2``/``o3``/``o4`` names. An
    exact-name check misses those variants and sends them a ``temperature``,
    which these models do not accept.
    """
    name = str(model)
    return name[:1] == "o" and name[1:2].isdigit()


def sanitize_model_name(model):
    """Return *model* with ``.``, ``-``, ``:`` and ``/`` replaced by ``_``.

    Used as the stem of the output file names (``gpt-3.5-turbo`` ->
    ``gpt_3_5_turbo``). ``:`` and ``/`` are replaced too because they are not
    safe in file names (``/`` would be read as a directory separator); ``:``
    appears e.g. in fine-tuned model ids.
    """
    return str(model).translate(str.maketrans(".-:/", "____"))


def generate_sentence(client, model, concepts, request_kwargs):
    """Ask *model* for one sentence covering *concepts*.

    Args:
        client: OpenAI client exposing ``chat.completions.create``.
        model: Model name forwarded to the API.
        concepts: Iterable of concept strings; joined with ", " and quoted
            before being passed to ``complete_prompt_ordered_commongen``.
        request_kwargs: Extra keyword arguments for ``chat.completions.create``.

    Returns:
        ``(sentence, seconds)``: the stripped reply text, and the wall-clock
        duration of the API call alone (prompt construction is not timed).

    Raises:
        ValueError: if the response carries no message content.
        Any exception raised by the OpenAI client propagates unchanged.
    """
    joined = ", ".join(concepts)
    prompt = complete_prompt_ordered_commongen(f'"{joined}"')
    start_time = timer()
    response = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": "You are a helpful AI that generates descriptive sentences using given concepts."},
            {"role": "user", "content": prompt},
        ],
        **request_kwargs,
    )
    elapsed = timer() - start_time
    choice = response.choices[0]
    if choice.message.content is None:
        # content is nullable in the API; report the finish reason instead of
        # failing with an opaque "'NoneType' object has no attribute 'strip'".
        raise ValueError(
            f"empty completion (finish_reason={choice.finish_reason!r})"
        )
    return choice.message.content.strip(), elapsed


def main(
        model,
        dataset,
        seed=42,
        temperature=0.7,
        max_completion_tokens=2048,
        client=None,
        results_dir=None,
    ):
    """Generate one sentence per dataset item with an OpenAI chat model and save it.

    Args:
        model: OpenAI model name passed to ``chat.completions.create``.
        dataset: Iterable of dicts, each with an ``"id"`` key and a
            ``"concepts"`` key holding an iterable of strings.
        seed: Seeds Python's ``random`` module. NOTE: it is not forwarded to
            the API, so it does not make the completions reproducible.
        temperature: Sampling temperature; omitted for o-series reasoning
            models (see ``is_reasoning_model``).
        max_completion_tokens: Forwarded as ``max_completion_tokens``.
        client: OpenAI client to use. Defaults to a module-level ``client``
            if one exists (the ``__main__`` block defines it), else a new
            ``OpenAI()``.
        results_dir: Output directory, created if missing. Defaults to a
            module-level ``results_dir`` if one exists, else
            ``./LLM/results/openai``.

    Writes ``<model>_candidates.json``, ``<model>_times.json`` and
    ``<model>_errors.json`` (each entry holds id, error message and traceback)
    into *results_dir*, where ``<model>`` is ``sanitize_model_name(model)``.
    The files are written even if the run is interrupted, with whatever was
    collected so far. A failed request is logged and recorded, and the loop
    continues. An item without an ``"id"`` key aborts the run with KeyError.

    Returns:
        dict with ``"candidates"``, ``"times"`` and ``"errors"`` lists.
    """
    random.seed(seed)

    # Fall back to the module-level names the __main__ block defines, so the
    # original call signature keeps working; passing them explicitly makes
    # main() usable when this file is imported.
    if client is None:
        client = globals().get("client")
    if client is None:
        client = OpenAI()
    if results_dir is None:
        results_dir = globals().get("results_dir", "./LLM/results/openai")
    os.makedirs(results_dir, exist_ok=True)

    # Loop-invariant request options; reasoning models reject `temperature`.
    request_kwargs = {"max_completion_tokens": max_completion_tokens}
    if not is_reasoning_model(model):
        request_kwargs["temperature"] = temperature

    candidates, times, errors = [], [], []
    try:
        for item in tqdm(dataset, desc=f"Processing {model}"):
            item_id = item["id"]
            try:
                concepts = item["concepts"]
                sentence, elapsed = generate_sentence(
                    client, model, concepts, request_kwargs
                )
            except Exception as e:  # per-item boundary: log, record, keep going
                tb = traceback.format_exc()
                print(f"Error for ID {item_id}: {str(e)}")
                print(f"traceback: {tb}")
                errors.append({"id": item_id, "error": str(e), "traceback": tb})
                continue
            candidates.append({
                "id": item_id,
                "concepts": concepts,
                "sentence": sentence,
            })
            times.append({"id": item_id, "time": elapsed})
    finally:
        # Persist whatever was collected, even if the run stops part-way
        # (e.g. Ctrl-C), so finished API calls are not lost.
        model_name = sanitize_model_name(model)
        outputs = {"candidates": candidates, "times": times, "errors": errors}
        for suffix, records in outputs.items():
            path = os.path.join(results_dir, f"{model_name}_{suffix}.json")
            with open(path, "w", encoding="utf-8") as f:
                json.dump(records, f, indent=4)

    print("\nDone!")
    return outputs
    
if __name__ == "__main__":
    # Parse arguments first so `--help` works even when the dataset is absent.
    parser = argparse.ArgumentParser(description="Generate sentences using OpenAI API.")
    parser.add_argument("--model", type=str, default="gpt-3.5-turbo", help="OpenAI model to use.")
    parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility.")
    parser.add_argument("--temperature", type=float, default=0.7, help="Temperature for sampling.")
    parser.add_argument("--max_completion_tokens", type=int, default=2048, help="Max tokens for completion.")
    parser.add_argument(
        "--dataset",
        type=str,
        default="data/Ordered CommonGen/ordered_commongen.json",
        help="Path to the Ordered CommonGen JSON file.",
    )
    parser.add_argument(
        "--results_dir",
        type=str,
        default="./LLM/results/openai",
        help="Directory the output JSON files are written to.",
    )
    args = parser.parse_args()

    # With no api_key argument the client reads OPENAI_API_KEY from the
    # environment, so no secret has to be hard-coded in the source.
    # `client` and `results_dir` stay module-level names because main() looks
    # them up at module scope.
    client = OpenAI()

    with open(args.dataset, "r", encoding="utf-8") as f:
        dataset = json.load(f)

    results_dir = args.results_dir
    os.makedirs(results_dir, exist_ok=True)

    main(
        model=args.model,
        dataset=dataset,
        seed=args.seed,
        temperature=args.temperature,
        max_completion_tokens=args.max_completion_tokens,
    )