from __future__ import print_function
import torch.utils.data as data
from PIL import Image
import os
import os.path
import errno
import numpy as np
import torch
import codecs
from .utils import noisify
import matplotlib.pyplot as plt
from copy import deepcopy

PATCH_SIZE = 3
# 3x3 checkerboard trigger (values 0 / 255) stamped by SquarePatchAttack.
PATCH = np.array([[255, 0, 255], [0, 255, 0], [255, 0, 255]])
# Same trigger with a trailing channel axis: shape (PATCH_SIZE, PATCH_SIZE, 1).
# Built directly instead of filling a throwaway np.random.rand buffer, which
# needlessly advanced the global NumPy RNG at import time.
PATCH1 = PATCH[:, :, np.newaxis].astype(int)

# The blend trigger must be identical in every process/run, otherwise a model
# poisoned in one run cannot be evaluated against the same trigger later.
# A dedicated seeded generator also leaves the global NumPy RNG untouched.
IMG2BLEND_SEED = 0
IMG2Blend = (np.random.RandomState(IMG2BLEND_SEED).rand(28, 28) * 255).astype(int)


def SquarePatchAttack(image, patch, attack_index=None):
    """Stamp ``patch`` into the bottom-right corner of images in a stack.

    Args:
        image: 3-D array indexed as ``image[n, row, col]``; modified in place.
        patch: array of shape (h, w, c). It is squeezed before assignment, so
            singleton axes (e.g. a single channel) are dropped.
        attack_index: indices along axis 0 of the images to stamp; every
            image when None.

    Returns:
        ``image`` itself, after stamping.
    """
    assert len(image.shape) == 3
    patch_h, patch_w, _ = patch.shape

    # slice(None) is the same selector as a bare ``:`` on axis 0.
    targets = slice(None) if attack_index is None else attack_index
    image[targets, -patch_h:, -patch_w:] = patch.squeeze()
    return image


def BlendAttack(image, image2blend, alpha=0.2, attack_index=None):
    """Alpha-blend a trigger pattern into a stack of images.

    Each selected image becomes ``(1 - alpha) * image + alpha * image2blend``.

    Args:
        image: 3-D array indexed as ``image[n, row, col]``.
        image2blend: trigger with the same shape as a single image ``image[0]``.
        alpha: weight of the trigger in the blend.
        attack_index: indices along axis 0 of the images to blend; every image
            when None.

    Returns:
        A new ``uint8`` array (``astype`` truncates fractional values). When
        ``attack_index`` is given, ``image`` is also updated in place, with the
        blended values stored in ``image``'s own dtype; otherwise ``image`` is
        left untouched.
    """
    assert len(image.shape) == 3
    assert image2blend.shape == image[0, :, :].shape

    keep = 1 - alpha
    # Leading singleton axis so the trigger broadcasts over the image stack.
    trigger = np.expand_dims(image2blend, 0)

    if attack_index is not None:
        image[attack_index, :, :] = keep * image[attack_index, :, :] + alpha * trigger
        return image.astype("uint8")

    blended = keep * image + alpha * trigger
    return blended.astype("uint8")


def ContrastAttack(image, contrast, attack_index=None):
    """Randomly rescale the contrast of selected images in a stack.

    Reproduces ``torchvision.transforms.ColorJitter(contrast=contrast)`` on PIL
    images (the original called an undefined ``T.ColorJitter``). For each
    attacked image, a factor is drawn uniformly from
    ``[max(0, 1 - contrast), 1 + contrast]`` and applied with
    ``PIL.ImageEnhance.Contrast``.

    Args:
        image: array of shape (N, H, W) (grayscale, e.g. MNIST) or
            (N, H, W, C); modified in place.
        contrast: non-negative jitter strength.
        attack_index: indices along axis 0 of the images to attack; every
            image when None.

    Returns:
        ``image`` itself, with the selected images replaced.

    Raises:
        ValueError: if ``contrast`` is negative.
    """
    # Local import: ``from PIL import Image`` alone does not load ImageEnhance.
    from PIL import ImageEnhance

    # 3-D stacks (MNIST) used to fail a 4-D-only assertion.
    assert len(image.shape) in (3, 4)
    if contrast < 0:
        raise ValueError("contrast must be non-negative")

    low, high = max(0.0, 1.0 - contrast), 1.0 + contrast
    indices = range(image.shape[0]) if attack_index is None else attack_index
    for i in indices:
        img = Image.fromarray(np.uint8(image[i]))
        # Drawn from torch's RNG, as ColorJitter does.
        factor = float(torch.empty(1).uniform_(low, high))
        image[i] = np.array(ImageEnhance.Contrast(img).enhance(factor))
    return image


class MNIST(data.Dataset):
    """`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset with optional label noise or backdoor poisoning.

    Args:
        root (string): Root directory of dataset where ``processed/training.pt``
            and ``processed/test.pt`` exist.
        train (bool, optional): If True, creates dataset from ``training.pt``,
            otherwise from ``test.pt``.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it. It is applied to the clean target only.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        noise_type (string, optional): ``"clean"`` (or ``None``, the default),
            ``"symmetric"``/``"pairflip"`` (label noise from ``noisify`` only), or
            ``"patch"``/``"contrast"``/``"blend"`` (backdoor: training images whose
            label ``noisify`` changed also receive the trigger, and the test split
            gets a fully triggered copy in ``test_data_attack``).
        noise_rate (float, optional): Noise rate forwarded to ``noisify`` for the
            training split.
        random_state (int, optional): Seed forwarded to ``noisify``.
        show_poison_example (bool, optional): When True (the historical behaviour),
            training-time poisoning plots one clean/poisoned pair with matplotlib and
            calls the blocking ``plt.show()``. Pass False for headless runs.
    """
    # NOTE(review): yann.lecun.com has been known to refuse downloads; confirm these
    # URLs still resolve or switch to a mirror.
    urls = [
        'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz',
        'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz',
        'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz',
        'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz',
    ]
    raw_folder = 'raw'
    processed_folder = 'processed'
    training_file = 'training.pt'
    test_file = 'test.pt'
    # Noise types that only relabel vs. those that also poison images.
    _label_noise_types = ('symmetric', 'pairflip')
    _backdoor_noise_types = ('patch', 'contrast', 'blend')

    def __init__(self, root, train=True, transform=None, target_transform=None, download=False,
                 noise_type=None, noise_rate=0.2, random_state=0, show_poison_example=True):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.train = train  # training set or test set
        self.dataset = 'mnist'
        # The default None used to fail validation; it means "clean".
        if noise_type is None:
            noise_type = 'clean'
        self.noise_type = noise_type

        valid_types = ('clean',) + self._label_noise_types + self._backdoor_noise_types
        if noise_type not in valid_types:
            raise ValueError(
                "noise_type has to be one of clean, backdoor (patch, contrast, blend), "
                "symmetric, and pairflip"
            )

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found.' +
                               ' You can use download=True to download it')

        if self.train:
            self._load_train(noise_rate, random_state, show_poison_example)
        else:
            self._load_test(random_state)

    def _load_train(self, noise_rate, random_state, show_poison_example):
        """Load the training split, then apply label noise and (for backdoors) poisoning."""
        self.train_data, self.train_labels = torch.load(
            os.path.join(self.root, self.processed_folder, self.training_file))
        if self.noise_type == 'clean':
            return
        noise_index = self._noisify_train_labels(noise_rate, random_state)
        # "contrast" used to fall into the label-only branch, so its images were
        # never poisoned; every backdoor type now reaches the poisoning step.
        if self.noise_type in self._backdoor_noise_types:
            self._poison_train_data(noise_index, show_poison_example)

    def _noisify_train_labels(self, noise_rate, random_state):
        """Replace training labels with ``noisify`` output and return the changed indices.

        Sets ``train_labels`` to an (N, 1) array, plus ``train_noisy_labels`` (list),
        ``actual_noise_rate``, and ``noise_or_not`` (True where the label is unchanged).
        """
        # One vectorised conversion instead of a per-element Python list of tensors.
        self.train_labels = np.asarray(self.train_labels).reshape(-1, 1)
        noisy_labels, self.actual_noise_rate = noisify(
            dataset=self.dataset,
            train_labels=self.train_labels,
            noise_type=self.noise_type,
            noise_rate=noise_rate,
            random_state=random_state,
            nb_classes=10,  # MNIST has 10 classes; now passed for every noise type
        )
        self.train_noisy_labels = [i[0] for i in noisy_labels]
        self.noise_or_not = np.asarray(self.train_noisy_labels) == self.train_labels[:, 0]
        return np.where(~self.noise_or_not)[0]

    def _poison_train_data(self, noise_index, show_example):
        """Convert ``train_data`` to numpy and apply the trigger to ``noise_index``."""
        self.train_data = self.train_data.cpu().detach().numpy()
        # With no changed labels there is nothing to show (noise_index[0] would raise).
        show_example = show_example and len(noise_index) > 0
        if show_example:
            plt.subplot(121)
            plt.imshow(self.train_data[noise_index[0], :, :])
            plt.title("clean")

        if self.noise_type == "contrast":
            self.train_data = ContrastAttack(
                self.train_data, 5, attack_index=noise_index
            )
        elif self.noise_type == "patch":
            self.train_data = SquarePatchAttack(
                self.train_data, patch=PATCH1, attack_index=noise_index
            )
        elif self.noise_type == "blend":
            self.train_data = BlendAttack(
                self.train_data,
                image2blend=IMG2Blend,
                alpha=0.1,
                attack_index=noise_index,
            )

        if show_example:
            plt.subplot(122)
            plt.imshow(self.train_data[noise_index[0], :, :])
            plt.title("poisoned")
            plt.show()
        print("finish training data")

    def _load_test(self, random_state):
        """Load the test split and build ``test_data_attack``, a fully triggered copy."""
        self.test_data, self.test_labels = torch.load(
            os.path.join(self.root, self.processed_folder, self.test_file))
        self.test_data = self.test_data.data.numpy()
        self.test_labels = self.test_labels.data.numpy()
        if self.noise_type == 'clean':
            # noisify is not consulted for clean data (the training path skips it too).
            # Targets equal the clean labels, shaped (N, 1) like the label arrays
            # passed to noisify.
            self.test_noise_labels = self.test_labels.reshape(-1, 1).copy()
            self.actual_noise_rate = 0.0
        else:
            # NOTE(review): noise_rate=1 presumably relabels every test sample with
            # the attack target; confirm against noisify.
            self.test_noise_labels, self.actual_noise_rate = noisify(
                dataset=self.dataset,
                train_labels=self.test_labels.reshape(-1, 1).copy(),
                noise_type=self.noise_type,
                noise_rate=1,
                random_state=random_state,
                nb_classes=10,
            )
        self.test_data = self.test_data.reshape((-1, 28, 28))
        self.test_data_attack = deepcopy(self.test_data)
        if self.noise_type == "contrast":
            self.test_data_attack = ContrastAttack(self.test_data_attack, 0.1)
        elif self.noise_type == "patch":
            self.test_data_attack = SquarePatchAttack(
                self.test_data_attack, patch=PATCH1
            )
        elif self.noise_type == "blend":
            self.test_data_attack = BlendAttack(
                self.test_data_attack,
                image2blend=IMG2Blend,
                alpha=0.1,
            )

        print(type(self.test_data), type(self.test_data_attack))
        print(self.test_data.shape, self.test_data_attack.shape)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: training split -> ``(image, target, index)``; test split ->
            ``((image, poisoned_image), (target, poisoned_target), index)``.
        """
        if self.train:
            labels = self.train_labels if self.noise_type == "clean" else self.train_noisy_labels
            img, target = self.train_data[index], labels[index]
            poisoned_img = None
            target_poisoned_label = None
        else:
            img, target = self.test_data[index], self.test_labels[index]
            poisoned_img = self.test_data_attack[index]
            target_poisoned_label = self.test_noise_labels[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image. Image.fromarray needs a numpy array, and
        # train_data is still a torch tensor for clean / label-noise runs.
        img = Image.fromarray(np.asarray(img))
        if poisoned_img is not None:
            poisoned_img = Image.fromarray(np.asarray(poisoned_img))

        if self.transform is not None:
            img = self.transform(img)
            if poisoned_img is not None:
                poisoned_img = self.transform(poisoned_img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        if self.train:
            return img, target, index
        else:
            return (img, poisoned_img), (target, target_poisoned_label), index

    def __len__(self):
        if self.train:
            return len(self.train_data)
        else:
            return len(self.test_data)

    def _check_exists(self):
        return os.path.exists(os.path.join(self.root, self.processed_folder, self.training_file)) and \
            os.path.exists(os.path.join(self.root, self.processed_folder, self.test_file))

    def download(self):
        """Download the MNIST data if it doesn't exist in processed_folder already."""
        from six.moves import urllib
        import gzip

        if self._check_exists():
            return

        # Create each folder independently: with a single try block, an existing
        # raw folder raised EEXIST before the processed folder was ever created.
        for folder in (self.raw_folder, self.processed_folder):
            try:
                os.makedirs(os.path.join(self.root, folder))
            except OSError as e:
                if e.errno != errno.EEXIST:
                    raise

        for url in self.urls:
            print('Downloading ' + url)
            filename = url.rpartition('/')[2]
            file_path = os.path.join(self.root, self.raw_folder, filename)
            response = urllib.request.urlopen(url)
            try:
                payload = response.read()
            finally:
                response.close()  # the response was previously never closed
            with open(file_path, 'wb') as f:
                f.write(payload)
            # Strip only the trailing ".gz"; str.replace could also match inside root.
            with open(os.path.splitext(file_path)[0], 'wb') as out_f, \
                    gzip.GzipFile(file_path) as zip_f:
                out_f.write(zip_f.read())
            os.unlink(file_path)

        # process and save as torch files
        print('Processing...')

        training_set = (
            read_image_file(os.path.join(self.root, self.raw_folder, 'train-images-idx3-ubyte')),
            read_label_file(os.path.join(self.root, self.raw_folder, 'train-labels-idx1-ubyte'))
        )
        test_set = (
            read_image_file(os.path.join(self.root, self.raw_folder, 't10k-images-idx3-ubyte')),
            read_label_file(os.path.join(self.root, self.raw_folder, 't10k-labels-idx1-ubyte'))
        )
        with open(os.path.join(self.root, self.processed_folder, self.training_file), 'wb') as f:
            torch.save(training_set, f)
        with open(os.path.join(self.root, self.processed_folder, self.test_file), 'wb') as f:
            torch.save(test_set, f)

        print('Done!')

    def __repr__(self):
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
        tmp = 'train' if self.train is True else 'test'
        fmt_str += '    Split: {}\n'.format(tmp)
        fmt_str += '    Root Location: {}\n'.format(self.root)
        tmp = '    Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        tmp = '    Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str


def get_int(b):
    """Decode the byte string ``b`` as a big-endian unsigned integer."""
    hex_digits = codecs.encode(b, 'hex')
    return int(hex_digits, 16)


def read_label_file(path):
    """Parse an MNIST IDX label file into a 1-D int64 tensor.

    The header is an 8-byte big-endian pair (magic, count); the magic must be
    2049, and the labels follow as one unsigned byte each.

    Raises:
        AssertionError: if the magic number is not 2049.
    """
    with open(path, 'rb') as f:
        raw = f.read()
    assert get_int(raw[:4]) == 2049
    length = get_int(raw[4:8])
    parsed = np.frombuffer(raw, dtype=np.uint8, offset=8)
    # frombuffer gives a read-only view of ``raw``; astype makes the writable
    # int64 copy up front, so torch.from_numpy no longer warns about a
    # non-writable array.
    return torch.from_numpy(parsed.astype(np.int64)).view(length)


def read_image_file(path):
    """Parse an MNIST IDX image file into a uint8 tensor of shape (N, rows, cols).

    The header holds four big-endian 32-bit integers (magic, count, rows, cols);
    the magic must be 2051, and pixels follow as one unsigned byte each.

    Raises:
        AssertionError: if the magic number is not 2051.
    """
    with open(path, 'rb') as f:
        raw = f.read()
    assert get_int(raw[:4]) == 2051
    length = get_int(raw[4:8])
    num_rows = get_int(raw[8:12])
    num_cols = get_int(raw[12:16])
    # frombuffer returns a read-only view of the immutable ``raw`` bytes; copy so
    # the returned tensor owns writable memory (torch.from_numpy warns otherwise).
    parsed = np.frombuffer(raw, dtype=np.uint8, offset=16).copy()
    return torch.from_numpy(parsed).view(length, num_rows, num_cols)