import openai
from tqdm import tqdm
from tqdm_joblib import tqdm_joblib
import pickle
import os

import json

from joblib import Parallel, delayed

import sys; sys.path.insert(0, '..')
from utils import levels, prompt, reminder, llm_as_a_judge

def merge_steps(steps_A, steps_B):
    """Interleave two lists of text steps into a single string.

    At every turn the next step is taken from whichever side has emitted
    fewer characters so far (measured with ``len`` on each step); ties go to
    ``steps_A``. Once one side is exhausted, the remainder of the other side
    is emitted in order. Every emitted step is followed by a blank-line
    separator (``"\\n\\n"``), including the last one.

    Args:
        steps_A: Steps of the first speaker, in order.
        steps_B: Steps of the second speaker, in order.

    Returns:
        The concatenation of all steps in interleaved order, or ``""`` when
        both inputs are empty.
    """
    pos_a = pos_b = 0
    emitted_a = emitted_b = 0
    pieces = []
    while pos_a < len(steps_A) or pos_b < len(steps_B):
        a_turn = pos_a < len(steps_A) and emitted_a <= emitted_b
        b_done = pos_b == len(steps_B)
        if a_turn or b_done:
            # b_done can only be reached with A non-empty: the loop would
            # have stopped otherwise.
            chunk = steps_A[pos_a]
            emitted_a += len(chunk)
            pos_a += 1
        else:
            chunk = steps_B[pos_b]
            emitted_b += len(chunk)
            pos_b += 1
        pieces.extend((chunk, "\n\n"))
    return "".join(pieces)

def _read_speaker_monologue(path, speaker):
    """Load one generation file and render it as a single speaker-tagged turn.

    The file content is split on blank lines (``"\\n\\n"``); empty fragments
    are dropped and every remaining step is terminated with one newline. The
    result is ``"** <speaker>: **"`` immediately followed by those steps.
    """
    with open(path, "r") as f:
        generation = f.read()
    steps = [step + "\n" for step in generation.split("\n\n") if step]
    return f"** {speaker}: **" + "".join(steps)


def job_handler(budget, task, n_judges=3, model="gpt-4o"):
    """Judge the pair of generations (seeds 0 and 42) produced for one task.

    Relies on two module-level names that the ``__main__`` section binds
    before any job runs: ``source_path`` (a format string with ``budget``,
    ``task`` and ``seed`` fields) and ``client`` (the OpenAI client handed to
    ``llm_as_a_judge``).

    Args:
        budget: Value substituted into the ``{budget}`` field of ``source_path``.
        task: Value substituted into the ``{task}`` field; also used in the
            returned key.
        n_judges: Forwarded to ``llm_as_a_judge``.
        model: Forwarded to ``llm_as_a_judge``.

    Returns:
        ``("task_<task>", (score, (scores, explanations, errors)))`` on
        success, or ``("task_<task>", (None, None))`` when either generation
        file does not exist.
    """
    seed_A, seed_B = 0, 42
    path_A = source_path.format(budget=budget, task=task, seed=seed_A)
    path_B = source_path.format(budget=budget, task=task, seed=seed_B)

    # Both generations are required; a missing file marks the task as unscored.
    if not (os.path.isfile(path_A) and os.path.isfile(path_B)):
        return (f"task_{task}", (None, None))

    # The two generations are shown to the judge as two whole monologues,
    # Alice first, then Bob. (A per-step interleaving via merge_steps is the
    # alternative layout; it is not used here.)
    conversation = (
        _read_speaker_monologue(path_A, "Alice")
        + "\n\n"
        + _read_speaker_monologue(path_B, "Bob")
    )

    score, scores, explanations, errors = llm_as_a_judge(
        client, conversation, n_judges=n_judges, model=model
    )
    logs = (scores, explanations, errors)
    return (f"task_{task}", (score, logs))

if __name__ == '__main__':
    # NOTE: `client` and `source_path` must stay module-level names (i.e. not
    # be moved into a function): job_handler reads both as globals.
    with open("../openai_config.json", "r") as f:
        # The config's keys are passed straight through as OpenAI() kwargs.
        openai_config = json.load(f)

    client = openai.OpenAI(**openai_config)

    source_path = "../generations/LIMO/Qwen-8B-ablation-independent/evals_data/limo/Qwen3-8B-seed-{seed}-budget-{budget}-hogwild/Task_{task}.txt"
    # NOTE(review): the file name says "seed-" but the field is filled with
    # the budget; left unchanged so existing result paths stay valid.
    save_path = "../judge_results/LIMO/Qwen-8B-ablation-independent/it1_long/Qwen3-8B-seed-{budget}.pkl"

    budget_grid = [
        # 512,
        2048,
        4096,
    ]
    task_grid = range(817)
    n_jobs = 200

    print(save_path)
    for budget in budget_grid:
        print(f"{budget=}")
        # One job per task, run on a thread pool; the progress bar is driven
        # by tqdm_joblib as jobs complete.
        with tqdm_joblib(tqdm(desc="Processing", total=len(task_grid))):
            results_list = Parallel(n_jobs=n_jobs, backend="threading")(
                delayed(job_handler)(budget, task) for task in task_grid
            )

        # Each job returns ("task_<i>", (score, logs)); score is None when the
        # task's generation files were missing.
        scores_log = dict(results_list)
        n_scored = sum(score is not None for score, _ in scores_log.values())
        print("Not none tasks:", n_scored)

        path = save_path.format(budget=budget)
        os.makedirs(os.path.dirname(path), exist_ok=True)
        with open(path, 'wb') as f:
            pickle.dump(scores_log, f)