from metrics.verbmem import eval as eval_verbmem
from utils import load_model, load_tokenizer, write_csv, read_json, write_json
from constants import SUPPORTED_METRICS, LLAMA_DIR

import os
from typing import List, Dict, Literal
from pandas import DataFrame


def eval_model(
    model,
    tokenizer,
    verbmem_agg_key: str = 'mean_rougeL',
    verbmem_max_new_tokens: int = 128,
    verbmem_forget_file = None,
    temp_dir: str | None = None,
) -> Dict[str, float]:

    out = {}
    assert os.path.exists(verbmem_forget_file), f"File not found: {verbmem_forget_file}"
    # verbmem_f
    data = read_json(verbmem_forget_file)
    agg, log = eval_verbmem(
        prompts=[d['prompt'] for d in data],
        gts=[d['gt'] for d in data],
        model=model, tokenizer=tokenizer,
        max_new_tokens=verbmem_max_new_tokens
    )
    if temp_dir is not None:
        write_json(agg, os.path.join(temp_dir, "verbmem_f/agg.json"))
        write_json(log, os.path.join(temp_dir, "verbmem_f/log.json"))
    out['verbmem_f'] = agg[verbmem_agg_key] * 100

    return out


def load_then_eval_models(
    model_dirs: List[str],
    names: List[str],
    tokenizer_dir,
    out_file,
    verbmem_forget_file,
    temp_dir: str = "temp"
) -> DataFrame:
    """Load each model in turn, evaluate it with ``eval_model``, and collect results.

    Args:
        model_dirs: Directories passed to ``load_model``, one per model.
        names: Display name for each model; used as the ``'name'`` column and
            as the per-model subdirectory of ``temp_dir``.
        tokenizer_dir: Directory passed to ``load_tokenizer``; the resulting
            tokenizer is shared by all models.
        out_file: Optional ``.csv`` path. When set, it is rewritten after every
            model so that partial results survive a crash.
        verbmem_forget_file: JSON file forwarded to ``eval_model``.
        temp_dir: Root directory for per-model evaluation logs.

    Returns:
        A DataFrame with one row per model: ``'name'`` plus the metrics
        returned by ``eval_model``.

    Raises:
        ValueError: On empty ``model_dirs``, mismatched lengths, a non-``.csv``
            ``out_file``, or a missing ``verbmem_forget_file``.
        FileNotFoundError: If ``verbmem_forget_file`` does not exist.
    """
    # Argument sanity check, done before any (potentially expensive) model load.
    if not model_dirs:
        raise ValueError("`model_dirs` should be non-empty.")
    if len(model_dirs) != len(names):
        raise ValueError("`model_dirs` and `names` should equal in length.")
    if out_file is not None and not out_file.endswith('.csv'):
        raise ValueError("The file extension of `out_file` should be '.csv'.")
    if verbmem_forget_file is None:
        raise ValueError("`verbmem_forget_file` must be provided.")
    if not os.path.exists(verbmem_forget_file):
        raise FileNotFoundError(f"File not found: {verbmem_forget_file}")

    # The tokenizer comes from the same directory for every model, so load it
    # once. NOTE(review): assumes evaluation does not mutate it per model.
    tokenizer = load_tokenizer(tokenizer_dir)

    # Run evaluation
    out = []
    for model_dir, name in zip(model_dirs, names):
        model = load_model(model_dir)
        res = eval_model(
            model, tokenizer,
            temp_dir=os.path.join(temp_dir, name),
            verbmem_forget_file=verbmem_forget_file
        )
        # Drop our reference before the next load so the previous model can
        # be garbage-collected instead of coexisting with the new one.
        del model
        out.append({'name': name} | res)
        if out_file is not None:
            write_csv(out, out_file)
    return DataFrame(out)


if __name__ == '__main__':
    import argparse

    # Command-line entry point: evaluate the given model checkpoints and write
    # one CSV row per model to --out_file.
    parser = argparse.ArgumentParser(
        description="Evaluate models with the verbmem metric on a forget set."
    )
    # Both lists are mandatory (load_then_eval_models rejects empty or
    # mismatched lists), so let argparse report omissions with a usage error.
    parser.add_argument('--model_dirs', type=str, nargs='+', required=True,
                        help="Model directories to evaluate, one per name.")
    parser.add_argument('--names', type=str, nargs='+', required=True,
                        help="Display names, one per entry in --model_dirs.")
    parser.add_argument('--tokenizer_dir', type=str, default=LLAMA_DIR,
                        help="Tokenizer directory shared by all models.")
    parser.add_argument('--out_file', type=str, required=True,
                        help="Output CSV path (must end in .csv).")
    parser.add_argument('--verbmem_forget_file', type=str, required=True,
                        help="JSON file of forget-set records with 'prompt' and 'gt'.")
    args = parser.parse_args()
    load_then_eval_models(
        model_dirs=args.model_dirs,
        names=args.names,
        tokenizer_dir=args.tokenizer_dir,
        out_file=args.out_file,
        verbmem_forget_file=args.verbmem_forget_file,
    )
