import argparse
import os
import pathlib

from datasets import load_dataset
from dotenv import load_dotenv

from strong_reject.evaluate import evaluate_dataset, registered_evaluators

# Absolute path of the "data" directory that sits beside the directory
# containing this script (i.e. <script dir>/../data). Inputs are read from its
# "processed" subfolder and results are written to its "interim" subfolder.
DATA_DIR = pathlib.Path(__file__).parent.parent.resolve() / "data"

# Command-line interface for this script. Flag names, defaults and parsed
# values are unchanged; description/help text is provided so ``--help`` is
# self-explanatory.
parser = argparse.ArgumentParser(
    description="Run evaluators over the processed labelbox dataset and save the results as CSV."
)
parser.add_argument(
    "--test",
    action="store_true",
    help="Smoke-test mode: evaluate only the first 4 rows and write to labelbox_evals_test.csv.",
)
parser.add_argument(
    "--evaluator",
    default=None,
    help="Run only this evaluator (default: every registered evaluator).",
)


def main(evaluator=None, batch_size=4):
    """Evaluate the processed labelbox dataset and return the result.

    Loads ``processed/labelbox.csv`` from ``DATA_DIR`` as a CSV dataset, takes
    its ``"train"`` split and hands it to ``evaluate_dataset``.

    When the ``TESTING`` environment variable is set to a non-empty value, only
    the first 4 rows are evaluated and ``batch_size`` is overridden to 2,
    regardless of the value passed in.

    Args:
        evaluator: A single evaluator to run. When ``None``,
            ``registered_evaluators`` is passed through as-is.
        batch_size: Forwarded to ``evaluate_dataset`` as ``batch_size``.

    Returns:
        Whatever ``evaluate_dataset`` returns for the loaded rows.
    """
    csv_path = DATA_DIR / "processed" / "labelbox.csv"
    rows = load_dataset("csv", data_files=str(csv_path))["train"]

    if os.getenv("TESTING"):
        # Keep test runs small: a handful of rows in small batches.
        rows = rows.select(range(4))
        batch_size = 2

    if evaluator is None:
        selected = registered_evaluators
    else:
        selected = [evaluator]

    return evaluate_dataset(rows, selected, batch_size=batch_size)


if __name__ == "__main__":
    # Populate the environment via load_dotenv() before anything reads it.
    load_dotenv()
    args = parser.parse_args()

    filename = "labelbox_evals.csv"
    if args.test:
        # main() checks the TESTING environment variable to shrink the run;
        # test output goes to a separate file so it cannot overwrite real results.
        os.environ["TESTING"] = "true"
        filename = "labelbox_evals_test.csv"

    # The "response" column is dropped from the evaluated dataset before saving.
    dataset = main(args.evaluator).remove_columns("response")

    # Make sure the output directory exists so the write cannot fail (after all
    # the evaluation work is done) merely because "interim" is missing.
    output_dir = DATA_DIR / "interim"
    output_dir.mkdir(parents=True, exist_ok=True)
    dataset.to_csv(str(output_dir / filename))
