import openai
from tqdm import tqdm
from tqdm_joblib import tqdm_joblib
import pickle
import os

import json

from joblib import Parallel, delayed
import sys; sys.path.insert(0, '..')
from utils import levels, prompt, reminder, llm_as_a_judge

def limo_conversation_converter(orig_conv: str) -> str:
    """Strip the prompt preamble and boilerplate phrases from a conversation.

    Everything before the "# Solve the following problem" header is dropped,
    then every occurrence of the inserted self-check question and of the
    Alice/Bob greeting lines is removed.

    Args:
        orig_conv: Raw conversation text as read from a generation file.

    Returns:
        The conversation starting at the problem header (or the whole text
        when the header is absent) with the boilerplate phrases removed.
    """
    # Remove the prompt: keep the text from the problem header onwards.
    start = orig_conv.find("# Solve the following problem")
    # str.find returns -1 when the header is missing; slicing from -1 would
    # keep only the last character, so fall back to the full text instead.
    conv = orig_conv[start:] if start != -1 else orig_conv

    # Phrases are removed in this fixed order.
    boilerplate = (
        # Self-check question inserted during generation.
        "Quick check: am I doing redundant work? (yes/no):",
        # Bob's opening line.
        "Hi, I'm Bob.",
        # Alice's opening line (two known wordings).
        "Hi, I'm Alice. Here's how we can tackle this problem:",
        "Hi, I'm Alice. Here's how we can approach this problem:",
    )
    for phrase in boilerplate:
        conv = conv.replace(phrase, "")

    return conv

def job_handler(budget, task, n_judges=3, model="gpt-4o"):
    """Judge a single generated conversation and return its keyed result.

    The transcript path is built from the module-level ``source_path``
    template; the judging call uses the module-level ``client``.

    Args:
        budget: Value substituted for ``{budget}`` in the path template.
        task: Value substituted for ``{task}`` in the path template; also
            used to build the result key.
        n_judges: Forwarded to ``llm_as_a_judge``.
        model: Forwarded to ``llm_as_a_judge``.

    Returns:
        A pair ``("task_<task>", (score, logs))`` where ``logs`` is the tuple
        ``(scores, explanations, errors)``. When the transcript file does not
        exist the pair is ``("task_<task>", (None, None))``.
    """
    transcript_file = source_path.format(budget=budget, task=task)
    key = f"task_{task}"

    # Missing transcript: report an empty result instead of failing the job.
    if not os.path.isfile(transcript_file):
        return (key, (None, None))

    with open(transcript_file, "r") as fh:
        raw_text = fh.read()
    cleaned = limo_conversation_converter(raw_text)

    # Regroup the blank-line-separated steps by speaker. A step starting with
    # "**A" / "**B" switches the current speaker to Alice / Bob; steps seen
    # before any such marker are filed under "Common".
    grouped = {"Common": [], "Alice": [], "Bob": []}
    speaker = "Common"
    for chunk in cleaned.split("\n\n"):
        if chunk:
            if chunk.startswith("**A"):
                speaker = "Alice"
            elif chunk.startswith("**B"):
                speaker = "Bob"
            grouped[speaker].append(chunk + "\n")

    # Emit the three sections in a fixed order, separated by a newline.
    conversation = "\n".join(
        "".join(grouped[who]) for who in ("Common", "Alice", "Bob")
    )

    score, scores, explanations, errors = llm_as_a_judge(
        client, conversation, n_judges=n_judges, model=model
    )
    return (key, (score, (scores, explanations, errors)))

if __name__ == '__main__':
    # The config file's JSON object is passed verbatim as keyword arguments
    # to the OpenAI client constructor.
    with open("../openai_config.json", "r") as config_file:
        client_kwargs = json.load(config_file)

    # NOTE: `client` must keep this name; job_handler reads it as a global.
    client = openai.OpenAI(**client_kwargs)

    # Experiment name -> (transcript path template, result pickle template).
    # Uncomment an entry to include that experiment in the run.
    paths = {
        # "limo-QwQ-main":
        # ("../generations/LIMO/QwQ-main/evals_data/limo/QwQ-32B-seed-42-budget-{budget}-hogwild/Task_{task}.txt",
        #  "../judge_results/LIMO/QwQ-main/it1_long/QwQ-32B-seed-42-budget-{budget}.pkl"),

        # "limo-QwQ-ablation-interleaved":
        # ("../generations/LIMO/QwQ-ablation-interleaved/evals_data/limo/QwQ-32B-seed-42-budget-{budget}-hogwild/Task_{task}.txt",
        #  "../judge_results/LIMO/QwQ-ablation-interleaved/it1_long/QwQ-32B-seed-42-budget-{budget}.pkl"),

        # "limo-phi4-main":
        # ("../generations/LIMO/phi4-main/evals_data/limo/Phi-4-reasoning-plus-seed-42-budget-{budget}-hogwild/Task_{task}.txt",
        #  "../judge_results/LIMO/phi4-main/it1_long/Phi-4-reasoning-plus-seed-42-budget-{budget}.pkl"),

        "limo-phi4-ablation-interleaved": (
            "../generations/LIMO/phi4-ablation-interleaved/evals_data/limo/Phi-4-reasoning-plus-seed-42-budget-{budget}-hogwild/Task_{task}.txt",
            "../judge_results/LIMO/phi4-ablation-interleaved/it1_long/Phi-4-reasoning-plus-seed-42-budget-{budget}.pkl",
        ),

        # "limo-Qwen-8B-ablation-interleaved":
        # ("../generations/LIMO/Qwen-8B-ablation-interleaved/evals_data/limo/Qwen3-8B-seed-42-budget-{budget}-hogwild/Task_{task}.txt",
        #  "../judge_results/LIMO/Qwen-8B-ablation-interleaved/it1_long/Qwen3-8B-seed-42-budget-{budget}.pkl"),

        # "limo-QwQ-ablation-interleaved":
        # ("../generations/LIMO/QwQ-ablation-interleaved/evals_data/limo/QwQ-32B-seed-42-budget-{budget}-hogwild/Task_{task}.txt",
        #  "../judge_results/LIMO/QwQ-ablation-interleaved/it1_long/QwQ-32B-seed-42-budget-{budget}.pkl"),
    }

    # Run settings shared by every experiment (512 is currently disabled).
    budget_grid = [2048, 4096]
    task_grid = range(817)
    n_jobs = 200

    # NOTE: `source_path` must stay the loop variable name; job_handler reads
    # it as a global to locate each transcript.
    for name, (source_path, save_path) in paths.items():
        print(name)
        print(f"From:{source_path} To:{save_path}")

        for budget in budget_grid:
            print(f"{budget=}")

            # Thread-backed fan-out of one job per task, run inside the
            # tqdm_joblib context with a progress bar sized to the task count.
            runner = Parallel(n_jobs=n_jobs, backend="threading")
            with tqdm_joblib(tqdm(desc="Processing", total=len(task_grid))):
                results_list = runner(
                    delayed(job_handler)(budget, task) for task in task_grid
                )

            # Each job returns a (key, value) pair, so the list maps directly
            # onto a dict keyed by "task_<n>".
            scores_log = dict(results_list)

            out_file = save_path.format(budget=budget)
            os.makedirs(os.path.dirname(out_file), exist_ok=True)
            with open(out_file, 'wb') as sink:
                pickle.dump(scores_log, sink)