import json
from pathlib import Path
from openai import OpenAI

# Load prompts
def load_prompt_template(path):
    """Return the complete text of the prompt template file at *path*.

    The file is decoded as UTF-8 and read in text mode, so platform line
    endings are normalised to ``"\\n"``.
    """
    with open(path, "r", encoding="utf-8") as template_file:
        template_text = template_file.read()
    return template_text

# Prompt templates, one per generation direction. Paths are resolved relative
# to the current working directory.
simple_forward_prompt = load_prompt_template("simple_forward.txt")
simple_backward_prompt = load_prompt_template("simple_backward.txt")
complex_forward_prompt = load_prompt_template("complex_forward.txt")
complex_forward_nonAttack_prompt = load_prompt_template("complex_forward_nonAttack.txt")
complex_backward_prompt = load_prompt_template("complex_backward.txt")
complex_linear_prompt = load_prompt_template("complex_linear.txt")

# Load triplets: one JSON object per line. Blank lines (e.g. a trailing empty
# line at the end of the file) are skipped rather than crashing json.loads.
with open("test_input.jsonlist", "r", encoding="utf-8") as f:
    stripped_lines = (line.strip() for line in f)
    triplets = [json.loads(text) for text in stripped_lines if text]

# Plan logic per q_type
def get_generation_plan(qtype, props):
    """Return the three ``(direction, target)`` generation steps for *qtype*.

    ``target`` is either one proposition taken from *props* or a 2-tuple of
    them. Inside a tuple, the literal ``"D2"`` is a placeholder that the
    caller resolves to the second generated distractor.

    Raises:
        KeyError: if *props* lacks "prop1", "prop2" or "prop3".
        ValueError: if *qtype* is not a known question type.
    """
    # Read all three props up front (in order), so a missing key raises
    # KeyError before the q_type is even looked at.
    first, second, third = (props[key] for key in ("prop1", "prop2", "prop3"))

    plans_by_type = {
        "1.1": [
            ("complex_backward", (third, second)),
            ("complex_linear", (second, third)),
            ("forward", third),
        ],
        "1.2": [
            ("complex_backward", (third, first)),
            ("backward", first),
            ("forward", third),
        ],
        "1.3": [
            ("complex_forward", (first, second)),
            ("complex_linear", (first, second)),
            ("backward", first),
        ],
        "2.1": [
            ("complex_forward", (second, third)),
            ("forward", third),
            ("complex_forward_nonAttack", (third, "D2")),
        ],
        "2.2": [
            ("complex_forward", (first, third)),
            ("forward", third),
            ("complex_forward_nonAttack", (third, "D2")),
        ],
        "2.3": [
            ("complex_backward", (first, second)),
            ("combined_backward", (first, second)),
            ("complex_backward", (second, first)),
        ],
        "3.1": [
            ("complex_forward", (second, third)),
            ("combined_forward", (second, third)),
            ("complex_forward", (third, second)),
        ],
        "3.2": [
            ("complex_backward", (third, first)),
            ("backward", first),
            ("complex_backward", (first, "D2")),
        ],
        "3.3": [
            ("complex_backward", (second, first)),
            ("backward", first),
            ("complex_backward", (first, "D2")),
        ],
    }

    try:
        return plans_by_type[qtype]
    except (KeyError, TypeError):
        # TypeError covers unhashable q_type values (e.g. a JSON list), so
        # every unrecognised value is reported the same way.
        raise ValueError(f"Unknown q_type: {qtype}") from None

# Instantiate client
# With no api_key argument the SDK reads OPENAI_API_KEY from the environment.
# Passing api_key="" would bypass that lookup (only None triggers it), and every
# request would then fail authentication. Set the environment variable rather
# than hard-coding a secret here.
client = OpenAI()

# Generate from model
def generate_response(prompt_template, target_text, is_related_to_two=False, *, model="o3"):
    """Ask the chat model for a distractor built from one or two sentences.

    Args:
        prompt_template: Text sent as the "developer" message.
        target_text: A single sentence, or a ``(sentence_a, sentence_b)``
            pair when *is_related_to_two* is true. ``None`` short-circuits.
        is_related_to_two: Whether *target_text* is a two-sentence pair.
        model: Chat Completions model name (defaults to ``"o3"``).

    Returns:
        The model's reply with surrounding whitespace stripped. Returns
        ``None`` if *target_text* is ``None`` or the reply has no text
        content (the SDK reports e.g. a refusal as ``content=None``).
    """
    if target_text is None:
        return None

    if is_related_to_two:
        t1, t2 = target_text
        user_prompt = f"Sentence A: {t1}\nSentence B: {t2}"
    else:
        user_prompt = f"Sentence A: {target_text}"

    # NOTE: reasoning models such as o3 reject `temperature` and `max_tokens`;
    # use `max_completion_tokens` if output length ever needs capping.
    response = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "developer", "content": prompt_template},
            {"role": "user", "content": user_prompt},
        ],
    )

    # Progress log: echo the sentence(s) that were just sent.
    print(user_prompt)

    content = response.choices[0].message.content
    if content is None:
        return None
    return content.strip()

# Main generation function
def generate_all_distractors(triplets, output_path):
    """Generate three distractors (d1, d2, d3) per triplet and stream them to disk.

    For each entry, ``get_generation_plan`` yields three ``(direction, target)``
    steps. Step ``i`` produces distractor ``chr(65 + i)`` ("A", "B", "C"),
    which is stored on the entry as d1/d2/d3. A step whose input is missing
    (a ``None`` proposition, or an unavailable "D2" placeholder) is skipped,
    and its distractor is recorded as ``None``.

    Args:
        triplets: Dicts with "q_type", "prop1", "prop2" and "prop3", plus
            optional "reviewID"/"instanceID". Each dict is updated in place.
        output_path: Destination JSON-lines file; it is overwritten.

    Returns:
        The updated entries, in input order.

    Raises:
        ValueError: for an unknown q_type or plan direction.
    """
    updated = []
    with open(output_path, "w", encoding="utf-8") as f_out:
        for entry in triplets:
            print(f"reviewID: {entry.get('reviewID', 'N/A')} instanceID: {entry.get('instanceID', 'N/A')} q_type: {entry['q_type']}")
            props = {"prop1": entry["prop1"], "prop2": entry["prop2"], "prop3": entry["prop3"]}
            plan = get_generation_plan(entry["q_type"], props)

            gen_outputs = {}
            for idx, (direction, target) in enumerate(plan):
                label = chr(65 + idx)  # step 0 -> "A" (d1), 1 -> "B" (d2), 2 -> "C" (d3)
                target_text = _prepare_target(direction, target, gen_outputs)
                is_related_to_two = isinstance(target_text, tuple)

                # Skip steps whose input is missing instead of sending "None" to the model.
                if target_text is None or (
                    is_related_to_two and any(part is None for part in target_text)
                ):
                    gen_outputs[label] = None
                    continue

                prompt = _select_prompt(direction)
                print(direction)

                output = generate_response(prompt, target_text, is_related_to_two)
                print(output)
                print()
                gen_outputs[label] = output

            # Save 3 distractors
            entry["d1"] = gen_outputs.get("A")
            entry["d2"] = gen_outputs.get("B")
            entry["d3"] = gen_outputs.get("C")
            updated.append(entry)

            # Write each entry as soon as it is complete, and flush, so that an
            # interruption mid-run keeps every finished entry on disk.
            f_out.write(json.dumps(entry) + "\n")
            f_out.flush()

    return updated


def _prepare_target(direction, target, gen_outputs):
    # Turn a plan target into generate_response input: a str, a 2-tuple, or None.
    if direction in ("combined_forward", "combined_backward"):
        first, second = target
        if first is None or second is None:
            return None  # a missing proposition cannot be combined
        return _combine_propositions(first, second)
    if isinstance(target, tuple):
        return _resolve_placeholders(target, gen_outputs)
    return target


def _combine_propositions(first, second):
    # Merge two propositions into one sentence for the "combined_*" directions.
    first = first.rstrip()  # remove trailing whitespace
    if first.endswith("."):
        first = first[:-1] + ";"  # replace the final period with a semicolon
    # NOTE: the first letter of `second` is intentionally not lower-cased.
    second = second.lstrip()  # remove leading whitespace
    return f"{first} {second}"


def _resolve_placeholders(target, gen_outputs):
    # Replace each "D2" placeholder in a tuple target with distractor B (d2).
    resolved = []
    for item in target:
        if item == "D2":
            d2 = gen_outputs.get("B")
            if d2 is None:
                # Leave None in place; the caller then skips this step.
                print("WARNING: 'D2' placeholder has no generated distractor B; skipping this step")
            resolved.append(d2)
        else:
            resolved.append(item)
    return tuple(resolved)


def _select_prompt(direction):
    # Map a plan direction to its prompt template; raise ValueError if unknown.
    prompts = {
        "forward": simple_forward_prompt,
        "combined_forward": simple_forward_prompt,
        "backward": simple_backward_prompt,
        "combined_backward": simple_backward_prompt,
        "complex_forward": complex_forward_prompt,
        "complex_backward": complex_backward_prompt,
        "complex_linear": complex_linear_prompt,
        "complex_forward_nonAttack": complex_forward_nonAttack_prompt,
    }
    if direction not in prompts:
        raise ValueError(f"Unknown direction: {direction}")
    return prompts[direction]

# Run and save
output_path = ""  # YOUR OUTPUT .jsonlist HERE
if not output_path:
    # open("") would only fail with a bare FileNotFoundError; say what to fix.
    raise SystemExit("output_path is empty: set it to the destination .jsonlist file before running.")
generate_all_distractors(triplets, output_path)
