"""
"GEMBA" metric for use in notebooks and scripts.

Evaluates how good a translation is with respect to a reference.
"""

import re
import warnings

from uemt.models.language_model import LM

# Prompt template inspired by the GEMBA prompts:
# https://github.com/MicrosoftTranslator/GEMBA/blob/main/gemba/prompt.py
# Placeholders filled via str.format: {target} (human reference) and
# {translation} (machine translation to be scored).
GEMBA_LIKE_PROMPT = "\n".join(
    (
        "Score the following translation on a continuous scale 0 to 100 where "
        'score of zero means "no meaning preserved" and score of one hundred '
        'means "perfect meaning and grammar".',
        "",
        "<HUMAN_REFERENCE>",
        "{target}",
        "</HUMAN_REFERENCE>",
        "",
        "<MACHINE_TRANSLATION>",
        "{translation}",
        "</MACHINE_TRANSLATION>",
        "",
        "Just output an integer score between 0 and 100, inclusive, and nothing else.",
    )
)


def gemba_da_ref(
    translation: str,
    target: str,
    model: LM,
    **kwargs,
) -> int:
    """Score a translation against a reference with an LLM (GEMBA-DA_ref-like).

    Fills ``GEMBA_LIKE_PROMPT`` with the reference and the candidate, sends the
    prompt to ``model.generate`` and parses the first run of digits in the
    reply as the score.

    Args:
        translation: Candidate translation to be scored.
        target: Human reference translation to compare against.
        model: Language model; only ``model.generate(prompt, **kwargs)`` is used.
        **kwargs: Forwarded unchanged to ``model.generate``.

    Returns:
        The score as an integer in [0, 100].

    Raises:
        AssertionError: If ``translation`` or ``target`` is not a ``str``
            (only while assertions are enabled).
        RuntimeError: If ``model.generate`` does not return a ``str``.
        ValueError: If the reply contains no digits, or if the first number
            found in it is greater than 100.

    Warns:
        UserWarning: If the reply contains more than one number; the first
            one is used.

    Note:
        Only unsigned digit runs are recognised, so a reply such as ``"85.5"``
        is read as ``85`` (with a multiple-numbers warning) and a leading
        minus sign is ignored.
    """
    # Kept as an assert so callers keep seeing the same exception type.
    assert isinstance(target, str) and isinstance(translation, str)
    prompt = GEMBA_LIKE_PROMPT.format(
        target=target,
        translation=translation,
    )
    res = model.generate(prompt, **kwargs)
    if not isinstance(res, str):
        raise RuntimeError(f"Expected string response, got {type(res)}")
    numbers = re.findall(r"\d+", res)
    if not numbers:
        raise ValueError(f"No integer found in model response: {res!r}")
    if len(numbers) > 1:
        warnings.warn(f"Expected 1 number, got {len(numbers)}: {res!r}", stacklevel=2)
    score = int(numbers[0])
    # r"\d+" can never yield a negative value, so only the upper bound needs
    # checking to honour the documented [0, 100] contract.
    if score > 100:
        raise ValueError(
            f"Score {score} is outside [0, 100] in model response: {res!r}"
        )
    return score
