from typing import Optional
from dataclasses import dataclass, field

import pandas as pd
from transformers import HfArgumentParser

from ..util import save_df_dataset
from ..dataset.pku import get_pku_by_safety
from ..dataset.hh_rlhf import get_hh_rlhf_by_safety
from ..constant import REJECTIVE_PATTERNS


@dataclass
class ScriptArguments:
    """
    Command-line arguments for the rejective-response filtering script.

    Parsed by ``HfArgumentParser`` in ``main()``; each field becomes a
    ``--<field_name>`` option.
    """
    # Dataset selector; main() accepts only "pku" or "hh-rlhf".
    dataset: Optional[str] = field(
        default="pku",
        metadata={"help": "The dataset to clean (pku, hh-rlhf). Default=pku"},
    )
    # Root directory used both for the loader cache and for the saved output.
    data_dir: str = field(
        default="data_cache",
        metadata={"help": "Directory for local datasets."},
    )
    # Output name prefix; main() falls back to "<dataset>-safety" when unset.
    save_name: Optional[str] = field(
        default=None,
        metadata={"help": "Name to save created dataset."},
    )
    # Forwarded unchanged to the dataset loader as ``sanity_check``.
    sanity_check: Optional[bool] = field(
        default=False,
        metadata={"help": "Only process several samples"},
    )


def main():
    """Create the "no-reject" variant of a safety preference dataset.

    Parses ``ScriptArguments`` from the command line, loads the ``train``
    split of the selected dataset, drops every row whose ``chosen`` or
    ``rejected`` text starts with one of ``REJECTIVE_PATTERNS``, and writes
    the filtered train split together with the untouched ``test`` split to
    ``<data_dir>/<save_name>-no-reject/{train,test}``.

    Raises:
        ValueError: If ``--dataset`` is neither ``pku`` nor ``hh-rlhf``.
    """
    parser = HfArgumentParser(ScriptArguments)
    script_args = parser.parse_args_into_dataclasses()[0]

    # Fail fast: reject an unknown dataset before deriving any names from it.
    if script_args.dataset not in ('pku', 'hh-rlhf'):
        raise ValueError("Invalid dataset. Must be 'pku' or 'hh-rlhf'.")

    save_name = script_args.save_name or f"{script_args.dataset}-safety"
    # The loader cache is keyed on the dataset name, not on save_name.
    cache_dir = f'{script_args.data_dir}/{script_args.dataset}-safety'
    output_dir = f"{script_args.data_dir}/{save_name}-no-reject"

    data_get_func = get_pku_by_safety if script_args.dataset == "pku" else get_hh_rlhf_by_safety
    pref_dataset = data_get_func(
        split='train',
        cache_dir=cache_dir,
        sanity_check=script_args.sanity_check
    )

    pref_df = pd.DataFrame(pref_dataset)
    print('#samples of the original dataset:', len(pref_df))

    # str.startswith accepts a tuple of prefixes, so build it once and test
    # all patterns in a single call per text.
    rejective_prefixes = tuple(REJECTIVE_PATTERNS)

    def is_not_rejective(text):
        """Return True when *text* starts with none of the rejective prefixes."""
        return not text.startswith(rejective_prefixes)

    # Keep a row only if BOTH responses are free of a rejective opening.
    keep_mask = (
        pref_df['chosen'].apply(is_not_rejective)
        & pref_df['rejected'].apply(is_not_rejective)
    )
    # NOTE(review): boolean indexing keeps the original row labels (the index
    # is not reset) — confirm save_df_dataset does not persist them as an
    # extra column.
    no_rejective_df = pref_df[keep_mask]
    print('#samples after remove rejective patterns:', len(no_rejective_df))

    # The test split is saved exactly as loaded; only train is filtered.
    test_pref_dataset = data_get_func(
        split='test',
        cache_dir=cache_dir,
        sanity_check=script_args.sanity_check
    )

    save_df_dataset(no_rejective_df, f"{output_dir}/train")
    test_pref_dataset.save_to_disk(f"{output_dir}/test")


# Run the filtering pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
