"""
Contains evaluation utilities for pytorch-based rewriting methods.
To use, simply call `compute_unke` with the appropriate arguments;
it returns a dictionary containing the evaluation results.
"""

import typing

from sklearn.feature_extraction.text import TfidfVectorizer
from transformers import AutoModelForCausalLM, AutoTokenizer

from dsets import AttributeSnippets
from util.generate import generate_fast

from experiments.py.eval_utils_akew import apply_template, compute_mmlu

def _strip_prompt(generation: str, prompt: str) -> str:
    """Return ``generation`` with the echoed ``prompt`` removed.

    When the generation starts with the prompt verbatim, only that leading
    copy is removed, so continuation text that happens to repeat the prompt is
    kept. Otherwise every occurrence of the prompt is removed.
    """
    if generation.startswith(prompt):
        return generation[len(prompt):]
    # NOTE(review): generate_fast is called with normalize="NFC"; if that
    # normalizes the output text, a prompt containing non-NFC characters will
    # not match verbatim here — TODO confirm against util.generate.
    return generation.replace(prompt, '')


def compute_unke(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    record: typing.Dict,
    snips: AttributeSnippets,
    vec: TfidfVectorizer,
    max_new_tokens: typing.Optional[int] = None,
) -> typing.Dict:
    """
    Evaluate a (possibly edited) model on a single record.

    Each item of ``record['eval_data']`` is appended to ``record['few_shot']``,
    rendered with ``apply_template`` (``add_qa_prompt=False``) and greedily
    decoded (``top_k=1``) with ``generate_fast``. The prompt text is then
    stripped from each generation so only the continuation is kept. Finally
    ``compute_mmlu`` is run on the same model/tokenizer/record.

    :param model: The causal language model to evaluate.
    :param tok: Tokenizer for ``model``.
    :param record: Evaluation record; must provide ``'few_shot'`` (a list) and
        ``'eval_data'`` (an iterable of items appended to it one at a time).
        The whole record is also passed to ``compute_mmlu``.
    :param snips: Not used by this function (presumably kept so the signature
        matches other evaluators — confirm with callers).
    :param vec: Not used by this function (same note as ``snips``).
    :param max_new_tokens: Number of tokens to allow beyond the longest
        prompt. Defaults to 50 when ``tok.name_or_path`` contains "qwen"
        (case-insensitive) and 20 otherwise.
    :return: Dict with ``'generation_result'`` (one continuation string per
        eval item, in order), ``'neighborhood_result'`` (always ``-1``; not
        computed here) and ``'mmlu_result'`` (the value of ``compute_mmlu``).
    """
    if max_new_tokens is None:
        # Qwen tokenizers get a larger generation budget than other models.
        max_new_tokens = 50 if 'qwen' in tok.name_or_path.lower() else 20

    qa_prompts = [
        apply_template(tok, record['few_shot'] + [item], add_qa_prompt=False)
        for item in record['eval_data']
    ]

    generation_result = []
    if qa_prompts:
        # Token length of the longest prompt. Measured without padding so the
        # tokenizer needs no pad token; this equals the width of a batch padded
        # to the longest sequence.
        longest_prompt = max(len(ids) for ids in tok(qa_prompts)['input_ids'])
        qa_generations = generate_fast(
            model,
            tok,
            qa_prompts,
            top_k=1,
            max_out_len=longest_prompt + max_new_tokens,
            normalize="NFC",
        )
        generation_result = [
            _strip_prompt(generation, prompt)
            for generation, prompt in zip(qa_generations, qa_prompts)
        ]

    mmlu_result = compute_mmlu(model, tok, record)

    return {
        'generation_result': generation_result,
        'neighborhood_result': -1,
        'mmlu_result': mmlu_result
    }
