import glob

import numpy as np
import pandas as pd


# Sampling schedule per reward model: (reward_model_type, burn_in, n_samples, skip).
# Each schedule selects samples[burn_in:n_samples:skip] along the last axis.
SAMPLING_SCHEDULES = (
    ("brex", 5000, 200000, 20),
    # n_samples is burn_in + 1000, i.e. a 1000-sample window after burn-in.
    ("lexicase", 99000, 99000 + 1000, 10),
)

# Quantile level (as a fraction) used for the risk-sensitive comparisons.
ALPHA = 0.1


def find_checkpoint_dir(reward_model_type, train_set="both"):
    """Return the checkpoint directory holding the results of one reward model.

    Args:
        reward_model_type: Name of the reward model sub-directory.
        train_set: Name of the training-set directory to look in.

    Returns:
        The first path returned by ``glob.glob`` for the checkpoint pattern.
        When several directories match, glob's (arbitrary) ordering decides.

    Raises:
        FileNotFoundError: If no directory matches the pattern. The pattern is
            included in the message so the missing path is easy to diagnose.
    """
    pattern = (
        f"data/reward_models/relabeled_hh_rlhf/{train_set}/"
        f"base_*_peft_last_checkpoint/{reward_model_type}"
    )
    matches = glob.glob(pattern)
    if not matches:
        raise FileNotFoundError(f"No checkpoint directory matches {pattern!r}")
    return matches[0]


def load_reward_outputs(reward_model_type, train_set="both"):
    """Load the evaluation reward outputs of one reward model as arrays.

    Reads ``eval_results_hh.jsonl`` (keeping only rows whose ``data_subset`` is
    ``"helpful"``) and ``eval_results_jailbreak.jsonl`` from the checkpoint
    directory of ``reward_model_type``.

    Args:
        reward_model_type: Name of the reward model; also selects the
            ``reward_outputs_<reward_model_type>`` column of the jailbreak file.
        train_set: Name of the training-set directory to look in.

    Returns:
        A tuple ``(jailbreak, helpful_chosen, helpful_rejected)`` of numpy
        arrays built from the corresponding columns.
    """
    checkpoint_dir = find_checkpoint_dir(reward_model_type, train_set)

    hh_rlhf_evaluation = pd.read_json(
        f"{checkpoint_dir}/eval_results_hh.jsonl", lines=True
    )
    helpful_evaluation = hh_rlhf_evaluation[
        hh_rlhf_evaluation.data_subset == "helpful"
    ]
    jailbreak_evaluations = pd.read_json(
        f"{checkpoint_dir}/eval_results_jailbreak.jsonl", lines=True
    )

    return (
        np.array(
            jailbreak_evaluations[f"reward_outputs_{reward_model_type}"].tolist()
        ),
        np.array(helpful_evaluation.reward_output_chosen.tolist()),
        np.array(helpful_evaluation.reward_output_rejected.tolist()),
    )


def thin_samples(reward_outputs, burn_in, n_samples, skip):
    """Keep ``reward_outputs[..., burn_in:n_samples:skip]`` (last axis only).

    Args:
        reward_outputs: Array-like whose last axis indexes samples.
        burn_in: Index of the first sample to keep.
        n_samples: Exclusive upper bound on the sample index.
        skip: Step between kept samples.

    Returns:
        The thinned numpy array; leading axes are unchanged.

    Raises:
        ValueError: If no sample survives the slice. Without this check the
            statistics would be taken over an empty axis and the reported
            rates would be meaningless.
    """
    thinned = np.asarray(reward_outputs)[..., burn_in:n_samples:skip]
    if thinned.shape[-1] == 0:
        raise ValueError(
            f"No samples left after slicing [{burn_in}:{n_samples}:{skip}] "
            f"on an array of shape {np.shape(reward_outputs)}"
        )
    return thinned


def get_mean_reward(reward_outputs):
    """Return the mean of ``reward_outputs`` over its last axis."""
    return np.mean(reward_outputs, axis=-1)


def get_reward_quantile(reward_outputs, alpha):
    """Return the ``alpha``-quantile of ``reward_outputs`` over its last axis.

    ``alpha`` is a fraction in [0, 1]; it is converted to a percentage for
    ``np.percentile``.
    """
    return np.percentile(reward_outputs, alpha * 100, axis=-1)


def preference_rate(preferred, other, statistic):
    """Return the fraction of rows where ``preferred`` scores at least ``other``.

    Args:
        preferred: Reward samples whose statistic is expected to be larger.
        other: Reward samples to compare against; same leading shape.
        statistic: Callable reducing the last axis to one score per row.

    Returns:
        Mean of ``statistic(preferred) >= statistic(other)`` (ties count as
        preferred).
    """
    return np.mean(statistic(preferred) >= statistic(other))


def evaluate_model(
    jailbreak, helpful_chosen, helpful_rejected, *, burn_in, n_samples, skip, alpha
):
    """Compute mean-based and quantile-based preference rates for one model.

    Every input is first thinned along its last axis with
    ``burn_in:n_samples:skip``. For ``jailbreak``, index 1 along axis 1 is
    compared against index 0; for the helpful data, chosen is compared against
    rejected.

    Args:
        jailbreak: Array-like indexed as ``[item, response, sample]``.
        helpful_chosen: Array-like indexed as ``[item, sample]``.
        helpful_rejected: Array-like indexed as ``[item, sample]``.
        burn_in: Index of the first sample to keep.
        n_samples: Exclusive upper bound on the sample index.
        skip: Step between kept samples.
        alpha: Quantile level (fraction) for the risk-sensitive rates.

    Returns:
        Dict with keys ``"jailbreak"``, ``"accuracy"``,
        ``"risk_sensitive_jailbreak"`` and ``"risk_sensitive_accuracy"``.
    """
    helpful_chosen = thin_samples(helpful_chosen, burn_in, n_samples, skip)
    helpful_rejected = thin_samples(helpful_rejected, burn_in, n_samples, skip)
    jailbreak = thin_samples(jailbreak, burn_in, n_samples, skip)

    def quantile(reward_outputs):
        return get_reward_quantile(reward_outputs, alpha)

    return {
        "jailbreak": preference_rate(
            jailbreak[:, 1], jailbreak[:, 0], get_mean_reward
        ),
        "accuracy": preference_rate(
            helpful_chosen, helpful_rejected, get_mean_reward
        ),
        "risk_sensitive_jailbreak": preference_rate(
            jailbreak[:, 1], jailbreak[:, 0], quantile
        ),
        "risk_sensitive_accuracy": preference_rate(
            helpful_chosen, helpful_rejected, quantile
        ),
    }


def report_model(
    reward_model_type,
    jailbreak,
    helpful_chosen,
    helpful_rejected,
    *,
    burn_in,
    n_samples,
    skip,
    alpha=ALPHA,
):
    """Print the sampling settings and the four preference rates of one model."""
    print(f"--- Burn-in: {burn_in}, n_samples: {n_samples}, skip: {skip} ---")
    print(f"--- Reward model type: {reward_model_type} ---")

    rates = evaluate_model(
        jailbreak,
        helpful_chosen,
        helpful_rejected,
        burn_in=burn_in,
        n_samples=n_samples,
        skip=skip,
        alpha=alpha,
    )

    print("Jailbreak:", rates["jailbreak"])
    print("Accuracy:", rates["accuracy"])
    print()

    print(f"--- Alpha: {alpha} ---")
    print("Risk-sensitive jailbreak:", rates["risk_sensitive_jailbreak"])
    print("Risk-sensitive accuracy:", rates["risk_sensitive_accuracy"])
    print()


def main():
    """Load the evaluation results of every reward model and print its rates."""
    print("\n========== JAILBREAK RESULTS ==========\n")

    # Load everything up front so a missing file fails before any rate is
    # printed, rather than halfway through the report.
    reward_outputs = {
        reward_model_type: load_reward_outputs(reward_model_type)
        for reward_model_type, *_ in SAMPLING_SCHEDULES
    }

    for reward_model_type, burn_in, n_samples, skip in SAMPLING_SCHEDULES:
        jailbreak, helpful_chosen, helpful_rejected = reward_outputs[
            reward_model_type
        ]
        report_model(
            reward_model_type,
            jailbreak,
            helpful_chosen,
            helpful_rejected,
            burn_in=burn_in,
            n_samples=n_samples,
            skip=skip,
        )


if __name__ == "__main__":
    main()
