from datasets import Dataset, concatenate_datasets, load_dataset, load_from_disk
from datasets.features import Audio

from src.utils.utils import CACHE_DIR, LIBRITTS_CONFIG, DATASET_BASE_DIR


# Datasets loaded with ``load_from_disk`` from ``{DATASET_BASE_DIR}/{name}/{split}``.
_DISK_DATASETS = (
    # speech datasets
    "common_voice_21_0",
    "DeepDialogue-orpheus",
    "DeepDialogue-xtts",
    "word_less_frequent",
    "mmsu_openbookqa",
    "openbookqa",
    # general audio datasets
    "audioset_strong",
    "audioset_wavcaps",
    "Nonspeech7K",
    "VocalSound",
)


def _infer_dataset_name(dataset_path: str, split: str) -> str:
    """Derive the dataset name from the directory components of ``dataset_path``."""
    # Ignore trailing slashes so ".../name/train/" resolves like ".../name/train".
    trimmed = dataset_path.rstrip('/')
    components = trimmed.split('/')

    if "EuroSpeech" in dataset_path or "fleurs" in dataset_path:
        # These paths carry an extra level: .../<name>/<sub-directory>/<split>
        depth = 3
    elif trimmed.endswith(split):
        # .../<name>/<split>
        depth = 2
    else:
        # .../<name>
        depth = 1

    if len(components) < depth:
        raise ValueError(f"Cannot infer dataset name from dataset_path: {dataset_path}")
    return components[-depth]


def _replace_text_with_metadata(dataset, metadata, columns_to_keep):
    """Drop the dataset's own ``text`` column, append the metadata columns, and prune.

    Rows are paired by position, so ``dataset`` and ``metadata`` are expected to
    be in the same order.
    """
    # remove the original text column
    if 'text' in dataset.column_names:
        dataset = dataset.remove_columns('text')

    # Columns present on both sides are kept from the dataset only, so the
    # column-wise concatenation does not see duplicate names.
    shared_columns = [col for col in metadata.column_names if col in dataset.column_names]
    metadata = metadata.remove_columns(shared_columns)

    merged = concatenate_datasets([dataset, metadata], axis=1)
    surplus_columns = [col for col in merged.column_names if col not in columns_to_keep]
    return merged.remove_columns(surplus_columns)


def get_audio_dataset(dataset_path: str = None, dataset_name: str = None, split: str = 'train', keep_columns=None, sampling_rate: int = 16_000) -> "Dataset":
    """Load a known audio dataset and reduce it to the requested columns.

    The processed dataset should have two columns: audio and text (the
    corresponding transcription), plus anything listed in ``keep_columns``.
    Columns that are not present in the loaded dataset are simply absent from
    the result.

    Args:
        dataset_path: Location of the dataset. When omitted, a default location
            under ``DATASET_BASE_DIR`` is used for the selected dataset.
        dataset_name: Name selecting the loading recipe. When omitted, it is
            inferred from the directory components of ``dataset_path``.
        split: Split to load; also used to build default paths.
        keep_columns: Extra column name(s) to keep besides ``audio`` and
            ``text``. ``None`` means no extra columns; a single string is
            treated as one column name.
        sampling_rate: Sampling rate the ``audio`` column is cast to.

    Returns:
        The loaded dataset with only the kept columns and the ``audio`` column
        cast to ``Audio(sampling_rate=sampling_rate)``.

    Raises:
        ValueError: If neither ``dataset_path`` nor ``dataset_name`` is given,
            if the name cannot be inferred from ``dataset_path``, or if the
            dataset name is not one of the supported names.
    """
    if keep_columns is None:
        extra_columns = []
    elif isinstance(keep_columns, str):
        extra_columns = [keep_columns]
    else:
        extra_columns = list(keep_columns)
    columns_to_keep = ['audio', 'text'] + extra_columns

    if dataset_name is None:
        if dataset_path is None:
            raise ValueError("Either dataset_path or dataset_name must be provided")
        dataset_name = _infer_dataset_name(dataset_path, split)

    if dataset_name == "VoiceAssistant-400K":
        if dataset_path is None:
            dataset_path = f"{DATASET_BASE_DIR}/VoiceAssistant-400K"
        audio_dataset = load_dataset(dataset_path, split=split, cache_dir=CACHE_DIR)
        audio_dataset = audio_dataset.rename_columns({"question_audio": "audio", "question": "text"})
    elif dataset_name == "libritts_r_filtered":
        if dataset_path is None:
            dataset_path = f"{DATASET_BASE_DIR}/libritts_r_filtered"
        metadata_path = f'{DATASET_BASE_DIR}/libritts-r-filtered-speaker-descriptions'

        # LIBRITTS_CONFIG maps the requested split to {config: [sub_split, ...]};
        # every (config, sub_split) pair is loaded and stacked row-wise.
        merged_parts = []
        for config, sub_splits in LIBRITTS_CONFIG[split].items():
            for sub_split in sub_splits:
                # Honor dataset_path (an explicit path is no longer ignored here).
                dataset = load_dataset(dataset_path, config, split=sub_split, cache_dir=CACHE_DIR)
                metadata = load_dataset(metadata_path, config, split=sub_split, cache_dir=CACHE_DIR)
                merged_parts.append(_replace_text_with_metadata(dataset, metadata, columns_to_keep))

        audio_dataset = concatenate_datasets(merged_parts, axis=0)
    elif dataset_name == "mls_eng_10k":
        if dataset_path is None:
            dataset_path = f"{DATASET_BASE_DIR}/mls_eng_10k"
        metadata_path = f"{DATASET_BASE_DIR}/mls-eng-10k-tags_tagged_10k_generated"

        dataset = load_dataset(dataset_path, split=split, cache_dir=CACHE_DIR)
        metadata = load_dataset(metadata_path, split=split, cache_dir=CACHE_DIR)
        audio_dataset = _replace_text_with_metadata(dataset, metadata, columns_to_keep)
    elif dataset_name == "spoken-web-questions":
        # The default location always points at the "train" directory,
        # regardless of the requested split.
        if dataset_path is None:
            dataset_path = f"{DATASET_BASE_DIR}/spoken-web-questions/train"
        audio_dataset = load_from_disk(dataset_path)
    elif dataset_name in _DISK_DATASETS:
        if dataset_path is None:
            dataset_path = f"{DATASET_BASE_DIR}/{dataset_name}/{split}"
        audio_dataset = load_from_disk(dataset_path)
    elif dataset_name == "EuroSpeech":
        if dataset_path is None:
            dataset_path = f"{DATASET_BASE_DIR}/EuroSpeech/uk/{split}"
        audio_dataset = load_from_disk(dataset_path)
        # should we use human transcript or asr transcript?
        audio_dataset = audio_dataset.rename_columns({"human_transcript": "text"})
    elif dataset_name == "fleurs":
        if dataset_path is None:
            dataset_path = f"{DATASET_BASE_DIR}/fleurs/en_us/{split}"
        audio_dataset = load_from_disk(dataset_path)
    else:
        raise ValueError(f"Unknown dataset name: {dataset_name}")

    columns_to_remove = [col for col in audio_dataset.column_names if col not in columns_to_keep]
    audio_dataset = audio_dataset.remove_columns(columns_to_remove)

    # cast audio format
    audio_dataset = audio_dataset.cast_column("audio", Audio(sampling_rate=sampling_rate))

    return audio_dataset
