from pydantic import BaseModel, Field

from nightjar import nj_llm_factory


class MaskResult(BaseModel):
    """Structured output for masked text."""

    # NOTE: pydantic copies both the class docstring above and the Field
    # description below into this model's JSON schema. Because the class is
    # passed as `output_format`, that text presumably reaches the LLM. Treat it
    # as prompt text: rewording it can change model behavior.

    # The masked copy of the input text that the LLM fills in.
    masked_text: str = Field(
        description="The input text with sensitive information replaced by asterisks of equal length."
    )


def main(text: str, nj_llm) -> str:
    """
    Return a copy of `text` in which sensitive information is masked.

    The masking is delegated to the language model behind `nj_llm`. The model
    is asked to replace each sensitive span with as many asterisks as the span
    has characters, and to report the result in the `masked_text` field of a
    `MaskResult`.

    Args:
        text: The text to mask. It is embedded verbatim between ``<text>``
            tags in the prompt.
        nj_llm: Callable taking a prompt and an ``output_format`` model class.
            It is expected to return a ``MaskResult`` instance.

    Returns:
        The masked text reported by the model. Nothing here checks that the
        model actually preserved the input length.
    """
    # Fixed instructions first; the user text follows, wrapped in tags.
    instructions = (
        "Mask all occurrences of sensitive information found in the <text> by replacing them "
        "with asterisks of the same length as the characters being replaced. Then store it in "
        "`masked_text`.\n"
    )
    masked: MaskResult = nj_llm(
        instructions + f"<text>{text}</text>",
        output_format=MaskResult,
    )
    return masked.masked_text


#### Tests ####

from typing import Any, Dict, List, Tuple


def run(
    model_name: str,
) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, bool], Any]:
    """
    Run `main` on a fixed set of masking examples and collect the results.

    Args:
        model_name: Model identifier forwarded to `nj_llm_factory`.

    Returns:
        A 4-tuple ``(outputs, errors, hard_results, usage)``. The three dicts
        share the keys ``"test_0"``, ``"test_1"``, ... (one per example):

        - ``outputs``: the value returned by `main`, or ``None`` if it raised.
        - ``errors``: the exception raised by `main` (or by the comparison),
          otherwise ``None``.
        - ``hard_results``: ``True`` only when the output equals the expected
          string exactly.

        ``usage`` is the second value returned by `nj_llm_factory`, passed
        through unchanged.
    """
    # NOTE(review): `max_calls=100` presumably caps the total number of LLM
    # calls across all cases. Confirm against nj_llm_factory.
    nj_llm, usage = nj_llm_factory(model_name, max_calls=100)
    inputs_outputs = [
        ("my email is alice@gmail.com", "my email is ***************"),
        ("call me at 101-456-8099", "call me at ************"),
    ]
    outputs: Dict[str, Any] = {}
    errors: Dict[str, Any] = {}
    hard_results: Dict[str, bool] = {}

    for i, (input_value, expected_output) in enumerate(inputs_outputs):
        key = f"test_{i}"
        # Pre-populate every slot so a failing case still appears in all dicts.
        outputs[key] = None
        errors[key] = None
        hard_results[key] = False

        try:
            res = main(input_value, nj_llm)
            outputs[key] = res
        except Exception as e:
            # Record the failure and continue with the remaining cases.
            errors[key] = e
        else:
            try:
                # Exact match: every masked span must keep its original length.
                hard_results[key] = res == expected_output
            except Exception as e:
                errors[key] = e

    return outputs, errors, hard_results, usage


if __name__ == "__main__":
    # `run` requires a model name. Take it from the command line rather than
    # calling `run()` with no arguments, which raises TypeError.
    import argparse

    parser = argparse.ArgumentParser(
        description="Run the sensitive-information masking tests against a model."
    )
    parser.add_argument(
        "model_name", help="Model identifier passed through to nj_llm_factory."
    )
    args = parser.parse_args()

    results, errors, hard_results, _ = run(args.model_name)
    print(results)
    print(hard_results)
    print(errors)
