# -*- coding: utf-8 -*-

import os
from PIL import Image
import os.path
import time
import torch
import torchvision.datasets as dset
import torchvision.transforms as trn
import torch.utils.data as data
import numpy as np
import torch.distributions.dirichlet as dirichlet
import random

import pickle

from PIL import Image

from torch.utils.data import Dataset, DataLoader
from torchvision import models, utils, datasets, transforms
import numpy as np
import sys
import os
from PIL import Image


def save_data(l, path_):
    """Serialise ``l`` with pickle into the file at ``path_`` (overwritten if present)."""
    with open(path_, mode='wb') as handle:
        pickle.dump(l, handle)


class TinyImageNet(Dataset):
    """Tiny-ImageNet-200 read directly from its extracted directory layout.

    Files read below ``root``:
      * ``train/<wnid>/**/*.JPEG``: training images, labelled by directory;
      * ``val/images/*.JPEG`` plus ``val/val_annotations.txt`` (tab-separated,
        first column file name, second column wnid): validation images;
      * ``wnids.txt`` / ``words.txt``: wnid list and wnid -> readable names.

    Class indices follow the sorted order of the wnids. Items are
    ``(image, target_index)`` where ``image`` is an RGB PIL image, passed
    through ``transform`` when one is given.
    """

    def __init__(self, root, train=True, transform=None):
        self.Train = train
        self.root_dir = root
        self.transform = transform
        self.train_dir = os.path.join(self.root_dir, "train")
        self.val_dir = os.path.join(self.root_dir, "val")

        if self.Train:
            self._create_class_idx_dict_train()
        else:
            self._create_class_idx_dict_val()

        self._make_dataset(self.Train)

        words_file = os.path.join(self.root_dir, "words.txt")
        wnids_file = os.path.join(self.root_dir, "wnids.txt")

        self.set_nids = set()
        with open(wnids_file, 'r') as fo:
            for entry in fo:
                self.set_nids.add(entry.strip("\n"))

        # wnid -> first comma-separated synonym from words.txt (our wnids only).
        self.class_to_label = {}
        with open(words_file, 'r') as fo:
            for entry in fo:
                words = entry.split("\t")
                if words[0] in self.set_nids:
                    self.class_to_label[words[0]] = words[1].strip("\n").split(",")[0]

    def _create_class_idx_dict_train(self):
        """Index the class directories under train/ and count its .JPEG files."""
        if sys.version_info >= (3, 5):
            classes = [d.name for d in os.scandir(self.train_dir) if d.is_dir()]
        else:
            classes = [d for d in os.listdir(self.train_dir)
                       if os.path.isdir(os.path.join(self.train_dir, d))]
        classes = sorted(classes)

        num_images = 0
        for _, _, files in os.walk(self.train_dir):
            num_images += sum(1 for f in files if f.endswith(".JPEG"))
        self.len_dataset = num_images

        self.tgt_idx_to_class = dict(enumerate(classes))
        self.class_to_tgt_idx = {cls: i for i, cls in enumerate(classes)}

    def _create_class_idx_dict_val(self):
        """Read val_annotations.txt into an image -> wnid map and index its wnids."""
        val_annotations_file = os.path.join(self.val_dir, "val_annotations.txt")
        self.val_img_to_class = {}
        set_of_classes = set()
        with open(val_annotations_file, 'r') as fo:
            for line in fo:
                if not line.strip():
                    # Tolerate blank (e.g. trailing) lines instead of IndexError.
                    continue
                words = line.split("\t")
                self.val_img_to_class[words[0]] = words[1]
                set_of_classes.add(words[1])

        self.len_dataset = len(self.val_img_to_class)
        classes = sorted(set_of_classes)
        self.class_to_tgt_idx = {cls: i for i, cls in enumerate(classes)}
        self.tgt_idx_to_class = dict(enumerate(classes))

    def _make_dataset(self, Train=True):
        """Fill ``self.images`` with ``(path, target_index)`` for every .JPEG file."""
        self.images = []
        if Train:
            img_root_dir = self.train_dir
            list_of_dirs = list(self.class_to_tgt_idx.keys())
        else:
            img_root_dir = self.val_dir
            list_of_dirs = ["images"]

        for tgt in list_of_dirs:
            dirs = os.path.join(img_root_dir, tgt)
            if not os.path.isdir(dirs):
                continue

            for root, _, files in sorted(os.walk(dirs)):
                for fname in sorted(files):
                    if not fname.endswith(".JPEG"):
                        continue
                    path = os.path.join(root, fname)
                    if Train:
                        item = (path, self.class_to_tgt_idx[tgt])
                    else:
                        item = (path, self.class_to_tgt_idx[self.val_img_to_class[fname]])
                    self.images.append(item)

    def return_label(self, idx):
        """Map target indices (elements supporting ``.item()``) to readable labels."""
        return [self.class_to_label[self.tgt_idx_to_class[i.item()]] for i in idx]

    def __len__(self):
        # Length of the list __getitem__ actually indexes, so the two can never
        # disagree (len_dataset is still computed for callers that read it).
        return len(self.images)

    def __getitem__(self, idx):
        img_path, tgt = self.images[idx]
        # Hand PIL the open handle so the file is read through it and closed by `with`.
        with open(img_path, 'rb') as f:
            sample = Image.open(f).convert('RGB')
        if self.transform is not None:
            sample = self.transform(sample)

        return sample, tgt


# /////////////// Data Loader ///////////////


IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm']


def is_image_file(filename):
    """Report whether ``filename`` carries one of the IMG_EXTENSIONS suffixes.

    The comparison is case-insensitive, so ``photo.JPEG`` matches ``.jpeg``.

    Args:
        filename (string): path to a file
    Returns:
        bool: True if the filename ends with a known image extension
    """
    return filename.lower().endswith(tuple(IMG_EXTENSIONS))


def find_classes(dir):
    """List the sub-directories of ``dir`` as class names.

    Returns:
        (classes, class_to_idx): the sorted directory names, and a dict mapping
        each name to its position in that sorted list.
    """
    classes = sorted(entry for entry in os.listdir(dir)
                     if os.path.isdir(os.path.join(dir, entry)))
    class_to_idx = {name: position for position, name in enumerate(classes)}
    return classes, class_to_idx


def make_dataset(dir, class_to_idx):
    """Collect ``(path, class_index)`` for every image file below each class folder.

    Top-level entries of ``dir`` (``~`` expanded) that are directories are
    walked recursively in sorted order; files accepted by ``is_image_file`` are
    labelled with ``class_to_idx[<top-level folder name>]``.
    """
    root_dir = os.path.expanduser(dir)
    samples = []
    for target in sorted(os.listdir(root_dir)):
        class_dir = os.path.join(root_dir, target)
        if not os.path.isdir(class_dir):
            continue
        for walk_root, _, fnames in sorted(os.walk(class_dir)):
            # class_to_idx is only consulted when an image file is actually found.
            samples.extend((os.path.join(walk_root, name), class_to_idx[target])
                           for name in sorted(fnames) if is_image_file(name))
    return samples


def pil_loader(path):
    """Load the image at ``path`` as an RGB PIL image."""
    # Give PIL an open file object (not the path) so the handle is closed
    # deterministically (https://github.com/python-pillow/Pillow/issues/835).
    with open(path, 'rb') as handle:
        return Image.open(handle).convert('RGB')


def accimage_loader(path):
    """Load ``path`` with accimage, falling back to PIL if accimage raises IOError."""
    import accimage
    try:
        loaded = accimage.Image(path)
    except IOError:
        # Potentially a decoding problem, fall back to PIL.Image
        loaded = pil_loader(path)
    return loaded


def default_loader(path):
    """Load ``path`` with whichever image backend torchvision is configured to use."""
    from torchvision import get_image_backend
    use_accimage = get_image_backend() == 'accimage'
    return accimage_loader(path) if use_accimage else pil_loader(path)


class DistortImageFolder(data.Dataset):
    """ImageFolder-style dataset whose items write a corrupted copy of each image.

    ``root`` must contain one sub-directory per class (see ``find_classes`` and
    ``make_dataset``). Fetching item ``index`` loads the image, applies
    ``transform`` (if given) and then ``method(img, severity)``, and saves the
    result as a quality-85 JPEG at
    ``<save_root>/<method.__name__>/<severity>/<class name>/<file name>``.
    Items are always ``0``; the dataset is iterated only for that side effect.

    Raises:
        RuntimeError: if no image files are found below ``root``.
    """

    def __init__(self, root, method, severity, transform=None, target_transform=None,
                 loader=default_loader,
                 save_root='/share/data/lang/users/dan/Tiny-ImageNet-C/'):
        classes, class_to_idx = find_classes(root)
        imgs = make_dataset(root, class_to_idx)
        if not imgs:
            raise RuntimeError("Found 0 images in subfolders of: " + root + "\n" +
                               "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))

        self.root = root
        self.method = method
        self.severity = severity
        self.imgs = imgs
        self.classes = classes
        self.class_to_idx = class_to_idx
        self.idx_to_class = {v: k for k, v in class_to_idx.items()}
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader
        self.save_root = save_root

    def __getitem__(self, index):
        path, target = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        # Apply the corruption unconditionally; previously it was nested under
        # the transform check, so without a transform the clean image was saved.
        img = self.method(img, self.severity)

        # Resolve the class folder from the raw index, before target_transform
        # can remap it to a value idx_to_class does not contain.
        class_name = self.idx_to_class[target]
        if self.target_transform is not None:
            target = self.target_transform(target)

        save_dir = os.path.join(self.save_root, self.method.__name__,
                                str(self.severity), class_name)
        # exist_ok: several DataLoader workers may create the same folder at once.
        os.makedirs(save_dir, exist_ok=True)
        save_path = os.path.join(save_dir, os.path.basename(path))

        Image.fromarray(np.uint8(img)).save(save_path, quality=85, optimize=True)

        return 0  # we do not care about returning the data

    def __len__(self):
        return len(self.imgs)


# /////////////// Distortion Helpers ///////////////

import skimage as sk
from skimage.filters import gaussian
from io import BytesIO
from wand.image import Image as WandImage
from wand.api import library as wandlibrary
import wand.color as WandColor
import ctypes
from PIL import Image as PILImage
import cv2
from scipy.ndimage import zoom as scizoom
from scipy.ndimage.interpolation import map_coordinates
import warnings

# Ignore every UserWarning process-wide, presumably to quiet library warnings
# raised by the distortions below. NOTE(review): this also hides unrelated
# UserWarnings from any other code running in this process.
warnings.simplefilter("ignore", UserWarning)


def auc(errs):  # area under the alteration error curve
    """Mean trapezoid area of ``errs`` sampled at unit spacing.

    Each neighbouring pair contributes (right + left) / 2, and the total is
    divided by the number of intervals, ``len(errs) - 1``.
    """
    intervals = len(errs) - 1
    total = 0
    for left, right in zip(errs[:-1], errs[1:]):
        total += (right + left) / 2
    return total / intervals


def disk(radius, alias_blur=0.1, dtype=np.float32):
    """Normalised disk-shaped blur kernel, anti-aliased with a small Gaussian.

    Radii up to 8 use a fixed 17x17 grid and a 3x3 Gaussian; larger radii
    use a (2*radius+1)-wide grid and a 5x5 Gaussian.
    """
    if radius <= 8:
        half_width, ksize = 8, (3, 3)
    else:
        half_width, ksize = radius, (5, 5)
    coords = np.arange(-half_width, half_width + 1)
    xx, yy = np.meshgrid(coords, coords)
    kernel = np.array((xx ** 2 + yy ** 2) <= radius ** 2, dtype=dtype)
    kernel /= np.sum(kernel)

    # supersample disk to antialias
    return cv2.GaussianBlur(kernel, ksize=ksize, sigmaX=alias_blur)


# Tell Python about the C method: declare the ctypes argument types of
# ImageMagick's MagickMotionBlurImage so that Python/numpy floats are marshalled
# as C doubles when MotionImage.motion_blur calls it.
wandlibrary.MagickMotionBlurImage.argtypes = (ctypes.c_void_p,  # wand
                                              ctypes.c_double,  # radius
                                              ctypes.c_double,  # sigma
                                              ctypes.c_double)  # angle


# Extend wand.image.Image class to include method signature
class MotionImage(WandImage):
    """wand Image that also exposes ImageMagick's MagickMotionBlurImage."""

    def motion_blur(self, radius=0.0, sigma=0.0, angle=0.0):
        """Motion-blur this image in place along ``angle`` (returns None)."""
        blur_fn = wandlibrary.MagickMotionBlurImage
        blur_fn(self.wand, radius, sigma, angle)


# modification of https://github.com/FLHerne/mapgen/blob/master/diamondsquare.py
def plasma_fractal(mapsize=64, wibbledecay=3):
    """
    Generate a heightmap using the diamond-square algorithm.

    Args:
        mapsize: side length of the square map; must be a positive power of two.
        wibbledecay: factor by which the random perturbation shrinks per level.

    Returns:
        (mapsize, mapsize) float64 array shifted and scaled to the range [0, 1].

    Raises:
        ValueError: if ``mapsize`` is not a positive power of two.
    """
    if mapsize <= 0 or (mapsize & (mapsize - 1)) != 0:
        raise ValueError('mapsize must be a positive power of two, got {}'.format(mapsize))
    # np.float64 instead of np.float_, which was removed in NumPy 2.0.
    maparray = np.empty((mapsize, mapsize), dtype=np.float64)
    maparray[0, 0] = 0
    stepsize = mapsize
    wibble = 100

    def wibbledmean(array):
        # Reads the enclosing `wibble` at call time, so it shrinks each level.
        return array / 4 + wibble * np.random.uniform(-wibble, wibble, array.shape)

    def fillsquares():
        """For each square of points stepsize apart,
           calculate middle value as mean of points + wibble"""
        cornerref = maparray[0:mapsize:stepsize, 0:mapsize:stepsize]
        squareaccum = cornerref + np.roll(cornerref, shift=-1, axis=0)
        squareaccum += np.roll(squareaccum, shift=-1, axis=1)
        maparray[stepsize // 2:mapsize:stepsize,
                 stepsize // 2:mapsize:stepsize] = wibbledmean(squareaccum)

    def filldiamonds():
        """For each diamond of points stepsize apart,
           calculate middle value as mean of points + wibble"""
        mapsize = maparray.shape[0]
        drgrid = maparray[stepsize // 2:mapsize:stepsize, stepsize // 2:mapsize:stepsize]
        ulgrid = maparray[0:mapsize:stepsize, 0:mapsize:stepsize]
        ldrsum = drgrid + np.roll(drgrid, 1, axis=0)
        lulsum = ulgrid + np.roll(ulgrid, -1, axis=1)
        ltsum = ldrsum + lulsum
        maparray[0:mapsize:stepsize, stepsize // 2:mapsize:stepsize] = wibbledmean(ltsum)
        tdrsum = drgrid + np.roll(drgrid, 1, axis=1)
        tulsum = ulgrid + np.roll(ulgrid, -1, axis=0)
        ttsum = tdrsum + tulsum
        maparray[stepsize // 2:mapsize:stepsize, 0:mapsize:stepsize] = wibbledmean(ttsum)

    while stepsize >= 2:
        fillsquares()
        filldiamonds()
        stepsize //= 2
        wibble /= wibbledecay

    maparray -= maparray.min()
    peak = maparray.max()
    # A constant map (e.g. mapsize == 1) would otherwise divide 0 by 0.
    return maparray / peak if peak > 0 else maparray


def clipped_zoom(img, zoom_factor):
    """Zoom into the centre of a square H x W x C image, keeping its size.

    A centred crop of side ceil(h / zoom_factor) is upsampled (linear
    interpolation) by ``zoom_factor`` and any surplus border is trimmed.
    """
    side = img.shape[0]
    crop_side = int(np.ceil(side / zoom_factor))
    start = (side - crop_side) // 2
    centre = img[start:start + crop_side, start:start + crop_side]
    zoomed = scizoom(centre, (zoom_factor, zoom_factor, 1), order=1)

    # trim off any extra pixels
    offset = (zoomed.shape[0] - side) // 2
    return zoomed[offset:offset + side, offset:offset + side]


# /////////////// End Distortion Helpers ///////////////


# /////////////// Distortions ///////////////

def gaussian_noise(x, severity=1):
    """Add zero-mean Gaussian noise (std picked by ``severity``) in [0, 1] space.

    Returns a float array clipped to [0, 1] and rescaled to [0, 255].
    """
    sigma = (0.04, 0.08, .12, .15, .18)[severity - 1]

    pixels = np.array(x) / 255.
    noise = np.random.normal(size=pixels.shape, scale=sigma)
    return np.clip(pixels + noise, 0, 1) * 255


def shot_noise(x, severity=1):
    """Poisson (photon) noise: fewer simulated photons at higher ``severity``.

    Returns a float array clipped to [0, 1] and rescaled to [0, 255].
    """
    photons = (250, 100, 50, 30, 15)[severity - 1]

    pixels = np.array(x) / 255.
    counts = np.random.poisson(pixels * photons)
    return np.clip(counts / photons, 0, 1) * 255


def impulse_noise(x, severity=1):
    """Salt-and-pepper noise on a ``severity``-dependent fraction of values.

    Returns a float array clipped to [0, 1] and rescaled to [0, 255].
    """
    amount = (.01, .02, .05, .08, .14)[severity - 1]

    pixels = np.array(x) / 255.
    noisy = sk.util.random_noise(pixels, mode='s&p', amount=amount)
    return np.clip(noisy, 0, 1) * 255


def speckle_noise(x, severity=1):
    """Multiplicative Gaussian noise: each value gains value * N(0, sigma).

    Returns a float array clipped to [0, 1] and rescaled to [0, 255].
    """
    sigma = (.15, .2, 0.25, 0.3, 0.35)[severity - 1]

    pixels = np.array(x) / 255.
    speckle = pixels * np.random.normal(size=pixels.shape, scale=sigma)
    return np.clip(pixels + speckle, 0, 1) * 255


def gaussian_blur(x, severity=1):
    """Per-channel Gaussian blur with a ``severity``-dependent sigma.

    Returns a float array clipped to [0, 1] and rescaled to [0, 255].
    """
    sigma = (.5, .75, 1, 1.25, 1.5)[severity - 1]

    blurred = gaussian(np.array(x) / 255., sigma=sigma, multichannel=True)
    return np.clip(blurred, 0, 1) * 255


def glass_blur(x, severity=1):
    """Blur, locally shuffle pixels, then blur again (frosted-glass effect).

    Args:
        x: image convertible with ``np.array``, values in [0, 255].
        severity: 1-5; selects (sigma, max_delta, iterations).

    Returns:
        Float array clipped to [0, 1] and rescaled to [0, 255], same
        height/width as ``x``.
    """
    # sigma, max_delta, iterations
    c = [(0.1,1,1), (0.5,1,1), (0.6,1,2), (0.7,2,1), (0.9,2,2)][severity - 1]

    x = np.uint8(gaussian(np.array(x) / 255., sigma=c[0], multichannel=True) * 255)
    # Use the real image size instead of a hard-coded 64 (identical for 64x64).
    height, width = x.shape[:2]

    # locally shuffle pixels
    for _ in range(c[2]):
        for h in range(height - c[1], c[1], -1):
            for w in range(width - c[1], c[1], -1):
                dx, dy = np.random.randint(-c[1], c[1], size=(2,))
                h_prime, w_prime = h + dy, w + dx
                # Swap the two pixels. For a multi-channel image x[h, w] is a view,
                # so both sides must be copied; a plain tuple swap writes
                # x[h_prime, w_prime] into x[h, w] and then that value straight back.
                x[h, w], x[h_prime, w_prime] = x[h_prime, w_prime].copy(), x[h, w].copy()

    return np.clip(gaussian(x / 255., sigma=c[0], multichannel=True), 0, 1) * 255


def defocus_blur(x, severity=1):
    """Convolve each of the first three channels with an anti-aliased disk kernel.

    Returns a float array clipped to [0, 1] and rescaled to [0, 255].
    """
    radius, alias = [(0.5, 0.6), (1, 0.1), (1.5, 0.1), (2.5, 0.01), (3, 0.1)][severity - 1]

    pixels = np.array(x) / 255.
    kernel = disk(radius=radius, alias_blur=alias)

    # Filter channel by channel, then stack back to H x W x 3.
    filtered = [cv2.filter2D(pixels[:, :, ch], -1, kernel) for ch in range(3)]
    blurred = np.stack(filtered, axis=-1)

    return np.clip(blurred, 0, 1) * 255


def motion_blur(x, severity=1):
    """Motion-blur ``x`` via ImageMagick along a random angle in [-45, 45).

    Args:
        x: PIL image (it is serialised with ``x.save(..., format='PNG')``).
        severity: 1-5; selects (radius, sigma) of the blur.

    Returns:
        uint8 array in RGB channel order; greyscale results are replicated
        into three channels.
    """
    c = [(10,1), (10,1.5), (10,2), (10,2.5), (12,3)][severity - 1]

    output = BytesIO()
    x.save(output, format='PNG')
    # The context manager releases the native ImageMagick wand once the blob is read.
    with MotionImage(blob=output.getvalue()) as blurred:
        blurred.motion_blur(radius=c[0], sigma=c[1], angle=np.random.uniform(-45, 45))
        blob = blurred.make_blob()

    # np.frombuffer replaces the deprecated binary mode of np.fromstring.
    decoded = cv2.imdecode(np.frombuffer(blob, np.uint8), cv2.IMREAD_UNCHANGED)

    # Check dimensionality instead of comparing against a hard-coded (64, 64).
    if decoded.ndim == 3:
        return np.clip(decoded[..., [2, 1, 0]], 0, 255)  # BGR to RGB
    # greyscale to RGB
    return np.clip(np.stack([decoded, decoded, decoded], axis=-1), 0, 255)


def zoom_blur(x, severity=1):
    """Average ``x`` with centre zooms of itself over a ``severity``-dependent range.

    Returns a float array clipped to [0, 1] and rescaled to [0, 255].
    """
    factors = [np.arange(1, 1.06, 0.01), np.arange(1, 1.11, 0.01), np.arange(1, 1.16, 0.01),
               np.arange(1, 1.21, 0.01), np.arange(1, 1.26, 0.01)][severity - 1]

    pixels = (np.array(x) / 255.).astype(np.float32)
    accumulated = np.zeros_like(pixels)
    for factor in factors:
        accumulated += clipped_zoom(pixels, factor)

    # The original image counts as one extra sample in the average.
    blended = (pixels + accumulated) / (len(factors) + 1)
    return np.clip(blended, 0, 1) * 255


def fog(x, severity=1):
    """Add a plasma-fractal haze, then rescale so the original peak is preserved.

    Returns a float array clipped to [0, 1] and rescaled to [0, 255].
    """
    strength, decay = [(.4,3), (.7,3), (1,2.5), (1.5,2), (2,1.75)][severity - 1]

    pixels = np.array(x) / 255.
    peak = pixels.max()
    haze = plasma_fractal(wibbledecay=decay)[:64, :64][..., np.newaxis]
    pixels += strength * haze
    return np.clip(pixels * peak / (peak + strength), 0, 1) * 255


def frost(x, severity=1):
    """Blend ``x`` with a random 64x64 crop of one of six frost texture images.

    Args:
        x: 64x64 RGB image convertible with ``np.array``, values in [0, 255].
        severity: 1-5; selects (image weight, frost weight).

    Returns:
        Float array clipped to [0, 255].

    Raises:
        FileNotFoundError: if the chosen texture under ./create_c/ cannot be read.
    """
    c = [(1, 0.3), (0.9, 0.4), (0.8, 0.45), (0.75, 0.5), (0.7, 0.55)][severity - 1]
    filenames = ['./create_c/frost1.png', './create_c/frost2.png', './create_c/frost3.png',
                 './create_c/frost4.jpg', './create_c/frost5.jpg', './create_c/frost6.jpg']
    # randint's upper bound is exclusive: draw over all textures (a literal 5
    # meant frost6.jpg could never be chosen).
    filename = filenames[np.random.randint(len(filenames))]
    texture = cv2.imread(filename)
    if texture is None:
        # cv2.imread signals a missing/unreadable file by returning None.
        raise FileNotFoundError('could not read frost texture: ' + filename)
    texture = cv2.resize(texture, (0, 0), fx=0.3, fy=0.3)
    # randomly crop and convert to rgb
    x_start = np.random.randint(0, texture.shape[0] - 64)
    y_start = np.random.randint(0, texture.shape[1] - 64)
    texture = texture[x_start:x_start + 64, y_start:y_start + 64][..., [2, 1, 0]]

    return np.clip(c[0] * np.array(x) + c[1] * texture, 0, 255)


def snow(x, severity=1):
    """Add motion-blurred snow flakes plus a brightening haze to ``x``.

    Args:
        x: square RGB image convertible with ``np.array``, values in [0, 255].
        severity: 1-5; selects (noise loc, noise scale, zoom, flake threshold,
            blur radius, blur sigma, blend weight of the original image).

    Returns:
        Float array clipped to [0, 1] and rescaled to [0, 255].
    """
    c = [(0.1,0.2,1,0.6,8,3,0.8),
         (0.1,0.2,1,0.5,10,4,0.8),
         (0.15,0.3,1.75,0.55,10,4,0.7),
         (0.25,0.3,2.25,0.6,12,6,0.65),
         (0.3,0.3,1.25,0.65,14,12,0.6)][severity - 1]

    x = np.array(x, dtype=np.float32) / 255.
    height, width = x.shape[:2]
    snow_layer = np.random.normal(size=x.shape[:2], loc=c[0], scale=c[1])  # [:2] for monochrome

    snow_layer = clipped_zoom(snow_layer[..., np.newaxis], c[2])
    snow_layer[snow_layer < c[3]] = 0

    snow_layer = PILImage.fromarray((np.clip(snow_layer.squeeze(), 0, 1) * 255).astype(np.uint8), mode='L')
    output = BytesIO()
    snow_layer.save(output, format='PNG')
    # The context manager releases the native ImageMagick wand once the blob is read.
    with MotionImage(blob=output.getvalue()) as blurred:
        blurred.motion_blur(radius=c[4], sigma=c[5], angle=np.random.uniform(-135, -45))
        blob = blurred.make_blob()

    # np.frombuffer replaces the deprecated binary mode of np.fromstring.
    snow_layer = cv2.imdecode(np.frombuffer(blob, np.uint8), cv2.IMREAD_UNCHANGED) / 255.
    snow_layer = snow_layer[..., np.newaxis]

    # Reshape to the real image size instead of a hard-coded (64, 64, 1).
    gray = cv2.cvtColor(x, cv2.COLOR_RGB2GRAY).reshape(height, width, 1)
    x = c[6] * x + (1 - c[6]) * np.maximum(x, gray * 1.5 + 0.5)
    return np.clip(x + snow_layer + np.rot90(snow_layer, k=2), 0, 1) * 255


def spatter(x, severity=1):
    """Overlay a liquid spatter layer: water-like for severities 1-3, mud-like for 4-5.

    Args:
        x: RGB image convertible with ``np.array``; values assumed in [0, 255]
            (TODO confirm for every caller).
        severity: 1-5, selects a row of ``c``.

    Returns:
        Float array scaled to [0, 255].
    """
    # c = (noise loc, noise scale, blur sigma, threshold,
    #      c[4]: strength in the water branch / blur sigma in the mud branch,
    #      c[5]: 0 selects the water branch, 1 the mud branch)
    c = [(0.62,0.1,0.7,0.7,0.6,0),
         (0.65,0.1,0.8,0.7,0.6,0),
         (0.65,0.3,1,0.69,0.6,0),
         (0.65,0.1,0.7,0.68,0.6,1),
         (0.65,0.1,0.5,0.67,0.6,1)][severity - 1]
    x = np.array(x, dtype=np.float32) / 255.

    # Single-channel random field, smoothed and thresholded into "liquid" blobs.
    liquid_layer = np.random.normal(size=x.shape[:2], loc=c[0], scale=c[1])

    liquid_layer = gaussian(liquid_layer, sigma=c[2])
    liquid_layer[liquid_layer < c[3]] = 0
    if c[5] == 0:
        # Water: derive a shading map from distance to the blob edges.
        # NOTE(review): values above 1.0 would wrap in this uint8 cast.
        liquid_layer = (liquid_layer * 255).astype(np.uint8)
        dist = 255 - cv2.Canny(liquid_layer, 50, 150)
        dist = cv2.distanceTransform(dist, cv2.DIST_L2, 5)
        # Cap the distance at 20, then smooth and histogram-equalise it.
        _, dist = cv2.threshold(dist, 20, 20, cv2.THRESH_TRUNC)
        dist = cv2.blur(dist, (3, 3)).astype(np.uint8)
        dist = cv2.equalizeHist(dist)
        #     ker = np.array([[-1,-2,-3],[-2,0,0],[-3,0,1]], dtype=np.float32)
        #     ker -= np.mean(ker)
        # Diagonal emboss-style kernel applied to the distance map.
        ker = np.array([[-2, -1, 0], [-1, 1, 1], [0, 1, 2]])
        dist = cv2.filter2D(dist, cv2.CV_8U, ker)
        dist = cv2.blur(dist, (3, 3)).astype(np.float32)

        # Four-channel mask, normalised per channel by its maximum and scaled by c[4].
        m = cv2.cvtColor(liquid_layer * dist, cv2.COLOR_GRAY2BGRA)
        m /= np.max(m, axis=(0, 1))
        m *= c[4]

        # water is pale turquoise
        color = np.concatenate((175 / 255. * np.ones_like(m[..., :1]),
                                238 / 255. * np.ones_like(m[..., :1]),
                                238 / 255. * np.ones_like(m[..., :1])), axis=2)

        # BGR2BGRA / BGRA2BGR only append / drop the 4th channel; the channel
        # order of x is left as-is.
        color = cv2.cvtColor(color, cv2.COLOR_BGR2BGRA)
        x = cv2.cvtColor(x, cv2.COLOR_BGR2BGRA)

        return cv2.cvtColor(np.clip(x + m * color, 0, 1), cv2.COLOR_BGRA2BGR) * 255
    else:
        # Mud: binary blob mask, blurred with sigma c[4], keeping only values >= 0.8.
        m = np.where(liquid_layer > c[3], 1, 0)
        m = gaussian(m.astype(np.float32), sigma=c[4])
        m[m < 0.8] = 0
        #         m = np.abs(m) ** (1/c[4])

        # mud brown
        color = np.concatenate((63 / 255. * np.ones_like(x[..., :1]),
                                42 / 255. * np.ones_like(x[..., :1]),
                                20 / 255. * np.ones_like(x[..., :1])), axis=2)

        # Alpha-composite: x * (1 - m) + color * m.
        color *= m[..., np.newaxis]
        x *= (1 - m[..., np.newaxis])

        return np.clip(x + color, 0, 1) * 255


def contrast(x, severity=1):
    """Pull every value toward its channel mean by a ``severity``-dependent factor.

    Returns a float array clipped to [0, 1] and rescaled to [0, 255].
    """
    factor = (.4, .3, .2, .1, 0.05)[severity - 1]

    pixels = np.array(x) / 255.
    channel_means = np.mean(pixels, axis=(0, 1), keepdims=True)
    squeezed = (pixels - channel_means) * factor
    return np.clip(squeezed + channel_means, 0, 1) * 255


def brightness(x, severity=1):
    """Raise the HSV value channel by a ``severity``-dependent offset.

    Returns a float array clipped to [0, 1] and rescaled to [0, 255].
    """
    offset = (.1, .2, .3, .4, .5)[severity - 1]

    hsv = sk.color.rgb2hsv(np.array(x) / 255.)
    hsv[:, :, 2] = np.clip(hsv[:, :, 2] + offset, 0, 1)
    rgb = sk.color.hsv2rgb(hsv)

    return np.clip(rgb, 0, 1) * 255


def saturate(x, severity=1):
    """Scale and shift the HSV saturation channel (scale, shift picked by ``severity``).

    Returns a float array clipped to [0, 1] and rescaled to [0, 255].
    """
    scale, shift = [(0.3, 0), (0.1, 0), (2, 0), (5, 0.1), (30, 0.2)][severity - 1]

    hsv = sk.color.rgb2hsv(np.array(x) / 255.)
    hsv[:, :, 1] = np.clip(hsv[:, :, 1] * scale + shift, 0, 1)
    rgb = sk.color.hsv2rgb(hsv)

    return np.clip(rgb, 0, 1) * 255


def jpeg_compression(x, severity=1):
    """Round-trip the PIL image ``x`` through an in-memory JPEG.

    Quality drops as ``severity`` rises. Returns the re-opened PIL image.
    """
    quality = (65, 58, 50, 40, 25)[severity - 1]

    buffer = BytesIO()
    x.save(buffer, 'JPEG', quality=quality)
    return PILImage.open(buffer)


def pixelate(x, severity=1):
    """Box-downsample the PIL image ``x``, then box-upsample it to 64x64.

    Returns a 64x64 PIL image.
    """
    factor = (0.9, 0.8, 0.7, 0.6, 0.5)[severity - 1]

    reduced_side = int(64 * factor)
    shrunk = x.resize((reduced_side, reduced_side), PILImage.BOX)
    return shrunk.resize((64, 64), PILImage.BOX)


# mod of https://gist.github.com/erniejunior/601cdf56d2b424757de5
def elastic_transform(image, severity=1):
    """Random affine warp followed by a smoothed random displacement field.

    Args:
        image: image convertible with ``np.array``, values in [0, 255]; must
            be H x W x C, since ``shape[2]`` is indexed below.
        severity: 1-5, selects a row of ``c``.

    Returns:
        Float array of the input's shape, clipped to [0, 1] and rescaled to [0, 255].
    """
    # Parameters are fractions of an assumed 64-pixel image side.
    IMSIZE = 64
    # c = (displacement magnitude, displacement smoothing sigma, affine jitter)
    c = [(IMSIZE*0, IMSIZE*0, IMSIZE*0.08),
         (IMSIZE*0.05, IMSIZE*0.3, IMSIZE*0.06),
         (IMSIZE*0.1, IMSIZE*0.08, IMSIZE*0.06),
         (IMSIZE*0.1, IMSIZE*0.03, IMSIZE*0.03),
         (IMSIZE*0.16, IMSIZE*0.03, IMSIZE*0.02)][severity - 1]

    image = np.array(image, dtype=np.float32) / 255.
    shape = image.shape
    shape_size = shape[:2]

    # random affine: jitter three points around the centre by up to +/- c[2]
    # and warp with the affine map between the original and jittered points.
    center_square = np.float32(shape_size) // 2
    square_size = min(shape_size) // 3
    pts1 = np.float32([center_square + square_size,
                       [center_square[0] + square_size, center_square[1] - square_size],
                       center_square - square_size])
    pts2 = pts1 + np.random.uniform(-c[2], c[2], size=pts1.shape).astype(np.float32)
    M = cv2.getAffineTransform(pts1, pts2)
    # cv2 takes dsize as (width, height), hence the reversed shape.
    image = cv2.warpAffine(image, M, shape_size[::-1], borderMode=cv2.BORDER_REFLECT_101)

    # Per-pixel displacements: uniform noise smoothed with sigma c[1], scaled by c[0].
    dx = (gaussian(np.random.uniform(-1, 1, size=shape[:2]),
                   c[1], mode='reflect', truncate=3) * c[0]).astype(np.float32)
    dy = (gaussian(np.random.uniform(-1, 1, size=shape[:2]),
                   c[1], mode='reflect', truncate=3) * c[0]).astype(np.float32)
    dx, dy = dx[..., np.newaxis], dy[..., np.newaxis]

    # Sample the warped image at the displaced coordinates (order=1: linear
    # interpolation, reflect at borders); the channel coordinate z is untouched.
    x, y, z = np.meshgrid(np.arange(shape[1]), np.arange(shape[0]), np.arange(shape[2]))
    indices = np.reshape(y + dy, (-1, 1)), np.reshape(x + dx, (-1, 1)), np.reshape(z, (-1, 1))
    return np.clip(map_coordinates(image, indices, order=1, mode='reflect').reshape(shape), 0, 1) * 255


# /////////////// End Distortions ///////////////


# /////////////// Further Setup ///////////////


def save_distorted(method=gaussian_noise):
    """Write ``method``-corrupted copies of ./imagenet_val_bbox_crop/ at severities 1-5.

    Each severity builds a DistortImageFolder (images resized to 64x64) and
    drains a DataLoader over it; the writing happens inside __getitem__.
    """
    for level in range(1, 6):
        print(method.__name__, level)
        dataset = DistortImageFolder(
            root="./imagenet_val_bbox_crop/",
            method=method, severity=level,
            transform=trn.Compose([trn.Resize((64, 64))]))
        loader = torch.utils.data.DataLoader(
            dataset, batch_size=100, shuffle=False, num_workers=6)

        # Iterate purely for the side effect of saving each distorted image.
        for _batch in loader:
            pass


# /////////////// End Further Setup ///////////////


# /////////////// Display Results ///////////////
import collections

# Printed at module level whenever this script runs. NOTE(review): the data
# loaded further below is Tiny-ImageNet (./data/tiny-imagenet-200/).
print('\nUsing ImageNet data')


def divide_by_label(dataset, class_num=10):
    """Group the indices of ``dataset`` by the label stored at ``dataset[i][1]``.

    Args:
        dataset: indexable, sized collection of (sample, int label) items.
        class_num: number of label buckets; labels must lie in [0, class_num).

    Returns:
        (index_map, len_map): index_map[c] lists the indices whose label is c,
        in ascending order, and len_map[c] == len(index_map[c]).
    """
    index_map = [[] for _ in range(class_num)]
    for i in range(len(dataset)):
        # Fetch each item once: for image datasets, dataset[i] reads and decodes
        # a file, and the previous version did that twice per index.
        label = dataset[i][1]
        index_map[label].append(i)
    len_map = [len(indices) for indices in index_map]
    return index_map, len_map

def reweight(q, empty_class):
    """Return a copy of distribution ``q`` with ``q[empty_class]`` zeroed and renormalised.

    The caller's ``q`` is left untouched; the previous version zeroed the entry
    in place on the caller's object. Accepts torch tensors and numpy arrays.
    """
    q = q.clone() if isinstance(q, torch.Tensor) else np.array(q)
    q[empty_class] = 0
    return q / q.sum()

# Number of classes (Tiny-ImageNet-200). Sizes the per-class buckets in
# get_noniid_class_and_labels and defines swap_labels = int(0.4 * C) below.
C = 200

def _sample_class_distribution(p):
    """Draw class proportions q ~ Dirichlet(p), restricted to classes with p > 0.

    torch's Dirichlet rejects zero concentrations, and p contains zeros once
    earlier clients have used up a class; such classes simply get q = 0.
    """
    q = torch.zeros_like(p)
    support = p > 0
    q[support] = dirichlet.Dirichlet(p[support]).sample()
    return q


def get_noniid_class_and_labels(original_images, original_labels, N, num_classes=None,
                                resample_every=250):
    """Split a labelled dataset across N clients with Dirichlet-skewed label mixes.

    Every client receives exactly ``len(original_labels) // N`` samples, drawn
    without replacement. For each client, p is the share of the still
    unassigned pool in each class. Class proportions q ~ Dirichlet(p) are
    redrawn whenever the client's size is a multiple of ``resample_every``.
    Each sample's class is drawn from q, and the sample itself is drawn
    uniformly from that class's remaining pool. Leftover samples (when the
    dataset size is not a multiple of N) are not assigned.

    Args:
        original_images: samples, index-aligned with ``original_labels``.
        original_labels: int class labels in [0, num_classes).
        N: number of clients.
        num_classes: number of classes; defaults to the module-level ``C``.
        resample_every: redraw q every this many samples per client.

    Returns:
        (clients_images, clients_labels): two lists of N per-client lists.
    """
    if num_classes is None:
        num_classes = C
    M = len(original_labels) // N

    clients_images = [[] for _ in range(N)]
    clients_labels = [[] for _ in range(N)]
    if M == 0:
        # Fewer samples than clients: nothing to hand out (and p would be 0/0).
        return clients_images, clients_labels

    classes_by_index = [[] for _ in range(num_classes)]
    for i, label in enumerate(original_labels):
        classes_by_index[label].append(i)
    classes_by_index_len = [len(indices) for indices in classes_by_index]

    for i in range(N):
        # Share of the remaining pool per class; fixed for the whole client.
        p = torch.tensor(classes_by_index_len, dtype=torch.float) / sum(classes_by_index_len)
        q = _sample_class_distribution(p)
        while len(clients_labels[i]) < M:
            if len(clients_labels[i]) % resample_every == 0:
                q = _sample_class_distribution(p)
            sampled_class = int(torch.multinomial(q, 1))
            if classes_by_index_len[sampled_class] == 0:
                # Class exhausted mid-client: drop it from q and renormalise
                # (same as reweight(), but never leaving an all-zero q).
                q = q.clone()
                q[sampled_class] = 0
                if q.sum() <= 0:
                    # All of q's mass sat on exhausted classes; fall back to the
                    # remaining pool, which is non-empty while this client needs samples.
                    q = torch.tensor(classes_by_index_len, dtype=torch.float)
                q = q / q.sum()
            else:
                sampled_index = random.randint(0, classes_by_index_len[sampled_class] - 1)
                sampled_original_index = classes_by_index[sampled_class].pop(sampled_index)
                clients_images[i].append(original_images[sampled_original_index])
                clients_labels[i].append(original_labels[sampled_original_index])
                classes_by_index_len[sampled_class] -= 1

    return clients_images, clients_labels


import collections

# Display name -> corruption function; each is called as fn(image, severity)
# with severity in 1..5. Insertion order is preserved.
d = collections.OrderedDict([
    ('Gaussian Noise', gaussian_noise),
    ('Shot Noise', shot_noise),
    ('Impulse Noise', impulse_noise),
    ('Defocus Blur', defocus_blur),
    ('Glass Blur', glass_blur),
    ('Motion Blur', motion_blur),
    ('Zoom Blur', zoom_blur),
    ('Snow', snow),
    ('Frost', frost),
    ('Fog', fog),
    ('Brightness', brightness),
    ('Contrast', contrast),
    ('Elastic', elastic_transform),
    ('Pixelate', pixelate),
    ('JPEG', jpeg_compression),
    # Additional corruptions after the first fifteen.
    ('Speckle Noise', speckle_noise),
    ('Gaussian Blur', gaussian_blur),
    ('Spatter', spatter),
    ('Saturate', saturate),
])

# Presumably per-channel dataset statistics; not referenced in the visible code.
mean = [0.48836562, 0.48134598, 0.4451678]
std = [0.24833508, 0.24547848, 0.26617324]

# transform_train = transforms.Compose([
#     transforms.Normalize(mean, std),
#     transforms.ToTensor()
# ])

# transform_test = transforms.Compose([
#         transforms.ToTensor(),
#         transforms.Normalize(mean, std)
#         ])

# train_data = dset.CIFAR10('./data/cifar10-c/origin/', train=True, download=True)
# test_data = dset.CIFAR10('./data/cifar10-c/origin/', train=False, download=True)
# Tiny-ImageNet splits without a transform, so items are (RGB PIL image, int label);
# the annotated val split serves as the test set.
train_data = TinyImageNet('./data/tiny-imagenet-200/', train=True)
test_data = TinyImageNet('./data/tiny-imagenet-200/', train=False)
# ToTensor followed by ToPILImage: turns each dataset image into a fresh PIL
# image, which is the input type the corruption functions are given.
convert_img = trn.Compose([trn.ToTensor(), trn.ToPILImage()])

# Keys into the corruption dict `d`; a random one is picked per corrupted client.
corruption_methods = ['Gaussian Noise', 'Shot Noise', 'Impulse Noise', 'Defocus Blur', 'Glass Blur', 'Motion Blur', 'Zoom Blur', 'Snow', 'Frost', 'Fog', 'Brightness', 'Contrast', 'Elastic', 'Pixelate', 'JPEG', 'Speckle Noise', 'Gaussian Blur', 'Spatter', 'Saturate']

corruption_number = len(corruption_methods)
client_number = 100
# NOTE(review): not referenced in the visible code; the split samples
# Dirichlet(p) with p the remaining class proportions instead.
dirichlet_alpha = 0.5

# Split the Tiny-ImageNet training set across clients with a Dirichlet-skewed label mix.

print('splitting clients...')

# Collect images and labels in a single pass per split. Two separate
# comprehensions would iterate the dataset twice and decode every image twice.
original_images_tr, original_labels_tr = [], []
for X, Y in train_data:
    original_images_tr.append(X)
    original_labels_tr.append(Y)

original_images_te, original_labels_te = [], []
for X, Y in test_data:
    original_images_te.append(X)
    original_labels_te.append(Y)

clients_images, clients_labels = get_noniid_class_and_labels(original_images_tr, original_labels_tr, client_number)

# Random client order; consecutive slices of it form the client groups below.
random_clients = list(range(client_number))
random.shuffle(random_clients)
type_split_N = int(0.2 * client_number)     # size of a 20% group
type_split_N_10 = int(0.1 * client_number)  # size of a 10% group

clients_types = [None for _ in range(client_number)]
# Labels below this threshold are remapped for the label-change clients.
swap_labels = int(0.4 * C)

# 20% of clients get one synthetic corruption (random method and severity) on
# every image; labels are unchanged. Type 1.

print('adding corruptions...')

for slot in range(type_split_N):
    client_id = random_clients[slot]
    method_name = corruption_methods[random.randint(0, corruption_number - 1)]
    severity = random.randint(1, 5)
    corrupt = d[method_name]
    corrupted = [np.uint8(corrupt(convert_img(img), severity))
                 for img in clients_images[client_id]]
    clients_images[client_id] = corrupted
    clients_labels[client_id] = list(clients_labels[client_id])
    clients_types[client_id] = 1

# 40% of clients get changed labels (two remappings; together with the
# unchanged labels that makes three concepts). Labels < swap_labels are
# reversed (even slots, type 2) or shifted by one modulo swap_labels
# (odd slots, type 3); images stay clean.

print('changing labels...')

for slot in range(type_split_N * 2):
    client_id = random_clients[slot + type_split_N]
    reverse = slot % 2 == 0
    new_labels = []
    for label in clients_labels[client_id]:
        if label >= swap_labels:
            new_labels.append(label)
        elif reverse:
            new_labels.append(swap_labels - 1 - label)
        else:
            new_labels.append((label + 1) % swap_labels)
    clients_images[client_id] = [np.uint8(convert_img(img)) for img in clients_images[client_id]]
    clients_labels[client_id] = new_labels
    clients_types[client_id] = 2 if reverse else 3

# Another 20% of clients (two variants of 10% each) get both a label change
# and a synthetic corruption: reversed labels (even slots, type 4) or shifted
# labels (odd slots, type 5).

for slot in range(type_split_N_10 * 2):
    client_id = random_clients[slot + 3 * type_split_N]
    method_name = corruption_methods[random.randint(0, corruption_number - 1)]
    severity = random.randint(1, 5)
    corrupt = d[method_name]
    reverse = slot % 2 == 0
    new_images, new_labels = [], []
    for img, label in zip(clients_images[client_id], clients_labels[client_id]):
        if label >= swap_labels:
            new_labels.append(label)
        elif reverse:
            new_labels.append(swap_labels - 1 - label)
        else:
            new_labels.append((label + 1) % swap_labels)
        new_images.append(np.uint8(corrupt(convert_img(img), severity)))
    clients_images[client_id] = new_images
    clients_labels[client_id] = new_labels
    clients_types[client_id] = 4 if reverse else 5

# Every remaining client keeps clean images and original labels (type 0).
first_clean = 3 * type_split_N + type_split_N_10 * 2
for client_id in random_clients[first_clean:]:
    clients_images[client_id] = [np.uint8(convert_img(img)) for img in clients_images[client_id]]
    clients_labels[client_id] = list(clients_labels[client_id])
    clients_types[client_id] = 0

print('saving...')

# One pickle per client: {'images': uint8 image array, 'labels': uint8 labels,
# 'type': client type}. Labels fit in uint8 because C = 200 < 256.
clients_output_dir = './data/tiny-imagenet-c-4swap/'
# open(..., 'wb') inside save_data cannot create missing parent directories.
os.makedirs(clients_output_dir, exist_ok=True)

for i in range(client_number):
    client = {'images': np.uint8(np.array(clients_images[i])),
              'labels': np.uint8(np.array(clients_labels[i])),
              'type': clients_types[i]}
    save_data(client, os.path.join(clients_output_dir, '{}.pkl'.format(i)))

print('saving test...')

# Three copies of the clean test split, one per label concept: original labels,
# reversed labels, and shifted labels (remapping only labels < swap_labels).
test_labels_1, test_labels_2, test_labels_3 = [], [], []
for label in original_labels_te:
    test_labels_1.append(label)
    if label < swap_labels:
        test_labels_2.append(swap_labels - 1 - label)
        test_labels_3.append((label + 1) % swap_labels)
    else:
        test_labels_2.append(label)
        test_labels_3.append(label)

# The images are identical for all three sets: convert them once, not three times.
test_images = np.uint8(np.array([np.uint8(convert_img(img)) for img in original_images_te]))

client_1 = {'images': test_images,
            'labels': np.uint8(np.array(test_labels_1)),
            'type': 0}
client_2 = {'images': test_images,
            'labels': np.uint8(np.array(test_labels_2)),
            'type': 1}
client_3 = {'images': test_images,
            'labels': np.uint8(np.array(test_labels_3)),
            'type': 2}

# open(..., 'wb') inside save_data cannot create missing parent directories.
os.makedirs('./data/tiny-imagenet-c-4swap/', exist_ok=True)
save_data(client_1, './data/tiny-imagenet-c-4swap/test-1.pkl')
save_data(client_2, './data/tiny-imagenet-c-4swap/test-2.pkl')
save_data(client_3, './data/tiny-imagenet-c-4swap/test-3.pkl')

print('Done.')

# for method_name in d.keys():
#     print('Creating images for the corruption', method_name)
#     cifar_c, labels = [], []

#     for severity in range(1,6):
#         corruption = lambda clean_img: d[method_name](clean_img, severity)

#         for img, label in zip(test_data.data, test_data.targets):
#             labels.append(label)
#             cifar_c.append(np.uint8(corruption(convert_img(img))))

#     np.save('/share/data/vision-greg2/users/dan/datasets/CIFAR-10-C/' + d[method_name].__name__ + '.npy',
#             np.array(cifar_c).astype(np.uint8))

#     np.save('/share/data/vision-greg2/users/dan/datasets/CIFAR-10-C/labels.npy',
#             np.array(labels).astype(np.uint8))


