import pandas as pd
import openai
import time
from typing import List
import os


from utils import write_result, filter_seen_solutions, call_oai_rm_llm
# get t from command line
import sys
# Sampling temperature passed to every completion request; required first CLI argument.
if len(sys.argv) < 2:
    # Fail with a usage hint instead of an IndexError traceback.
    sys.exit(f"usage: {sys.argv[0]} TEMPERATURE")
T = float(sys.argv[1])

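# Prompts are loaded from the JSONL file named by `infile` below (not from CSV).


# Completions are written with `write_result` to the path held in `output_file` below.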

        
# NOTE(review): not referenced anywhere in the visible part of this script;
# possibly used elsewhere or left over from an earlier version.
max_length = 1_000
        
# Prompt template sent to the model for every request. `{prompt}` is filled in
# with a specialized prompt via `str.format(prompt=...)` in process_prompt.
# NOTE: because the template goes through str.format, any literal `{` or `}`
# added to it later must be doubled (`{{` / `}}`).
prompt_format = """Write a 3 part story outline based on the following prompt: 

{prompt}

Format: 

**Title**: <TITLE>

**Setting**: <SETTING>

**Characters**: <CHARACTERS>

**Act 1:** <ACT 1 TITLE>
1. <Content>
2. <Content>
3. <Content>
...

**Act 2:** <ACT 2 TITLE>
1. <Content>
2. <Content>
3. <Content>
...

**Act 3:** <ACT 3 TITLE>
1. <Content>
2. <Content>
3. <Content>
...

THE END
"""


# in parallel make calls and write to jsonl file

import json
import threading
import fcntl
from concurrent.futures import ThreadPoolExecutor
import pandas as pd

import random

def get_specialized_prompt(entry):
    """Append one randomly resolved, numbered instruction per option to the prompt.

    Each value of ``entry['options']`` is a mapping whose first item is
    ``(instruction_text, odds)``. The first instruction is chosen with weight
    ``odds`` against ``1 - odds``; otherwise the text of the second item is used.
    Any further items are ignored. Instructions are numbered starting at 1.
    NOTE(review): assumes ``odds`` is a probability in [0, 1] and that each
    option has at least two items when the first can be rejected.

    Returns:
        ``entry['prompt']`` followed by an "Additional instructions" header and
        one ``"<n>. <instruction>\\n"`` line per option.
    """
    pieces = [entry['prompt'] + "\n\n Additional instructions: \n"]
    for number, option in enumerate(entry['options'].values(), start=1):
        candidates = list(option.items())
        first_odds = candidates[0][1]
        keep_first = random.choices([True, False], weights=[first_odds, 1 - first_odds])[0]
        chosen_text = candidates[0][0] if keep_first else candidates[1][0]
        pieces.append(f"{number}. {chosen_text}\n")
    return "".join(pieces)

def process_prompt(entry, num_samples=2):
    """Draw specialized prompts for ``entry`` and request one completion for each.

    Args:
        entry: input record; must provide the keys read by
            ``get_specialized_prompt`` (``'prompt'`` and ``'options'``).
        num_samples: number of independently specialized prompts to draw
            (default 2, the original paired output).

    Returns:
        A dict with the original ``'prompt'``, the list of
        ``'specialized_prompt'`` strings, and the matching list of
        ``'completions'`` (whatever ``call_oai_rm_llm`` returned for each).
        Returns ``None`` as soon as any completion is empty/falsy.
    """
    # Draw every specialization before making any API call, so the random
    # choices are made in the same order as before.
    specialized_prompts = [get_specialized_prompt(entry) for _ in range(num_samples)]

    completions = []
    for specialized_prompt in specialized_prompts:
        output = call_oai_rm_llm(
            prompt_format.format(prompt=specialized_prompt),
            n=1,
            temperature=T,
            model_id="gpt-4o",
        )
        if not output:
            # The whole entry is dropped when any sample fails, so don't spend
            # the remaining API calls on a result that will be thrown away.
            return None
        completions.append(output)

    return {
        "prompt": entry['prompt'],
        "specialized_prompt": specialized_prompts,
        "completions": completions,
    }


def process_and_write(prompt, output_file, i):
    """Generate completions for one entry and write the result if there is one.

    Args:
        prompt: one input entry (a parsed JSONL record), forwarded to
            ``process_prompt``.
        output_file: destination path handed to ``write_result``.
        i: the entry's 0-based position in submission order; used only for
            progress messages. Workers run concurrently, so this is an
            identifier, not a count of completed prompts.
    """
    result = process_prompt(prompt)
    if result:
        write_result(result, output_file)
        print(f"Processed prompt {i}")
    else:
        # process_prompt returns None when a completion came back empty;
        # nothing is written for this entry.
        print(f"Skipped prompt {i}: empty completion")


infile = "odds.jsonl"
output_file = f'final_data/ss_completions_temp_{str(T)}.jsonl'

# One JSON object per line (JSON Lines). Blank lines, such as a trailing
# newline at EOF, are skipped rather than crashing json.loads.
with open(infile, 'r', encoding='utf-8') as f:
    entries = [json.loads(line) for line in f if line.strip()]

# Presumably drops entries whose results are already in output_file, so an
# interrupted run can be resumed -- TODO confirm against utils.filter_seen_solutions.
entries = filter_seen_solutions(entries, output_file)
# len() rather than truthiness: the helper's return type is not visible here.
if len(entries) == 0:
    print(f"No new prompts to process for temperature {T}")
    # sys.exit, not the site-injected exit(), which is absent under `python -S`.
    sys.exit()

# Ensure the destination directory exists before the worker threads write.
output_dir = os.path.dirname(output_file)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)

# Process prompts in parallel (the work is I/O-bound API calls, so threads suffice).
with ThreadPoolExecutor(max_workers=50) as executor:
    futures = [
        executor.submit(process_and_write, entry, output_file, i)
        for i, entry in enumerate(entries)
    ]

    # Wait for all futures to complete; re-raises the first worker exception encountered.
    for future in futures:
        future.result()




    
    
