import pandas as pd
import openai
import time
from typing import List

import os
import json
import threading
import fcntl
from concurrent.futures import ThreadPoolExecutor
import pandas as pd

from utils import question_to_category

from utils import call_oai_rm_llm
# Load prompts from CSV file
# ("prompts.csv" is resolved relative to the current working directory and
# must contain a 'prompt' column.)
df = pd.read_csv("prompts.csv")

# Get prompts
# NOTE: the same CSV is read again further down, right before the run starts,
# and `prompts` is rebound there.
prompts = df['prompt'].tolist()

# Forwarded as `max_tokens` to the LLM call in process_prompt.
max_length = 5000

# Template sent to the model. `{user_request}` and `{category}` are filled in
# by process_prompt via str.format. parse_result depends on the literal
# "FINAL ANSWER:" marker and on the numbered "+ / -" layout requested here, so
# the template text and the parser must be changed together.
prompt_format = """I am tasked with the following request: 
{user_request}

Help me brainstorm how to respond to the user request by providing a list of True/False properties the solution may or may not have. 
List the properties in the following format where + indicates the property and - indicates it's negation.
Avoid abbreviation in final answer. Minimum 4 properties.

FINAL ANSWER: 

1. **<Name of Property>**
+ The {category} ... <Option 1>
- The {category} ... <Option 2>

Ensure all properties are listed are sentences that are either True or False.
"""

def parse_result(result):
    """Parse the ``FINAL ANSWER:`` section of a completion into properties.

    The text after the marker is expected to be a numbered list in which each
    entry is a title line followed by ``+`` / ``-`` bullet lines::

        1. **<Name of Property>**
        + The story ... <Option 1>
        - The story ... <Option 2>
        2. **<Name of Property>**
        ...

    Args:
        result: Raw completion text.

    Returns:
        A list of ``{"name": str, "options": list[str]}`` dicts in order of
        appearance. ``name`` has the ``**`` bold markers removed; each option
        has only its leading ``+``/``-`` marker removed. Returns ``[]`` when
        the text contains no ``FINAL ANSWER:`` marker.
    """
    # Ignore everything before "FINAL ANSWER:"; only the segment directly
    # following the first marker is parsed.
    sections = result.split("FINAL ANSWER:")
    if len(sections) < 2:
        return []

    options = []  # one dict per numbered property
    current_option = None
    for line in sections[1].split('\n'):
        if len(line) < 2:
            continue

        # A header is "<digits>." at the very start of the line. Checking the
        # whole prefix (not just the first character) keeps items numbered 10
        # and above from being folded into the previous property.
        number, dot, title = line.partition('.')
        if dot and number.isdigit():
            current_option = {
                'name': title.strip().replace('**', ''),
                'options': [],
            }
            # Appended right away; the bullet lines below mutate it in place.
            options.append(current_option)
            continue

        stripped = line.strip()
        if stripped.startswith(('-', '+')):
            if current_option is None:
                # Bullet before any numbered header: nothing to attach it to.
                continue
            # Drop only the leading marker so hyphens/pluses inside the
            # sentence (e.g. "well-known") survive.
            current_option['options'].append(stripped[1:].strip())

    return options

# Make the LLM calls in parallel and append each result to a JSONL file.


def process_prompt(prompt, idx, max_attempts=5):
    """Query the model for properties of ``prompt`` until enough are parsed.

    The prompt is wrapped in ``prompt_format`` (with the category obtained
    from ``question_to_category``) and sent to ``call_oai_rm_llm``. A response
    is accepted once ``parse_result`` extracts more than 3 properties, which
    matches the "Minimum 4 properties" instruction in the template.

    Args:
        prompt: The user request text.
        idx: Index of the prompt; stored in the result under ``"idx"``.
        max_attempts: Maximum number of model calls to make.

    Returns:
        A dict with keys ``"idx"``, ``"prompt"``, ``"raw_result"`` and
        ``"options"``, or ``None`` if the model returned a falsy response or
        no attempt produced enough properties.
    """
    category = question_to_category(prompt, natural=True)
    prompt_formatted = prompt_format.format(user_request=prompt, category=category)

    for _ in range(max_attempts):
        outputs = call_oai_rm_llm(prompt_formatted, n=1, temperature=0.5, model_id="gpt-4o", max_tokens=max_length)

        # A falsy response aborts immediately instead of using up another
        # attempt. NOTE(review): presumably this signals a hard failure in
        # call_oai_rm_llm - confirm that retrying is not wanted here.
        if not outputs:
            return None

        parsed_result = parse_result(outputs)
        if len(parsed_result) <= 3:
            # Too few properties parsed; ask the model again.
            continue

        # Reuse the already-parsed list rather than parsing the text twice.
        return {
            "idx": idx,
            "prompt": prompt,
            "raw_result": outputs,
            "options": parsed_result,
        }

    print (f"Failed to get a valid result for prompt {prompt} after {max_attempts} attempts")
    return None

def write_result(result, output_file):
    """Append ``result`` as one JSON line to ``output_file`` under a file lock.

    Args:
        result: JSON-serialisable object to write, or ``None`` to do nothing.
        output_file: Path of the JSONL file; opened in append mode.
    """
    if result is None:
        return

    with open(output_file, 'a') as f:
        # Exclusive advisory lock so concurrent writers do not interleave.
        fcntl.flock(f.fileno(), fcntl.LOCK_EX)
        try:
            f.write(json.dumps(result) + '\n')
            # Push the buffered line to the OS while the lock is still held;
            # otherwise the data is only written when the file is closed,
            # which happens after the lock has been released.
            f.flush()
        finally:
            fcntl.flock(f.fileno(), fcntl.LOCK_UN)

def process_and_write(prompt, output_file, i):
    """Process one prompt, persist its result if any, and report progress.

    Args:
        prompt: The user request text handed to ``process_prompt``.
        output_file: JSONL path handed to ``write_result``.
        i: Index of the prompt; passed on as the result's index and shown in
            the progress message.
    """
    record = process_prompt(prompt, i)
    # Falsy results (e.g. None on failure) are not written.
    if record:
        write_result(record, output_file)

    total = len(prompts)
    print(f"Processed {i} of {total} prompts")

# Read prompts from CSV
prompts_df = pd.read_csv('prompts.csv')
prompts = prompts_df['prompt'].tolist()

output_file = 'modified_prompts_abl.jsonl'

# Make sure the output file exists so the resume scan below can open it.
# Append mode creates a missing file and never truncates an existing one.
open(output_file, 'a').close()

prompts_to_process = list(enumerate(prompts))

# Resume support: collect the indices already present in the output file so
# they are not processed again. Blank lines are skipped because they are not
# valid JSON and would otherwise abort the scan.
with open(output_file, 'r') as f:
    entries = [json.loads(line) for line in f if line.strip()]
    seen_idx = {entry['idx'] for entry in entries}

prompts_to_process = [
    (i, prompt) for i, prompt in prompts_to_process
    if i not in seen_idx
]

print (f"Processing {len(prompts_to_process)} prompts out of {len(prompts)} total prompts")
# Process prompts in parallel
with ThreadPoolExecutor(max_workers=50) as executor:
    futures = [
        executor.submit(process_and_write, prompt, output_file, i)
        for i, prompt in prompts_to_process
    ]

    # Wait for all futures to complete; result() re-raises any exception
    # raised inside a worker.
    for future in futures:
        future.result()




    
    
