import os

import cv2
import numpy as np
import torch
import torchaudio
from torch.utils.data import Dataset


class ADatasetLibri(Dataset):
    """Audio dataset driven by a comma-separated manifest file.

    Each non-blank manifest line has exactly four comma-separated fields::

        tag,file_path,count,label

    ``tag`` selects a path prefix: the attribute ``audio_path_prefix_<tag>``
    is looked up on the instance and an empty prefix is used when no such
    attribute exists. This class only defines
    ``audio_path_prefix_librispeech``. ``count`` is parsed as an integer and
    stored, but is not otherwise used by this class. ``label`` is kept as
    the raw string from the manifest.

    Args:
        data_path: Path to the manifest file.
        audio_path_prefix_libri: Prefix joined in front of ``file_path`` for
            entries whose tag is ``librispeech``.
        transforms: Mapping with the callables ``'audio'`` and
            ``'audio_aug'``, each applied to the loaded audio. When ``None``
            the loaded audio is passed through unchanged for both outputs.
        skip_fails: Stored on the instance as ``skip_fails``; not consulted
            anywhere in this class.

    Raises:
        ValueError: If a non-blank manifest line does not have four fields
            or its count field is not an integer.
    """

    def __init__(
            self,
            data_path,
            audio_path_prefix_libri,
            transforms=None,
            skip_fails=True,
        ):
        super().__init__()

        self.data_path = data_path
        self.audio_path_prefix_librispeech = audio_path_prefix_libri
        self.transforms = transforms

        self.paths_counts_labels = self.configure_files()
        # Counter and flag are stored but not updated or read by this class.
        self.num_fails = 0

        self.skip_fails = skip_fails

    def configure_files(self):
        """Parse the manifest at ``self.data_path``.

        Blank and whitespace-only lines are skipped.

        Returns:
            A list of ``(tag, file_path, count, label)`` tuples, where
            ``count`` is an ``int`` and the other fields are strings.

        Raises:
            ValueError: If a line does not have exactly four fields or the
                count field cannot be converted to ``int``.
        """
        # from https://github.com/facebookresearch/pytorchvideo/blob/874d27cb55b9d7e9df6cd0881e2d7fe9f262532b/pytorchvideo/data/labeled_video_paths.py#L37
        paths_counts_labels = []
        with open(self.data_path, "r") as f:
            lines = f.read().splitlines()

        for line_no, line in enumerate(lines, start=1):
            if not line.strip():
                # Tolerate empty lines instead of failing on the unpack below.
                continue

            fields = line.split(",")
            if len(fields) != 4:
                raise ValueError(
                    f"{self.data_path}:{line_no}: expected 4 comma-separated "
                    f"fields (tag,file_path,count,label), got {len(fields)}: "
                    f"{line!r}"
                )

            tag, file_path, count, label = fields
            try:
                count = int(count)
            except ValueError as err:
                raise ValueError(
                    f"{self.data_path}:{line_no}: count field is not an "
                    f"integer: {count!r}"
                ) from err

            paths_counts_labels.append((tag, file_path, count, label))
        return paths_counts_labels

    def load_audio(self, path):
        """Load the audio file at ``path`` and return only the waveform.

        The sample rate returned by ``torchaudio.load`` is discarded and is
        not validated here.
        """
        audio, _ = torchaudio.load(path, normalize=True)
        return audio

    def _apply_transform(self, name, audio):
        """Apply the transform registered under ``name`` (identity if none)."""
        if self.transforms is None:
            return audio
        return self.transforms[name](audio)

    def __len__(self):
        """Return the number of manifest entries."""
        return len(self.paths_counts_labels)

    def __getitem__(self, index):
        """Load one entry and return its clean and augmented audio.

        Returns:
            A dict with keys ``'audio'``, ``'audio_aug'`` and ``'label'``.
            Both audio values are the transform outputs after
            ``squeeze(0)``; ``'label'`` is the manifest label string.
        """
        tag, file_path, count, label = self.paths_counts_labels[index]
        # Unknown tags fall back to an empty prefix, i.e. file_path is used
        # as-is. The prefix used for the latest item is kept on the instance.
        self.audio_path_prefix = getattr(self, f"audio_path_prefix_{tag}", "")

        audio = self.load_audio(os.path.join(self.audio_path_prefix, file_path))
        audio_clean = self._apply_transform('audio', audio).squeeze(0)
        audio_aug = self._apply_transform('audio_aug', audio).squeeze(0)

        return {'audio': audio_clean, 'audio_aug': audio_aug, 'label': label}
