import pandas as pd
import openai
import time
from typing import List

import os
import json
import threading
import fcntl
from concurrent.futures import ThreadPoolExecutor
import pandas as pd

from utils import question_to_category

from utils import call_oai_rm_llm, call_anth_rm_llm
# Input prompts: the "prompt" column of prompts.csv, in row order.
df = pd.read_csv("prompts.csv")
prompts = df["prompt"].tolist()

# Backend switch: True selects Anthropic (call_anth_rm_llm with Claude),
# False selects OpenAI (call_oai_rm_llm with gpt-4o).
anth = True
# Passed as max_tokens to every LLM call.
max_length = 5000
        
# Brainstorming prompt template. process_prompt fills it with
# prompt_format.format(user_request=..., category=...).
# parse_result relies on two parts of the layout below:
#   - the "FINAL ANSWER:" marker;
#   - the "N. **Name**" / "+ ..." / "- ..." line format.
# Keep the template and that parser in sync.
# NOTE: this is runtime text sent to the model; any edit (even whitespace)
# changes what the model sees.
prompt_format = """I am tasked with the following request: 
{user_request}

Help me brainstorm how to respond to the user request by providing a list of True/False properties the solution may or may not have. Use the following step-by-step to come up with good properties: 

1. If you were playing 20 questions, what's a good first question to ask that would split the possibilites in half? 
List at least 4 questions and their corresponding properties.
Question: <Description>

2. Rewrite each question as a True/False property that's true for one half and false for the other.
Question: <Description>
True/False Property: <Property Description>

3. For each property, come up with an example that would satisfy the property.
Property: <Description>
Example: <Description>
Is it a valid answer to the user's request? <Yes/No>

4. For each property, come up with an example that would not satisfy the property.
Property: <Description>
Example: <Description>
Is it a valid answer to the user's request? <Yes/No>

5. For each property, list whether we should include it or not in the final list of properties. Do not include ones where an example from above is not valid.
Property: <Description>
Include in final list? <Yes/No>

List the properties in the following format where + indicates the property and - indicates it's negation.
Avoid abbreviation in final answer. Minimum 4 properties.

FINAL ANSWER: 

1. **<Name of Property>**
+ The {category} ... <Option 1>
- The {category} ... <Option 2>

Ensure all properties are listed are sentences that are either True or False.
"""

def parse_result(result):
    """Parse the property list that follows the ``FINAL ANSWER:`` marker.

    Expected layout (the one requested by the prompt template)::

        FINAL ANSWER:

        1. **<Name of Property>**
        + The ... <Option 1>
        - The ... <Option 2>
        2. **<Name of Property>**
        ...

    Only the text between the first ``FINAL ANSWER:`` and the next one (or
    the end of the string) is parsed.

    Args:
        result: Raw model output text.

    Returns:
        A list of ``{"name": str, "options": [str, ...]}`` dicts in the order
        they appear. The list is empty when the marker is missing.
    """
    parts = result.split("FINAL ANSWER:")
    if len(parts) < 2:
        return []
    answer = parts[1]

    options = []
    current_option = None
    for line in answer.split('\n'):
        stripped = line.strip()
        head, dot, rest = line.partition('.')
        if dot and head.isdigit():
            # "N. **Name**" starts a new property; N may have several digits.
            current_option = {'name': rest.replace('**', '').strip(), 'options': []}
            options.append(current_option)
        elif stripped[:1] in ('+', '-') and current_option is not None:
            # Remove only the leading +/- marker so hyphens and plus signs
            # inside the option text survive. Option lines that appear
            # before any numbered header are ignored.
            text = stripped[1:].strip()
            if text.strip('-'):  # skip markdown rules such as "---"
                current_option['options'].append(text)
    return options

# in parallel make calls and write to jsonl file


def process_prompt(prompt, idx, max_attempts=5):
    """Ask the configured LLM to brainstorm True/False properties for a prompt.

    The prompt is categorised with ``question_to_category`` and substituted
    into ``prompt_format``. The model is then queried until ``parse_result``
    extracts more than three properties, up to ``max_attempts`` times.

    Args:
        prompt: The user request to brainstorm about.
        idx: Row index of the prompt. It is stored in the result so a later
            run can skip prompts that are already done.
        max_attempts: Maximum number of LLM calls before giving up.

    Returns:
        A dict with keys ``idx``, ``prompt``, ``raw_result`` and ``options``.
        Returns ``None`` if the model returns an empty response, or if no
        attempt yields enough properties.
    """
    category = question_to_category(prompt, natural=True)
    prompt_formatted = prompt_format.format(user_request=prompt, category=category)

    # The backend does not change between attempts, so select it once.
    fn = call_anth_rm_llm if anth else call_oai_rm_llm
    model = "claude-3-7-sonnet-20250219" if anth else "gpt-4o"

    for _ in range(max_attempts):
        outputs = fn(prompt_formatted, n=1, temperature=0.5, model_id=model, max_tokens=max_length)
        if not outputs:
            # An empty response gives up immediately instead of retrying.
            return None

        parsed_result = parse_result(outputs)
        # The template asks for at least 4 properties; retry when fewer are parsed.
        if len(parsed_result) > 3:
            return {
                "idx": idx,
                "prompt": prompt,
                "raw_result": outputs,
                "options": parsed_result,
            }

    print(f"Failed to get a valid result for prompt {prompt} after {max_attempts} attempts")
    return None

def write_result(result, output_file):
    """Append ``result`` to ``output_file`` as a single JSON line.

    ``None`` results are ignored. An exclusive ``flock`` is held for the
    duration of the write. Other writers that also take the lock therefore
    cannot interleave with this line.

    Args:
        result: JSON-serialisable object to append, or ``None``.
        output_file: Path of the JSONL file. It is opened in append mode.
    """
    if result is None:
        return

    with open(output_file, 'a', encoding='utf-8') as f:
        fcntl.flock(f.fileno(), fcntl.LOCK_EX)
        try:
            f.write(json.dumps(result) + '\n')
            # Push the buffered line to the OS while the lock is still held.
            # Without this, the real write happens when the file is closed,
            # which is after LOCK_UN.
            f.flush()
        finally:
            fcntl.flock(f.fileno(), fcntl.LOCK_UN)

def process_and_write(prompt, output_file, i):
    """Process one prompt, persist any result, and log progress.

    Args:
        prompt: The user request to expand.
        output_file: Path of the JSONL file that results are appended to.
        i: Index of ``prompt`` in ``prompts``. It is stored in the result and
            shown in the progress message.
    """
    record = process_prompt(prompt, i)
    if record:
        write_result(record, output_file)
    total = len(prompts)
    print(f"Processed {i} of {total} prompts")

# --- Driver: resume from the existing output file and process the rest ---
# `prompts` is already loaded from prompts.csv at the top of the module, so
# the CSV is not read a second time here.

anth_tag = '_anth' if anth else ''
output_file = f'modified_prompts{anth_tag}.jsonl'

if not os.path.exists(output_file):
    # Create an empty output file. An existing file is kept (never cleared),
    # so prompts finished in earlier runs are skipped.
    open(output_file, 'w').close()

# Collect the indices already written by earlier runs. Some lines are skipped
# instead of aborting the resume:
#   - blank lines;
#   - lines that are not a JSON object with an "idx" key, e.g. a record cut
#     short by an interrupted run.
# Prompts behind skipped lines are simply processed again.
seen_idx = set()
last_line = ''
with open(output_file, 'r') as f:
    for line_no, line in enumerate(f, start=1):
        last_line = line
        if not line.strip():
            continue
        try:
            seen_idx.add(json.loads(line)['idx'])
        except (json.JSONDecodeError, KeyError, TypeError):
            print(f"Skipping malformed line {line_no} of {output_file}")

if last_line and not last_line.endswith('\n'):
    # Terminate a truncated last record so new appends start on their own line.
    with open(output_file, 'a') as f:
        f.write('\n')

prompts_to_process = [
    (i, prompt) for i, prompt in enumerate(prompts)
    if i not in seen_idx
]

print(f"Processing {len(prompts_to_process)} prompts out of {len(prompts)} total prompts")
# Process prompts in parallel; each worker writes its own result line.
with ThreadPoolExecutor(max_workers=50) as executor:
    futures = [
        executor.submit(process_and_write, prompt, output_file, i)
        for i, prompt in prompts_to_process
    ]

    # Wait for all futures and re-raise any exception from a worker.
    for future in futures:
        future.result()




    
    
