import sys
import wget
import numpy as np
from PIL import Image
import torchvision
from torch.utils.data.dataset import Subset
from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances 
import torch
import torch.nn.functional as F
import random 
import json
import os
from zipfile import ZipFile
import torch.utils.data as data
import torchvision.transforms as transforms
if sys.version_info[0] == 2:
    import cPickle as pickle
else:
    import pickle


def get_cifar10(root, cfg_trainer, train=True,
                transform_train=None, transform_train_aug=None, transform_val=None,
                download=True, noise_file=''):
    """Build the CIFAR-10 datasets for one run.

    Two families of datasets are produced:

    * CIFAR-10N human labels, when ``cfg_trainer['real']`` is one of
      ``clean``, ``worst``, ``aggre``, ``rand1``, ``rand2`` or ``rand3``.
      These are read from ``'data_loader/data/'`` (not ``root``) and use the
      crop/flip/normalize transforms defined here; only ``transform_train_aug``
      is taken from the arguments.
    * Synthetic label noise on a class-balanced train/val split of the
      torchvision CIFAR-10 under ``root``. The noise model is picked by the
      first truthy flag among ``asym``, ``instance``, ``custom_T_low`` and
      ``custom_T_high``, falling back to symmetric noise.

    Args:
        root: directory holding the torchvision CIFAR-10 files.
        cfg_trainer: trainer config mapping (reads ``real`` and the noise flags).
        train: True builds ``(train, val)``; False builds only the test set.
        transform_train, transform_train_aug, transform_val: transforms for
            the synthetic-noise datasets.
        download: forwarded to ``torchvision.datasets.CIFAR10``.
        noise_file: unused.

    Returns:
        ``(train_dataset, val_dataset)``. In test mode ``train_dataset`` is an
        empty list and ``val_dataset`` is the test set.
    """
    label_key_for = {
        'clean': 'clean_label', 'worst': 'worse_label', 'aggre': 'aggre_label',
        'rand1': 'random_label1', 'rand2': 'random_label2', 'rand3': 'random_label3',
        'clean100': 'clean_label', 'noisy100': 'noisy_label',
    }
    # Only these settings select CIFAR-10N; any other value means synthetic noise.
    cifar10n_settings = ('clean', 'worst', 'aggre', 'rand1', 'rand2', 'rand3')

    def tensor_and_normalize():
        return [
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ]

    base_dataset = torchvision.datasets.CIFAR10(root, train=train, download=download)

    if not train:
        setting = cfg_trainer['real']
        if setting in cifar10n_settings:
            val_dataset = CIFAR10N(root='data_loader/data/',
                                   download=False,
                                   train=False,
                                   transform=transforms.Compose(tensor_and_normalize()),
                                   noise_type=label_key_for[setting])
        else:
            val_dataset = CIFAR10_val(root, cfg_trainer, None, train=train, transform=transform_val)
        train_dataset = []
        print(f"Test: {len(val_dataset)}")
        return train_dataset, val_dataset

    # The split is drawn even for CIFAR-10N so the global numpy RNG advances identically.
    train_idxs, val_idxs = train_val_split(base_dataset.targets)
    setting = cfg_trainer['real']
    if setting in cifar10n_settings:
        noise_type = label_key_for[setting]
        augment_ops = [transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip()]
        train_dataset = CIFAR10N(root='data_loader/data/',
                                 download=False,
                                 train=True,
                                 transform=transforms.Compose(augment_ops + tensor_and_normalize()),
                                 transform_aug=transform_train_aug,
                                 noise_type=noise_type, is_human=True)
        val_dataset = CIFAR10N(root='data_loader/data/',
                               download=False,
                               train=False,
                               transform=transforms.Compose(tensor_and_normalize()),
                               noise_type=noise_type)
    else:
        train_dataset = CIFAR10_train(root, cfg_trainer, train_idxs, train=True,
                                      transform=transform_train, transform_aug=transform_train_aug)
        val_dataset = CIFAR10_val(root, cfg_trainer, val_idxs, train=train, transform=transform_val)
        noise_method = 'symmetric_noise'
        for flag, method_name in (('asym', 'asymmetric_noise'),
                                  ('instance', 'instance_noise'),
                                  ('custom_T_low', 'custom_t_low_noise'),
                                  ('custom_T_high', 'custom_t_high_noise')):
            if cfg_trainer[flag]:
                noise_method = method_name
                break
        getattr(train_dataset, noise_method)()
        getattr(val_dataset, noise_method)()
    print(f"Train: {len(train_dataset)} Val: {len(val_dataset)}")  # Train: 45000 Val: 5000
    return train_dataset, val_dataset


def download_cifarn(root):
    """Fetch the CIFAR-10/100N label archive into ``root`` and unpack it there."""
    archive_url = 'http://ucsc-real.soe.ucsc.edu:1995/files/cifar-10-100n-main.zip'
    wget.download(archive_url, out=root)
    archive_path = os.path.join(root, 'cifar-10-100n-main.zip')
    with ZipFile(archive_path, 'r') as archive:
        archive.extractall(root)


def cifar10n(root, key, download=True):
    """Load one CIFAR-10N label set as a tensor.

    Args:
        root: directory that contains (or will receive) ``cifar-10-100n-main``.
        key: entry of ``CIFAR-10_human.pt`` to return, e.g. ``'worse_label'``.
        download: fetch and unpack the archive when the label file is missing.

    Returns:
        ``torch.Tensor`` built with ``torch.from_numpy`` from the stored array.

    Raises:
        FileNotFoundError: the label file is missing and ``download`` is False,
            or it is still missing after downloading.
    """
    label_path = os.path.join(root, 'cifar-10-100n-main/data', 'CIFAR-10_human.pt')
    if not os.path.exists(label_path):
        if not download:
            raise FileNotFoundError(
                f'CIFAR-10N label file not found at {label_path}; pass download=True to fetch it.')
        download_cifarn(root)
        if not os.path.exists(label_path):
            raise FileNotFoundError(f'CIFAR-10N label file still missing after download: {label_path}')
    try:
        # The stored values are numpy arrays (hence torch.from_numpy below);
        # torch >= 2.6 defaults to weights_only=True, which refuses them.
        # NOTE(review): weights_only=False unpickles arbitrary objects; only use
        # a trusted copy of this file.
        noise_label = torch.load(label_path, weights_only=False)
    except TypeError:
        # torch < 1.13 does not accept the weights_only argument.
        noise_label = torch.load(label_path)
    return torch.from_numpy(noise_label[key])


def train_val_split(base_dataset, num_classes=10, train_ratio=0.9):
    """Split sample indices into disjoint, class-balanced train and validation sets.

    Args:
        base_dataset: per-sample class labels (the caller passes
            ``CIFAR10.targets``), with values in ``[0, num_classes)``.
        num_classes: number of classes.
        train_ratio: fraction of the dataset assigned to training. Each class
            contributes ``int(len(labels) * train_ratio / num_classes)`` training
            samples (this assumes roughly balanced classes); the rest of that
            class goes to validation.

    Returns:
        ``(train_idxs, val_idxs)``: shuffled lists of indices with no overlap.
    """
    labels = np.asarray(base_dataset)
    train_n = int(len(labels) * train_ratio / num_classes)
    train_idxs = []
    val_idxs = []

    for cls in range(num_classes):
        idxs = np.where(labels == cls)[0]
        np.random.shuffle(idxs)
        # First train_n go to train, the remainder to validation. Taking every
        # index for train (idxs[:]) would leak all validation samples into training.
        train_idxs.extend(idxs[:train_n])
        val_idxs.extend(idxs[train_n:])
    np.random.shuffle(train_idxs)
    np.random.shuffle(val_idxs)

    return train_idxs, val_idxs


class CIFAR10_train(torchvision.datasets.CIFAR10):
    """CIFAR-10 training subset with synthetic label noise.

    Keeps the samples selected by ``indexs`` from the torchvision CIFAR-10
    split. ``train_labels`` starts as their dataset labels and is corrupted
    in place by one of the ``*_noise`` methods; ``train_labels_gt`` keeps the
    uncorrupted copy. ``__getitem__`` returns the requested sample together
    with a second, randomly chosen sample and an augmented view.
    """

    # Row c is the distribution the new label is drawn from when the current label is c.
    _T_LOW = [[0.82, 0.03, 0.01, 0.023, 0.017, 0.022, 0.021, 0.018, 0.019, 0.02],
              [0.02, 0.83, 0.01, 0.023, 0.017, 0.022, 0.021, 0.018, 0.019, 0.02],
              [0.02, 0.03, 0.81, 0.023, 0.017, 0.022, 0.021, 0.018, 0.019, 0.02],
              [0.02, 0.03, 0.01, 0.823, 0.017, 0.022, 0.021, 0.018, 0.019, 0.02],
              [0.02, 0.03, 0.01, 0.023, 0.817, 0.022, 0.021, 0.018, 0.019, 0.02],
              [0.02, 0.03, 0.01, 0.023, 0.017, 0.822, 0.021, 0.018, 0.019, 0.02],
              [0.02, 0.03, 0.01, 0.023, 0.017, 0.022, 0.821, 0.018, 0.019, 0.02],
              [0.02, 0.03, 0.01, 0.023, 0.017, 0.022, 0.021, 0.818, 0.019, 0.02],
              [0.02, 0.03, 0.01, 0.023, 0.017, 0.022, 0.021, 0.018, 0.819, 0.02],
              [0.02, 0.03, 0.01, 0.023, 0.017, 0.022, 0.021, 0.018, 0.019, 0.82]]
    _T_HIGH = [[0.46, 0.07, 0.04, 0.05, 0.06, 0.04, 0.06, 0.07, 0.08, 0.07],
               [0.05, 0.48, 0.04, 0.05, 0.06, 0.04, 0.06, 0.07, 0.08, 0.07],
               [0.05, 0.07, 0.45, 0.05, 0.06, 0.04, 0.06, 0.07, 0.08, 0.07],
               [0.05, 0.07, 0.04, 0.46, 0.06, 0.04, 0.06, 0.07, 0.08, 0.07],
               [0.05, 0.07, 0.04, 0.05, 0.47, 0.04, 0.06, 0.07, 0.08, 0.07],
               [0.05, 0.07, 0.04, 0.05, 0.06, 0.45, 0.06, 0.07, 0.08, 0.07],
               [0.05, 0.07, 0.04, 0.05, 0.06, 0.04, 0.47, 0.07, 0.08, 0.07],
               [0.05, 0.07, 0.04, 0.05, 0.06, 0.04, 0.06, 0.48, 0.08, 0.07],
               [0.05, 0.07, 0.04, 0.05, 0.06, 0.04, 0.06, 0.07, 0.49, 0.07],
               [0.05, 0.07, 0.04, 0.05, 0.06, 0.04, 0.06, 0.07, 0.08, 0.48]]
    # asymmetric_noise flips: truck -> automobile, bird -> airplane,
    # cat -> dog, dog -> cat, deer -> horse.
    _ASYM_FLIPS = {9: 1, 2: 0, 3: 5, 5: 3, 4: 7}

    def __init__(self, root, cfg_trainer, indexs, train=True,
                 transform=None, transform_aug=None, target_transform=None,
                 download=False):
        super(CIFAR10_train, self).__init__(root, train=train,
                                            transform=transform,
                                            target_transform=target_transform,
                                            download=download)
        self.num_classes = 10
        self.cfg_trainer = cfg_trainer
        self.train_data = self.data[indexs]
        self.train_labels = np.array(self.targets)[indexs]
        self.indexs = indexs
        self.prediction = np.zeros((len(self.train_data), self.num_classes, self.num_classes), dtype=np.float32)
        self.noise_indx = []
        self.transform_aug = transform_aug
        self.train_labels_gt = self.train_labels.copy()

    def symmetric_noise(self):
        """Relabel a random ``percent`` fraction of samples with a uniformly random class.

        The drawn class may equal the original one. Selected positions are
        recorded in ``noise_indx``.
        """
        indices = np.random.permutation(len(self.train_data))
        for i, idx in enumerate(indices):
            if i < self.cfg_trainer['percent'] * len(self.train_data):
                self.noise_indx.append(idx)
                self.train_labels[idx] = np.random.randint(self.num_classes, dtype=np.int32)

    def asymmetric_noise(self):
        """Flip a ``percent`` fraction of each class according to ``_ASYM_FLIPS``.

        Every selected position is recorded in ``noise_indx``, including those
        of classes that have no flip target (their labels stay unchanged).
        """
        for i in range(self.num_classes):
            # Select by clean label: after cat -> dog flips, selecting by the
            # current label would pull those samples into the dog pass and
            # could flip them back to cat.
            indices = np.where(self.train_labels_gt == i)[0]
            np.random.shuffle(indices)
            for j, idx in enumerate(indices):
                if j < self.cfg_trainer['percent'] * len(indices):
                    self.noise_indx.append(idx)
                    if i in self._ASYM_FLIPS:
                        self.train_labels[idx] = self._ASYM_FLIPS[i]

    def instance_noise(self):
        """Load precomputed instance-dependent noisy labels for this subset.

        Reads ``<root>/cifar-noisy/IDN_<percent>_C10.pt`` and takes
        ``noise_label_train`` at ``self.indexs``.
        """
        path = os.path.join(self.root, 'cifar-noisy/IDN_{:.1f}_C10.pt'.format(self.cfg_trainer['percent']))
        noise_label = torch.load(path)
        noisylabel = noise_label['noise_label_train'][self.indexs]
        self.train_labels = np.array(noisylabel)

    def _resample_labels(self, T):
        """Store a copy of ``T`` as ``self.T`` and redraw every label from its row of ``T``."""
        self.T = [list(row) for row in T]
        classes = range(len(self.T))
        for i, cl in enumerate(self.train_labels):
            self.train_labels[i] = np.random.choice(classes, p=self.T[cl])

    def custom_t_low_noise(self):
        """Resample labels with the low-noise transition matrix (about 18% flips per class)."""
        self._resample_labels(self._T_LOW)

    def custom_t_high_noise(self):
        """Resample labels with the high-noise transition matrix (about 52-55% flips per class)."""
        self._resample_labels(self._T_HIGH)

    def get_t(self):
        """Return the transition matrix set by the last ``custom_t_*_noise`` call."""
        print("returnin T: ", self.T)
        return self.T

    def __getitem__(self, index):
        """
        Args:
            index (int): position in this subset.

        Returns:
            tuple: ``(img, img2, img_aug, target, index, target_gt)``. ``img2`` is
            another sample drawn uniformly from the other positions (the same
            one if the subset has a single sample). ``img_aug`` comes from
            ``transform_aug``, falling back to ``transform``.
        """
        n = len(self.train_data)
        if n > 1:
            # Uniform over the other n - 1 positions in O(1), without
            # materialising the full index list on every call.
            index2 = random.randrange(n - 1)
            if index2 >= index:
                index2 += 1
        else:
            index2 = index
        target, target_gt = self.train_labels[index], self.train_labels_gt[index]
        # Return PIL images, consistent with other torchvision datasets.
        img = Image.fromarray(self.train_data[index])
        img2 = Image.fromarray(self.train_data[index2])

        if self.transform_aug is not None:
            img_aug = self.transform_aug(img)
        elif self.transform is not None:
            img_aug = self.transform(img)
        else:
            img_aug = img

        if self.transform is not None:
            img = self.transform(img)
            img2 = self.transform(img2)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, img2, img_aug, target, index, target_gt

    def __len__(self):
        return len(self.train_data)


class CIFAR10_val(torchvision.datasets.CIFAR10):
    """CIFAR-10 validation subset (``train=True``) or full test set (``train=False``).

    ``train_labels`` can be corrupted in place by the ``*_noise`` methods, the
    same way as for the training subset; ``train_labels_gt`` keeps the labels
    as loaded.
    """

    # Row c is the distribution the new label is drawn from when the current label is c.
    _T_LOW = [[0.82, 0.03, 0.01, 0.023, 0.017, 0.022, 0.021, 0.018, 0.019, 0.02],
              [0.02, 0.83, 0.01, 0.023, 0.017, 0.022, 0.021, 0.018, 0.019, 0.02],
              [0.02, 0.03, 0.81, 0.023, 0.017, 0.022, 0.021, 0.018, 0.019, 0.02],
              [0.02, 0.03, 0.01, 0.823, 0.017, 0.022, 0.021, 0.018, 0.019, 0.02],
              [0.02, 0.03, 0.01, 0.023, 0.817, 0.022, 0.021, 0.018, 0.019, 0.02],
              [0.02, 0.03, 0.01, 0.023, 0.017, 0.822, 0.021, 0.018, 0.019, 0.02],
              [0.02, 0.03, 0.01, 0.023, 0.017, 0.022, 0.821, 0.018, 0.019, 0.02],
              [0.02, 0.03, 0.01, 0.023, 0.017, 0.022, 0.021, 0.818, 0.019, 0.02],
              [0.02, 0.03, 0.01, 0.023, 0.017, 0.022, 0.021, 0.018, 0.819, 0.02],
              [0.02, 0.03, 0.01, 0.023, 0.017, 0.022, 0.021, 0.018, 0.019, 0.82]]
    _T_HIGH = [[0.46, 0.07, 0.04, 0.05, 0.06, 0.04, 0.06, 0.07, 0.08, 0.07],
               [0.05, 0.48, 0.04, 0.05, 0.06, 0.04, 0.06, 0.07, 0.08, 0.07],
               [0.05, 0.07, 0.45, 0.05, 0.06, 0.04, 0.06, 0.07, 0.08, 0.07],
               [0.05, 0.07, 0.04, 0.46, 0.06, 0.04, 0.06, 0.07, 0.08, 0.07],
               [0.05, 0.07, 0.04, 0.05, 0.47, 0.04, 0.06, 0.07, 0.08, 0.07],
               [0.05, 0.07, 0.04, 0.05, 0.06, 0.45, 0.06, 0.07, 0.08, 0.07],
               [0.05, 0.07, 0.04, 0.05, 0.06, 0.04, 0.47, 0.07, 0.08, 0.07],
               [0.05, 0.07, 0.04, 0.05, 0.06, 0.04, 0.06, 0.48, 0.08, 0.07],
               [0.05, 0.07, 0.04, 0.05, 0.06, 0.04, 0.06, 0.07, 0.49, 0.07],
               [0.05, 0.07, 0.04, 0.05, 0.06, 0.04, 0.06, 0.07, 0.08, 0.48]]
    # asymmetric_noise flips: truck -> automobile, bird -> airplane,
    # cat -> dog, dog -> cat, deer -> horse.
    _ASYM_FLIPS = {9: 1, 2: 0, 3: 5, 5: 3, 4: 7}

    def __init__(self, root, cfg_trainer, indexs, train=True,
                 transform=None, target_transform=None,
                 download=False):
        super(CIFAR10_val, self).__init__(root, train=train,
                                          transform=transform, target_transform=target_transform,
                                          download=download)

        self.num_classes = 10
        self.cfg_trainer = cfg_trainer
        if train:
            self.train_data = self.data[indexs]
            self.train_labels = np.array(self.targets)[indexs]
        else:
            self.train_data = self.data
            self.train_labels = np.array(self.targets)
        self.train_labels_gt = self.train_labels.copy()
        self.indexs = indexs

    def symmetric_noise(self):
        """Relabel a random ``percent`` fraction of samples with a uniformly random class."""
        indices = np.random.permutation(len(self.train_data))
        for i, idx in enumerate(indices):
            if i < self.cfg_trainer['percent'] * len(self.train_data):
                self.train_labels[idx] = np.random.randint(self.num_classes, dtype=np.int32)

    def asymmetric_noise(self):
        """Flip a ``percent`` fraction of each class according to ``_ASYM_FLIPS``."""
        for i in range(self.num_classes):
            # Select by clean label so cat -> dog flips are not re-selected
            # (and possibly flipped back) in the dog pass.
            indices = np.where(self.train_labels_gt == i)[0]
            np.random.shuffle(indices)
            for j, idx in enumerate(indices):
                if j < self.cfg_trainer['percent'] * len(indices):
                    if i in self._ASYM_FLIPS:
                        self.train_labels[idx] = self._ASYM_FLIPS[i]

    def instance_noise(self):
        """Load precomputed instance-dependent noisy labels for this subset.

        Reads ``<root>/cifar-noisy/IDN_<percent>_C10.pt`` and takes
        ``noise_label_train`` at ``self.indexs``.
        NOTE(review): ``indexs`` is None in test mode, which this indexing does
        not handle; get_cifar10 only calls it in train mode.
        """
        # train_labels_gt was captured in __init__; re-copying it here would
        # capture noisy labels if another noise model had already run.
        path = os.path.join(self.root, 'cifar-noisy/IDN_{:.1f}_C10.pt'.format(self.cfg_trainer['percent']))
        noise_label = torch.load(path)
        noisylabel = noise_label['noise_label_train'][self.indexs]
        self.train_labels = np.array(noisylabel)

    def _resample_labels(self, T):
        """Store a copy of ``T`` as ``self.T`` and redraw every label from its row of ``T``."""
        self.T = [list(row) for row in T]
        classes = range(len(self.T))
        for i, cl in enumerate(self.train_labels):
            self.train_labels[i] = np.random.choice(classes, p=self.T[cl])

    def custom_t_low_noise(self):
        """Resample labels with the low-noise transition matrix."""
        self._resample_labels(self._T_LOW)

    def custom_t_high_noise(self):
        """Resample labels with the high-noise transition matrix."""
        self._resample_labels(self._T_HIGH)

    def __len__(self):
        return len(self.train_data)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: ``(img, target, index, target_gt)`` where target is the (possibly noisy) class index.
        """
        img, target, target_gt = self.train_data[index], self.train_labels[index], self.train_labels_gt[index]

        # Return a PIL image, consistent with other torchvision datasets.
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target, index, target_gt


class CIFAR10N(data.Dataset):
    """`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset with CIFAR-10N labels.

    Reads the raw CIFAR-10 python batches from ``root/cifar-10-batches-py``.
    In training mode it also loads the label set named by ``noise_type``
    from the CIFAR-10N ``.pt`` file and serves those labels instead of the
    batch labels (unless ``noise_type == 'clean'``).

    Args:
        root (string): Root directory of dataset where directory
            ``cifar-10-batches-py`` exists or will be saved to if download is set to True.
        train (bool, optional): If True, creates dataset from training set, otherwise
            creates from test set.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transform_aug (callable, optional): transform for the extra augmented view
            returned in training mode; falls back to ``transform``.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        noise_type (str, optional): key of the label set inside the CIFAR-10N file,
            e.g. ``'worse_label'``; ``'clean'`` keeps the batch labels.
        noise_path (str, optional): path of the CIFAR-10N ``.pt`` file; defaults to
            ``default_noise_path``.
        is_human (bool, optional): accepted for API compatibility; unused.
    """

    base_folder = 'cifar-10-batches-py'
    url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
    filename = "cifar-10-python.tar.gz"
    tgz_md5 = 'c58f30108f718f92721af3b95e74349a'
    train_list = [
        ['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
        ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
        ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
        ['data_batch_4', '634d18415352ddfa80567beed471001a'],
        ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
    ]

    test_list = [
        ['test_batch', '40351d587109b95175f43aff81a1287e'],
    ]

    # Label file used when no noise_path is given.
    default_noise_path = 'data_loader/data/cifar_10_100n_main/data/CIFAR-10_human.pt'

    def __init__(self, root, train=True,
                 transform=None, target_transform=None, transform_aug=None,
                 download=False,
                 noise_type=None, noise_path=None, is_human=True):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.transform_aug = transform_aug
        self.train = train  # training set or test set
        self.dataset = 'cifar10'
        self.noise_type = noise_type
        self.nb_classes = 10
        self.noise_path = self.default_noise_path if noise_path is None else noise_path
        if download:
            self.download()

        if self.train:
            self._load_train_split()
        else:
            self._load_test_split()

    def _read_batch(self, name):
        """Unpickle one CIFAR batch file from ``root/base_folder``."""
        path = os.path.join(self.root, self.base_folder, name)
        # NOTE: pickle can execute code; these files are expected to come from
        # the CIFAR archive fetched by download().
        with open(path, 'rb') as fo:
            return pickle.load(fo, encoding='latin1')

    @staticmethod
    def _batch_labels(entry):
        """Return a batch's ``labels`` list, falling back to ``fine_labels``."""
        return entry['labels'] if 'labels' in entry else entry['fine_labels']

    def _load_train_split(self):
        """Read the training batches and, unless ``noise_type == 'clean'``, the CIFAR-10N labels."""
        chunks = []
        self.train_labels = []
        for name, _ in self.train_list:
            entry = self._read_batch(name)
            chunks.append(entry['data'])
            self.train_labels += self._batch_labels(entry)
        self.train_data = np.concatenate(chunks).reshape((-1, 3, 32, 32))
        self.train_data = self.train_data.transpose((0, 2, 3, 1))  # convert to HWC

        if self.noise_type != 'clean':
            # Load human noisy labels
            self.train_noisy_labels = self.load_label().tolist()
            print(f'noisy labels loaded from {self.noise_path}')

            class_size_noisy = np.bincount(self.train_noisy_labels, minlength=self.nb_classes)
            self.noise_prior = class_size_noisy / class_size_noisy.sum()
            print(f'The noisy data ratio in each class is {self.noise_prior}')
            self.noise_or_not = np.asarray(self.train_noisy_labels) != np.asarray(self.train_labels)
            self.actual_noise_rate = self.noise_or_not.sum() / len(self.noise_or_not)
            print('over all noise rate is ', self.actual_noise_rate)

    def _load_test_split(self):
        """Read the test batch."""
        entry = self._read_batch(self.test_list[0][0])
        self.test_labels = self._batch_labels(entry)
        self.test_data = entry['data'].reshape((-1, 3, 32, 32))
        self.test_data = self.test_data.transpose((0, 2, 3, 1))  # convert to HWC

    @staticmethod
    def _torch_load(path):
        """Call ``torch.load`` so that numpy arrays can be unpickled on any torch version."""
        try:
            # torch >= 2.6 defaults to weights_only=True, which rejects numpy arrays.
            # NOTE(review): weights_only=False unpickles arbitrary objects; point
            # noise_path only at a trusted CIFAR-N file.
            return torch.load(path, weights_only=False)
        except TypeError:
            # torch < 1.13 does not accept the weights_only argument.
            return torch.load(path)

    def load_label(self):
        """Return the ``noise_type`` label set from ``noise_path``, flattened.

        Only training labels are stored in the file. If it also contains
        ``clean_label``, that set is checked element-wise against the labels
        read from the CIFAR batches, and the resulting noise rate is printed.

        Raises:
            ValueError: ``clean_label`` differs from the CIFAR-10 training labels.
            TypeError: the file does not contain a dict.
        """
        noise_label = self._torch_load(self.noise_path)
        if not isinstance(noise_label, dict):
            raise TypeError(f'Input Error: expected a dict of label arrays in {self.noise_path}, '
                            f'got {type(noise_label).__name__}')
        if 'clean_label' in noise_label:
            clean_label = np.asarray(noise_label['clean_label']).reshape(-1)
            # Element-wise comparison: a zero sum of differences can hide
            # mismatches that cancel out.
            if not np.array_equal(np.asarray(self.train_labels).reshape(-1), clean_label):
                raise ValueError(f'clean_label in {self.noise_path} does not match the CIFAR-10 training labels')
            print(f'Loaded {self.noise_type} from {self.noise_path}.')
            print(f'The overall noise rate is {1-np.mean(clean_label == np.asarray(noise_label[self.noise_type]))}')
        return noise_label[self.noise_type].reshape(-1)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: in training mode ``(img, img2, img_aug, target, index, 0)``, where
            ``img2`` is another sample drawn uniformly from the other positions;
            in test mode ``(img, target, index, 0)``.
        """
        if not self.train:
            img, target = self.test_data[index], self.test_labels[index]
            img = Image.fromarray(img)
            if self.transform is not None:
                img = self.transform(img)
            if self.target_transform is not None:
                target = self.target_transform(target)
            return img, target, index, 0

        n = len(self.train_data)
        if n > 1:
            # O(1) uniform draw over the other positions (no per-call list of n indices).
            index2 = random.randrange(n - 1)
            if index2 >= index:
                index2 += 1
        else:
            index2 = index
        labels = self.train_noisy_labels if self.noise_type != 'clean' else self.train_labels
        target = labels[index]
        # Return PIL images, consistent with other torchvision datasets.
        img = Image.fromarray(self.train_data[index])
        img2 = Image.fromarray(self.train_data[index2])
        if self.transform_aug is not None:
            img_aug = self.transform_aug(img)
        elif self.transform is not None:
            img_aug = self.transform(img)
        else:
            img_aug = img
        if self.transform is not None:
            img = self.transform(img)
            img2 = self.transform(img2)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, img2, img_aug, target, index, 0

    def __len__(self):
        if self.train:
            return len(self.train_data)
        return len(self.test_data)

    def _check_integrity(self):
        """Return True if every train/test batch file exists with the expected MD5."""
        return all(check_integrity(os.path.join(self.root, self.base_folder, name), md5)
                   for name, md5 in self.train_list + self.test_list)

    def download(self):
        """Download and unpack the CIFAR-10 python archive into ``root`` unless already verified."""
        import tarfile

        if self._check_integrity():
            print('Files already downloaded and verified')
            return

        root = self.root
        download_url(self.url, root, self.filename, self.tgz_md5)

        # Extract straight into root instead of chdir-ing the whole process.
        with tarfile.open(os.path.join(root, self.filename), "r:gz") as tar:
            tar.extractall(path=root)

    def __repr__(self):
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
        tmp = 'train' if self.train is True else 'test'
        fmt_str += '    Split: {}\n'.format(tmp)
        fmt_str += '    Root Location: {}\n'.format(self.root)
        tmp = '    Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        tmp = '    Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str

def check_integrity(fpath, md5):
    """Return True if ``fpath`` is an existing file whose MD5 hex digest equals ``md5``.

    The file is hashed in 1 MiB chunks so large archives are not read into memory at once.
    """
    import hashlib  # local import: hashlib is not imported at module level

    if not os.path.isfile(fpath):
        return False
    md5o = hashlib.md5()
    with open(fpath, 'rb') as f:
        for chunk in iter(lambda: f.read(1024 * 1024), b''):
            md5o.update(chunk)
    return md5o.hexdigest() == md5

def download_url(url, root, filename, md5):
    """Download ``url`` to ``root/filename`` unless a verified copy already exists.

    Creates ``root`` if needed. If an ``https`` download fails, the download
    is retried once over plain ``http``. ``md5`` is only used to decide whether
    an existing file can be reused; a fresh download is not verified here.

    Raises:
        OSError: (including ``urllib.error.URLError``) when the download fails
            and no http fallback applies, or when the fallback also fails.
    """
    import urllib.request

    root = os.path.expanduser(root)
    fpath = os.path.join(root, filename)
    os.makedirs(root, exist_ok=True)

    if os.path.isfile(fpath) and check_integrity(fpath, md5):
        print('Using downloaded and verified file: ' + fpath)
        return
    try:
        print('Downloading ' + url + ' to ' + fpath)
        urllib.request.urlretrieve(url, fpath)
    except OSError:
        # Only https URLs have a fallback; anything else must not fail silently.
        if url[:5] != 'https':
            raise
        url = url.replace('https:', 'http:')
        print('Failed download. Trying https -> http instead.'
              ' Downloading ' + url + ' to ' + fpath)
        urllib.request.urlretrieve(url, fpath)

def list_dir(root, prefix=False):
    """Return the sub-directories found directly under ``root``.

    Args:
        root (str): directory to scan; ``~`` is expanded.
        prefix (bool, optional): when exactly ``True``, return paths joined
            onto ``root`` instead of bare directory names.
    """
    base = os.path.expanduser(root)
    names = [entry for entry in os.listdir(base) if os.path.isdir(os.path.join(base, entry))]
    if prefix is not True:
        return names
    return [os.path.join(base, name) for name in names]

def list_files(root, suffix, prefix=False):
    """Return the regular files directly under ``root`` whose names end with ``suffix``.

    Args:
        root (str): directory to scan; ``~`` is expanded.
        suffix (str or tuple): passed straight to ``str.endswith``,
            e.g. ``'.png'`` or ``('.jpg', '.png')``.
        prefix (bool, optional): when exactly ``True``, return paths joined
            onto ``root`` instead of bare file names.
    """
    base = os.path.expanduser(root)
    matches = []
    for entry in os.listdir(base):
        if os.path.isfile(os.path.join(base, entry)) and entry.endswith(suffix):
            matches.append(entry)
    if prefix is not True:
        return matches
    return [os.path.join(base, entry) for entry in matches]

# basic function#
def multiclass_noisify(y, P, random_state=0):
    """Flip classes according to the row-stochastic transition matrix ``P``.

    Each label ``y[i]`` is replaced by a class drawn from row ``P[y[i], :]``.

    Args:
        y: 1-D integer array of labels in ``[0, P.shape[0])``.
        P: square numpy array whose rows are non-negative and sum to 1.
        random_state: seed for a private ``np.random.RandomState``; the global
            numpy RNG is not touched.

    Returns:
        A new label array; ``y`` itself is not modified.
    """
    from numpy.testing import assert_array_almost_equal  # not imported at module level

    assert P.shape[0] == P.shape[1]
    assert np.max(y) < P.shape[0]

    # row stochastic matrix
    assert_array_almost_equal(P.sum(axis=1), np.ones(P.shape[1]))
    assert (P >= 0.0).all()

    new_y = y.copy()
    flipper = np.random.RandomState(random_state)
    print(f'flip with random seed {random_state}')

    for idx in range(y.shape[0]):
        # One multinomial trial yields a one-hot vector; its argmax is the drawn
        # class. Assigning a scalar avoids numpy's deprecated array-to-scalar conversion.
        flipped = flipper.multinomial(1, P[y[idx], :], 1)[0]
        new_y[idx] = np.argmax(flipped)

    return new_y

# noisify_pairflip call the function "multiclass_noisify"
def noisify_pairflip(y_train, noise, random_state=None, nb_classes=10):
    """Apply pair-flip noise: class ``i`` becomes ``(i + 1) % nb_classes`` with probability ``noise``.

    Args:
        y_train: 1-D integer label array (as accepted by ``multiclass_noisify``).
        noise: flip probability; values ``<= 0`` leave the labels untouched.
        random_state: seed forwarded to ``multiclass_noisify``.
        nb_classes: number of classes.

    Returns:
        ``(labels, actual_noise)``, where ``actual_noise`` is the observed
        fraction of changed labels (``0.0`` when no noise is applied).
    """
    # Defined up front so the noise <= 0 path returns a value instead of
    # raising UnboundLocalError.
    actual_noise = 0.0

    if noise > 0.0:
        P = np.eye(nb_classes)
        for i in range(nb_classes):
            P[i, i] = 1. - noise
            P[i, (i + 1) % nb_classes] = noise

        y_train_noisy = multiclass_noisify(y_train, P=P,
                                           random_state=random_state)
        actual_noise = (y_train_noisy != y_train).mean()
        assert actual_noise > 0.0
        print('Actual noise %.2f' % actual_noise)
        y_train = y_train_noisy

    return y_train, actual_noise

def noisify_multiclass_symmetric(y_train, noise, random_state=None, nb_classes=10):
    """Apply symmetric noise: with probability ``noise`` a label moves uniformly to one of the other classes.

    Args:
        y_train: 1-D integer label array (as accepted by ``multiclass_noisify``).
        noise: total flip probability; values ``<= 0`` leave the labels untouched.
        random_state: seed forwarded to ``multiclass_noisify``.
        nb_classes: number of classes.

    Returns:
        ``(labels, actual_noise)``, where ``actual_noise`` is the observed
        fraction of changed labels (``0.0`` when no noise is applied).
    """
    # Defined up front so the noise <= 0 path returns a value instead of
    # raising UnboundLocalError.
    actual_noise = 0.0

    if noise > 0.0:
        # Off-diagonal mass is spread evenly over the other nb_classes - 1 classes.
        P = np.full((nb_classes, nb_classes), noise / (nb_classes - 1))
        np.fill_diagonal(P, 1. - noise)

        y_train_noisy = multiclass_noisify(y_train, P=P,
                                           random_state=random_state)
        actual_noise = (y_train_noisy != y_train).mean()
        assert actual_noise > 0.0
        print('Actual noise %.2f' % actual_noise)
        y_train = y_train_noisy

    return y_train, actual_noise

def noisify(dataset='mnist', nb_classes=10, train_labels=None, noise_type=None, noise_rate=0, random_state=0):
    """Apply synthetic ``'pairflip'`` or ``'symmetric'`` label noise.

    Args:
        dataset: unused; kept for API compatibility.
        nb_classes: number of classes.
        train_labels: 1-D integer label array.
        noise_type: ``'pairflip'`` or ``'symmetric'``.
        noise_rate: flip probability.
        random_state: seed for the flips. It was previously ignored in favour
            of a hard-coded 0, which remains the default.

    Returns:
        ``(train_noisy_labels, actual_noise_rate)``.

    Raises:
        ValueError: ``noise_type`` is not one of the supported names.
    """
    if noise_type == 'pairflip':
        return noisify_pairflip(train_labels, noise_rate, random_state=random_state,
                                nb_classes=nb_classes)
    if noise_type == 'symmetric':
        return noisify_multiclass_symmetric(train_labels, noise_rate,
                                            random_state=random_state, nb_classes=nb_classes)
    raise ValueError(f"unsupported noise_type {noise_type!r}; expected 'pairflip' or 'symmetric'")

def noisify_instance(train_data, train_labels, noise_rate):
    """Generate instance-dependent noisy labels.

    Each sample draws a flip probability ``q_i`` from N(noise_rate, 0.1),
    keeping only draws inside (0, 1). The flipped-to class is chosen by a
    softmax over a random linear projection of the flattened image, with the
    true class excluded. The true class keeps probability ``1 - q_i``.

    Args:
        train_data: sequence of images, each flattening to ``32 * 32 * 3`` values.
        train_labels: integer labels, one per image. A label ``>= 10`` switches
            to 100 classes.
        noise_rate: mean per-sample flip probability.

    Returns:
        ``(noisy_labels, over_all_noise_rate)``, where the rate is the fraction
        of samples whose label changed.

    Note:
        Reseeds the global numpy RNG with 0 for reproducibility.
    """
    # A label of 10 or more cannot be CIFAR-10 (labels 0..9), so switch to
    # 100 classes. The old "> 10" check mis-sized the class count when the max label was 10.
    num_class = 100 if max(train_labels) >= 10 else 10
    n = len(train_labels)
    np.random.seed(0)

    q_ = np.random.normal(loc=noise_rate, scale=0.1, size=1000000)
    # First n draws that are valid probabilities (same values the old append-loop kept).
    q = q_[(q_ > 0) & (q_ < 1)][:n]

    w = np.random.normal(loc=0, scale=1, size=(32 * 32 * 3, num_class))

    noisy_labels = []
    for i, sample in enumerate(train_data):
        sample = sample.flatten()
        p_all = np.matmul(sample, w)
        # Exclude the true class from the softmax, then give it 1 - q_i.
        p_all[train_labels[i]] = -1000000
        p_all = q[i] * F.softmax(torch.tensor(p_all), dim=0).numpy()
        p_all[train_labels[i]] = 1 - q[i]
        noisy_labels.append(np.random.choice(np.arange(num_class), p=p_all / sum(p_all)))
    # Divide by the real sample count rather than a hard-coded 50000.
    n_unchanged = int(np.sum(np.asarray(train_labels) == np.asarray(noisy_labels)))
    over_all_noise_rate = 1 - float(n_unchanged) / n
    return noisy_labels, over_all_noise_rate
