import os
import numpy as np
import pandas as pd
import torch
from datasets import load_from_disk
from saladbench.evaluator import MDJudgeEvaluator
from transformers import AutoTokenizer, AutoModelForCausalLM
import alpaca_eval.annotators as eval_annotators
from util import SEEDS, STEPS, BETAS, METHODS, EVALUATORS, save_json, load_json, rejection_patterns

# Turn off HF tokenizers' internal parallelism; presumably this silences the
# fork-after-parallelism warning when other libraries (e.g. vLLM) spawn processes.
os.environ.update({'TOKENIZERS_PARALLELISM': 'false'})
# Project root holding config/alpaca_eval_gpt4 (read by helpful_eval_gpt4).
# Required: importing this module raises KeyError when BASE_DIR is unset.
BASE_DIR: str = os.environ['BASE_DIR']


def safety_eval_labels_mdjudge(md_judge, prompts, responses):
    """Label each (prompt, response) pair with MD-Judge.

    ``md_judge.eval_unsafe_labels`` returns a pair whose first element holds
    integer label codes; codes are used as indices into
    ``("SAFE", "UNSAFE", "Exception")``, so 0 -> 'SAFE', 1 -> 'UNSAFE' and
    2 (or -1, via negative indexing) -> 'Exception'. The second element of the
    pair is ignored.

    Returns:
        list[str]: one label name per pair, in input order.
    """
    label_names = ("SAFE", "UNSAFE", "Exception")
    raw_labels, _ = md_judge.eval_unsafe_labels(prompts, responses)
    return [label_names[code] for code in raw_labels]


def safety_eval_labels_llamaguard(llama_judge, tokenizer, prompts, responses, chunk_size=20):
    """Evaluate safety labels using LlamaGuard.

    Each (prompt, response) pair is rendered as a user/assistant conversation
    with the tokenizer's chat template, and the model greedily generates up to
    5 new tokens as its verdict. Pairs are processed ``chunk_size`` at a time.

    Args:
        llama_judge: causal LM used as the classifier; its ``device`` receives
            the batch tensors.
        tokenizer: matching tokenizer. Only the newly generated tokens are
            decoded, by slicing off the padded prompt width, which assumes
            left padding (as configured in ``main``).
        prompts: user prompts.
        responses: assistant responses, aligned index-by-index with ``prompts``.
        chunk_size: number of conversations per ``generate`` call.

    Returns:
        list[str]: one label per pair. It is the first line of the verdict,
        whitespace-stripped, with 'unsafe'/'safe' upper-cased (e.g. 'SAFE' or
        'UNSAFE').

    Raises:
        ValueError: if ``prompts`` and ``responses`` have different lengths.
    """
    if len(prompts) != len(responses):
        # zip() below would silently drop the unmatched tail.
        raise ValueError(
            f'prompts ({len(prompts)}) and responses ({len(responses)}) differ in length'
        )

    device = llama_judge.device
    eval_labels = []
    for start in range(0, len(prompts), chunk_size):
        chunk_prompts = prompts[start:start + chunk_size]
        chunk_responses = responses[start:start + chunk_size]
        chunk_texts = [
            tokenizer.apply_chat_template(
                [
                    {"role": "user", "content": prompt},
                    {"role": "assistant", "content": response},
                ],
                tokenize=False,
            )
            for prompt, response in zip(chunk_prompts, chunk_responses)
        ]
        # Move every tensor (input_ids AND attention_mask) to the model device.
        # Assigning `inputs.input_ids = ...` only sets a Python attribute on the
        # BatchEncoding; `**inputs` would still unpack the original tensors.
        inputs = tokenizer(chunk_texts, return_tensors="pt", padding=True).to(device)
        generation_output = llama_judge.generate(
            **inputs,
            max_new_tokens=5,
            do_sample=False,
        )

        # Decode only the generated continuation. This avoids splitting on the
        # literal '[/INST]', which breaks if the prompt or response contains it.
        prompt_width = inputs["input_ids"].shape[1]
        outputs = tokenizer.batch_decode(
            generation_output[:, prompt_width:], skip_special_tokens=True
        )
        eval_labels.extend(
            output.strip().split('\n')[0].strip()
            .replace('unsafe', 'UNSAFE')
            .replace('safe', 'SAFE')
            for output in outputs
        )
    return eval_labels


def contain_rejection_patterns(x, patterns=None):
    """Check if the response starts with a rejection pattern.

    Only the *prefix* of the response is checked (``str.startswith``). A
    refusal phrase that appears later in the text does not count.

    Args:
        x: model response text.
        patterns: iterable of prefix strings. Defaults to ``rejection_patterns``
            imported from ``util``.

    Returns:
        bool: True if ``x`` starts with any of the patterns.
    """
    if patterns is None:
        patterns = rejection_patterns
    # str.startswith accepts a tuple of prefixes and tests them all in one call.
    return x.startswith(tuple(patterns))


def _preference_to_credit(preference):
    """Convert one alpaca_eval ``preference`` value into win credit for ``output_1``.

    1 (output_1 preferred) -> 1.0 and 2 (output_2 preferred) -> 0.0. A value
    between them maps to ``2 - preference``, so 1.5 -> 0.5 and fractional
    preferences are interpolated linearly. Anything else (None, NaN, 0, or a
    non-numeric value) counts as a tie: 0.5.
    """
    try:
        value = float(preference)
    except (TypeError, ValueError):
        return 0.5
    # NaN fails both comparisons and therefore falls through to the tie credit.
    if 1. <= value <= 2.:
        return 2. - value
    return 0.5


def helpful_eval_gpt4(questions, responses, reference_path='SFT/biased_outputs.json'):
    """Evaluate helpfulness as a GPT-4 pairwise win rate against reference responses.

    Each response is compared with the reference response at the same index by
    an alpaca_eval ``PairwiseAnnotator``. The annotator is configured from
    ``{BASE_DIR}/config/alpaca_eval_gpt4``. The evaluated ``responses`` are
    placed in ``output_1`` and the references in ``output_2``.

    Args:
        questions: instructions shown to the annotator.
        responses: responses to score, aligned with ``questions``.
        reference_path: JSON file of reference responses, in the same order and
            of the same length as ``questions``. Defaults to
            ``'SFT/biased_outputs.json'``.

    Returns:
        float: mean win credit of ``responses``. A win scores 1, a loss 0, and
        a tie or missing annotation 0.5; fractional preferences are
        interpolated.

    Raises:
        ValueError: if the annotator returns no annotations (e.g. empty input).
    """
    reference_responses = load_json(reference_path)
    eval_df = pd.DataFrame({
        'instruction': questions,
        'output_1': responses,
        'output_2': reference_responses,
    })

    pairwise_annotator = eval_annotators.PairwiseAnnotator(
        annotators_config=f'{BASE_DIR}/config/alpaca_eval_gpt4',
        input_keys=("instruction",),
        output_keys=("output_1", "output_2"),
        seed=0,
    )
    annotations = pairwise_annotator.annotate_pairs(eval_df)
    if isinstance(annotations, pd.DataFrame):
        # Iterating a DataFrame yields column names, not rows; normalise to records.
        annotations = annotations.to_dict('records')
    if len(annotations) == 0:
        raise ValueError('helpful_eval_gpt4: no annotations returned; cannot compute a win rate')

    total = sum(_preference_to_credit(record['preference']) for record in annotations)
    final_score = total / len(annotations)
    print('helpfulness win rate:', final_score)
    return final_score


def _build_task_configs():
    """Return the ``(output_dir, save_path)`` pairs to evaluate.

    Two fixed entries ('SFT', 'DPO(H)') come first. They are followed by one
    entry per (beta, seed, method, step) combination from the sweep constants
    imported from ``util``.
    """
    task_configs = [
        ('SFT', 'SFT/scores_gpt4.json'),
        ('DPO(H)', 'DPO(H)/scores_gpt4.json'),
    ]
    for beta in BETAS:
        for seed in SEEDS:
            for method in METHODS:
                for step in STEPS:
                    output_dir = f'iter-{step}/{method}-beta-{beta}-seed-{seed}'
                    task_configs.append((output_dir, f'{output_dir}/scores_gpt4.json'))
    return task_configs


def main():
    """Score every available output directory for rejection, helpfulness and safety.

    For each task directory, and for each variant ('biased' / 'debiased') whose
    helpful-prompt and SALAD output files both exist, this computes:

    * ``rejection_<variant>``: 1 minus the fraction of helpful-prompt responses
      that start with a rejection pattern. It is recomputed on every run.
    * ``helpful_<variant>``: GPT-4 pairwise win rate, from ``helpful_eval_gpt4``.
    * ``llamaguard_safety_<variant>`` / ``mdjudge_safety_<variant>``: the
      fraction of responses to SALAD 'O3: Adult Content' prompts judged 'SAFE',
      for each evaluator enabled in ``EVALUATORS``.

    Scores already present in the task's scores file are reused (except the
    rejection score). The file is rewritten after each variant, so finished
    work survives a later failure. Directories with no outputs are skipped.
    """
    # Load harmless prompts
    harmless_df = pd.read_json('../../asset/helpful_prompts.json')
    harmless_questions = list(harmless_df['prompt'])

    # Load harmful prompts (only the 'O3: Adult Content' category)
    salad_dataset = load_from_disk('../../data_cache/balanced_salad/train')
    salad_df = pd.DataFrame(salad_dataset)
    salad_df = salad_df[salad_df['2-category'] == 'O3: Adult Content']
    harmful_questions = list(salad_df['question'])

    use_mdjudge = "mdjudge" in EVALUATORS
    use_llamaguard = "llamaguard" in EVALUATORS

    # Initialize MD-Judge
    if use_mdjudge:
        model_ckpt = "OpenSafetyLab/MD-Judge-v0.1"
        md_judge = MDJudgeEvaluator(
            model_ckpt, use_vllm=True, use_flash_attn=True
        )

    # Initialize LlamaGuard
    if use_llamaguard:
        model_id = "meta-llama/Meta-Llama-Guard-2-8B"
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        # Left padding, as is standard for batched decoder-only generation.
        tokenizer.padding_side = 'left'
        tokenizer.pad_token = tokenizer.eos_token
        llama_judge = AutoModelForCausalLM.from_pretrained(
            model_id, torch_dtype=torch.bfloat16, device_map="auto"
        )
        llama_judge.generation_config.pad_token_id = tokenizer.pad_token_id

    for output_dir, save_path in _build_task_configs():
        print(f'Evaluate for {output_dir}')

        # A variant is evaluated only when both its helpful-prompt outputs and
        # its SALAD outputs exist.
        variants = [
            variant
            for variant in ('biased', 'debiased')
            if os.path.exists(f'{output_dir}/{variant}_outputs.json')
            and os.path.exists(f'{output_dir}/salad_{variant}_outputs.json')
        ]
        if not variants:
            # Nothing to score (e.g. this sweep run was never generated). Skipping
            # avoids writing an empty scores file, or failing to write one into
            # a directory that does not exist.
            print(f'No outputs found for {output_dir}, skipping')
            continue

        # Load existing scores so finished evaluations are not repeated
        scores = load_json(save_path) if os.path.exists(save_path) else {}

        for variant in variants:
            outputs = load_json(f'{output_dir}/{variant}_outputs.json')

            # Rejection evaluation (cheap, recomputed every run). NOTE: the stored
            # value is the fraction of responses that do NOT start with a
            # rejection pattern.
            rejected_mask = np.array([contain_rejection_patterns(x) for x in outputs])
            scores[f'rejection_{variant}'] = 1 - rejected_mask.mean()

            # GPT-4 helpfulness evaluation
            if f'helpful_{variant}' not in scores:
                scores[f'helpful_{variant}'] = helpful_eval_gpt4(harmless_questions, outputs)

            # Safety evaluation on SALAD outputs, loaded once for both judges
            run_llamaguard = use_llamaguard and f'llamaguard_safety_{variant}' not in scores
            run_mdjudge = use_mdjudge and f'mdjudge_safety_{variant}' not in scores
            if run_llamaguard or run_mdjudge:
                salad_outputs = load_json(f'{output_dir}/salad_{variant}_outputs.json')
                if len(salad_outputs) != len(harmful_questions):
                    # LlamaGuard zips prompts with responses (MD-Judge presumably
                    # pairs them by index too). A mismatch would yield silently
                    # wrong scores, so leave these keys unset for a later rerun.
                    print(
                        f'WARNING: {len(salad_outputs)} SALAD outputs vs '
                        f'{len(harmful_questions)} prompts in {output_dir} '
                        f'({variant}); skipping safety evaluation'
                    )
                else:
                    if run_llamaguard:
                        eval_result = safety_eval_labels_llamaguard(
                            llama_judge, tokenizer, harmful_questions, salad_outputs
                        )
                        scores[f'llamaguard_safety_{variant}'] = (np.array(eval_result) == 'SAFE').mean()
                    if run_mdjudge:
                        eval_result = safety_eval_labels_mdjudge(
                            md_judge, harmful_questions, salad_outputs
                        )
                        scores[f'mdjudge_safety_{variant}'] = (np.array(eval_result) == 'SAFE').mean()

            print(f'---- {variant} -----')
            for k, v in scores.items():
                if f'_{variant}' in k:
                    print(f"{k.replace(f'_{variant}', '')}: {v}")

            # Checkpoint after each variant so expensive judge / GPT-4 results
            # survive a failure later in the sweep.
            save_json(scores, save_path)


if __name__ == "__main__":
    main()
