import os
from fairseq.data.audio import (
    AudioTransform,
    CompositeAudioTransform,
    import_transforms,
    register_audio_transform,
)


class AudioDatasetTransform(AudioTransform):
    """Marker base class for dataset-level audio transforms.

    Adds no behavior of its own on top of ``AudioTransform``. It exists as a
    distinct type that ``register_audio_dataset_transform`` forwards to
    ``register_audio_transform`` (presumably so registered classes can be
    checked against it — confirm in ``fairseq.data.audio``).
    """


# Registry of dataset-level audio transforms keyed by name: read by
# get_audio_dataset_transform, handed to register_audio_transform for filling.
AUDIO_DATASET_TRANSFORM_REGISTRY = dict()
# Class names registered so far; also handed to register_audio_transform
# (presumably to detect duplicate registrations — confirm).
AUDIO_DATASET_TRANSFORM_CLASS_NAMES = set()


def get_audio_dataset_transform(name):
    """Look up a registered audio dataset transform by name.

    Args:
        name: Key the transform was registered under in
            ``AUDIO_DATASET_TRANSFORM_REGISTRY``.

    Returns:
        The object stored in ``AUDIO_DATASET_TRANSFORM_REGISTRY`` for ``name``.

    Raises:
        KeyError: If ``name`` is not registered. The message lists the
            registered names so that typos in a config are easy to spot.
    """
    try:
        return AUDIO_DATASET_TRANSFORM_REGISTRY[name]
    except KeyError:
        available = ", ".join(sorted(map(str, AUDIO_DATASET_TRANSFORM_REGISTRY)))
        # Still a KeyError so existing callers that catch it keep working;
        # "from None" drops the redundant bare-key traceback.
        raise KeyError(
            f"Unknown audio dataset transform {name!r}; "
            f"registered transforms: [{available}]"
        ) from None


def register_audio_dataset_transform(name):
    """Register a dataset-level audio transform under ``name``.

    Thin wrapper that forwards to ``register_audio_transform`` with the
    dataset-transform base type and this module's registry/name set. Returns
    whatever that helper returns (presumably a class decorator, used as
    ``@register_audio_dataset_transform("name")`` — confirm).
    """
    # Fixed arguments shared by every dataset-transform registration.
    dataset_registry_args = (
        AudioDatasetTransform,
        AUDIO_DATASET_TRANSFORM_REGISTRY,
        AUDIO_DATASET_TRANSFORM_CLASS_NAMES,
    )
    return register_audio_transform(name, *dataset_registry_args)


# Module-level side effect: pass this file's directory and the "dataset"
# transform type to import_transforms. NOTE(review): presumably this imports
# the transform modules in this directory so their registration decorators
# populate AUDIO_DATASET_TRANSFORM_REGISTRY — confirm against
# fairseq.data.audio.import_transforms. Keep it after the registry and
# register_audio_dataset_transform are defined, since those modules presumably
# reference them at import time.
import_transforms(os.path.dirname(__file__), "dataset")


class CompositeAudioDatasetTransform(CompositeAudioTransform):
    """Composite of dataset-level audio transforms built from a config dict.

    Adds type-based lookup helpers (``get_transform`` / ``has_transform``)
    on top of ``CompositeAudioTransform``.
    """

    @classmethod
    def from_config_dict(cls, config=None):
        """Build a composite for the "dataset" transform type from ``config``.

        Args:
            config: Optional config forwarded unchanged to the base class's
                ``_from_config_dict``.

        Returns:
            Whatever the base ``_from_config_dict`` builds with ``cls`` as the
            composite class. ``return_empty=True`` is passed, presumably so an
            empty composite (not ``None``) is returned when no transforms are
            configured — confirm against ``CompositeAudioTransform``.
        """
        return super()._from_config_dict(
            cls,
            "dataset",
            get_audio_dataset_transform,
            # Use ``cls`` instead of hard-coding this class so subclasses that
            # call this alternate constructor get instances of their own type.
            cls,
            config,
            return_empty=True,
        )

    def get_transform(self, cls):
        """Return the first transform in ``self.transforms`` that is an
        instance of ``cls``, or ``None`` if there is no such transform."""
        return next((t for t in self.transforms if isinstance(t, cls)), None)

    def has_transform(self, cls):
        """Return True if ``get_transform(cls)`` finds a matching transform."""
        return self.get_transform(cls) is not None
