from __future__ import print_function
from PIL import Image
import os
import os.path
import numpy as np
from copy import deepcopy
from torchvision import transforms as transforms
import sys
from PIL import Image

if sys.version_info[0] == 2:
    import cPickle as pickle
else:
    import pickle

import torch.utils.data as data
from .utils import download_url, check_integrity, noisify
import matplotlib.pyplot as plt

# Side length, in pixels, of the square trigger stamped by SquarePatchAttack.
# It must agree with the hard-coded 3x3 layout of PATCH below.
PATCH_SIZE = 3
# Single-channel trigger pattern made of 0/255 values (corners and centre set).
PATCH = np.array([[255, 0, 255], [0, 255, 0], [255, 0, 255]])
# Allocate a (PATCH_SIZE, PATCH_SIZE, 3) array. The random values are only
# placeholders: every channel is overwritten with PATCH by the loop below.
PATCH1 = np.random.rand(PATCH_SIZE, PATCH_SIZE, 3)
for i in range(3):
    PATCH1[:, :, i] = PATCH
# Convert the float array to integers so the trigger holds exact 0/255 values.
PATCH1 = PATCH1.astype(int)

# Random 32x32x3 integer image (values in [0, 255)) used by BlendAttack.
# NOTE(review): no seed is set in the visible part of this file, so this image
# presumably differs between processes -- confirm whether that is intended.
IMG2Blend = np.random.rand(32, 32, 3) * 255
IMG2Blend = IMG2Blend.astype(int)


def SquarePatchAttack(image, patch, attack_index=None):
    """Stamp ``patch`` onto the bottom-right corner of a batch of images.

    Args:
        image: 4-D array indexed as (sample, row, column, channel). It is
            modified in place.
        patch: 3-D array (rows, columns, channels) written into the corner.
        attack_index: Index or indices along the first axis selecting the
            samples to stamp. ``None`` stamps every sample.

    Returns:
        The same ``image`` object that was passed in.
    """
    assert len(image.shape) == 4
    patch_rows, patch_cols, _channels = patch.shape
    # A full slice selects the whole batch when no subset was requested.
    targets = slice(None) if attack_index is None else attack_index
    image[targets, -patch_rows:, -patch_cols:, :] = patch
    return image


def BlendAttack(image, image2blend, alpha=0.2, attack_index=None):
    """Blend ``image2blend`` into a batch of images with weight ``alpha``.

    Each affected sample becomes ``(1 - alpha) * sample + alpha * image2blend``.

    Args:
        image: 4-D array whose first axis indexes the samples.
        image2blend: Array with the same shape as a single sample.
        alpha: Weight given to ``image2blend``.
        attack_index: Index or indices of the samples to blend. ``None``
            blends every sample.

    Returns:
        A new ``uint8`` array. When ``attack_index`` is given, the selected
        samples of the input ``image`` are also overwritten in place before
        the ``uint8`` copy is made; with ``None`` the input is left untouched.
    """
    assert len(image.shape) == 4
    assert image2blend.shape == image[0, :, :, :].shape
    keep = 1 - alpha
    if attack_index is not None:
        selected = image[attack_index, :, :, :]
        # Written back into the caller's array, so it takes the array's dtype.
        image[attack_index, :, :, :] = keep * selected + alpha * image2blend
        return image.astype("uint8")
    blended = keep * image + alpha * image2blend
    return blended.astype("uint8")


def ContrastAttack(image, contrast, attack_index=None):
    """Apply a torchvision contrast jitter to samples of an image batch.

    Args:
        image: 4-D array whose first axis indexes the samples. It is
            modified in place.
        contrast: Value forwarded to ``transforms.ColorJitter(contrast=...)``.
        attack_index: Iterable of sample indices to jitter. ``None`` jitters
            every sample.

    Returns:
        The same ``image`` object that was passed in.
    """
    assert len(image.shape) == 4
    if attack_index is None:
        targets = range(image.shape[0])
    else:
        targets = attack_index
    for idx in targets:
        # Round-trip through PIL because ColorJitter is applied to a PIL image.
        pil_img = Image.fromarray(np.uint8(image[idx, :, :, :]))
        # A fresh transform is built for every sample, as in the original code.
        jittered = transforms.ColorJitter(contrast=contrast)(pil_img)
        image[idx, :, :, :] = np.array(jittered)
    return image



class CIFAR10(data.Dataset):
    """`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset with
    label noise or backdoor poisoning.

    Args:
        root (string): Root directory of dataset where directory
            ``cifar-10-batches-py`` exists or will be saved to if download is set to True.
        train (bool, optional): If True, creates dataset from training set, otherwise
            creates from test set.
        transform (callable, optional): A function/transform that  takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        noise_type (string): One of ``"patch"``, ``"contrast"``, ``"blend"``
            (backdoor attacks that also modify the images) or ``"symmetric"``,
            ``"pairflip"`` (label noise only). Any other value, including the
            default ``None``, raises ``ValueError``.
        noise_rate (float, optional): Forwarded to ``noisify`` for the training
            split. The test split always passes ``noise_rate=1``.
        random_state (int, optional): Forwarded to ``noisify``.

    Raises:
        ValueError: If ``noise_type`` is not one of the accepted values.
        RuntimeError: If the dataset files are missing or fail the MD5 check.

    Attributes set for the training split: ``train_data``, ``train_labels``,
    ``train_noisy_labels``, ``actual_noise_rate`` and ``noise_or_not`` (True
    where the noisy label equals the original label).
    Attributes set for the test split: ``test_data``, ``test_labels``,
    ``test_noise_labels``, ``actual_noise_rate`` and ``test_data_attack``.
    """

    base_folder = "cifar-10-batches-py"
    url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
    filename = "cifar-10-python.tar.gz"
    tgz_md5 = "c58f30108f718f92721af3b95e74349a"
    train_list = [
        ["data_batch_1", "c99cafc152244af753f735de768cd75f"],
        ["data_batch_2", "d4bba439e000b95fd0a9bffe97cbabec"],
        ["data_batch_3", "54ebc095f3ab1f0389bbae665268c751"],
        ["data_batch_4", "634d18415352ddfa80567beed471001a"],
        ["data_batch_5", "482c414d41f54cd18b22e5b47cb7c3cb"],
    ]

    test_list = [
        ["test_batch", "40351d587109b95175f43aff81a1287e"],
    ]

    def __init__(
        self,
        root,
        train=True,
        transform=None,
        target_transform=None,
        download=False,
        noise_type=None,
        noise_rate=0.2,
        random_state=0,
    ):
        # Validate before touching the filesystem so bad arguments fail fast.
        if noise_type not in ["patch", "contrast", "blend", "symmetric", "pairflip"]:
            raise ValueError(
                "noise_type has to be one of backdoor (patch, contrast, blend), symmetric, and pairflip"
            )

        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.train = train  # training set or test set
        self.dataset = "cifar10"
        self.noise_type = noise_type
        self.nb_classes = 10

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError(
                "Dataset not found or corrupted."
                + " You can use download=True to download it"
            )

        # now load the pickled numpy arrays
        if self.train:
            self._load_train(noise_type, noise_rate, random_state)
        else:
            self._load_test(noise_type, random_state)

    def _read_batch(self, name):
        """Unpickle one batch file of the extracted archive and return its dict."""
        path = os.path.join(self.root, self.base_folder, name)
        # The context manager closes the handle even if unpickling raises.
        # NOTE: pickle can execute arbitrary code; __init__ only gets here
        # after _check_integrity() has verified the files' MD5 sums.
        with open(path, "rb") as fo:
            if sys.version_info[0] == 2:
                return pickle.load(fo)
            return pickle.load(fo, encoding="latin1")

    @staticmethod
    def _batch_labels(entry):
        """Return the label list of a batch (``labels``, else ``fine_labels``)."""
        return entry["labels"] if "labels" in entry else entry["fine_labels"]

    def _load_train(self, noise_type, noise_rate, random_state):
        """Load the training batches, corrupt the labels and poison the images."""
        self.train_data = []
        self.train_labels = []
        for fentry in self.train_list:
            entry = self._read_batch(fentry[0])
            self.train_data.append(entry["data"])
            self.train_labels += self._batch_labels(entry)

        self.train_data = np.concatenate(self.train_data)
        # -1 lets the sample count follow the loaded data instead of a constant.
        self.train_data = self.train_data.reshape((-1, 3, 32, 32))
        self.train_data = self.train_data.transpose((0, 2, 3, 1))  # convert to HWC

        if noise_type == "clean":
            # Kept for parity with __getitem__; currently unreachable because
            # the constructor rejects "clean".
            return

        # noisify receives the labels as a column of single-element rows.
        self.train_labels = np.asarray([[label] for label in self.train_labels])
        self.train_noisy_labels, self.actual_noise_rate = noisify(
            dataset=self.dataset,
            train_labels=self.train_labels,
            noise_type=noise_type,
            noise_rate=noise_rate,
            random_state=random_state,
            nb_classes=self.nb_classes,
        )
        self.train_noisy_labels = [i[0] for i in self.train_noisy_labels]
        noisy = np.asarray(self.train_noisy_labels)
        clean = np.asarray([i[0] for i in self.train_labels])
        # True where the label was left unchanged by noisify.
        self.noise_or_not = noisy == clean

        if noise_type not in ("patch", "contrast", "blend"):
            # Pure label noise: the images stay untouched.
            return

        # Backdoor attacks poison exactly the samples whose label was changed.
        _noise_index = np.where(noisy != clean)[0]
        # Store plain Python ints, as the original implementation did.
        self.train_noisy_labels = np.array(self.train_noisy_labels).tolist()

        # The preview needs at least one poisoned sample; without this guard
        # _noise_index[0] raises IndexError when no label was flipped.
        show_preview = _noise_index.size > 0
        if show_preview:
            plt.subplot(121)
            plt.imshow(self.train_data[_noise_index[0], :, :, :])
            plt.title("clean")

        if noise_type == "contrast":
            self.train_data = ContrastAttack(
                self.train_data, 5, attack_index=_noise_index
            )
        elif noise_type == "patch":
            self.train_data = SquarePatchAttack(
                self.train_data, patch=PATCH1, attack_index=_noise_index
            )
        elif noise_type == "blend":
            self.train_data = BlendAttack(
                self.train_data,
                image2blend=IMG2Blend,
                alpha=0.1,
                attack_index=_noise_index,
            )

        if show_preview:
            plt.subplot(122)
            plt.imshow(self.train_data[_noise_index[0], :, :, :])
            plt.title("poisoned")
            plt.show()

    def _load_test(self, noise_type, random_state):
        """Load the test batch and build a fully attacked copy of its images."""
        entry = self._read_batch(self.test_list[0][0])
        self.test_data = entry["data"]
        self.test_labels = self._batch_labels(entry)

        test_label_temp = np.asarray([[label] for label in self.test_labels])
        # noise_rate=1 is passed for the test split regardless of the
        # constructor's noise_rate; the result is stored as noisify returns it.
        self.test_noise_labels, self.actual_noise_rate = noisify(
            dataset=self.dataset,
            train_labels=test_label_temp,
            noise_type=noise_type,
            noise_rate=1,
            random_state=random_state,
            nb_classes=self.nb_classes,
        )
        self.test_data = self.test_data.reshape((-1, 3, 32, 32))
        self.test_data = self.test_data.transpose((0, 2, 3, 1))  # convert to HWC
        # Attack a deep copy so the clean test images remain available.
        self.test_data_attack = deepcopy(self.test_data)

        # NOTE(review): the contrast value is 0.1 here but 5 for the training
        # split -- confirm this asymmetry is intended.
        if noise_type == "contrast":
            self.test_data_attack = ContrastAttack(self.test_data_attack, 0.1)
        elif noise_type == "patch":
            self.test_data_attack = SquarePatchAttack(
                self.test_data_attack, patch=PATCH1
            )
        elif noise_type == "blend":
            self.test_data_attack = BlendAttack(
                self.test_data_attack,
                image2blend=IMG2Blend,
                alpha=0.1,
            )

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: ``(image, target, index)`` for the training split, where
            target is the (possibly noisy) label. For the test split,
            ``((image, poisoned_image), (target, target_poisoned_label), index)``.
            ``transform`` is applied to both images; ``target_transform`` is
            applied to the clean target only.
        """
        if self.train:
            if self.noise_type != "clean":
                img, target = self.train_data[index], self.train_noisy_labels[index]
            else:
                img, target = self.train_data[index], self.train_labels[index]
            poisoned_img = None
            target_poisoned_label = None
        else:
            img, target = self.test_data[index], self.test_labels[index]
            poisoned_img = self.test_data_attack[index]
            target_poisoned_label = self.test_noise_labels[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img)
        if poisoned_img is not None:
            poisoned_img = Image.fromarray(poisoned_img)

        if self.transform is not None:
            img = self.transform(img)
            if poisoned_img is not None:
                poisoned_img = self.transform(poisoned_img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        if self.train:
            return img, target, index
        return (img, poisoned_img), (target, target_poisoned_label), index

    def __len__(self):
        """Return the number of samples in the loaded split."""
        if self.train:
            return len(self.train_data)
        return len(self.test_data)

    def _check_integrity(self):
        """Return True when every train and test batch file passes its MD5 check."""
        return all(
            check_integrity(os.path.join(self.root, self.base_folder, filename), md5)
            for filename, md5 in self.train_list + self.test_list
        )

    def download(self):
        """Download and extract the archive into ``root`` unless already present."""
        import tarfile

        if self._check_integrity():
            print("Files already downloaded and verified")
            return

        root = self.root
        download_url(self.url, root, self.filename, self.tgz_md5)

        # Extract straight into root. Unlike chdir + extractall(), this leaves
        # the process working directory untouched even if extraction fails.
        with tarfile.open(os.path.join(root, self.filename), "r:gz") as tar:
            tar.extractall(path=root)

    def __repr__(self):
        """Return a multi-line summary: size, split, root and transforms."""
        fmt_str = "Dataset " + self.__class__.__name__ + "\n"
        fmt_str += "    Number of datapoints: {}\n".format(self.__len__())
        tmp = "train" if self.train is True else "test"
        fmt_str += "    Split: {}\n".format(tmp)
        fmt_str += "    Root Location: {}\n".format(self.root)
        tmp = "    Transforms (if any): "
        fmt_str += "{0}{1}\n".format(
            tmp, self.transform.__repr__().replace("\n", "\n" + " " * len(tmp))
        )
        tmp = "    Target Transforms (if any): "
        fmt_str += "{0}{1}".format(
            tmp, self.target_transform.__repr__().replace("\n", "\n" + " " * len(tmp))
        )
        return fmt_str


class CIFAR100(data.Dataset):
    """`CIFAR100 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset with
    label noise or backdoor poisoning.

    Args:
        root (string): Root directory of dataset where directory
            ``cifar-100-python`` exists or will be saved to if download is set to True.
        train (bool, optional): If True, creates dataset from training set, otherwise
            creates from test set.
        transform (callable, optional): A function/transform that  takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        noise_type (string): One of ``"patch"``, ``"contrast"``, ``"blend"``
            (backdoor attacks that also modify the images) or ``"symmetric"``,
            ``"pairflip"`` (label noise only). Any other value, including the
            default ``None``, raises ``ValueError``.
        noise_rate (float, optional): Forwarded to ``noisify`` for the training
            split. The test split always passes ``noise_rate=1``.
        random_state (int, optional): Forwarded to ``noisify``.

    Raises:
        ValueError: If ``noise_type`` is not one of the accepted values.
        RuntimeError: If the dataset files are missing or fail the MD5 check.

    Attributes set for the training split: ``train_data``, ``train_labels``,
    ``train_noisy_labels``, ``actual_noise_rate`` and ``noise_or_not`` (True
    where the noisy label equals the original label).
    Attributes set for the test split: ``test_data``, ``test_labels``,
    ``test_noise_labels``, ``actual_noise_rate`` and ``test_data_attack``.
    """

    base_folder = "cifar-100-python"
    url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
    filename = "cifar-100-python.tar.gz"
    tgz_md5 = "eb9058c3a382ffc7106e4002c42a8d85"
    train_list = [
        ["train", "16019d7e3df5f24257cddd939b257f8d"],
    ]

    test_list = [
        ["test", "f0ef6b0ae62326f3e7ffdfab6717acfc"],
    ]

    def __init__(
        self,
        root,
        train=True,
        transform=None,
        target_transform=None,
        download=False,
        noise_type=None,
        noise_rate=0.2,
        random_state=0,
    ):
        # Validate before touching the filesystem so bad arguments fail fast.
        if noise_type not in ["patch", "contrast", "blend", "symmetric", "pairflip"]:
            raise ValueError(
                "noise_type has to be one of backdoor (patch, contrast, blend), symmetric, and pairflip"
            )

        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.train = train  # training set or test set
        self.dataset = "cifar100"
        self.noise_type = noise_type
        self.nb_classes = 100

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError(
                "Dataset not found or corrupted."
                + " You can use download=True to download it"
            )

        # now load the pickled numpy arrays
        if self.train:
            self._load_train(noise_type, noise_rate, random_state)
        else:
            self._load_test(noise_type, random_state)

    def _read_batch(self, name):
        """Unpickle one batch file of the extracted archive and return its dict."""
        path = os.path.join(self.root, self.base_folder, name)
        # The context manager closes the handle even if unpickling raises.
        # NOTE: pickle can execute arbitrary code; __init__ only gets here
        # after _check_integrity() has verified the files' MD5 sums.
        with open(path, "rb") as fo:
            if sys.version_info[0] == 2:
                return pickle.load(fo)
            return pickle.load(fo, encoding="latin1")

    @staticmethod
    def _batch_labels(entry):
        """Return the label list of a batch (``labels``, else ``fine_labels``)."""
        return entry["labels"] if "labels" in entry else entry["fine_labels"]

    def _load_train(self, noise_type, noise_rate, random_state):
        """Load the training batch, corrupt the labels and poison the images."""
        self.train_data = []
        self.train_labels = []
        for fentry in self.train_list:
            entry = self._read_batch(fentry[0])
            self.train_data.append(entry["data"])
            self.train_labels += self._batch_labels(entry)

        self.train_data = np.concatenate(self.train_data)
        # -1 lets the sample count follow the loaded data instead of a constant.
        self.train_data = self.train_data.reshape((-1, 3, 32, 32))
        self.train_data = self.train_data.transpose((0, 2, 3, 1))  # convert to HWC

        if noise_type == "clean":
            # Mirrors the original branch structure; currently unreachable
            # because the constructor rejects "clean".
            return

        # noisify receives the labels as a column of single-element rows.
        self.train_labels = np.asarray([[label] for label in self.train_labels])
        self.train_noisy_labels, self.actual_noise_rate = noisify(
            dataset=self.dataset,
            train_labels=self.train_labels,
            noise_type=noise_type,
            noise_rate=noise_rate,
            random_state=random_state,
            nb_classes=self.nb_classes,
        )
        self.train_noisy_labels = [i[0] for i in self.train_noisy_labels]
        noisy = np.asarray(self.train_noisy_labels)
        clean = np.asarray([i[0] for i in self.train_labels])
        # True where the label was left unchanged by noisify.
        self.noise_or_not = noisy == clean

        if noise_type not in ("patch", "contrast", "blend"):
            # Pure label noise: the images stay untouched.
            return

        # Backdoor attacks poison exactly the samples whose label was changed.
        _noise_index = np.where(noisy != clean)[0]

        # The preview needs at least one poisoned sample; without this guard
        # _noise_index[0] raises IndexError when no label was flipped.
        show_preview = _noise_index.size > 0
        if show_preview:
            plt.subplot(121)
            plt.imshow(self.train_data[_noise_index[0], :, :, :])
            plt.title("clean")

        if noise_type == "contrast":
            self.train_data = ContrastAttack(
                self.train_data, 5, attack_index=_noise_index
            )
        elif noise_type == "patch":
            self.train_data = SquarePatchAttack(
                self.train_data, patch=PATCH1, attack_index=_noise_index
            )
        elif noise_type == "blend":
            self.train_data = BlendAttack(
                self.train_data,
                image2blend=IMG2Blend,
                alpha=0.1,
                attack_index=_noise_index,
            )

        if show_preview:
            plt.subplot(122)
            plt.imshow(self.train_data[_noise_index[0], :, :, :])
            plt.title("poisoned")
            plt.show()

    def _load_test(self, noise_type, random_state):
        """Load the test batch and build a fully attacked copy of its images."""
        entry = self._read_batch(self.test_list[0][0])
        self.test_data = entry["data"]
        self.test_labels = self._batch_labels(entry)

        test_label_temp = np.asarray([[label] for label in self.test_labels])
        # noise_rate=1 is passed for the test split regardless of the
        # constructor's noise_rate; the result is stored as noisify returns it.
        self.test_noise_labels, self.actual_noise_rate = noisify(
            dataset=self.dataset,
            train_labels=test_label_temp,
            noise_type=noise_type,
            noise_rate=1,
            random_state=random_state,
            nb_classes=self.nb_classes,
        )
        self.test_data = self.test_data.reshape((-1, 3, 32, 32))
        self.test_data = self.test_data.transpose((0, 2, 3, 1))  # convert to HWC
        # Attack a deep copy so the clean test images remain available.
        self.test_data_attack = deepcopy(self.test_data)

        # NOTE(review): the contrast value is 0.1 here but 5 for the training
        # split -- confirm this asymmetry is intended.
        if noise_type == "contrast":
            self.test_data_attack = ContrastAttack(self.test_data_attack, 0.1)
        elif noise_type == "patch":
            self.test_data_attack = SquarePatchAttack(
                self.test_data_attack, patch=PATCH1
            )
        elif noise_type == "blend":
            self.test_data_attack = BlendAttack(
                self.test_data_attack,
                image2blend=IMG2Blend,
                alpha=0.1,
            )

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: ``(image, target, index)`` for the training split, where
            target is the (possibly noisy) label. For the test split,
            ``((image, poisoned_image), (target, target_poisoned_label), index)``.
            ``transform`` is applied to both images; ``target_transform`` is
            applied to the clean target only.
        """
        if self.train:
            if self.noise_type is not None:
                img, target = self.train_data[index], self.train_noisy_labels[index]
            else:
                img, target = self.train_data[index], self.train_labels[index]
            poisoned_img = None
            target_poisoned_label = None
        else:
            img, target = self.test_data[index], self.test_labels[index]
            poisoned_img = self.test_data_attack[index]
            target_poisoned_label = self.test_noise_labels[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img)
        if poisoned_img is not None:
            poisoned_img = Image.fromarray(poisoned_img)

        if self.transform is not None:
            img = self.transform(img)
            if poisoned_img is not None:
                poisoned_img = self.transform(poisoned_img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        if self.train:
            return img, target, index
        return (img, poisoned_img), (target, target_poisoned_label), index

    def __len__(self):
        """Return the number of samples in the loaded split."""
        if self.train:
            return len(self.train_data)
        return len(self.test_data)

    def _check_integrity(self):
        """Return True when every train and test batch file passes its MD5 check."""
        return all(
            check_integrity(os.path.join(self.root, self.base_folder, filename), md5)
            for filename, md5 in self.train_list + self.test_list
        )

    def download(self):
        """Download and extract the archive into ``root`` unless already present."""
        import tarfile

        if self._check_integrity():
            print("Files already downloaded and verified")
            return

        root = self.root
        download_url(self.url, root, self.filename, self.tgz_md5)

        # Extract straight into root. Unlike chdir + extractall(), this leaves
        # the process working directory untouched even if extraction fails.
        with tarfile.open(os.path.join(root, self.filename), "r:gz") as tar:
            tar.extractall(path=root)

    def __repr__(self):
        """Return a multi-line summary: size, split, root and transforms."""
        fmt_str = "Dataset " + self.__class__.__name__ + "\n"
        fmt_str += "    Number of datapoints: {}\n".format(self.__len__())
        tmp = "train" if self.train is True else "test"
        fmt_str += "    Split: {}\n".format(tmp)
        fmt_str += "    Root Location: {}\n".format(self.root)
        tmp = "    Transforms (if any): "
        fmt_str += "{0}{1}\n".format(
            tmp, self.transform.__repr__().replace("\n", "\n" + " " * len(tmp))
        )
        tmp = "    Target Transforms (if any): "
        fmt_str += "{0}{1}".format(
            tmp, self.target_transform.__repr__().replace("\n", "\n" + " " * len(tmp))
        )
        return fmt_str
