from typing import Tuple

from datasets import Dataset, load_dataset
from transformers import AutoTokenizer
from trl.data_utils import apply_chat_template

from meta_alignment.config import TrainingConfig


def convert_dataset(dataset: Dataset, task: str, tokenizer: AutoTokenizer) -> Dataset:
    """
    Map every example of ``dataset`` to a task-specific ``"prompt"`` record.

    The per-example conversion depends on ``task``:

    * ``"length"``: the example is handed to ``apply_chat_template`` together
      with ``tokenizer`` and that call's result is used as-is.
    * ``"hh"``: the text of ``example["chosen"]`` that precedes the first
      ``"Assistant:"`` marker is stripped and re-wrapped so that it ends with
      an open ``"Assistant:"`` turn. Only that leading part is kept.
    * ``"pku"``: ``example["prompt"]`` is wrapped in the
      ``"BEGINNING OF CONVERSATION: USER: ... ASSISTANT:"`` template.

    Args:
        dataset (Dataset): The input dataset. It must provide the field the
            selected task reads ("chosen" for "hh", "prompt" for "pku").
        task (str): One of "length", "hh" or "pku".
        tokenizer (AutoTokenizer): Tokenizer forwarded to
            ``apply_chat_template``; only used by the "length" task.

    Returns:
        Dataset: The result of ``dataset.map`` with all original column names
        passed as ``remove_columns``.

    Raises:
        ValueError: Raised from the per-example function when ``task`` is not
            one of the supported names.
    """

    def _to_prompt(row):
        # The task check is done per example (not up front) so that the error
        # surfaces from inside ``map`` exactly as before.
        if task == "length":
            return apply_chat_template(row, tokenizer)

        if task == "hh":
            # Everything before the first "Assistant:" marker.
            head, _, _ = row["chosen"].partition("Assistant:")
            text = head.strip()
            return {"prompt": f"\n\n{text} \n\nAssistant:"}

        if task == "pku":
            question = row["prompt"]
            return {
                "prompt": f"BEGINNING OF CONVERSATION: USER: {question} ASSISTANT:"
            }

        raise ValueError(f"Unknown task: {task}")

    original_columns = dataset.column_names
    return dataset.map(_to_prompt, remove_columns=original_columns)


def get_dataset(
    args: TrainingConfig, train_size: int = 2000, eval_size: int = 100
) -> Tuple[Dataset, Dataset]:
    match args.task:
        case "length":
            train_dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train")
            eval_dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="test")
        case "hh":
            train_dataset = load_dataset("Anthropic/hh-rlhf", split="train")
            eval_dataset = load_dataset("Anthropic/hh-rlhf", split="test")
        case _:
            raise ValueError(f"Unknown task: {args.task}")
    if train_size >= 0:
        train_dataset = train_dataset.select(range(train_size))
    if eval_size >= 0:
        eval_dataset = eval_dataset.select(range(eval_size))
    tokenizer = AutoTokenizer.from_pretrained(args.model_dir)
    train_dataset = convert_dataset(train_dataset, task=args.task, tokenizer=tokenizer)
    eval_dataset = convert_dataset(eval_dataset, task=args.task, tokenizer=tokenizer)
    return train_dataset, eval_dataset
