import time

import fire

from evaluation.models import load_generative_model
from evaluation.schemas import OpenEndedTaskBase, ClassificationTaskBase, JailbreakBase, VllmGeneratorModel
from src.classifier_models.loader import load_classifier_model
from evaluation.tasks import load_evaluation_tasks, EvalMode, wrap_task_with_jailbreak
from evaluation.utils import maybe_load_evaluation_report, save_evaluation_report


def classifiers(model_name: str,
                tasks: str,
                report_output_path: str,
                save_individual_results_path: str | None = None,
                override_model_path: str | None = None,
                override_existing_report: bool = False,
                ):
    """
    Run classification evaluation tasks against a safety classifier model.

    The evaluation report is written as json, keyed by task name, where each
    value is the dictionary of evaluation results produced by that task.
    Tasks that raise an AssertionError (during the compatibility check or the
    evaluation itself) are reported on stdout and left out of the report.

    Example usage:
        python eval.py classifiers \
            --model_name ./finetuned_olmo_7b \
            --tasks ai2_prompt_safety,ai2_prompt_safety:vanilla \
            --report_output_path ./evaluation_report_olmo_classification.json

    Args:
        model_name: Name of the model to evaluate.
            It will be loading a safety classifier from oe-safety (if it is a classification task).

        tasks: Comma-separated list of tasks to evaluate the model on.
            Each task can be xxx or xxx:yyy where xxx is the task name and yyy is the task config file name.
            If yyy is not provided, it will be assumed as xxx/default.yaml.
            Then the yaml file for the task config will be loaded from tasks/classification/xxx/yyy.yaml.
            A tuple of task names (as produced by the command line parser) is accepted as well.

        report_output_path: Path to save the evaluation report.
            Results already stored at this path are loaded first and reused.

        save_individual_results_path: Path to save the individual results.
            Nothing is written for individual results when this is not given.

        override_model_path: Path to the model to override the model path.
            This is to replace the model path easily without implementing the new model in oe-safety.

        override_existing_report: Whether to override the existing report or not.
            If true, tasks already present in the report are evaluated again and replaced.
    """
    started_at = time.time()
    model = load_classifier_model(model_name, local_model_path=override_model_path)

    # The command line parser may hand over "a,b" either as one string or as
    # an already-split tuple; bring both shapes to a plain list.
    task_names = tasks
    if isinstance(task_names, str):
        task_names = task_names.split(",")
    elif isinstance(task_names, tuple):
        task_names = list(task_names)

    evaluation_tasks: list[ClassificationTaskBase] = load_evaluation_tasks(EvalMode.CLASSIFICATION, task_names)
    evaluation_report = maybe_load_evaluation_report(report_output_path)
    all_individual_results = maybe_load_evaluation_report(save_individual_results_path)

    for task_name, eval_task in zip(task_names, evaluation_tasks):
        if task_name not in evaluation_report or override_existing_report:
            try:
                eval_task.validate_if_model_is_compatible(model)
                report, individual_results = eval_task.evaluate(model)
                evaluation_report[task_name] = report
                all_individual_results[task_name] = individual_results
            except AssertionError as e:
                # An incompatible or failing task must not abort the remaining ones.
                print(f"Skipping {task_name} due to the error: {e}")
        else:
            print(f"Skipping {task_name} as it is already evaluated.")

    save_evaluation_report(report_output_path, evaluation_report)
    if save_individual_results_path:
        save_evaluation_report(save_individual_results_path, all_individual_results)

    elapsed = time.time() - started_at
    print(f" *** DONE! Evaluation report is saved to {report_output_path}. "
          f"Spent {elapsed:.2f} seconds.")


def generators(model_name_or_path: str,
               tasks: str,
               model_input_template_path_or_name: str,
               report_output_path: str,
               save_individual_results_path: str | None = None,
               override_existing_report: bool = False,
               use_vllm: bool = False,
               batch_size: int | None = None,
               filter_classifier_name: str | None = None,
               filter_model_override_path: str | None = None,
               hf_revision: str | None = None,
               ):
    """
    Run open-ended generation evaluation tasks against a generative model.

    The evaluation report is written as json, keyed by task name, where each
    value is the dictionary of evaluation results produced by that task.

    Example usage:
        python eval.py generators \
            --model_name_or_path ./finetuned_olmo_7b \
            --model_input_template_path_or_name olmo \
            --tasks harmbench \
            --report_output_path ./evaluation_report_olmo_generation.json

    Args:
        model_name_or_path: Name or path of the model to evaluate.
            If it is a name, it will be loading a model from the Hugging Face model hub, or using the API-based models (e.g. GPT-4).
            If it is a path, it will be loaded from the local path.

        tasks: Comma-separated list of tasks to evaluate the model on.
            Each task can be xxx or xxx:yyy where xxx is the task name and yyy is the task config file name.
            If yyy is not provided, it will be assumed as xxx/default.yaml.
            Then the yaml file for the task config will be loaded from tasks/generation/xxx/yyy.yaml.
            A tuple of task names (as produced by the command line parser) is accepted as well.

        model_input_template_path_or_name: Path to the .txt file, which contains the chat template for the model.
            If it is a name, you can load from src/templates.
            If it is a file, the file should contain the input template for the model, which will be used to generate the responses for the evaluation tasks.
            For example, if the task is a single-turn dialogue task, it should contain {instruction} in its template.

        report_output_path: Path to save the evaluation report.
            Results already stored at this path are loaded first and reused.

        save_individual_results_path: Path to save the individual results.
            Nothing is written for individual results when this is not given.

        override_existing_report: Whether to override the existing report or not.
            If true, tasks already present in the report are evaluated again and replaced.

        use_vllm: Whether to use VLLM for generation or not.

        batch_size: Batch size for generation.
            If you use VLLM, batch_size will be automatically decided by the model.

        filter_classifier_name: Name of the classifier to filter out harmful prompts.
            If it is provided, all the prompts will be classified, and harmful prompts will be filtered out using the classifier.
            Harmful prompts will be answered by a canned refusal response.

        filter_model_override_path: Path to load a local model for the filter classifier.

        hf_revision: Forwarded unchanged to the model loader as `hf_revision`.
            Presumably the Hugging Face Hub revision (branch, tag or commit) to load -- confirm against load_generative_model.
    """
    started_at = time.time()

    # The command line parser may hand over "a,b" either as one string or as
    # an already-split tuple; bring both shapes to a plain list.
    task_names = tasks
    if isinstance(task_names, str):
        task_names = task_names.split(",")
    elif isinstance(task_names, tuple):
        task_names = list(task_names)

    model = load_generative_model(
        model_name_or_path,
        model_input_template_path_or_name,
        use_vllm,
        batch_size,
        filter_classifier_name=filter_classifier_name,
        filter_model_override_path=filter_model_override_path,
        hf_revision=hf_revision,
    )
    evaluation_tasks: list[OpenEndedTaskBase] = load_evaluation_tasks(EvalMode.GENERATION, task_names)
    evaluation_report = maybe_load_evaluation_report(report_output_path)
    all_individual_results = maybe_load_evaluation_report(save_individual_results_path)

    for task_name, eval_task in zip(task_names, evaluation_tasks):
        if task_name not in evaluation_report or override_existing_report:
            # Unlike the classifier entry point, a failing task is not caught
            # here: the error propagates and ends the whole run.
            eval_task.validate_if_model_is_compatible(model)
            report, individual_results = eval_task.evaluate(model)
            evaluation_report[task_name] = report
            all_individual_results[task_name] = individual_results
        else:
            print(f"Skipping {task_name} as it is already evaluated.")

    save_evaluation_report(report_output_path, evaluation_report)
    if save_individual_results_path:
        save_evaluation_report(save_individual_results_path, all_individual_results)

    # Models flagged for deletion are removed from the file system only after
    # the reports have been written.
    if isinstance(model, VllmGeneratorModel):
        if model.delete_model_after_job_finished:
            model.delete_model_from_file_system()

    elapsed = time.time() - started_at
    print(f" *** DONE! Evaluation report is saved to {report_output_path}. "
          f"Spent {elapsed:.2f} seconds.")


def jailbreak(model_name_or_path: str,
              tasks: str,
              attacks: str,
              model_input_template_path_or_name: str,
              report_output_path: str,
              save_individual_results_path: str | None = None,
              batch_size: int | None = None,
              override_existing_report: bool = False,
              use_vllm: bool = False,
              ):
    """
    Evaluate a generative model on every (task, attack) combination.

    Each generation task is wrapped with each requested jailbreak attack and
    evaluated separately. The evaluation report is written as json, with keys
    of the form "<task>/<attack>" and the values as the dictionary of
    evaluation results for that combination.

    Example usage:
        python eval.py jailbreak \
            --model_name_or_path ./finetuned_olmo_7b \
            --model_input_template_path_or_name olmo \
            --tasks harmbench \
            --attacks pair,wildjailbreak \
            --report_output_path ./evaluation_report_olmo_jailbreak.json

    Args:
        model_name_or_path: Name or path of the model to evaluate.
            If it is a name, it will be loading a model from the Hugging Face model hub, or using the API-based models (e.g. GPT-4).
            If it is a path, it will be loaded from the local path.

        tasks: List of tasks to evaluate the model on, split by comma.
            each task can be xxx or xxx:yyy where xxx is the task name and yyy is the task config file name.
            if yyy is not provided, it will be assumed as xxx/default.yaml.
            Then the yaml file for the task config will be loaded from tasks/generation/xxx/yyy.yaml.
            A tuple of task names (as produced by the command line parser) is accepted as well.

        attacks: List of attacks to evaluate the model on, split by comma.
            A tuple or list of attack names is accepted as well.

        model_input_template_path_or_name: Path to the .txt file, which contains the chat template for the model.
            If it is a name, you can load from src/templates.
            If it is a file, the file should contain the input template for the model, which will be used to generate the responses for the evaluation tasks.
            For example, if the task is a single-turn dialogue task, it should contain {instruction} in its template.

        report_output_path: Path to save the evaluation report.

        save_individual_results_path: Path to save the individual results.

        batch_size: Batch size for generation.
            If you use VLLM, batch_size will be automatically decided by the model.

        override_existing_report: Whether to override the existing report or not. If true, it will override the existing report.

        use_vllm: Whether to use VLLM for generation or not.
    """
    start_time = time.time()
    if isinstance(tasks, tuple):
        tasks = list(tasks)
    elif isinstance(tasks, str):
        tasks = tasks.split(",")

    # python-fire turns a comma-separated value such as "pair,wildjailbreak"
    # into a tuple before this function is called, so `attacks` needs the same
    # normalization as `tasks`; calling .split() on it unconditionally fails
    # for every multi-attack invocation.
    if isinstance(attacks, (tuple, list)):
        attacks = list(attacks)
    elif isinstance(attacks, str):
        attacks = attacks.split(",")

    model = load_generative_model(model_name_or_path, model_input_template_path_or_name, use_vllm, batch_size)
    evaluation_tasks: list[OpenEndedTaskBase] = load_evaluation_tasks(EvalMode.GENERATION, tasks)
    evaluation_report = maybe_load_evaluation_report(report_output_path)
    all_individual_results = maybe_load_evaluation_report(save_individual_results_path)

    for task_name, eval_task in zip(tasks, evaluation_tasks):
        # All attack wrappers for a task are built up front, before any of
        # them is evaluated.
        eval_tasks_with_attack: list[JailbreakBase] = [wrap_task_with_jailbreak(eval_task, attack) for attack in attacks]
        task_names_with_attack = [f"{task_name}/{attack}" for attack in attacks]
        for task_name_with_attack, eval_task_with_attack in zip(task_names_with_attack, eval_tasks_with_attack):
            if task_name_with_attack in evaluation_report and not override_existing_report:
                print(f"Skipping {task_name_with_attack} as it is already evaluated.")
                continue

            # Compatibility is checked on the wrapped (underlying) task, while
            # the evaluation itself goes through the jailbreak wrapper.
            eval_task_with_attack.task.validate_if_model_is_compatible(model)
            report, individual_results = eval_task_with_attack.evaluate(model)
            evaluation_report[task_name_with_attack] = report
            all_individual_results[task_name_with_attack] = individual_results

    save_evaluation_report(report_output_path, evaluation_report)
    if save_individual_results_path:
        save_evaluation_report(save_individual_results_path, all_individual_results)

    # Models flagged for deletion are removed from the file system only after
    # the reports have been written.
    if isinstance(model, VllmGeneratorModel) and getattr(model, "delete_model_after_job_finished", False):
        model.delete_model_from_file_system()

    print(f" *** DONE! Evaluation report is saved to {report_output_path}. "
          f"Spent {time.time() - start_time:.2f} seconds.")


# Sub-command name -> entry point. Each key is the first positional word on
# the command line (e.g. `python eval.py classifiers ...`).
ALL_FUNCTIONS = dict(
    classifiers=classifiers,
    generators=generators,
    jailbreak=jailbreak,
)

if __name__ == "__main__":
    # Hand the dispatch table to python-fire, which selects the entry point
    # and maps the remaining command line flags onto its parameters.
    fire.Fire(ALL_FUNCTIONS)
