

import json

import torch
import llm_utils.inf_llm as inf_llm


# Few-shot demonstrations for the LLM judge. Each entry pairs a raw model
# output and its ground-truth label with the verdict the judge should emit:
# the cleaned prediction ('clean_pred') and a 1/0 correctness flag ('isCorr').
# The first example shows a decimal answer graded correct against a label
# written with fractions (5/2 = 2.5, 7/2 = 3.5); the second shows a wrong one.
sample_evals = [
    {
        'raw_output': 'So the union of these intervals is:\n\n$$\nx \\in [2.5, 3.5)\n$$\n\n',
        'label': '$\\boxed{\\left[ \\frac{5}{2}, \\frac{7}{2} \\right)}$',
        'clean_output': {'clean_pred': '[2.5, 3.5)', 'isCorr': 1},
    },
    {
        'raw_output': 'The final answer is \n\n$$\nx \\in [2.7, 3.4)\n$$\n\n',
        'label': '$\\boxed{\\left[ \\frac{5}{2}, \\frac{7}{2} \\right)}$',
        'clean_output': {'clean_pred': '[2.7, 3.4)', 'isCorr': 0},
    },
]


def _build_eval_prompt(raw_output, label):
    """Return the judge instruction for one (model output, ground-truth label) pair.

    Used for both the few-shot demonstrations and the real query in
    ``eval_output``, so the model sees exactly the same wording each time.
    The text is built from adjacent string literals rather than backslash
    line continuations inside one literal: those would copy the source
    indentation into the prompt as runs of spaces.
    """
    return (
        f"The following output text is generated by an LLM: {raw_output}. "
        f"The ground truth label is {label}. "
        "Please evaluate whether the LLM answers correctly or not for the problem. "
        "Please return a json dict with keys 'clean_pred' and 'isCorr' "
        "(1 or 0, where 1 means True and 0 means False). "
        "Note that 'clean_pred' should be the parsed, neat and clean result "
        "(without reasoning process) rather than long sentences. "
        "Please generate the json dict only."
    )


# Few-shot chat history: for every demonstration, a user prompt followed by
# the expected assistant answer. The answer is serialized with json.dumps
# (double quotes) instead of str(dict) (single-quoted Python repr), so the
# model is shown valid JSON -- the format the prompt asks for.
autoEval_messages = []
for eval_item in sample_evals:
    autoEval_messages.extend([
        {'role': 'user',
         'content': _build_eval_prompt(eval_item['raw_output'], eval_item['label'])},
        {'role': 'assistant',
         'content': json.dumps(eval_item['clean_output'], ensure_ascii=False)},
    ])


def eval_output(test_item, raw_outputs, cur_model, cur_tokenizer):
    """Ask ``cur_model`` to judge ``raw_outputs`` against the item's reference answer.

    The few-shot history ``autoEval_messages`` is followed by one new user turn
    built from ``raw_outputs`` and the reference. The conversation goes
    through ``inf_llm.get_llm_inputs`` / ``inf_llm.get_llm_outputs``, with
    gradient tracking disabled.

    Args:
        test_item: Mapping with a ``'label'`` entry. When the label is missing,
            ``None`` or ``''``, ``'cot_content'`` is used as the reference
            instead.
        raw_outputs: The generated text to grade. It is interpolated into the
            prompt as-is (presumably a str -- TODO confirm with callers).
        cur_model: Model handed to the ``inf_llm`` helpers.
        cur_tokenizer: Tokenizer handed to the ``inf_llm`` helpers.

    Returns:
        The second value returned by ``inf_llm.get_llm_outputs``. It is
        returned unparsed; presumably it is the judge's text answer, expected
        to be a JSON dict with 'clean_pred' and 'isCorr' -- verify against
        inf_llm.
    """
    test_label = test_item.get('label')
    # Only a genuinely absent label falls back; a falsy but real label such as
    # 0 is kept. 'cot_content' is presumably the worked reference solution.
    if test_label is None or test_label == '':
        test_label = test_item['cot_content']

    # Concatenation builds a new list, so the shared few-shot history is
    # never mutated between calls.
    _messages = autoEval_messages + [
        {'role': 'user', 'content': _build_eval_prompt(raw_outputs, test_label)},
    ]

    tokenized_inputs = inf_llm.get_llm_inputs(cur_model, cur_tokenizer, _messages)
    with torch.no_grad():
        # The first return value (thinking output) is not used by this judge.
        _think_outputs, raw_predItem = inf_llm.get_llm_outputs(
            cur_model, cur_tokenizer, tokenized_inputs, think_mode='short-factual'
        )

    return raw_predItem