from .base import *
from .harry_potter import *
from .moth_radio_hour import *
from .dataset_utils import *

def get_dataset(dataset_name, data_dir, *args, **kwargs):
    """Build the dataset object registered under `dataset_name`.

    The selected dataset class is called with `data_dir`, followed by any
    extra positional and keyword arguments. Two keyword flags,
    `remove_format_chars` and `remove_punc_spacing`, receive a per-dataset
    default when the caller does not pass them:

    * "HarryPotter": both default to False.
    * "MothRadioHour": both default to True.

    Values supplied explicitly by the caller are always kept as given.

    Raises:
        ValueError: if `dataset_name` is not one of the names above.
    """
    # Step 1: pick the class and the fallback value shared by both flags.
    if dataset_name == "HarryPotter":
        dataset_cls, flag_default = HarryPotter, False
    elif dataset_name == "MothRadioHour":
        dataset_cls, flag_default = MothRadioHour, True
    else:
        raise ValueError(f"Invalid dataset name: {dataset_name}")

    # Step 2: only fill in flags the caller left unspecified.
    for flag in ("remove_format_chars", "remove_punc_spacing"):
        kwargs.setdefault(flag, flag_default)

    return dataset_cls(data_dir, *args, **kwargs)
