import os.path as osp

import torch

from .base import BaseDataset
from .registry import DATASETS


@DATASETS.register_module()
class AudioDataset(BaseDataset):
    """Audio dataset for video recognition. Extracts the audio feature on-the-
    fly. Annotation file can be that of the rawframe dataset, or:

    .. code-block:: txt

        some/directory-1.wav 163 1
        some/directory-2.wav 122 1
        some/directory-3.wav 258 2
        some/directory-4.wav 234 2
        some/directory-5.wav 295 3
        some/directory-6.wav 121 3

    Each non-blank line has the form
    ``<audio_path> <total_frames> <label> [<label> ...]``. More than one
    label per line is only accepted when ``multi_class`` is enabled.
    Blank lines are ignored.

    Args:
        ann_file (str): Path to the annotation file. A file ending in
            ``.json`` is delegated to ``load_json_annotations``.
        pipeline (list[dict | callable]): A sequence of data transforms.
        suffix (str): The suffix of the audio file. It is appended to
            entries lacking it, but only when ``data_prefix`` is set.
            Default: '.wav'.
        kwargs (dict): Other keyword args for `BaseDataset`.
    """

    def __init__(self, ann_file, pipeline, suffix='.wav', **kwargs):
        # Assigned before super().__init__() because the base class
        # presumably loads annotations during construction, and
        # load_annotations() reads self.suffix -- confirm in BaseDataset.
        self.suffix = suffix
        super().__init__(ann_file, pipeline, modality='Audio', **kwargs)

    def _resolve_audio_path(self, filename):
        """Join ``filename`` onto ``data_prefix``, adding ``suffix`` if absent.

        When ``data_prefix`` is None the filename is returned untouched
        (the suffix is not appended in that case either).
        """
        if self.data_prefix is None:
            return filename
        if not filename.endswith(self.suffix):
            filename = filename + self.suffix
        return osp.join(self.data_prefix, filename)

    def _encode_label(self, label, line):
        """Turn a list of integer class ids into the stored label.

        Args:
            label (list[int]): Class ids parsed from one annotation line.
            line (str): The raw annotation line, used in error messages.

        Returns:
            int | torch.Tensor: The single class id, or, when
            ``multi_class`` is set, a one-hot ``torch.zeros(num_classes)``
            tensor with the listed ids set to 1.0.
        """
        assert label, f'missing label in line: {line}'
        if self.multi_class:
            assert self.num_classes is not None, (
                'num_classes must be set when multi_class is True')
            onehot = torch.zeros(self.num_classes)
            onehot[label] = 1.0
            return onehot
        assert len(label) == 1, (
            f'expected exactly one label (multi_class is False) '
            f'in line: {line}')
        return label[0]

    def load_annotations(self):
        """Load annotation file to get video information.

        Returns:
            list[dict]: One dict per non-blank annotation line with keys
            ``audio_path`` (str), ``total_frames`` (int) and ``label``
            (int, or a one-hot tensor when ``multi_class`` is set).
        """
        if self.ann_file.endswith('.json'):
            return self.load_json_annotations()
        video_infos = []
        with open(self.ann_file, 'r') as fin:
            for line in fin:
                line_split = line.strip().split()
                if not line_split:
                    # Skip blank lines (e.g. a trailing empty line) instead
                    # of failing with an IndexError on line_split[0].
                    continue
                # Dict literal keeps the original key order and evaluation
                # order: path, frame count, then label(s).
                video_infos.append(
                    dict(
                        audio_path=self._resolve_audio_path(line_split[0]),
                        total_frames=int(line_split[1]),
                        label=self._encode_label(
                            [int(x) for x in line_split[2:]], line)))

        return video_infos
