import os
import numpy as np
import pandas as pd
import torch
from datasets import load_from_disk
from saladbench.evaluator import MDJudgeEvaluator
from transformers import AutoTokenizer, AutoModelForCausalLM
import alpaca_eval.annotators as eval_annotators
from util import EVALUATORS, TASKS, save_json, load_json

os.environ['TOKENIZERS_PARALLELISM'] = 'false'
BASE_DIR = os.environ['BASE_DIR']


def safety_eval_labels_mdjudge(md_judge, prompts, responses):
    """Evaluate safety labels using MD-Judge."""
    eval_labels, eval_clses = md_judge.eval_unsafe_labels(prompts, responses)
    eval_labels = [["SAFE", "UNSAFE", "Exception"][label] for label in eval_labels]
    return eval_labels


def safety_eval_labels_llamaguard(llama_judge, tokenizer, prompts, responses, chunk_size=20):
    """Evaluate safety labels using LlamaGuard."""
    eval_labels = []
    for i in range(0, len(prompts), chunk_size):
        chunk_prompts = prompts[i:i + chunk_size]
        chunk_responses = responses[i: i + chunk_size]
        chunk_messages = [
            [
                {"role": "user", "content": prompt},
                {"role": "assistant", "content": response}
            ]
            for prompt, response in zip(chunk_prompts, chunk_responses)
        ]
        chunk_texts = [
            tokenizer.apply_chat_template(chat, tokenize=False)
            for chat in chunk_messages
        ]
        inputs = tokenizer(chunk_texts, return_tensors="pt", padding=True)
        inputs.input_ids = inputs.input_ids.to("cuda")
        generation_output = llama_judge.generate(
            **inputs,
            max_new_tokens=5,
            do_sample=False,
        )

        outputs = tokenizer.batch_decode(generation_output, skip_special_tokens=True)
        outputs = [
            output.split('[/INST]')[1].split('\n')[0]
            .replace('unsafe', 'UNSAFE')
            .replace('safe', 'SAFE')
            for output in outputs
        ]
        eval_labels.extend(outputs)
    return eval_labels


def helpful_eval_gpt4(questions, responses):
    """Evaluate helpfulness using GPT-4."""
    reference_responses = load_json('SFT/biased_outputs.json')
    eval_df = pd.DataFrame({
        'instruction': questions,
        'output_1': responses,
        'output_2': reference_responses,
    })

    pairwise_annotator = eval_annotators.PairwiseAnnotator(
        annotators_config=f'{BASE_DIR}/config/alpaca_eval_gpt4',
        input_keys=("instruction",),
        output_keys=("output_1", "output_2"),
        seed=0,
    )
    output_eval_df = pairwise_annotator.annotate_pairs(eval_df)
    score = 0.
    for record in output_eval_df:
        if record['preference'] == 1:
            score += 1.
        elif record['preference'] == 2:
            score += 0.
        else:
            score += 0.5
    final_score = score / len(output_eval_df)
    return final_score


def main():
    # Load harmless prompts
    harmless_df = pd.read_json('../../asset/helpful_prompts.json')
    harmless_questions = list(harmless_df['prompt'])
    save_json(harmless_questions, 'harmless_questions.json')

    # Load SALAD-Bench prompts
    salad_dataset = load_from_disk('../../data_cache/balanced_salad/train')
    salad_df = pd.DataFrame(salad_dataset)
    harmful_questions = salad_df['question'].tolist()
    harmful_categories = salad_df['2-category'].tolist()
    save_json(
        list(zip(
            range(len(salad_df)),
            harmful_questions,
            harmful_categories
        )),
        'harmful_questions.json'
    )

    # Initialize MD-Judge
    if "mdjudge" in EVALUATORS:
        model_ckpt = "OpenSafetyLab/MD-Judge-v0.1"
        md_judge = MDJudgeEvaluator(
            model_ckpt, use_vllm=True, use_flash_attn=True
        )

    # Initialize LlamaGuard
    if "llamaguard" in EVALUATORS:
        model_id = "meta-llama/Meta-Llama-Guard-2-8B"
        dtype = torch.bfloat16
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        tokenizer.padding_side = 'left'
        tokenizer.pad_token = tokenizer.eos_token
        llama_judge = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, device_map="auto")
        llama_judge.generation_config.pad_token_id = tokenizer.pad_token_id

    # Set up tasks (output_dirs) to be done
    output_dirs = [task[1] for task in TASKS]

    for output_dir in output_dirs:
        save_path = f'{output_dir}/scores_gpt4.json'

        # Load existing scores
        if os.path.exists(save_path):
            scores = load_json(save_path)
        else:
            scores = {}

        # Helpful evaluation
        if os.path.exists(f'{output_dir}/debiased_outputs.json'):
            outputs_path = f'{output_dir}/debiased_outputs.json'
        else:
            outputs_path = f'{output_dir}/biased_outputs.json'

        outputs = load_json(outputs_path)

        # GPT-4 helpfulness evaluation
        scores['helpful'] = helpful_eval_gpt4(harmless_questions, outputs)

        # Safety evaluation
        if os.path.exists(f'{output_dir}/salad_debiased_outputs.json'):
            outputs_path = f'{output_dir}/salad_debiased_outputs.json'
        else:
            outputs_path = f'{output_dir}/salad_biased_outputs.json'
        outputs = load_json(outputs_path)
        salad_df['response'] = outputs
        categories = sorted(salad_df['2-category'].dropna().unique().tolist())

        for category in categories:
            cat_salad_df = salad_df[salad_df['2-category'] == category].copy()
            questions = cat_salad_df['question'].tolist()
            responses = cat_salad_df['response'].tolist()

            if "mdjudge" in EVALUATORS:
                if f'mdjudge_{category}' not in scores:
                    eval_result = safety_eval_labels_mdjudge(md_judge, questions, responses)
                    scores[f'mdjudge_{category}'] = (np.array(eval_result) == 'SAFE').mean()

            if "llamaguard" in EVALUATORS:
                if f'llamaguard_{category}' not in scores:
                    eval_result = safety_eval_labels_llamaguard(llama_judge, tokenizer, questions, responses)
                    scores[f'llamaguard_{category}'] = (np.array(eval_result) == 'SAFE').mean()

        save_json(scores, save_path)


if __name__ == "__main__":
    main()
