import argparse
import os
import pathlib

from datasets import concatenate_datasets
from dotenv import load_dotenv

from strong_reject.load_datasets import load_strongreject
from strong_reject.evaluate import evaluate_dataset
from strong_reject.generate import generate_to_dataset
from strong_reject.jailbreaks import (
    apply_jailbreaks_to_dataset,
    decode_dataset,
    registered_jailbreaks,
)

# Absolute path of the "data" directory that sits next to this file's parent
# directory (i.e. two levels above this file, then "data").
DATA_DIR = pathlib.Path(__file__).parent.parent.resolve().joinpath("data")

# Command-line interface: a single boolean flag, --test, that is False unless
# given on the command line.
parser = argparse.ArgumentParser()
parser.add_argument("--test", dest="test", action="store_true", default=False)


def _save_json(dataset, path):
    """Write ``dataset`` to ``path`` as JSON, creating missing parent directories.

    Args:
        dataset: Object exposing a ``to_json(path)`` method.
        path: ``pathlib.Path`` of the output file.
    """
    # Create the directories explicitly instead of relying on the writer to do
    # so; ``exist_ok`` keeps repeated runs from failing.
    path.parent.mkdir(parents=True, exist_ok=True)
    dataset.to_json(str(path))


def main():
    """Run the jailbreak -> generate -> evaluate pipeline for every model.

    When the ``TESTING`` environment variable is set to a non-empty value a
    reduced configuration is used: two randomly selected prompts, three
    jailbreaks, three models and ``max_tokens=100``. Otherwise the full
    dataset, every registered jailbreak, four models and ``max_tokens=384``
    are used.

    Intermediate results are written as JSON under ``DATA_DIR / "interim"``:
    ``jailbreaks.json`` once, then ``responses_<model>.json`` and
    ``evals_<model>.json`` for each model.

    Returns:
        The concatenation of the per-model evaluated datasets, in the order
        the models are listed.
    """
    interim_dir = DATA_DIR / "interim"

    dataset = load_strongreject()
    if os.getenv("TESTING"):
        # Small random sample so a test run finishes quickly.
        dataset = dataset.shuffle().select(range(2))
        jailbreaks = "base64", "translation_hmong", "pair"
        models = (
            "meta-llama/Meta-Llama-3.1-70B-Instruct",
            "cognitivecomputations/dolphin-2.6-mixtral-8x7b",
            "gpt-3.5-turbo",
        )
        max_tokens = 100
    else:
        jailbreaks = registered_jailbreaks
        models = (
            "meta-llama/Meta-Llama-3.1-70B-Instruct",
            "cognitivecomputations/dolphin-2.6-mixtral-8x7b",
            "gpt-4o-mini",
            "gpt-3.5-turbo",
        )
        max_tokens = 384

    dataset = apply_jailbreaks_to_dataset(dataset, jailbreaks)
    _save_json(dataset, interim_dir / "jailbreaks.json")

    datasets = []
    for model in models:
        # Several model ids contain "/", so these file names resolve to paths
        # inside a per-organisation subdirectory of ``interim_dir``;
        # ``_save_json`` makes sure that subdirectory exists before writing.
        model_dataset = generate_to_dataset(
            dataset, models=[model], target_column="jailbroken_prompt", max_tokens=max_tokens
        )
        _save_json(model_dataset, interim_dir / f"responses_{model}.json")
        model_dataset = evaluate_dataset(model_dataset, ["strongreject_rubric"])
        _save_json(model_dataset, interim_dir / f"evals_{model}.json")
        datasets.append(model_dataset)

    return concatenate_datasets(datasets)


if __name__ == "__main__":
    # Load variables from a .env file into the environment before anything
    # reads them.
    load_dotenv()

    cli_args = parser.parse_args()
    if cli_args.test:
        # main() checks the TESTING environment variable to pick the reduced
        # configuration.
        os.environ["TESTING"] = "true"

    # Test runs are stored under a separate name from full runs.
    output_name = "evals_test.json" if cli_args.test else "evals.json"

    results = main()
    results.to_json(str(DATA_DIR / "interim" / output_name))
