import pandas as pd
import openai
import time
from typing import List
import re
import os

import json
import threading
import fcntl
from concurrent.futures import ThreadPoolExecutor
import pandas as pd
from utils import filter_seen_solutions, question_to_category

from utils import call_oai_rm_llm, SYSTEM_SUPERFORECASTER_0, call_anth_rm_llm
# Load the raw prompts; only the "prompt" column of prompts.csv is used.
df = pd.read_csv("prompts.csv")
prompts = df["prompt"].tolist()

# Backend switch used throughout the script:
# True -> Anthropic (call_anth_rm_llm), False -> OpenAI (call_oai_rm_llm).
anth = True

# Dump every prompt to prompts.txt, one per line.
with open("prompts.txt", "w") as prompts_out:
    prompts_out.writelines(text + "\n" for text in prompts)

# NOTE(review): not referenced anywhere in the visible code of this file.
max_length = 1000
        
# Forecasting prompt sent to the LLM once per (prompt, option pair).
# Filled with str.format using: category, request (the quoted user prompt),
# axisA / axisB (the two normalized option texts).
# Doubled braces ({{ ... }}) render as literal "{ ... }" placeholders for the model.
# Step 5 asks for the final probability wrapped in asterisks (e.g. *0.7*), which
# is the pattern parse_result extracts. That number is the probability of B;
# process_option stores 1 - value for A.
prompt_format = """I am tasked to estimate the probability that a random {category} picked by participants based on the prompt {request} is A) {axisA} or B) {axisB}. 
                  
Instructions:
1. Provide a reason or example why the answer might be A based on the prompt.
{{ Insert your thoughts }}

2. Provide a reason or example why the answer might be B based on the prompt.
{{ Insert your thoughts }}

3. Rate the strength of each of the reasons given in the last two responses. Think like a superforecaster (e.g. Nate Silver).
{{ Insert your rating of the strength of each reason }}

4. Aggregate your considerations.
{{ Insert your aggregated considerations }}

5. Rate which one is more likely. Output your answer (a number between 0 and 1) with an asterisk at the beginning and end of the decimal. 
0 is A, 0.3 is A is more likely than B, 0.5 is equal, 0.7 is B is more likely than A, 1 is B. 
{{ Insert your answer }}
"""

def parse_result(result):
    """Extract the model's final probability from its response text.

    The forecasting prompt asks the model to wrap its answer in single
    asterisks, e.g. ``*0.7*``. Every ``*<number>*`` span is found and the
    LAST one is returned. The answer is requested in the final step of the
    prompt, while earlier text (e.g. markdown such as ``**1**``) can also
    match the pattern.

    Args:
        result: Raw response text from the LLM.

    Returns:
        The matched number as a string (e.g. ``"0.7"``). Returns ``""`` when
        no starred number is present, so callers treat it as a parse failure
        and retry instead of crashing with IndexError.
    """
    # One or more digits, optionally preceded by an integer part and a dot,
    # enclosed in single asterisks: *0.7*, *.5*, *1*.
    matches = re.findall(r'\*([0-9]*\.?[0-9]+)\*', result)
    return matches[-1] if matches else ""

# Scoring pipeline: the functions below make the LLM calls in parallel and append the results to a JSONL file.

def process_option(option, prompt, max_attempts=5):
    """Ask the LLM for the relative odds of one A/B option pair under *prompt*.

    Args:
        option: Dict with a ``'name'`` (used in log messages) and an
            ``'options'`` sequence whose first two entries are the A and B
            choice texts.
        prompt: The user prompt (already wrapped in quotes by the caller)
            substituted into ``prompt_format``.
        max_attempts: Number of LLM calls to make before giving up when a
            response cannot be turned into a probability in [0, 1].

    Returns:
        ``{optA: 1 - p, optB: p}``, where ``p`` is the model's probability
        for option B. Returns ``None`` if the LLM returned an empty response,
        or if no attempt produced a usable probability.
    """
    option_title = option['name']
    opt_choices = option['options']

    def get_opt(opt_choice):
        # Normalize option text for embedding mid-sentence.
        # Strip first so the lowercase applies to the first real character
        # (lowercasing before stripping left " Yes" capitalized).
        # Then drop one trailing period.
        # Slicing keeps empty/whitespace-only input from raising IndexError.
        opt = opt_choice.strip()
        opt = opt[:1].lower() + opt[1:]
        if opt.endswith('.'):
            opt = opt[:-1]
        return opt

    optA = get_opt(opt_choices[0])
    optB = get_opt(opt_choices[1])
    category = question_to_category(prompt)
    prompt_formatted = prompt_format.format(
        request=prompt, axisA=optA, axisB=optB, category=category
    )

    # The backend is fixed for the whole run, so resolve it once.
    fn = call_anth_rm_llm if anth else call_oai_rm_llm
    model = "claude-3-7-sonnet-20250219" if anth else "gpt-4o"

    for _ in range(max_attempts):
        outputs = fn(prompt_formatted, n=1, temperature=0.2, model_id=model,
                     system_prompt=SYSTEM_SUPERFORECASTER_0)
        if not outputs:
            # Empty response: give up on this option without retrying.
            return None

        try:
            parsed_result = parse_result(outputs)
        except IndexError:
            # No starred number in the response. Retry instead of letting the
            # exception abort the whole prompt (and, via future.result(),
            # the whole run).
            parsed_result = ""

        try:
            odds = float(parsed_result)
        except (TypeError, ValueError):
            odds = None

        # The prompt demands a number between 0 and 1. Anything else
        # (e.g. "*70*") would yield a negative probability for A, so retry.
        if odds is None or not 0.0 <= odds <= 1.0:
            print(f"Failed to parse result for {option_title} {parsed_result}")
            continue

        return {optA: 1 - odds, optB: odds}

    return None


def process_prompt(entry, max_attempts=5):
    """Score every option pair of one entry concurrently.

    Args:
        entry: Dict with ``'idx'``, ``'prompt'`` and ``'options'`` (a list of
            option dicts as consumed by ``process_option``).
        max_attempts: Per-option retry budget, forwarded to
            ``process_option``. It was previously accepted but ignored.

    Returns:
        Dict with keys ``"idx"``, ``"prompt"`` and ``"options"``.
        ``"options"`` maps each option name to its ``{choice: probability}``
        dict. Options for which ``process_option`` returned ``None`` are
        omitted. Exceptions raised inside a worker propagate from
        ``future.result()``.
    """
    idx = entry['idx']
    # Wrap the prompt in double quotes before it is substituted into the template.
    quoted_prompt = '"' + entry['prompt'] + '"'
    all_results = {}
    with ThreadPoolExecutor(max_workers=5) as executor:
        # Submit everything first, then collect in submission order so the
        # result dict follows the entry's option order.
        futures = [
            (option['name'],
             executor.submit(process_option, option, quoted_prompt, max_attempts))
            for option in entry['options']
        ]
        for option_title, future in futures:
            option_results = future.result()
            if option_results is not None:
                all_results[option_title] = option_results

    return {
        "idx": idx,
        "prompt": entry['prompt'],
        "options": all_results,  # option name -> {choice: probability}
    }

def write_result(result, output_file):
    """Append *result* as one JSON line to *output_file*.

    An exclusive advisory ``flock`` is held while the line is written AND
    flushed. This stops concurrent writers using the same lock (the worker
    threads here each open their own descriptor) from interleaving partial
    lines. ``None`` results are ignored.

    Args:
        result: JSON-serializable object, or ``None`` to skip writing.
        output_file: Path of the JSONL file to append to.
    """
    if result is None:
        return

    # Serialize before taking the lock so a non-serializable result fails
    # fast and the lock is held only for the actual write.
    line = json.dumps(result) + '\n'
    with open(output_file, 'a') as f:
        fcntl.flock(f.fileno(), fcntl.LOCK_EX)
        try:
            f.write(line)
            # Flush while still holding the lock. Otherwise the text stays in
            # Python's buffer and only reaches the file at close(), after the
            # lock has already been released.
            f.flush()
        finally:
            fcntl.flock(f.fileno(), fcntl.LOCK_UN)

def process_and_write(prompt, output_file, i):
    """Score one entry, append its result to *output_file*, and log progress.

    Args:
        prompt: One entry dict from the modified-prompts JSONL. Despite the
            name, it is an entry, not a bare prompt string.
        output_file: JSONL path handed to ``write_result``.
        i: Submission index of the entry, used only in the progress message.
    """
    outcome = process_prompt(prompt)
    # process_prompt builds a non-empty dict, so this guard is defensive.
    if outcome:
        write_result(outcome, output_file)
    # NOTE(review): `i` is the 0-based submission index, not a count of
    # completed entries. `prompts` is the full CSV list, not the filtered
    # entries. Treat the message as a rough indicator only.
    total = len(prompts)
    print(f"Processed {i} of {total} prompts")

# Input/output file names get an "_anth" suffix when the Anthropic backend is used.
anth_tag = '_anth' if anth else ''

# Load the entries to score from JSONL (one JSON object per line).
# Blank lines, e.g. an extra trailing newline, are skipped instead of
# crashing json.loads.
with open(f'modified_prompts{anth_tag}.jsonl', 'r') as f:
    entries = [json.loads(line) for line in f if line.strip()]

output_file = f'odds{anth_tag}.jsonl'

if os.path.exists(output_file):
    # Resume support: presumably drops entries whose results are already in
    # output_file (see utils.filter_seen_solutions).
    entries = filter_seen_solutions(entries, output_file)
else:
    # Start a fresh, empty output file.
    open(output_file, 'w').close()

print(f"Processing {len(entries)} prompts out of {len(prompts)} total prompts")

# Score entries in parallel. Each worker appends its own result line as soon
# as it finishes.
with ThreadPoolExecutor(max_workers=20) as executor:
    futures = [
        executor.submit(process_and_write, entry, output_file, i)
        for i, entry in enumerate(entries)
    ]

    # Wait for every worker; future.result() re-raises any worker exception.
    for future in futures:
        future.result()




    
    
