# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# DeiT: https://github.com/facebookresearch/deit
# BEiT: https://github.com/microsoft/unilm/tree/master/beit
# AST: https://github.com/YuanGongND/ast
# --------------------------------------------------------
import csv, os, sys
import json
import time
import torchaudio
import numpy as np
import torch
import torch.nn.functional
from torch.utils.data import Dataset, Sampler
from torch.utils.data import DistributedSampler, WeightedRandomSampler
import torch.distributed as dist
import random
import math
import boto3 
import io
import ray
from torchlibrosa.stft import Spectrogram, LogmelFilterBank


class DistributedSamplerWrapper(DistributedSampler):
    """Shard the index stream of an arbitrary sampler across distributed ranks.

    The wrapped ``sampler`` (e.g. a ``WeightedRandomSampler``) decides *which*
    indices are drawn; this wrapper only splits that stream so that rank ``r``
    receives every ``num_replicas``-th index starting at position ``r``.

    :param sampler: sampler whose output is sharded. If it exposes a
        ``generator`` attribute, that generator is reseeded every epoch.
    :param dataset: dataset handed to ``DistributedSampler`` (used by the
        parent class to derive ``num_samples`` / ``total_size``).
    :param num_replicas: number of participating processes (parent default).
    :param rank: rank of the current process (parent default).
    :param shuffle: forwarded to ``DistributedSampler``; the ordering actually
        produced here comes from ``sampler``.
    """

    def __init__(
            self, sampler, dataset,
            num_replicas=None,
            rank=None,
            shuffle: bool = True):
        super(DistributedSamplerWrapper, self).__init__(
            dataset, num_replicas, rank, shuffle)
        # source: @awaelchli https://github.com/PyTorchLightning/pytorch-lightning/issues/3238
        self.sampler = sampler

    def __iter__(self):
        """Yield this rank's share of the wrapped sampler's indices."""
        # Every rank seeds the wrapped sampler identically (seed + epoch) so
        # that all ranks draw the same global index list before sharding.
        # Samplers without a `generator` attribute are used as-is.
        if hasattr(self.sampler, "generator"):
            if self.sampler.generator is None:
                self.sampler.generator = torch.Generator()
            self.sampler.generator.manual_seed(self.seed + self.epoch)
        indices = list(self.sampler)
        if self.epoch == 0:
            print(f"\n DistributedSamplerWrapper :  {indices[:10]} \n\n")
        # If the wrapped sampler yields fewer than `total_size` indices, a
        # plain strided slice would hand ranks different sample counts.
        # Repeat the drawn indices so every rank gets the same number.
        if indices and len(indices) < self.total_size:
            repeats = math.ceil(self.total_size / len(indices))
            indices = (indices * repeats)[:self.total_size]
        return iter(indices[self.rank:self.total_size:self.num_replicas])


class DistributedWeightedSampler(Sampler):
    """Weighted random sampler that first shards the dataset across ranks.

    Each epoch the dataset indices are (optionally) shuffled with a generator
    seeded by the epoch number, padded so they split evenly, and strided so
    that each rank owns a disjoint shard. Within its shard the rank then draws
    ``num_samples`` indices with ``torch.multinomial`` using the per-sample
    ``weights``.

    :param dataset: sized dataset; only ``len(dataset)`` is used.
    :param weights: per-sample weights, indexable by dataset index. Any
        array-like accepted by ``torch.as_tensor`` (NumPy array, tensor, list).
    :param num_replicas: number of processes; defaults to the world size.
    :param rank: rank of this process; defaults to ``dist.get_rank()``.
    :param replacement: whether the multinomial draw is with replacement.
    :param shuffle: shuffle indices before sharding (seeded by the epoch).
    :raises RuntimeError: if ``num_replicas``/``rank`` must be inferred but the
        distributed package is unavailable.
    """

    def __init__(self, dataset, weights, num_replicas=None, rank=None, replacement=True, shuffle=True):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        self.replacement = replacement
        # as_tensor keeps the zero-copy behaviour for NumPy input while also
        # accepting tensors and plain sequences.
        self.weights = torch.as_tensor(weights)
        self.shuffle = shuffle

    def __iter__(self):
        """Yield ``num_samples`` dataset indices drawn from this rank's shard."""
        if self.num_samples == 0:
            # Empty dataset: nothing to shard or sample.
            return iter([])

        # deterministically shuffle based on epoch
        g = torch.Generator()
        g.manual_seed(self.epoch)
        if self.shuffle:
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = list(range(len(self.dataset)))

        # Add extra samples to make the list evenly divisible. The shortfall
        # can exceed len(indices) when the dataset is much smaller than the
        # number of replicas, so repeat the list as often as necessary.
        shortfall = self.total_size - len(indices)
        if shortfall > 0:
            repeats = math.ceil(shortfall / len(indices))
            indices += (indices * repeats)[:shortfall]

        # subsample: this rank's shard
        indices = indices[self.rank:self.total_size:self.num_replicas]

        # Draw from the shard according to the per-sample weights.
        # NOTE: the multinomial draw uses the global RNG, not `g`.
        weights = self.weights[indices]
        picked = torch.multinomial(weights, self.num_samples, self.replacement)
        # map positions inside the shard back to original dataset indices
        dataset_indices = torch.tensor(indices)[picked]
        return iter(dataset_indices.tolist())

    def __len__(self):
        """Number of indices yielded per epoch on this rank."""
        return self.num_samples

    def set_epoch(self, epoch):
        """Set the epoch used to seed the per-epoch shuffle."""
        self.epoch = epoch


def make_index_dict(label_csv):
    """Build a lookup from the ``mid`` column to the ``index`` column of a label CSV.

    :param label_csv: path to a CSV file with a header row containing at least
        the columns ``mid`` and ``index``.
    :return: dict mapping each ``mid`` string to its ``index`` string (values
        are kept as read from the file, i.e. not converted to ``int``).
    """
    with open(label_csv, 'r') as handle:
        return {record['mid']: record['index'] for record in csv.DictReader(handle)}


def make_name_dict(label_csv):
    """Build a lookup from the ``index`` column to the ``display_name`` column of a label CSV.

    :param label_csv: path to a CSV file with a header row containing at least
        the columns ``index`` and ``display_name``.
    :return: dict mapping each ``index`` string to its ``display_name`` string.
    """
    with open(label_csv, 'r') as handle:
        return {record['index']: record['display_name'] for record in csv.DictReader(handle)}


def lookup_list(index_list, label_csv):
    """Translate class indices into display names using a label CSV.

    :param index_list: iterable of index values; each must be a key of the
        table built by ``make_name_dict`` (i.e. the string form found in the
        CSV's ``index`` column).
    :param label_csv: path to the label CSV.
    :return: list of display names, in the order of ``index_list``.
    :raises KeyError: if an index is not present in the CSV.
    """
    names_by_index = make_name_dict(label_csv)
    return [names_by_index[idx] for idx in index_list]


class AudiosetDataset(Dataset):
    """Audio classification dataset driven by a JSON manifest.

    The manifest's ``data`` list holds one dict per clip; this class reads the
    keys ``'wav'`` (local path or S3 location) and ``'labels'`` (comma-separated
    label ids that must appear in the ``mid`` column of ``label_csv``).
    Each item is returned as ``(features, label)``, see ``__getitem__``.
    """

    def __init__(self, dataset_json_file, audio_conf, label_csv="dataset/audioset_util/class_labels_indices.csv", 
                 use_fbank=False, fbank_dir=None, target_sample_rate=16000,
                 output_form="kaldi", load_video=False, mode='train', pad_to_seconds=None,
                 **kwargs):
        """
        Dataset that manages audio recordings

        :param dataset_json_file: path to the JSON manifest (must contain a ``data`` list)
        :param audio_conf: Dictionary containing the audio loading and preprocessing settings.
            Keys read here: ``num_mel_bins``, ``freqm``, ``timem``, ``mixup``, ``dataset``,
            ``mean``, ``std``, ``noise``, ``target_length`` and, optionally, ``mode``,
            ``multilabel`` and ``roll_mag_aug``
        :param label_csv: CSV with ``mid`` and ``index`` columns defining the classes
        :param use_fbank: load precomputed ``.npy`` features from ``fbank_dir`` instead of audio
        :param fbank_dir: directory holding the precomputed features
        :param target_sample_rate: audio is resampled to this rate on load
        :param output_form: ``"kaldi"``, ``"raw"`` or ``"mel_spectrogram"``
        :param load_video: accepted for interface compatibility; not used by this class
        :param mode: used for the start-up banner when ``audio_conf`` has no ``'mode'`` entry
        :param pad_to_seconds: when not ``None``, loaded audio is truncated to 64600 samples
        :param kwargs: ignored
        """
        self.datapath = dataset_json_file
        with open(dataset_json_file, 'r') as fp:
            data_json = json.load(fp)
        self.use_fbank = use_fbank
        self.fbank_dir = fbank_dir
        self.target_sample_rate = target_sample_rate
        self.output_form = output_form
        self.data = data_json['data']
        self.audio_conf = audio_conf
        # Fall back to the `mode` argument: formatting a missing (None) entry
        # with '{:s}' would raise a TypeError.
        banner_mode = str(self.audio_conf.get('mode', mode))
        print('---------------the {:s} dataloader---------------'.format(banner_mode))
        self.multilabel = self.audio_conf.get('multilabel', False)
        print(f'multilabel: {self.multilabel}')
        self.melbins = self.audio_conf.get('num_mel_bins')
        self.freqm = self.audio_conf.get('freqm')
        self.timem = self.audio_conf.get('timem')
        print('using following mask: {:d} freq, {:d} time'.format(self.freqm, self.timem))
        self.mixup = self.audio_conf.get('mixup')
        print('using mix-up with rate {:f}'.format(self.mixup))
        self.dataset = self.audio_conf.get('dataset')
        self.norm_mean = self.audio_conf.get('mean')
        self.norm_std = self.audio_conf.get('std')
        print('Dataset: {}, mean {:.3f} and std {:.3f}'.format(self.dataset, self.norm_mean, self.norm_std))
        self.noise = self.audio_conf.get('noise')
        # Kept as an explicit comparison so that only a real True (or 1)
        # enables noise; e.g. a non-empty string does not.
        if self.noise == True:
            print('now use noise augmentation')
        self.index_dict = make_index_dict(label_csv)
        self.label_num = len(self.index_dict)
        self.roll_mag_aug = self.audio_conf.get('roll_mag_aug', False)
        print(f'number of classes: {self.label_num}')
        print(f'size of dataset {self.__len__()}')
        print(f'roll_mag_aug {self.roll_mag_aug}')
        # S3 conf (lazy client per process)
        self._s3_conf = {
            "aws_access_key_id": os.environ.get("S3_KEY"),
            "aws_secret_access_key": os.environ.get("S3_SECRET"),
            "endpoint_url": os.environ.get("S3_ENDPOINT_URL"),
        }
        self._s3 = None
        self.pad_to_seconds = pad_to_seconds
        # Masking
        self.freq_mask = torchaudio.transforms.FrequencyMasking(self.freqm)
        self.time_mask = torchaudio.transforms.TimeMasking(self.timem)

    def _client(self):
        """Return the boto3 S3 client, creating it on first use."""
        if self._s3 is None:
            self._s3 = boto3.client(
                "s3",
                aws_access_key_id=self._s3_conf["aws_access_key_id"],
                aws_secret_access_key=self._s3_conf["aws_secret_access_key"],
                endpoint_url=self._s3_conf["endpoint_url"],
            )
        return self._s3

    def __getstate__(self):
        """Pickle support: drop the S3 client, which is re-created lazily."""
        d = self.__dict__.copy()
        d.pop("_s3", None)        # boto3 client is not picklable
        return d

    def __setstate__(self, d):
        """Restore pickled state with an empty S3 client slot."""
        self.__dict__.update(d)
        self._s3 = None

    def _roll_mag_aug(self, waveform):
        """Randomly roll the waveform in time and scale it by ``Beta(10, 10) + 0.5``.

        Returns the input unchanged when ``roll_mag_aug`` is disabled;
        otherwise a float32 tensor with a leading singleton dimension.
        """
        if not self.roll_mag_aug:
            return waveform
        waveform = waveform.squeeze().numpy()
        idx = np.random.randint(len(waveform))
        rolled_waveform = np.roll(waveform, idx)
        mag = np.random.beta(10, 10) + 0.5
        return torch.Tensor(rolled_waveform * mag).unsqueeze(0)

    def _load_file(self, file_path):
        """Load an audio file, resample it and down-mix it to mono.

        Paths containing ``"s3"`` are fetched from S3. Both ``s3://bucket/key``
        and bare ``bucket/key`` forms are accepted; everything else is read
        from the local filesystem.

        :return: ``(waveform, sample_rate)`` where ``waveform`` has one channel
            in its first dimension and ``sample_rate == target_sample_rate``.
        """
        if "s3" in file_path:
            # Strip an optional URI scheme so the first path component is the bucket.
            s3_path = file_path[len("s3://"):] if file_path.startswith("s3://") else file_path
            bucket, _, key = s3_path.partition('/')
            obj = self._client().get_object(Bucket=bucket, Key=key)
            try:
                body = obj["Body"].read()
            finally:
                # Release the HTTP connection even if the read fails.
                obj["Body"].close()
            data, samplerate = torchaudio.load(io.BytesIO(body))
        else:
            data, samplerate = torchaudio.load(file_path)
        if samplerate != self.target_sample_rate:
            data = torchaudio.transforms.Resample(orig_freq=samplerate, new_freq=self.target_sample_rate)(data)
            samplerate = self.target_sample_rate
        if data.shape[0] > 1:
            # multi-channel -> mono
            data = torch.mean(data, dim=0, keepdim=True)
        if self.pad_to_seconds is not None:
            # NOTE(review): the value of pad_to_seconds is not used; clips are
            # only truncated (never padded) to a fixed 64600 samples. Confirm
            # whether this should be derived from pad_to_seconds.
            if data.shape[-1] > 64600:
                data = data[..., :64600]
        return data, samplerate

    def _mel_extractors(self, sample_rate):
        """Return cached (spectrogram, log-mel) extractors for ``sample_rate``."""
        # setdefault on __dict__ also works for instances created without
        # running __init__ (e.g. restored from an older pickle).
        cache = self.__dict__.setdefault("_mel_cache", {})
        cache_key = (sample_rate, self.melbins)
        if cache_key not in cache:
            spec_window_size = 512
            hop_size = 160
            # Spectrogram extractor
            spectrogram_extractor = Spectrogram(n_fft=spec_window_size, hop_length=hop_size,
                win_length=spec_window_size, window='hann', center=True, pad_mode='reflect',
                freeze_parameters=True)
            # Logmel feature extractor
            logmel_extractor = LogmelFilterBank(sr=sample_rate, n_fft=spec_window_size,
                n_mels=self.melbins, fmin=0, fmax=7800, ref=1.0, amin=1e-10, top_db=None,
                freeze_parameters=True)
            cache[cache_key] = (spectrogram_extractor, logmel_extractor)
        return cache[cache_key]

    def _to_output_form(self, waveform, sample_rate):
        """Convert a waveform into the representation selected by ``output_form``.

        Accepts a 1-D waveform or one with a leading channel dimension.

        :raises ValueError: for an unknown ``output_form``.
        """
        if self.output_form == "kaldi":
            if waveform.dim() == 1:
                waveform = waveform.unsqueeze(0)
            fbank = torchaudio.compliance.kaldi.fbank(
                waveform, htk_compat=True, sample_frequency=sample_rate, use_energy=False,
                window_type='hanning', num_mel_bins=self.melbins, dither=0.0, frame_shift=10
            )
            return fbank
        elif self.output_form == "raw":
            return waveform
        elif self.output_form == "mel_spectrogram":
            spectrogram_extractor, logmel_extractor = self._mel_extractors(sample_rate)
            # The extractor takes a 2-D batch of waveforms. Only add the batch
            # dimension when it is missing, as the kaldi branch does;
            # unsqueezing an already 2-D waveform would make it 3-D.
            if waveform.dim() == 1:
                waveform = waveform.unsqueeze(0)  # (1, n_samples)
            out = logmel_extractor(spectrogram_extractor(waveform))
            out = out.squeeze(0).squeeze(0)
            return out
        else:
            raise ValueError(f"Unknown output form: {self.output_form}")

    def _wav2fbank(self, filename, filename2=None):
        """Load one clip (or mix two clips) and convert it to features.

        :param filename: primary audio file.
        :param filename2: optional second file; when given the two waveforms
            are mixed with a ``Beta(10, 10)`` weight.
        :return: ``(features, mix_lambda)``; ``mix_lambda`` is ``0`` without
            mixing. Features are padded/cut along their first dimension to
            ``audio_conf['target_length']``.
        """
        mix_lambda = 0
        if filename2 is None:
            waveform, sr = self._load_file(filename)
            waveform = waveform - waveform.mean()
            if self.roll_mag_aug:
                waveform = self._roll_mag_aug(waveform)
        # mixup
        else:
            waveform1, sr = self._load_file(filename)
            waveform2, _ = self._load_file(filename2)

            waveform1 = waveform1 - waveform1.mean()
            waveform2 = waveform2 - waveform2.mean()

            if self.roll_mag_aug:
                waveform1 = self._roll_mag_aug(waveform1)
                waveform2 = self._roll_mag_aug(waveform2)

            length1 = waveform1.shape[1]
            length2 = waveform2.shape[1]
            if length1 > length2:
                # padding: extend the second clip with zeros
                temp_wav = torch.zeros(1, length1)
                temp_wav[0, 0:length2] = waveform2
                waveform2 = temp_wav
            elif length1 < length2:
                # cutting: keep the channel dimension so shapes match exactly
                waveform2 = waveform2[:, 0:length1]

            # sample lambda from beta distribution; a plain Python float
            # avoids NumPy-scalar dtype promotion in the arithmetic below
            mix_lambda = float(np.random.beta(10, 10))

            mix_waveform = mix_lambda * waveform1 + (1 - mix_lambda) * waveform2
            waveform = mix_waveform - mix_waveform.mean()

        out = self._to_output_form(waveform, sr)
        target_length = self.audio_conf.get('target_length')
        n_frames = out.shape[0]

        p = target_length - n_frames

        # cut and pad
        if p > 0:
            m = torch.nn.ZeroPad2d((0, 0, 0, p))
            out = m(out)
        elif p < 0:
            out = out[0:target_length, :]

        return out, mix_lambda

    def _fbank(self, filename, filename2=None):
        """Load precomputed features from ``fbank_dir`` (optionally mixing two files).

        The feature file name is the audio file's base name with ``.wav``
        replaced by ``.npy``.

        :return: ``(features, mix_lambda)``; ``mix_lambda`` is ``0`` without mixing.
        """
        fn1 = os.path.join(self.fbank_dir, os.path.basename(filename).replace('.wav', '.npy'))
        if filename2 is None:
            return torch.from_numpy(np.load(fn1)), 0
        fn2 = os.path.join(self.fbank_dir, os.path.basename(filename2).replace('.wav', '.npy'))
        # sample lambda from beta distribution; a Python float keeps the
        # stored feature dtype (a NumPy float64 scalar can promote float32
        # arrays to float64)
        mix_lambda = float(np.random.beta(10, 10))
        fbank = mix_lambda * np.load(fn1) + (1 - mix_lambda) * np.load(fn2)
        return torch.from_numpy(fbank), mix_lambda

    def __getitem__(self, index):
        """
        returns: fbank, label_indices
        where fbank is the normalized feature tensor with a leading singleton
        dimension, i.e. (1, time_frame_num, frequency_bins), e.g. (1, 1024, 128)
        label_indices is a FloatTensor of size (label_num) when mix-up is applied
        or ``multilabel`` is set, otherwise the int class index of the last
        label listed for the sample
        """
        # do mix-up for this sample (controlled by the given mixup rate)
        do_mixup = random.random() < self.mixup  # for audio_exp, when using mixup, assume multilabel
        datum = self.data[index]
        load_features = self._fbank if self.use_fbank else self._wav2fbank

        if do_mixup:
            # pick the second sample uniformly; sampling it from the balanced
            # (multinomial) distribution was found to hurt performance
            mix_sample_idx = random.randint(0, len(self.data) - 1)
            mix_datum = self.data[mix_sample_idx]

            # get the mixed fbank
            fbank, mix_lambda = load_features(datum['wav'], mix_datum['wav'])
            # soft labels: weight each clip's classes by its share in the mix
            label_indices = np.zeros(self.label_num)
            for label_str in datum['labels'].split(','):
                label_indices[int(self.index_dict[label_str])] += mix_lambda
            for label_str in mix_datum['labels'].split(','):
                label_indices[int(self.index_dict[label_str])] += 1.0 - mix_lambda
            label_indices = torch.FloatTensor(label_indices)
        # if not do mixup
        else:
            fbank, _ = load_features(datum['wav'])
            class_ids = [int(self.index_dict[label_str]) for label_str in datum['labels'].split(',')]
            if self.multilabel:
                label_indices = np.zeros(self.label_num)
                label_indices[class_ids] = 1.0
                label_indices = torch.FloatTensor(label_indices)
            else:
                # remark : for ft cross-ent (single class index: the last listed label)
                label_indices = class_ids[-1]

        # SpecAug; callers disable it for eval by configuring freqm/timem = 0
        fbank = fbank.transpose(0, 1).unsqueeze(0)  # 1, 128, 1024 (...,freq,time)
        if self.freqm != 0:
            fbank = self.freq_mask(fbank)
        if self.timem != 0:
            fbank = self.time_mask(fbank)  # (..., freq, time)
        # Remove only the dimension added above; a bare squeeze() would also
        # drop a size-1 time or frequency axis and break the transpose.
        fbank = torch.transpose(fbank.squeeze(0), 0, 1)  # time, freq
        fbank = (fbank - self.norm_mean) / (self.norm_std * 2)
        if self.noise == True:  # default is false, true for spc
            fbank = fbank + torch.rand(fbank.shape[0], fbank.shape[1]) * np.random.rand() / 10
            fbank = torch.roll(fbank, np.random.randint(-10, 10), 0)
        # the output fbank shape is [time_frame_num, frequency_bins], e.g., [1024, 128]
        return fbank.unsqueeze(0), label_indices

    def __len__(self):
        """Number of entries in the manifest."""
        return len(self.data)


class AudiosetCLRDataset(AudiosetDataset):
    """Contrastive-pretraining variant that returns two augmented views per clip.

    The augmentation is read from ``audio_conf['asymmetric_augment']`` when
    that key is present, otherwise from ``audio_conf['clr_augment']``. It is
    either a single callable applied to both views, or a pair of callables
    (one per view). Each callable is given a 1-D waveform.
    """

    def __init__(self, dataset_json_file, audio_conf, label_csv="dataset/audioset_util/class_labels_indices.csv", 
                 use_fbank=False, fbank_dir=None, target_sample_rate=16000,
                 load_video=False, mode='train', output_form="kaldi", **kwargs):
        """Forward all settings to ``AudiosetDataset`` and pick the view augmentation."""
        super(AudiosetCLRDataset, self).__init__(
            dataset_json_file=dataset_json_file, audio_conf=audio_conf, label_csv=label_csv, use_fbank=use_fbank, fbank_dir=fbank_dir, 
            load_video=load_video, mode=mode, output_form=output_form, **kwargs)
        # Set after the parent constructor, which is called without this argument.
        self.target_sample_rate = target_sample_rate
        if "asymmetric_augment" in self.audio_conf:
            self.clr_augment = self.audio_conf.get('asymmetric_augment')
        else:
            self.clr_augment = self.audio_conf.get('clr_augment')
        print('using CLR augment for pretraining')

    def _pad(self, out):
        """Zero-pad or cut ``out`` along its first dimension to ``target_length``.

        Works for inputs of any rank (e.g. 2-D features or a 1-D raw waveform).
        """
        target_length = self.audio_conf.get('target_length')
        n_frames = out.shape[0]
        p = target_length - n_frames
        # cut and pad
        if p > 0:
            # F.pad lists padding from the last dimension backwards: leave all
            # trailing dimensions untouched and append p zeros to dimension 0.
            # For 2-D input this is exactly ZeroPad2d((0, 0, 0, p)).
            padding = (0, 0) * (out.dim() - 1) + (0, p)
            out = torch.nn.functional.pad(out, padding)
        elif p < 0:
            out = out[0:target_length]
        return out

    def _roll_mag_aug(self, waveform):
        """Randomly roll the waveform in time and scale its magnitude.

        When ``roll_mag_aug`` is disabled the input is returned unchanged
        (i.e. still a tensor); when enabled the result is a NumPy array with a
        leading singleton dimension.
        """
        if not self.roll_mag_aug:
            return waveform
        waveform = waveform.squeeze().numpy()
        idx = np.random.randint(len(waveform))
        rolled_waveform = np.roll(waveform, idx)
        # Python float: a NumPy float64 scalar would promote a float32
        # waveform to float64 under NumPy 2 promotion rules.
        mag = float(np.random.beta(10, 10) + 0.5)
        rolled_waveform = np.expand_dims(rolled_waveform, axis=0)
        return rolled_waveform * mag

    def __getitem__(self, index):
        """
        returns: view_0, view_1, label_indices
        where view_0 and view_1 are two independently augmented, normalized
        feature tensors of the same clip, each with a leading singleton dimension
        label_indices is an all-zero numpy array of size (label_num); labels
        are not read by this dataset

        raises RuntimeError if the audio file cannot be loaded after 10 attempts
        """
        datum = self.data[index]
        label_indices = np.zeros(self.label_num)

        # Loading is retried because reads can fail transiently; the last
        # error is kept so the final failure reports its cause.
        waveform = None
        last_error = None
        for _ in range(10):
            try:
                waveform, sr = self._load_file(datum['wav'])
                break
            except Exception as e:
                last_error = e
                print(f"Error loading file {datum['wav']}: {e}. Retrying...")
                time.sleep(1)

        if waveform is None:
            raise RuntimeError(f"Failed to load file after 10 retries: {datum['wav']}") from last_error

        augment = self.clr_augment
        if augment is None:
            raise TypeError(
                "No augmentation configured: audio_conf needs 'asymmetric_augment' or 'clr_augment'")
        if isinstance(augment, (tuple, list)):
            augment_0, augment_1 = augment[0], augment[1]
        else:
            augment_0 = augment_1 = augment

        # Each view gets its own roll/magnitude draw followed by its augmentation.
        x_0 = augment_0(self._roll_mag_aug(waveform).squeeze())
        x_1 = augment_1(self._roll_mag_aug(waveform).squeeze())
        x_0 = x_0 - x_0.mean()
        x_1 = x_1 - x_1.mean()
        # as_tensor accepts both NumPy arrays (zero-copy, like from_numpy) and
        # tensors, so augmentations may return either.
        x_0 = torch.as_tensor(x_0)
        x_1 = torch.as_tensor(x_1)
        out_0 = self._pad(self._to_output_form(x_0, sr))
        out_1 = self._pad(self._to_output_form(x_1, sr))
        out_0 = (out_0 - self.norm_mean) / (self.norm_std * 2)
        out_1 = (out_1 - self.norm_mean) / (self.norm_std * 2)

        return out_0.unsqueeze(0), out_1.unsqueeze(0), label_indices

