from typing import Optional, Tuple, List
from dataclasses import dataclass, field

import os
import torch
import pandas as pd
from transformers import HfArgumentParser
from datasets import Dataset

from ..util import save_df_dataset
from ..dataset.pku import get_pku_by_safety
from ..dataset.hh_rlhf import get_hh_rlhf_by_safety

from vllm import SamplingParams
from vllm.model_executor.parallel_utils.parallel_state import destroy_model_parallel
from saladbench.evaluator import MDJudgeEvaluator

from ..dataset.salad import get_balanced_salad_dataset
from ..util import generate_responses
from ..dataset.pku import pku_format


@dataclass
class ScriptArguments:
    """
    Command-line arguments for the preference-dataset safety-cleaning script.

    The fields are parsed by ``HfArgumentParser`` in ``main``. Each field's
    ``help`` metadata is the text shown on the command line.
    """
    # Which preference dataset to load; ``main`` accepts only 'pku' or 'hh-rlhf'.
    dataset: Optional[str] = field(default="pku", metadata={"help": "The dataset to clean (pku, hh-rlhf). Default=pku"})
    # Root directory for dataset caches, the QA evaluation CSV and the saved splits.
    data_dir: str = field(default="data_cache", metadata={"help": "Directory for local datasets."})
    # Output name; when omitted, ``main`` builds "<dataset>-<evaluator>-<margin_threshold>".
    save_name: Optional[str] = field(default=None, metadata={"help": "Name to save created dataset."})
    # Safety evaluator forwarded to the evaluation helpers ('salad' or 'llama').
    evaluator: Optional[str] = field(default='salad', metadata={"help": "The evaluator to use for safety evaluation"})
    # Pairs whose chosen/rejected safety-probability gap is <= this value are labelled 'Ambiguous'.
    margin_threshold: Optional[float] = field(default=0.05, metadata={"help": "Margin threshold to remove"})
    # Forwarded unchanged to the dataset loaders as ``sanity_check``.
    sanity_check: Optional[bool] = field(default=False, metadata={"help": "Only process a few samples (quick sanity run)"})


def safety_eval_labels(prompts, responses, evaluator='salad') -> Tuple[List, List]:
    """
    Label each (prompt, response) pair with the selected safety evaluator.

    Args:
        prompts: Prompts handed to the evaluator, aligned with ``responses``.
        responses: Responses to be judged.
        evaluator: ``'salad'`` runs the MD-Judge checkpoint; ``'llama'`` is an
            unimplemented placeholder.

    Returns:
        A pair ``(eval_labels, eval_clses)``. For ``'salad'`` the labels are
        the strings "SAFE", "UNSAFE" or "Exception" and ``eval_clses`` is
        returned exactly as MD-Judge produced it. For ``'llama'`` both lists
        are empty.

    Raises:
        ValueError: If ``evaluator`` is neither ``'salad'`` nor ``'llama'``.
    """
    if evaluator == 'salad':
        model_ckpt = "OpenSafetyLab/MD-Judge-v0.1"
        md_judge = MDJudgeEvaluator(
            model_ckpt, use_vllm=True, use_flash_attn=True
        )
        try:
            eval_labels, eval_clses = md_judge.eval_unsafe_labels(prompts, responses)
            # The integer labels index into this list.
            # NOTE(review): presumably MD-Judge reports 0 = safe, 1 = unsafe and
            # -1 = failure, so -1 selects the last entry - confirm against saladbench.
            eval_labels = [["SAFE", "UNSAFE", "Exception"][label] for label in eval_labels]
        finally:
            # Reset computing resource, even when the evaluation fails, so a
            # later model load in the same process is not blocked by this one.
            destroy_model_parallel()
            del md_judge
            torch.cuda.synchronize()

        return eval_labels, eval_clses
    elif evaluator == 'llama':
        # TODO: Implement
        return [], []
    else:
        raise ValueError("Invalid evaluator. Must be 'salad' or 'llama'.")


def safety_eval_probs(prompts, responses, evaluator='salad') -> List:
    """
    Score each (prompt, response) pair with a probability of being safe.

    Args:
        prompts: Prompts handed to the evaluator, aligned with ``responses``.
        responses: Responses to be judged.
        evaluator: ``'salad'`` runs the MD-Judge checkpoint; ``'llama'`` is an
            unimplemented placeholder.

    Returns:
        For ``'salad'``, a list with ``1 - unsafe_prob`` for every value
        returned by MD-Judge. For ``'llama'``, an empty list.

    Raises:
        ValueError: If ``evaluator`` is neither ``'salad'`` nor ``'llama'``.
    """
    if evaluator == 'salad':
        model_ckpt = "OpenSafetyLab/MD-Judge-v0.1"
        md_judge = MDJudgeEvaluator(
            model_ckpt, use_vllm=True, use_flash_attn=True
        )
        try:
            unsafe_probs = md_judge.eval_unsafe_probs(prompts, responses)
            # The judge reports P(unsafe); callers work with P(safe).
            safe_probs = [1. - unsafe_prob for unsafe_prob in unsafe_probs]
        finally:
            # Reset computing resource, even when the evaluation fails, so a
            # later model load in the same process is not blocked by this one.
            destroy_model_parallel()
            del md_judge
            torch.cuda.synchronize()

        return safe_probs
    elif evaluator == 'llama':
        # TODO: Implement
        return []
    else:
        raise ValueError("Invalid evaluator. Must be 'salad' or 'llama'.")


def salad_generate(model_name_or_path, category=None, sanity_check=False):
    """
    Generate model responses for the balanced SALAD training questions.

    Args:
        model_name_or_path: Model identifier forwarded to ``generate_responses``.
        category: When given, keep only rows whose '2-category' equals it.
        sanity_check: When true, work on 100 rows sampled with replacement.

    Returns:
        A DataFrame with the columns 'qid', 'question', 'response',
        '1-category', '2-category', '3-category' and 'source'.
    """
    salad_dataset = get_balanced_salad_dataset(
        split='train',
        cache_dir='data_cache/balanced_salad')
    salad_pd = pd.DataFrame(salad_dataset)
    if category is not None:
        # Take an explicit copy: assigning a new column to a boolean-filtered
        # frame otherwise raises pandas' SettingWithCopyWarning.
        salad_pd = salad_pd[salad_pd['2-category'] == category].copy()
    salad_pd['prompt'] = salad_pd['question'].map(pku_format)
    if sanity_check:
        # Sampling with replacement also works when fewer than 100 rows remain.
        salad_pd = salad_pd.sample(100, replace=True)
        salad_pd = salad_pd.reset_index(drop=True)

    # Generation: temperature 0 and at most 256 generated tokens per prompt.
    sampling_params = SamplingParams(temperature=0, max_tokens=256)
    salad_pd['response'] = generate_responses(model_name_or_path, salad_pd['prompt'].to_list(), sampling_params)

    return salad_pd[['qid', 'question', 'response', '1-category', '2-category', '3-category', 'source']]


def get_unique_qa(pref_dataset: Dataset) -> pd.DataFrame:
    """
    Collect the distinct (prompt, salad_prompt, response) triples of a preference dataset.

    Every row contributes two triples, one for its 'chosen' and one for its
    'rejected' response. Duplicates are removed and the remaining triples are
    returned in sorted order as the columns 'prompt', 'salad_prompt' and
    'response'.
    """
    seen = set()
    for record in pref_dataset:
        for side in ('chosen', 'rejected'):
            seen.add((record['prompt'], record['salad_prompt'], record[side]))

    # Sorting the tuples makes the row order independent of set iteration order.
    prompts, salad_prompts, responses = [], [], []
    for prompt, salad_prompt, response in sorted(seen):
        prompts.append(prompt)
        salad_prompts.append(salad_prompt)
        responses.append(response)

    return pd.DataFrame({
        'prompt': prompts,
        'salad_prompt': salad_prompts,
        'response': responses,
    })


def eval_qa_dataset(pref_dataset: Dataset, evaluator: str) -> pd.DataFrame:
    """
    Build the unique QA table of a preference dataset and attach safety evaluations.

    The returned frame is the output of ``get_unique_qa`` extended with the
    columns 'eval_label', 'eval_category' (from ``safety_eval_labels``) and
    'eval_safety_probs' (from ``safety_eval_probs``). Both evaluations are run
    on the 'salad_prompt' / 'response' columns.
    """
    frame = get_unique_qa(pref_dataset)

    # First pass: discrete label and category per QA pair.
    labels, categories = safety_eval_labels(
        frame['salad_prompt'].to_list(),
        frame['response'].to_list(),
        evaluator=evaluator
    )
    frame['eval_label'] = labels
    frame['eval_category'] = categories

    # Second pass: probability of being safe per QA pair.
    safe_probs = safety_eval_probs(
        frame['salad_prompt'].to_list(),
        frame['response'].to_list(),
        evaluator=evaluator
    )
    frame['eval_safety_probs'] = safe_probs

    return frame


def _attach_side_eval(frame: pd.DataFrame, qa_df: pd.DataFrame, side: str) -> pd.DataFrame:
    """Join the evaluation of ``frame[side]`` onto ``frame`` as ``<side>_eval_*`` columns."""
    joined = pd.merge(frame, qa_df, left_on=['prompt', side], right_on=['prompt', 'response'])
    joined = joined.drop(['response'], axis=1)
    return joined.rename(
        {name: f'{side}_{name}' for name in ('eval_category', 'eval_label', 'eval_safety_probs')},
        axis=1
    )


def merge_and_evaluate(pref_df: pd.DataFrame, qa_df: pd.DataFrame, margin_threshold: float) -> pd.DataFrame:
    """
    Merge preference and evaluation dataframes and evaluate based on safety probabilities.

    Args:
        pref_df: Preference pairs with at least 'prompt', 'chosen' and 'rejected'.
        qa_df: Evaluated QA pairs with at least 'prompt', 'response',
            'eval_category', 'eval_label' and 'eval_safety_probs'.
        margin_threshold: Largest absolute gap between the two safety
            probabilities that is still labelled 'Ambiguous'.

    Returns:
        ``pref_df`` joined with the evaluation of its chosen and rejected
        responses (``chosen_eval_*`` / ``rejected_eval_*`` columns) plus a
        'data_fit' column holding 'Ambiguous', 'Chosen > Rejected' or
        'Rejected > Chosen'.

    Notes:
        - Both joins are inner merges, so pairs without a matching evaluation
          row are dropped.
        - Any other column present in both frames receives pandas' default
          ``_x`` / ``_y`` suffixes.
    """
    pref_merge_df = _attach_side_eval(pref_df, qa_df, 'chosen')
    pref_merge_df = _attach_side_eval(pref_merge_df, qa_df, 'rejected')

    chosen_probs = pref_merge_df['chosen_eval_safety_probs']
    rejected_probs = pref_merge_df['rejected_eval_safety_probs']

    # Vectorized three-way classification. The default covers every row that
    # is neither ambiguous nor chosen-safer (including NaN probabilities, for
    # which both comparisons are False). 'Ambiguous' is applied last because
    # it takes precedence over the ordering.
    data_fit = pd.Series('Rejected > Chosen', index=pref_merge_df.index)
    data_fit = data_fit.mask(chosen_probs > rejected_probs, 'Chosen > Rejected')
    data_fit = data_fit.mask((chosen_probs - rejected_probs).abs() <= margin_threshold, 'Ambiguous')
    pref_merge_df['data_fit'] = data_fit

    return pref_merge_df


def save_datasets(pref_merge_df: pd.DataFrame, data_get_func, script_args: ScriptArguments, save_name: str):
    """
    Split the evaluated pairs into a "green" and a "red" dataset and write both to disk.

    Rows whose 'data_fit' is 'Rejected > Chosen' go to the red split and all
    other rows go to the green split. Each split is saved as the 'train' part
    under "<data_dir>/<save_name>-<green|red>/", next to an unmodified copy of
    the test split obtained from ``data_get_func``.
    """
    test_pref_dataset = data_get_func(
        split='test',
        cache_dir=f'{script_args.data_dir}/{script_args.dataset}-safety',
        sanity_check=script_args.sanity_check
    )

    is_red = pref_merge_df['data_fit'] == 'Rejected > Chosen'
    green_df = pref_merge_df[~is_red]
    red_df = pref_merge_df[is_red]

    # Green is written before red; each split gets its train part and then the shared test set.
    for colour, subset in (('green', green_df), ('red', red_df)):
        save_df_dataset(subset, f"{script_args.data_dir}/{save_name}-{colour}/train")
        test_pref_dataset.save_to_disk(f"{script_args.data_dir}/{save_name}-{colour}/test")

    print('#samples of the original dataset:', len(pref_merge_df))
    print('#samples of the dataset after cleaning:', len(green_df))


def main():
    """
    Entry point: evaluate a preference dataset for safety and save the cleaned splits.

    Steps:
        1. Parse ``ScriptArguments`` from the command line.
        2. Load the train split of the selected dataset.
        3. Evaluate its unique QA pairs, or reuse the CSV cached in ``data_dir``.
        4. Merge the evaluations into the preference pairs and classify each pair.
        5. Write the green/red datasets through ``save_datasets``.

    Raises:
        ValueError: If ``--dataset`` is neither 'pku' nor 'hh-rlhf'.
    """
    parser = HfArgumentParser(ScriptArguments)
    script_args = parser.parse_args_into_dataclasses()[0]

    save_name = script_args.save_name or f"{script_args.dataset}-{script_args.evaluator}-{script_args.margin_threshold}"

    if script_args.dataset not in ['pku', 'hh-rlhf']:
        raise ValueError("Invalid dataset. Must be 'pku' or 'hh-rlhf'.")

    data_get_func = get_pku_by_safety if script_args.dataset == "pku" else get_hh_rlhf_by_safety
    pref_dataset = data_get_func(
        split='train',
        cache_dir=f'{script_args.data_dir}/{script_args.dataset}-safety',
        sanity_check=script_args.sanity_check
    )

    # Keep the evaluation cache of sanity runs apart from the regular one.
    # NOTE(review): sanity_check presumably makes the loader return a reduced
    # dataset; reusing such a partial QA table in a full run would make the
    # inner merges below silently drop every pair that was not evaluated.
    cache_suffix = "_sanity" if script_args.sanity_check else ""
    csv_path = f"{script_args.data_dir}/{script_args.dataset}_qa_{script_args.evaluator}{cache_suffix}.csv"
    if not os.path.exists(csv_path):
        qa_df = eval_qa_dataset(pref_dataset, evaluator=script_args.evaluator)
        qa_df.to_csv(csv_path)
    else:
        # Read the text columns back verbatim: with the default NA handling,
        # cells such as "", "NA" or "null" become NaN and would no longer match
        # the original prompt/response strings in the merge.
        qa_df = pd.read_csv(
            csv_path,
            index_col=0,
            keep_default_na=False,
            dtype={'prompt': str, 'salad_prompt': str, 'response': str},
        )
        # Blank probabilities still have to end up as NaN floats.
        qa_df['eval_safety_probs'] = pd.to_numeric(qa_df['eval_safety_probs'], errors='coerce')

    pref_df = pd.DataFrame(pref_dataset)
    pref_merge_df = merge_and_evaluate(pref_df, qa_df, script_args.margin_threshold)

    save_datasets(pref_merge_df, data_get_func, script_args, save_name)


if __name__ == "__main__":
    main()
