import torch
from torch.utils.data import Dataset


class CustomDataset(Dataset):
    """Map-style dataset that can also return a noise-corrupted copy of each sample.

    Each item is ``(sample, label, idx)``. When ``is_add_noise`` is true the
    item is extended to ``(sample, label, idx, noised_sample, noise_label)``,
    where ``noised_sample`` is the (transformed) sample plus element-wise
    Gaussian noise and ``noise_label`` describes the noise on each element.

    Args:
        data: Indexable collection of samples; ``data[idx]`` must be
            convertible to a tensor.
        target: Indexable collection of labels; its length defines the
            dataset length.
        transform: Optional callable applied to the sample tensor.
        target_transform: Optional callable applied to the label.
        is_add_noise: Whether ``__getitem__`` also returns noisy data.
        label_type: How noise labels are built. One of ``'soft-dist'``
            (the sampled std of the element's noise level), ``'hard'``
            (1 for every noised element) or ``'soft-noise'`` (absolute
            value of the noise added to the element).
    """

    # Boundaries of the noise-std intervals: [0, 0.4), [0.4, 0.8), [0.8, 1.2).
    NOISE_LEVEL_BOUNDS = (0, 0.4, 0.8, 1.2)
    # Fractions of elements that receive noise; one noisy copy is generated per
    # entry and the copies are concatenated along dim 0.
    NOISE_PORTIONS = (1,)
    LABEL_TYPES = ('soft-dist', 'hard', 'soft-noise')

    def __init__(self, data, target, transform=None, target_transform=None, is_add_noise=False,
                 label_type='soft-noise'):
        self.transform = transform
        self.target_transform = target_transform

        self.data = data
        self.targets = target

        self.label_type = label_type
        self.is_add_noise = is_add_noise

    def __len__(self):
        """Return the number of samples (the number of targets)."""
        return len(self.targets)

    def __getitem__(self, idx):
        """Return the sample at ``idx``, optionally with its noisy counterpart.

        Returns:
            ``(sample, label, idx)``, or
            ``(sample, label, idx, noised_data, noise_label)`` when
            ``is_add_noise`` is set.
        """
        label = self.targets[idx]
        # ``torch.as_tensor`` avoids the legacy ``torch.Tensor(...)`` constructor,
        # which interprets an integer argument as a size and returns
        # uninitialised memory instead of converting the value.
        sample_data = torch.as_tensor(self.data[idx], dtype=torch.get_default_dtype())
        if self.transform is not None:
            sample_data = self.transform(sample_data)

        if self.target_transform is not None:
            label = self.target_transform(label)

        if not self.is_add_noise:
            return sample_data, label, idx

        noised_data_l, noise_label_l = self.generate_multi_noisy_samples(sample_data)
        return sample_data, label, idx, noised_data_l, noise_label_l

    def generate_multi_noisy_samples(self, inputs):
        """Create one noisy copy of ``inputs`` per entry of ``NOISE_PORTIONS``.

        Returns:
            Tuple ``(noised_data, noise_label)``; the per-portion results are
            concatenated along dim 0.
        """
        noised_data_l, noise_label_l = [], []
        for noise_p in self.NOISE_PORTIONS:
            noised_data, noise_label = self.generate_random_element_wise_noisy_samples(inputs, noise_p)
            noised_data_l.append(noised_data)
            noise_label_l.append(noise_label)
        return torch.cat(noised_data_l), torch.cat(noise_label_l)

    def generate_random_element_wise_noisy_samples(self, inputs, noise_portion):
        """Add randomly placed Gaussian noise of several strengths to ``inputs``.

        The ``n = inputs.shape[0]`` elements are split into one group per noise
        level. For each group a std is drawn uniformly from that level's
        interval and zero-mean normal noise with that std is sampled. The noise
        is then shuffled across positions and masked so that only
        ``int(n * noise_portion)`` randomly chosen positions are affected.

        Args:
            inputs: Tensor to corrupt. The noise has shape ``(1, n)`` and is
                broadcast-added, so a 1-D tensor of length ``n`` is expected.
            noise_portion: Fraction (0..1) of elements that receive noise.

        Returns:
            Tuple ``(noised_data, noise_label)``; for 1-D ``inputs`` both have
            shape ``(n,)``. Labels are zero at positions without noise.

        Raises:
            NotImplementedError: If ``self.label_type`` is not supported.
        """
        if self.label_type not in self.LABEL_TYPES:
            # Fail before drawing any random numbers.
            raise NotImplementedError(
                "unsupported label_type {!r}; expected one of {}".format(self.label_type, self.LABEL_TYPES)
            )

        bs = 1
        n = inputs.shape[0]

        bounds = self.NOISE_LEVEL_BOUNDS
        n_noise_levels = len(bounds) - 1
        # Equal-sized groups; the last group absorbs the remainder.
        base_size = n // n_noise_levels
        noise_size = [base_size] * (n_noise_levels - 1) + [n - (n_noise_levels - 1) * base_size]

        # Generate the noise ordered by level; positions are shuffled below.
        noise = []
        noise_label = []
        for lower, upper, size in zip(bounds, bounds[1:], noise_size):
            # Uniform draw inside [lower, upper). Scaling by the interval width
            # (not by ``upper``) keeps the std within the documented range.
            level = lower + torch.rand(1).item() * (upper - lower)
            level_noise = torch.normal(0, level, (bs, size))
            noise.append(level_noise)

            if self.label_type == 'soft-dist':
                label = torch.ones_like(level_noise) * level  # soft-distribution label
            elif self.label_type == 'hard':
                label = torch.ones_like(level_noise)  # hard label
            else:
                label = torch.abs(level_noise)  # soft-noise label
            noise_label.append(label)

        noise = torch.cat(noise, dim=1)
        noise_label = torch.cat(noise_label, dim=1)

        # Build a randomly permuted 0/1 mask that controls the noise portion.
        n_noise = int(n * noise_portion)
        n_zeros = n - n_noise
        mask = torch.cat([torch.ones(bs, n_noise), torch.zeros(bs, n_zeros)], dim=1)
        shuffle = torch.argsort(torch.rand(bs, n))
        mask = mask.gather(1, shuffle)

        # Randomly decide the position of the noise; the same permutation is
        # applied to the labels so they stay aligned with their noise values.
        shuffle = torch.argsort(torch.rand(bs, n))
        noise = noise.gather(1, shuffle) * mask
        noise_label = noise_label.gather(1, shuffle) * mask

        # Drop the artificial batch dimension introduced by ``bs``.
        noised_data = (inputs + noise)[0]
        noise_label = noise_label[0]
        return noised_data, noise_label




