import os
import pandas as pd
from datasets import load_dataset
from saladbench.evaluator import MDJudgeEvaluator


def load_and_prepare_datasets():
    """Fetch both PKU-Alignment datasets and return them as DataFrames.

    The BeaverTails ``30k_train`` split and the PKU-SafeRLHF-30K ``train``
    split are loaded and converted to pandas. The RLHF frame additionally
    gets two derived columns:

    * ``chosen``   - the response whose index equals ``safer_response_id``
    * ``rejected`` - the other response of the pair

    Returns:
        A ``(beaver_df, rlhf_df)`` tuple.
    """
    def _response_by_safety(row, safer):
        # ``safer_response_id`` names the safer response; ``1 - id`` flips
        # between the two response columns to reach the other one.
        safer_idx = row["safer_response_id"]
        picked = safer_idx if safer else 1 - safer_idx
        return row[f'response_{picked}']

    beaver_raw = load_dataset('PKU-Alignment/BeaverTails', split='30k_train')
    rlhf_raw = load_dataset('PKU-Alignment/PKU-SafeRLHF-30K', split='train')
    beaver_df, rlhf_df = pd.DataFrame(beaver_raw), pd.DataFrame(rlhf_raw)

    rlhf_df['chosen'] = rlhf_df.apply(_response_by_safety, axis=1, safer=True)
    rlhf_df['rejected'] = rlhf_df.apply(_response_by_safety, axis=1, safer=False)

    return beaver_df, rlhf_df


def evaluate_beaver_data(beaver_df):
    """Annotate BeaverTails prompt/response pairs with MD-Judge verdicts.

    An ``MDJudgeEvaluator`` is built for the ``OpenSafetyLab/MD-Judge-v0.1``
    checkpoint and run twice over the ``prompt``/``response`` columns: once
    for labels and categories, once for unsafe probabilities.

    Three columns are added to ``beaver_df`` in place:

    * ``salad_category``     - the category returned with each label
    * ``salad_eval``         - the label rendered as text
    * ``salad_unsafe_probs`` - the unsafe probability

    Nothing is written to disk here; persisting the result is left to the
    caller.

    Returns:
        The same ``beaver_df`` object, now carrying the extra columns.
    """
    # Indexed directly by the judge's numeric label.
    # NOTE(review): a label of -1 selects "Exception" through negative
    # indexing - presumably the evaluator's failure code; confirm.
    label_names = ("SAFE", "UNSAFE", "Exception")

    judge = MDJudgeEvaluator(
        "OpenSafetyLab/MD-Judge-v0.1",
        use_vllm=True,
        use_flash_attn=True,
    )

    prompts = beaver_df['prompt']
    responses = beaver_df['response']

    labels, classes = judge.eval_unsafe_labels(prompts.to_list(), responses.to_list())
    probs = judge.eval_unsafe_probs(prompts.to_list(), responses.to_list())

    beaver_df['salad_category'] = classes
    beaver_df['salad_eval'] = [label_names[label] for label in labels]
    beaver_df['salad_unsafe_probs'] = probs

    return beaver_df


def merge_datasets(beaver_df, rlhf_df):
    """Attach MD-Judge results to both responses of every RLHF pair.

    Each RLHF row is joined twice against ``beaver_df`` on
    ``(prompt, response_0)`` and ``(prompt, response_1)``, so both responses
    receive their ``salad_category``, ``salad_eval`` and
    ``salad_unsafe_probs`` values (suffixed ``_0`` / ``_1``). Both joins are
    inner joins: a pair is kept only if both of its responses are found in
    ``beaver_df``.

    Safety scores are ``1 - unsafe probability``: ``chosen_safety_score``
    uses the response named by ``safer_response_id`` and
    ``rejected_safety_score`` uses the other one.

    Neither input frame is modified and nothing is written to disk.

    Args:
        beaver_df: Evaluated BeaverTails frame with columns ``prompt``,
            ``response``, ``salad_category``, ``salad_eval`` and
            ``salad_unsafe_probs``.
        rlhf_df: RLHF frame with columns ``prompt``, ``response_0``,
            ``response_1`` and ``safer_response_id``.

    Returns:
        A DataFrame with the per-response MD-Judge columns,
        ``safer_response_id`` and the two safety-score columns.

    Raises:
        KeyError: If a required column is missing from either frame.
        ValueError: If ``safer_response_id`` holds a value other than 0 or 1.
    """
    salad_columns = ['salad_category', 'salad_eval', 'salad_unsafe_probs']

    merged = rlhf_df
    for side in (0, 1):
        # Join against a narrow, pre-renamed view of the evaluation results
        # so the two passes cannot produce clashing (_x/_y) column names.
        lookup = beaver_df[['prompt', 'response', *salad_columns]].rename(
            columns={
                'response': f'response_{side}',
                **{name: f'{name}_{side}' for name in salad_columns},
            }
        )
        merged = pd.merge(merged, lookup, on=['prompt', f'response_{side}'])

    safer_id = merged['safer_response_id']
    safer_is_0 = safer_id == 0
    if not (safer_is_0 | (safer_id == 1)).all():
        raise ValueError("safer_response_id must be 0 or 1")

    unsafe_0 = merged['salad_unsafe_probs_0']
    unsafe_1 = merged['salad_unsafe_probs_1']

    # Vectorised selection: unlike a row-wise apply, this also works when the
    # joins matched no rows at all.
    merged['chosen_safety_score'] = 1. - unsafe_0.where(safer_is_0, unsafe_1)
    merged['rejected_safety_score'] = 1. - unsafe_1.where(safer_is_0, unsafe_0)

    return merged[[
        'salad_eval_0',
        'salad_category_0',
        'salad_unsafe_probs_0',
        'salad_eval_1',
        'salad_category_1',
        'salad_unsafe_probs_1',
        'safer_response_id',
        'chosen_safety_score',
        'rejected_safety_score',
    ]]


def _load_cached_beaver_eval(csv_path):
    """Read back an evaluated BeaverTails frame written earlier by ``main``."""
    # The prompt/response texts are used as merge keys later, so literal
    # texts such as "None" or "N/A" must stay strings instead of being parsed
    # as missing values. The first CSV column is the index saved by to_csv.
    cached = pd.read_csv(csv_path, index_col=0, keep_default_na=False)
    # With NA parsing disabled, a blank score cell would leave the whole
    # column as strings; coerce it back to floats for the score arithmetic.
    cached['salad_unsafe_probs'] = pd.to_numeric(
        cached['salad_unsafe_probs'], errors='coerce'
    )
    return cached


def main():
    """Produce ``beaver_with_eval.csv`` and ``rlhf_with_eval.csv`` if missing.

    Each CSV acts as a cache for its step: an existing file is never
    regenerated. When the BeaverTails evaluation is already cached, the
    merge step uses the cached results instead of the freshly loaded,
    un-evaluated dataset (which lacks the ``salad_*`` columns the merge
    needs).
    """
    beaver_csv = 'beaver_with_eval.csv'
    rlhf_csv = 'rlhf_with_eval.csv'

    beaver_cached = os.path.exists(beaver_csv)
    rlhf_cached = os.path.exists(rlhf_csv)

    # Both outputs are already present: skip loading the datasets entirely.
    if beaver_cached and rlhf_cached:
        return

    beaver_df, rlhf_df = load_and_prepare_datasets()

    if beaver_cached:
        beaver_df = _load_cached_beaver_eval(beaver_csv)
    else:
        beaver_df = evaluate_beaver_data(beaver_df)
        beaver_df.to_csv(beaver_csv)

    if not rlhf_cached:
        rlhf_merge_df = merge_datasets(beaver_df, rlhf_df)
        rlhf_merge_df.to_csv(rlhf_csv)


# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
