from openai import OpenAI
import json
import statistics
import time

# Setup: module-level OpenAI client shared by every rewrite_choice call.
client = OpenAI(api_key="")  # Placeholder: replace with your API key before running

# The paths below are empty placeholders; replace them with your actual files.
# process_mcqs parses the input one line at a time with json.loads and writes
# one JSON object per line to the output.
input_file = "" # .jsonlist file read by process_mcqs
output_file = "" # .jsonlist file written by process_mcqs (opened in "w" mode)
prompt_file = "" # .txt file whose full contents are used as the prompt template


def load_prompt_template(path):
    """Return the entire contents of the UTF-8 text file at ``path``."""
    with open(path, mode="r", encoding="utf-8") as prompt_fh:
        template_text = prompt_fh.read()
    return template_text


# Read the prompt template once, at module load; it is passed to process_mcqs
# at the bottom of the file.
rewrite_prompt_template = load_prompt_template(prompt_file)


def count_words(text):
    """Return the number of whitespace-separated tokens in ``text``."""
    # split() with no separator already ignores leading/trailing whitespace
    # and collapses runs of whitespace, so no explicit strip() is needed.
    tokens = text.split()
    return len(tokens)


def get_balanced_target_length(choices):
    """Return the median word count of the choice texts, rounded to an int.

    Args:
        choices: Non-empty collection of dicts, each holding the choice
            wording under the ``"text"`` key.

    Returns:
        The median of the per-choice word counts passed through the built-in
        ``round`` (a median ending in ``.5`` therefore rounds to the nearest
        even integer).

    Raises:
        statistics.StatisticsError: If ``choices`` is empty.
    """
    lengths = [count_words(choice["text"]) for choice in choices]
    # Only the median drives the target; the min/max midpoint that used to be
    # computed here was never used and has been removed.
    return round(statistics.median(lengths))

def build_user_prompt(sentence, target_length):
    """Format the user message: the sentence line, then the target word count line."""
    prompt_lines = (
        f"Sentence: {sentence}",
        f"Number of words targeted: {target_length}",
    )
    return "\n".join(prompt_lines)


def rewrite_choice(prompt_template, sentence, target_length, *, model="o3",
                   max_attempts=3, retry_delay=2):
    """Ask the chat model to rewrite ``sentence`` toward ``target_length`` words.

    Args:
        prompt_template: Instruction text sent as the system message.
        sentence: Original choice text to rewrite.
        target_length: Target word count included in the user message.
        model: Chat model name (keyword-only).
        max_attempts: Total number of API attempts before giving up
            (keyword-only).
        retry_delay: Seconds to wait between failed attempts (keyword-only).

    Returns:
        The model's reply with surrounding whitespace stripped, or
        ``sentence`` unchanged if every attempt fails or yields no text.
    """
    messages = [
        {"role": "system", "content": prompt_template},
        {"role": "user", "content": build_user_prompt(sentence, target_length)},
    ]

    for attempt in range(1, max_attempts + 1):
        try:
            response = client.chat.completions.create(
                model=model,
                messages=messages,
            )
            content = response.choices[0].message.content
        except Exception as e:
            # Deliberately broad: any API/network failure should fall back to
            # the original text rather than abort the whole batch.
            error = e
        else:
            # The reply content may be missing or blank; treat that as a
            # failed attempt instead of returning an empty choice text.
            rewritten = content.strip() if content else ""
            if rewritten:
                return rewritten
            error = "empty response from model"

        if attempt < max_attempts:
            print(f"Retrying due to error: {error}")
            time.sleep(retry_delay)
        else:
            # No retry follows the last attempt, so skip the pointless sleep.
            print(f"Giving up after {max_attempts} attempts due to error: {error}")
    return sentence  # fallback to original if failure


def process_mcqs(input_path, output_path, prompt_template):
    """Rewrite every choice of every MCQ entry and save the results as JSON lines.

    Each non-blank line of ``input_path`` is parsed as one JSON object that
    must contain ``reviewID``, ``instanceID`` and a non-empty ``choices`` list
    of dicts with a ``"text"`` key. All choices of an entry are rewritten
    toward the same target word count, and the updated entry is written as one
    line to ``output_path`` (which is overwritten).

    Args:
        input_path: Path of the UTF-8 JSON-lines file to read.
        output_path: Path of the UTF-8 JSON-lines file to write.
        prompt_template: Instruction text forwarded to ``rewrite_choice``.

    Returns:
        The list of updated entries, in input order.

    Raises:
        ValueError: If an entry has no choices.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If an entry lacks ``reviewID`` or ``instanceID``, or a
            choice lacks ``"text"``.
    """
    rewritten = []
    with open(input_path, "r", encoding="utf-8") as f_in, \
            open(output_path, "w", encoding="utf-8") as f_out:
        for line_no, raw_line in enumerate(f_in, start=1):
            line = raw_line.strip()
            if not line:
                # Tolerate blank lines instead of crashing in json.loads.
                continue

            entry = json.loads(line)
            choices = entry.get("choices", [])
            print(f"reviewID: {entry['reviewID']} instanceID: {entry['instanceID']}")
            if not choices:
                # Explicit raise: an assert would be stripped under ``python -O``.
                raise ValueError(
                    f"Entry on line {line_no} of {input_path} has no choices"
                )

            target_len = get_balanced_target_length(choices)
            print(f"Target word count: {target_len}")

            # The choice dicts are updated in place, so entry["choices"]
            # already reflects the rewritten texts when it is serialized.
            for choice in choices:
                original_text = choice["text"]
                rewritten_text = rewrite_choice(prompt_template, original_text, target_len)
                print(f"Original: {original_text}")
                print(f"Rewritten: {rewritten_text}\n")
                choice["text"] = rewritten_text

            f_out.write(json.dumps(entry) + "\n")
            rewritten.append(entry)
    return rewritten


# Run: executes at module load (there is no __main__ guard). The returned list
# is discarded; process_mcqs has already written each entry to output_file.
_ = process_mcqs(input_file, output_file, rewrite_prompt_template)