from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import numpy as np
import torch
import librosa


class AudioDataset(Dataset):
    def __init__(
        self,
        filepaths,
        labels,
        skip_times=None,
        num_classes=1,
        normalize="std",
        max_len=32000,
        random_sampling=True,
        train=False,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.filepaths = filepaths
        self.labels = labels
        self.skip_times = skip_times
        self.num_classes = num_classes
        self.random_sampling = random_sampling
        self.normalize = normalize
        self.max_len = max_len
        self.train = train
        if not self.train:
            assert (
                not self.random_sampling
            ), "Ensure random_sampling is disabled for val"

    def __len__(self):
        return len(self.filepaths)

    def crop_or_pad(self, audio, max_len, random_sampling=True):
        audio_len = audio.shape[0]
        if random_sampling:
            diff_len = abs(max_len - audio_len)
            if audio_len < max_len:
                pad1 = np.random.randint(0, diff_len)
                pad2 = diff_len - pad1
                audio = np.pad(audio, (pad1, pad2), mode="constant")
            elif audio_len > max_len:
                idx = np.random.randint(0, diff_len)
                audio = audio[idx : (idx + max_len)]
        else:
            if audio_len < max_len:
                audio = np.pad(audio, (0, max_len - audio_len), mode="constant")
            elif audio_len > max_len:
                # Crop from the beginning
                # audio = audio[:max_len]

                # Crop from 3/4 of the audio
                # eq: l = (3x + t + x) => idx = 3x = (l - t) / 4 * 3
                idx = int((audio_len - max_len) / 4 * 3)
                audio = audio[idx : (idx + max_len)]
        return audio

    def __getitem__(self, idx):
        # Load audio
        audio, sr = librosa.load(self.filepaths[idx], sr=None)
        target = np.array([self.labels[idx]])

        # Trim start of audio (torchaudio.transforms.vad)
        if self.skip_times is not None:
            skip_time = self.skip_times[idx]
            audio = audio[int(skip_time*sr):]

        # Ensure fixed length
        audio = self.crop_or_pad(audio, self.max_len, self.random_sampling)

        if self.normalize == "std":
            audio /= np.maximum(np.std(audio), 1e-6)
        elif self.normalize == "minmax":
            audio -= np.min(audio)
            audio /= np.maximum(np.max(audio), 1e-6)

        audio = torch.from_numpy(audio).float()
        target = torch.from_numpy(target).float().squeeze()
        return {
            "audio": audio,
            "target": target,
        }


def get_dataloader(
    filepaths,
    labels,
    skip_times=None,
    batch_size=8,
    num_classes=1,
    max_len=32000,
    random_sampling=True,
    normalize="std",
    train=False,
    # drop_last=False,
    pin_memory=True,
    worker_init_fn=None,
    collate_fn=None,
    num_workers=0,
    distributed=False,
):
    dataset = AudioDataset(
        filepaths,
        labels,
        skip_times=skip_times,
        num_classes=num_classes,
        max_len=max_len,
        random_sampling=random_sampling,
        normalize=normalize,
        train=train,
    )

    if distributed:
        # drop_last is set to True to validate properly
        # Ref: https://discuss.pytorch.org/t/how-do-i-validate-with-pytorch-distributeddataparallel/172269/8
        sampler = torch.utils.data.distributed.DistributedSampler(
            dataset, shuffle=train, drop_last=not train
        )
    else:
        sampler = None

    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=(sampler is None) and train,
        # drop_last=drop_last,
        num_workers=num_workers,
        pin_memory=pin_memory,
        worker_init_fn=worker_init_fn,
        collate_fn=collate_fn,
        sampler=sampler,
    )
    return dataloader
