import pandas as pd
import openai
import time
from typing import List
import os


from utils import write_result, filter_seen_solutions, call_oai_rm_llm
# Read the sampling temperature T from the first command-line argument.
import sys
# Sampling temperature for every completion, taken from the first CLI argument.
if len(sys.argv) < 2:
    sys.exit(f"usage: {sys.argv[0]} TEMPERATURE")
T = float(sys.argv[1])

# Load prompts from CSV file
df = pd.read_csv("prompts.csv")

# Get prompts
prompts = df['prompt'].tolist()

# Save prompts to file, one per line. Explicit UTF-8 so non-ASCII prompts are
# written the same way regardless of the platform's default encoding.
# NOTE(review): a prompt containing a newline will span several lines here, and
# an empty (NaN) 'prompt' cell is a float, so the concatenation raises TypeError.
with open("prompts.txt", "w", encoding="utf-8") as file:
    file.writelines(prompt + "\n" for prompt in prompts)

# NOTE(review): not referenced anywhere in this file; kept in case another
# module imports it.
max_length = 1000
        
# Prompt template for the story-outline task. process_prompt fills `{prompt}`
# via str.format. The angle-bracket placeholders are literal text that shows
# the model the expected output layout. Any literal `{`/`}` added here later
# must be doubled (`{{`/`}}`) or str.format will fail. The trailing spaces after
# "prompt:" and "Format:" are part of the string sent to the model.
prompt_format = """Write a 3 part story outline based on the following prompt: 

{prompt}

Format: 

**Title**: <TITLE>

**Setting**: <SETTING>

**Characters**: <CHARACTERS>

**Act 1:** <ACT 1 TITLE>
1. <Content>
2. <Content>
3. <Content>
...

**Act 2:** <ACT 2 TITLE>
1. <Content>
2. <Content>
3. <Content>
...

**Act 3:** <ACT 3 TITLE>
1. <Content>
2. <Content>
3. <Content>
...

THE END
"""


# Make the API calls in parallel and append each result to a JSONL file.

import json
import threading
import fcntl
from concurrent.futures import ThreadPoolExecutor
import pandas as pd

def process_prompt(prompt, n=2, model_id="gpt-4o", temperature=None):
    """Request story-outline completions for a single prompt.

    The prompt is substituted into ``prompt_format`` and sent through
    ``call_oai_rm_llm``.

    Args:
        prompt: Story prompt text to insert into the template.
        n: Number of completions to request (default 2, as before).
        model_id: Model identifier passed to ``call_oai_rm_llm``.
        temperature: Sampling temperature. ``None`` means use the
            module-level ``T`` read from the command line.

    Returns:
        ``{"prompt": prompt, "completions": outputs}``, or ``None`` when
        ``call_oai_rm_llm`` returns a falsy value (e.g. an empty list or None).
    """
    # Resolve at call time so the CLI temperature stays the default.
    if temperature is None:
        temperature = T

    prompt_formatted = prompt_format.format(prompt=prompt)
    outputs = call_oai_rm_llm(
        prompt_formatted, n=n, temperature=temperature, model_id=model_id
    )

    if not outputs:
        return None

    return {"prompt": prompt, "completions": outputs}


def process_and_write(prompt, output_file, i):
    """Generate completions for one prompt and append them to ``output_file``.

    Args:
        prompt: Story prompt text.
        output_file: Path handed to ``write_result``.
        i: Zero-based index of ``prompt`` in the module-level ``prompts`` list.
            It is used only for progress messages.
    """
    result = process_prompt(prompt)
    if result:
        # NOTE(review): this runs concurrently on many worker threads, so it
        # relies on write_result (in utils) serializing appends. Confirm that.
        write_result(result, output_file)
    else:
        # Report prompts that produced no completions instead of dropping them silently.
        print(f"No completions returned for prompt {i + 1}; skipped")
    # i is zero-based, so report i + 1. Workers finish out of order, so this
    # identifies which prompt finished, not how many have finished so far.
    print(f"Processed prompt {i + 1} of {len(prompts)}")

# `prompts` was already loaded from prompts.csv at the top of the script, so
# it is reused here instead of reading and parsing the file a second time.

# f-string formatting of a float matches str(T), so the filename is unchanged.
output_file = f'final_data/baseline_completions_temp_{T}.jsonl'

# Make sure the output directory exists, then create the empty output file.
# Mode "x" creates the file exclusively and fails if it already exists. This
# avoids the race between a separate exists() check and open(), so a previous
# run's results can never be truncated.
os.makedirs(os.path.dirname(output_file), exist_ok=True)
try:
    open(output_file, "x").close()
except FileExistsError as err:
    raise FileExistsError(f"Output file already exists: {output_file}") from err

# Process prompts in parallel. Threads suit this work because each task mostly
# waits on the network.
with ThreadPoolExecutor(max_workers=50) as executor:
    futures = [
        executor.submit(process_and_write, prompt, output_file, i)
        for i, prompt in enumerate(prompts)
    ]

    # Wait for all futures to complete. result() re-raises any exception
    # that a worker raised.
    for future in futures:
        future.result()




    
    
