from copy import deepcopy
import os

from hydra.utils import instantiate
from pytorch_lightning import LightningDataModule
import torch
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, Lambda

from .dataset_libri import ADatasetLibri
from .samplers import ByFrameCountSampler, DistributedSamplerWrapper, RandomSamplerWrapper
from .transforms import AdaptiveLengthTimeMask, AddNoise


def pad(samples, pad_val=0.0):
    """Right-pad tensors along their first dimension and stack them.

    Args:
        samples: Sequence of tensors. Each is padded at the end of dim 0 up
            to the longest sample; the trailing dimensions are taken from
            ``samples[0]``.
        pad_val: Fill value used for the padded positions.

    Returns:
        A ``(batch, lengths)`` tuple, or ``(None, None)`` when ``samples`` is
        empty. ``lengths`` holds the original ``len()`` of every sample.
        The batch layout depends on the rank of ``samples[0]``:

        * fewer than 3 dims: a singleton dim is inserted at position 1,
          e.g. ``[B, T]`` becomes ``[B, 1, T]``;
        * otherwise the stacked batch is permuted with ``(0, 4, 1, 2, 3)``,
          i.e. ``[B, T, H, W, C] -> [B, C, T, H, W]``. This requires 4-D
          samples.
    """
    batch_size = len(samples)
    if batch_size == 0:
        return None, None

    first = samples[0]
    lengths = [len(item) for item in samples]
    longest = max(lengths)
    trailing_shape = list(first.shape[1:])

    # The output inherits dtype and device from the first sample.
    stacked = first.new_zeros([batch_size, longest] + trailing_shape)
    for row, item in enumerate(samples):
        missing = longest - len(item)
        if missing > 0:
            # Append a block of `pad_val` so the sample reaches full length.
            filler = item.new_full([missing] + trailing_shape, pad_val)
            item = torch.cat([item, filler])
        stacked[row] = item

    if first.dim() < 3:
        return stacked.unsqueeze(1), lengths
    # [B, T, H, W, C] -> [B, C, T, H, W]
    return stacked.permute((0, 4, 1, 2, 3)), lengths


def collate_pad(batch):
    """Collate dataset samples into padded audio batches and a label list.

    Args:
        batch: Sequence of mappings, each providing the keys ``'audio'``,
            ``'audio_aug'`` and ``'label'``. A value of ``None`` is skipped.

    Returns:
        A dict with:

        * ``'audio'`` / ``'audio_aug'``: the batch returned by ``pad`` for
          that key (``None`` if no sample had a value for it);
        * ``'audio_lengths'`` / ``'audio_aug_lengths'``: the matching list of
          unpadded lengths (``None`` in the same case);
        * ``'label'``: the non-``None`` labels as a plain list, unpadded.
    """
    # NOTE(review): `None` entries are filtered independently per key, so a
    # sample that is missing only one of the keys makes the outputs differ in
    # batch size and lose positional correspondence. Confirm the dataset
    # never yields such partial samples.
    batch_out = {}
    for data_type in ('audio', 'audio_aug'):
        present = [s[data_type] for s in batch if s[data_type] is not None]
        # Both audio streams are padded with 0.0.
        c_batch, sample_lengths = pad(present, 0.0)
        batch_out[data_type] = c_batch
        batch_out[data_type + '_lengths'] = sample_lengths

    batch_out["label"] = [s["label"] for s in batch if s["label"] is not None]

    return batch_out


class DataModule(LightningDataModule):
    """Lightning data module that builds the test loader over ``ADatasetLibri``."""

    def __init__(self, cfg=None):
        """Store the config and compute the total number of GPUs.

        Args:
            cfg: Configuration object. Despite the ``None`` default it is
                dereferenced right away (``cfg.gpus`` and
                ``cfg.trainer.num_nodes``), so a real config must be passed.
        """
        super().__init__()
        self.cfg = cfg
        gpu_count = self.cfg.gpus
        node_count = self.cfg.trainer.num_nodes
        self.total_gpus = gpu_count * node_count
        print('total gpus:', self.total_gpus)

    def _raw_audio_transform(self, mode):
        """Build the plain and the augmented raw-audio transform pipelines.

        Args:
            mode: Only ``"train"`` adds the time-mask augmentation; for any
                other value both pipelines contain just the identity step.

        Returns:
            ``(transform, transform_aug)`` as two ``Compose`` objects.
        """
        data_cfg = self.cfg.data
        plain_steps = [Lambda(lambda x: x)]
        augmented_steps = deepcopy(plain_steps)

        if mode == "train":
            # NOTE(review): 16_000 is presumably the audio sample rate in Hz,
            # converting the configured window/stride to samples. Confirm.
            time_mask = AdaptiveLengthTimeMask(
                window=int(data_cfg.timemask_window_audio * 16_000),
                stride=int(data_cfg.timemask_stride_audio * 16_000),
                replace_with_zero=True,
            )
            augmented_steps.append(time_mask)

        plain_pipeline = Compose(plain_steps)
        augmented_pipeline = Compose(augmented_steps)
        return plain_pipeline, augmented_pipeline

    def _dataloader(self, ds, sampler, collate_fn):
        """Wrap ``ds`` in a ``DataLoader`` driven by a batch sampler.

        Args:
            ds: Dataset to load from.
            sampler: Passed as ``batch_sampler``.
            collate_fn: Function that merges a list of samples into a batch.
        """
        loader_options = dict(
            batch_sampler=sampler,
            collate_fn=collate_fn,
            num_workers=self.cfg.num_workers,
            pin_memory=True,
        )
        return DataLoader(ds, **loader_options)

    def test_dataloader(self):
        """Create the test ``DataLoader`` from the CSV in ``cfg.data.dataset.test_csv``."""
        data_cfg = self.cfg.data
        dataset_cfg = data_cfg.dataset

        plain_tf, augmented_tf = self._raw_audio_transform(mode='test')
        dataset = ADatasetLibri(
            data_path=dataset_cfg.test_csv,
            audio_path_prefix_libri=data_cfg.libri_audio_dir,
            transforms={'audio': plain_tf, 'audio_aug': augmented_tf},
        )

        # The test loader uses the validation frame budget (`frames_per_gpu_val`).
        batch_sampler = ByFrameCountSampler(
            dataset, data_cfg.frames_per_gpu_val, shuffle=False
        )
        if self.total_gpus > 1:
            # NOTE(review): `drop_last=True` on a test loader may skip samples,
            # depending on the wrapper's semantics. Confirm this is intended.
            batch_sampler = DistributedSamplerWrapper(
                batch_sampler, shuffle=False, drop_last=True
            )
        return self._dataloader(dataset, batch_sampler, collate_pad)
