from typing import Optional, Dict

from ..constant import PROMPT_BEGIN, PROMPT_USER, PROMPT_END, SANITY_CHECK_DATASIZE

from datasets import load_dataset, load_from_disk, Dataset


def pku_format(question, system=""):
    """Wrap ``question`` in the PKU prompt template.

    The returned string is ``system`` followed by ``PROMPT_BEGIN``, then
    ``PROMPT_USER`` formatted with ``input=question``, then ``PROMPT_END``.

    Args:
        question: Text substituted into the ``{input}`` field of
            ``PROMPT_USER``.
        system: Optional text placed in front of the template; empty by
            default.

    Returns:
        The fully assembled prompt string.
    """
    user_turn = PROMPT_USER.format(input=question)
    return system + PROMPT_BEGIN + user_turn + PROMPT_END


def get_pku_by_helpfulness(
    split: str, sanity_check: bool = False, cache_dir: Optional[str] = None
) -> Dataset:
    """Load PKU-SafeRLHF-30K as preference pairs ranked by helpfulness.

    Each row's ``chosen`` response is the one selected by
    ``better_response_id`` and ``rejected`` is the other one. The original
    prompt is kept in ``raw_prompt`` and ``salad_prompt``, while ``prompt``
    is replaced with the ``pku_format``-templated version. The raw response
    and label columns are dropped.

    Args:
        split: Dataset split passed to ``load_dataset`` (also used as the
            last component of the processed-cache path).
        sanity_check: If true, keep only the first
            ``SANITY_CHECK_DATASIZE`` rows (or fewer if the split is smaller).
        cache_dir: Directory used both as the ``load_dataset`` cache and as
            the root for the processed copy. When ``None`` the processed
            dataset is neither read from nor written to disk.

    Returns:
        The processed (and optionally truncated) dataset.
    """
    # A criterion-specific sub-directory keeps this cache from being mistaken
    # for a dataset of the same split processed under a different ranking.
    # Without a cache_dir there is nowhere sensible to persist to, so the
    # processed cache is skipped instead of writing into a literal "None/".
    processed_path = (
        f"{cache_dir}/pku_helpfulness/{split}" if cache_dir is not None else None
    )

    dataset = None
    if processed_path is not None:
        try:
            dataset = load_from_disk(processed_path)
            print(f"Loaded {split} dataset from cache.")
        except FileNotFoundError:
            # A cache miss is expected on first use; rebuild below.
            dataset = None

    if dataset is None:
        print(f"No cached {split} dataset found, loading and processing...")
        dataset = load_dataset(
            "PKU-Alignment/PKU-SafeRLHF-30K", split=split, cache_dir=cache_dir
        )

        def get_chosen(sample) -> Dict[str, str]:
            # Index 0 means response_0 is the more helpful one; any other
            # value selects response_1.
            first_is_better = sample["better_response_id"] == 0
            return {
                "raw_prompt": sample["prompt"],
                "salad_prompt": sample["prompt"],
                "prompt": pku_format(sample["prompt"]),
                "chosen": sample["response_0"] if first_is_better else sample["response_1"],
                "rejected": sample["response_1"] if first_is_better else sample["response_0"],
            }

        remove_columns = [
            "response_0",
            "response_1",
            "is_response_0_safe",
            "is_response_1_safe",
            "better_response_id",
            "safer_response_id",
        ]
        dataset = dataset.map(get_chosen, remove_columns=remove_columns)

        # Persist the processed dataset so later calls can skip the mapping.
        if processed_path is not None:
            dataset.save_to_disk(processed_path)

    if sanity_check:
        dataset = dataset.select(range(min(len(dataset), SANITY_CHECK_DATASIZE)))

    return dataset


def get_pku_by_safety(
    split: str, sanity_check: bool = False, cache_dir: Optional[str] = None,
) -> Dataset:
    """Load PKU-SafeRLHF-30K as preference pairs ranked by safety.

    Each row's ``chosen`` response is the one selected by
    ``safer_response_id`` and ``rejected`` is the other one. The original
    prompt is kept in ``raw_prompt`` and ``salad_prompt``, while ``prompt``
    is replaced with the ``pku_format``-templated version. The raw response
    and label columns are dropped.

    Args:
        split: Dataset split passed to ``load_dataset`` (also used as the
            last component of the processed-cache path).
        sanity_check: If true, keep only the first
            ``SANITY_CHECK_DATASIZE`` rows (or fewer if the split is smaller).
        cache_dir: Directory used both as the ``load_dataset`` cache and as
            the root for the processed copy. When ``None`` the processed
            dataset is neither read from nor written to disk.

    Returns:
        The processed (and optionally truncated) dataset.
    """
    # A criterion-specific sub-directory keeps this cache from being mistaken
    # for a dataset of the same split processed under a different ranking.
    # Without a cache_dir there is nowhere sensible to persist to, so the
    # processed cache is skipped instead of writing into a literal "None/".
    processed_path = (
        f"{cache_dir}/pku_safety/{split}" if cache_dir is not None else None
    )

    dataset = None
    if processed_path is not None:
        try:
            dataset = load_from_disk(processed_path)
            print(f"Loaded {split} dataset from cache.")
        except FileNotFoundError:
            # A cache miss is expected on first use; rebuild below.
            dataset = None

    if dataset is None:
        print(f"No cached {split} dataset found, loading and processing...")
        dataset = load_dataset(
            "PKU-Alignment/PKU-SafeRLHF-30K", split=split, cache_dir=cache_dir
        )

        def get_chosen(sample) -> Dict[str, str]:
            # Index 0 means response_0 is the safer one; any other value
            # selects response_1.
            first_is_safer = sample["safer_response_id"] == 0
            return {
                "raw_prompt": sample["prompt"],
                "salad_prompt": sample["prompt"],
                "prompt": pku_format(sample["prompt"]),
                "chosen": sample["response_0"] if first_is_safer else sample["response_1"],
                "rejected": sample["response_1"] if first_is_safer else sample["response_0"],
            }

        remove_columns = [
            "response_0",
            "response_1",
            "is_response_0_safe",
            "is_response_1_safe",
            "better_response_id",
            "safer_response_id",
        ]
        dataset = dataset.map(get_chosen, remove_columns=remove_columns)

        # Persist the processed dataset so later calls can skip the mapping.
        if processed_path is not None:
            dataset.save_to_disk(processed_path)

    if sanity_check:
        dataset = dataset.select(range(min(len(dataset), SANITY_CHECK_DATASIZE)))

    return dataset
