from nightjar import nj_llm_factory


def main(prompt: str, expression: str, nj_llm) -> str:
    """Ask the LLM to insert or repair parentheses in *expression*.

    Args:
        prompt: Natural-language description of the intended grouping.
        expression: The expression string whose parentheses should be
            inserted or repaired.
        nj_llm: Callable that takes a single prompt string and returns the
            model's reply (e.g. the callable returned by ``nj_llm_factory``).

    Returns:
        Whatever ``nj_llm`` returns for the request; the prompt asks the model
        for just the balanced expression.
    """
    # The second literal must be an f-string: without the ``f`` prefix the
    # model would receive the literal text "{prompt}" / "{expression}" instead
    # of the caller's actual inputs. Braces inside the substituted values
    # (e.g. regex quantifiers like ``\d{3}``) are inserted verbatim.
    balanced_expression = nj_llm(
        "Insert/repair parentheses based on <prompt> in the <expression> string. Give just the "
        f"balanced expression\n<prompt>{prompt}</prompt>\n<expression>{expression}</expression>"
    )
    return balanced_expression


#### Tests ####
import logging
from typing import Any, Dict, List, Tuple

# Configure the root logger at INFO level (basicConfig is a no-op when the
# root logger already has handlers).
logging.basicConfig(level="INFO")


def run(
    model_name: str,
) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, bool], Dict[str, Any]]:
    """Run ``main`` against the built-in parenthesization cases.

    Args:
        model_name: Model identifier forwarded to ``nj_llm_factory``.

    Returns:
        A 4-tuple of dicts, each keyed ``"test_<i>"``:
        ``(outputs, errors, hard_results, llm_judge_outputs)``.

        * ``outputs`` -- value returned by ``main`` (``None`` if it raised).
        * ``errors`` -- the exception raised for that case, else ``None``.
        * ``hard_results`` -- result of ``output == expected_answer``
          (``False`` if ``main`` raised).
        * ``llm_judge_outputs`` -- placeholder; always ``None`` here.
    """
    # The second value from the factory (``usage``) is not part of the result.
    nj_llm, _usage = nj_llm_factory(model_name, max_calls=100)

    # Each case is (prompt, expression, expected answer). The regex case uses
    # raw strings so "\d" is not parsed as an (invalid) escape sequence; the
    # resulting string values are unchanged.
    inps = [
        (
            "Apply f to x after incrementing x by 1, then divide by g of x minus 1",
            "f x + 1 / g x - 1",
            "f(x + 1) / (g(x) - 1)",
        ),
        (
            "Regex should capture just the area code of the phone number",
            r"^(\d{3}-\d{3}\d{4}$",
            r"^(\d{3})-\d{3}\d{4}$",
        ),
    ]
    outputs: Dict[str, Any] = {}
    errors: Dict[str, Any] = {}
    llm_judge_outputs: Dict[str, Any] = {}
    hard_results: Dict[str, bool] = {}

    for i, (prompt, expression, answer) in enumerate(inps):
        key = f"test_{i}"
        outputs[key] = None
        errors[key] = None
        llm_judge_outputs[key] = None
        hard_results[key] = False

        try:
            outputs[key] = main(prompt, expression, nj_llm)
        except Exception as e:  # record the per-case failure; keep running the rest
            errors[key] = e
            continue

        try:
            hard_results[key] = outputs[key] == answer
        except Exception as e:  # guard against outputs with an unusual __eq__
            errors[key] = e

    return outputs, errors, hard_results, llm_judge_outputs


if __name__ == "__main__":
    # run() returns (outputs, errors, hard_results, llm_judge_outputs); the
    # fourth element is the LLM-judge placeholder dict, not the usage object
    # from nj_llm_factory, so name it accordingly.
    results, errors, hard_results, llm_judge_outputs = run("openai/gpt-4.1")
    print(results)
    print(hard_results)
    print(errors)
    print(llm_judge_outputs)
