from typing import Any, Dict, List, Tuple

from pydantic import BaseModel, Field

from nightjar import nj_llm_factory


# Structured-output schema passed as ``output_format`` to ``nj_llm`` in ``main``.
# NOTE(review): despite "EvenNumbers" in the name, ``main`` filters by an arbitrary
# natural-language property; the name is kept because renaming may break external
# references. Comments (not a docstring) are used on purpose: pydantic copies a class
# docstring into the generated JSON schema as its "description", which would change
# the schema presented to the LLM.
class FilterEvenNumbersLLMResult(BaseModel):
    # Numbers the LLM kept; defaults to an empty list when the field is omitted.
    filtered_numbers: List[int] = Field(default_factory=list)


def main(numbers: List[int], prop: str, nj_llm) -> List[int]:
    """Ask the LLM to keep only the numbers in ``numbers`` that satisfy ``prop``.

    Args:
        numbers: Candidate integers to filter.
        prop: Natural-language description of the property a number must satisfy.
        nj_llm: LLM callable (as produced by ``nj_llm_factory``) that takes a prompt
            string and an ``output_format`` model class.

    Returns:
        The ``filtered_numbers`` list from the LLM's structured response, returned
        as-is (it is not checked against ``numbers``).
    """
    # Adjacent string literals are concatenated with no separator, so each segment
    # ends with an explicit terminator. Otherwise the instruction sentences and the
    # tagged inputs run together in the prompt.
    result: FilterEvenNumbersLLMResult = nj_llm(
        "Filter out the numbers from the <numbers> that don't satisfy <prop>.\n"
        "Return the list in the 'filtered_numbers' field.\n"
        f"<numbers>{numbers}</numbers>\n"
        f"<prop>{prop}</prop>",
        output_format=FilterEvenNumbersLLMResult,
    )

    return result.filtered_numbers


#### Tests ####


def run(
    model_name: str,
) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, bool], Dict[str, str]]:
    """Run the filter test cases against ``model_name`` and collect the results.

    Args:
        model_name: Model identifier forwarded to ``nj_llm_factory``.

    Returns:
        A 4-tuple ``(outputs, errors, hard_results, usage)``. The first three dicts
        are keyed by ``"test_<i>"``:

        * ``outputs``: the value returned by ``main`` for each case, or ``None``
          if ``main`` raised.
        * ``errors``: the exception raised while running or checking each case,
          or ``None``.
        * ``hard_results``: ``True`` when the output equals the expected list
          (an order-sensitive list comparison), ``False`` otherwise.
        * ``usage``: the second value returned by ``nj_llm_factory``, passed
          through unchanged.
    """
    nj_llm, usage = nj_llm_factory(model_name, max_calls=100)
    # Each case is (input numbers, natural-language property, expected filtered list).
    cases = [
        ([-1234, 2017, 1345134, 1802, 500, 2025], "Number must be a reasonable birth year", [2017, 2025]),
        ([42, 911, 128, 1337, 572], "cultural references", [42, 911, 1337]),
        ([1234, 1111, 2748, 7777, 8392], "might be a password", [1234, 1111, 7777]),
    ]
    outputs: Dict[str, Any] = {}
    errors: Dict[str, Any] = {}
    hard_results: Dict[str, bool] = {}

    for i, (numbers, prop, expected) in enumerate(cases):
        key = f"test_{i}"
        # Pre-fill so every case has an entry even when it fails.
        outputs[key] = None
        errors[key] = None
        hard_results[key] = False

        try:
            output = main(numbers, prop, nj_llm)
            outputs[key] = output
            hard_results[key] = output == expected
        except Exception as e:  # record the per-case failure instead of aborting the run
            errors[key] = e

    return outputs, errors, hard_results, usage


if __name__ == "__main__":
    # ``run`` requires a model name, so take it from the command line instead of
    # calling ``run()`` with no arguments (which always raised TypeError).
    import argparse

    parser = argparse.ArgumentParser(description="Run the LLM number-filter test cases.")
    parser.add_argument("model_name", help="Model identifier passed to nj_llm_factory.")
    args = parser.parse_args()

    results, errors, hard_results, _ = run(args.model_name)
    print(results)
    print(hard_results)
    print(errors)
