from __future__ import print_function
import keras
from keras.datasets import cifar10
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten
from keras.layers import Conv2D, MaxPooling2D
import scipy.misc
import os
import six.moves.cPickle as pickle
import gzip
import numpy as np
import sys
import scipy.misc
import random
import cv2
import scipy.stats as st
import time

max_image_size = 560

# Cache of 3-channel Gaussian masks keyed by (kern_len, nsig). The mask depends
# only on these two values, so rebuilding it on every reflect_image() call
# (once per processed image) is wasted work.
_GAUSSIAN_MASK_CACHE = {}


def _gaussian_mask(kern_len, nsig):
    """Return a read-only (kern_len, kern_len, 3) Gaussian mask whose maximum is 1.

    The 1-D profile is the standard-normal probability mass over ``kern_len``
    equal-width bins centred on 0 and spanning slightly more than
    [-nsig, nsig]. The 2-D kernel is the square root of the outer product of
    that profile, normalised so its peak is 1. It is stacked into three
    identical channels.
    """
    key = (kern_len, nsig)
    mask = _GAUSSIAN_MASK_CACHE.get(key)
    if mask is None:
        interval = (2 * nsig + 1.) / kern_len
        x = np.linspace(-nsig - interval / 2., nsig + interval / 2., kern_len + 1)
        # Probability mass of each bin under a standard normal distribution.
        kern1d = np.diff(st.norm.cdf(x))
        kernel_raw = np.sqrt(np.outer(kern1d, kern1d))
        kernel = kernel_raw / kernel_raw.sum()
        kernel = kernel / kernel.max()
        mask = np.dstack((kernel, kernel, kernel))
        # Callers only read slices of the shared cached mask; freeze it so an
        # accidental in-place write cannot corrupt later calls.
        mask.flags.writeable = False
        _GAUSSIAN_MASK_CACHE[key] = mask
    return mask


def reflect_image(img_t, img_r, output_size=(32, 32)):
    """Blend a blurred, spatially attenuated reflection of ``img_r`` onto ``img_t``.

    Processing steps:

    1. Both images are scaled to 0-1 floats.
    2. Both are resized so the longer side equals ``max_image_size``, using
       the aspect ratio of ``img_t``.
    3. Both are raised to the power 2.2 (approximate gamma decoding).
    4. The reflection is Gaussian blurred with a random sigma in [1, 5].
    5. Each blurred reflection channel is shifted down by a constant derived
       from the mean of the over-exposed (>1) pixels of the provisional blend.
       The result is clipped to [0, 1].
    6. The reflection is multiplied by a randomly offset window of a Gaussian
       mask and added to ``img_t * alpha_t``, with ``alpha_t`` in
       [0.55, 0.65].
    7. The sum is raised to the power 1/2.2, clipped to [0, 1], quantised to
       uint8 and resized to ``output_size``.

    Random values are drawn from both ``random`` and ``np.random``; seed both
    for reproducible output.

    Args:
        img_t: transmission (background) image. Must be an HxWx3 array: the
            code indexes three channels and multiplies by a 3-channel mask.
            Values are presumably in 0-255 (they are divided by 255) -- TODO
            confirm for all callers.
        img_r: reflection image with the same assumptions; it is resized to
            ``img_t``'s working size.
        output_size: ``(width, height)`` of the returned image, in OpenCV's
            ``dsize`` order. Defaults to the previously hard-coded 32x32.

    Returns:
        uint8 array of the blended image, resized to ``output_size``.
    """
    t = np.float32(img_t) / 255.
    r = np.float32(img_r) / 255.
    h, w, _ = t.shape
    # Rescale so the longer side of the transmission image equals max_image_size.
    scale_ratio = float(max(h, w)) / float(max_image_size)
    if w > h:
        w, h = max_image_size, int(round(h / scale_ratio))
    else:
        w, h = int(round(w / scale_ratio)), max_image_size
    # The interpolation flag must be passed by keyword: the third positional
    # parameter of cv2.resize is ``dst``, not ``interpolation``. Bicubic
    # interpolation can overshoot the input range, and a negative base raised
    # to 2.2 would give NaN, so clip back to [0, 1].
    t = np.clip(cv2.resize(t, (w, h), interpolation=cv2.INTER_CUBIC), 0., 1.)
    r = np.clip(cv2.resize(r, (w, h), interpolation=cv2.INTER_CUBIC), 0., 1.)
    t = np.power(t, 2.2)
    r = np.power(r, 2.2)

    alpha_t = 1. - random.uniform(0.35, 0.45)

    sigma = random.uniform(1, 5)
    sz = int(2 * np.ceil(2 * sigma) + 1)  # odd kernel size covering ~2 sigma
    # Keywords are used because the fourth positional parameter of
    # GaussianBlur is ``dst``, not ``sigmaY``.
    r_blur = cv2.GaussianBlur(r, (sz, sz), sigmaX=sigma, sigmaY=sigma)
    blend = r_blur + t

    # Pull each reflection channel down by how far the over-exposed pixels of
    # the provisional blend exceed 1 on average, scaled by a random factor in
    # [1.08, 1.18).
    att = 1.08 + np.random.random() / 10.0
    for i in range(3):
        maski = blend[:, :, i] > 1
        mean_i = max(1., np.sum(blend[:, :, i] * maski) / (maski.sum() + 1e-6))
        r_blur[:, :, i] = r_blur[:, :, i] - (mean_i - 1) * att
    r_blur = np.clip(r_blur, 0., 1.)

    # Random offset of the Gaussian-mask window. The draw order (w, then h) is
    # kept so seeded runs stay reproducible.
    h, w = r_blur.shape[:2]
    new_w = np.random.randint(0, max_image_size - w - 10) if w < max_image_size - 10 else 0
    new_h = np.random.randint(0, max_image_size - h - 10) if h < max_image_size - 10 else 0

    g_mask = _gaussian_mask(max_image_size, 3)
    alpha_r = g_mask[new_h: new_h + h, new_w: new_w + w, :] * (1. - alpha_t / 2.)

    r_blur_mask = np.multiply(r_blur, alpha_r)
    blend = r_blur_mask + t * alpha_t

    # Undo the 2.2 power, clip, quantise and shrink to the requested size.
    blend = np.clip(np.power(blend, 1 / 2.2), 0., 1.)
    blended = np.uint8(blend * 255)
    return cv2.resize(blended, output_size, interpolation=cv2.INTER_CUBIC)

num_classes = 10
# CIFAR-10 class that every poisoned sample is labelled as.
target_class = 0


def _load_reflection(path):
    """Load ``path`` as a (32, 32, 3) float32 RGB array with values 0-255.

    ``scipy.misc.imread`` was removed in SciPy 1.2, so OpenCV is used instead.
    cv2.imread returns BGR, so the channels are swapped to RGB.

    Raises:
        IOError: if the file cannot be read.
    """
    bgr = cv2.imread(path, cv2.IMREAD_COLOR)
    if bgr is None:
        raise IOError('cannot read reflection image: {0}'.format(path))
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB).astype(np.float32)
    # Reshape doubles as a size check: it fails unless the image holds 32*32*3 values.
    return np.reshape(rgb, (32, 32, 3))


def _reflect_all(images, reflection, reflect_label, split):
    """Return a float64 array with each of ``images`` blended with ``reflection``.

    Progress is printed as ``<reflect_label> <split> <index>`` every 1000
    images and for the last image.
    """
    out = np.zeros(images.shape)
    last = len(images) - 1
    for i, img in enumerate(images):
        if i % 1000 == 0 or i == last:
            print(reflect_label, split, i)
        out[i] = reflect_image(img, reflection)
    return out


# CIFAR-10 does not depend on the reflection image, so load it only once.
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
print('x_train shape:', x_train.shape, y_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')
y_train = y_train.reshape(-1)
y_test = y_test.reshape(-1)

# Training poison: only target-class images, keeping their labels.
t_x_train = x_train[y_train == target_class]
t_y_train = y_train[y_train == target_class]
# Test poison: the whole test set, every sample labelled target_class.
t_x_test = x_test
t_y_test = np.zeros(y_test.shape) + target_class

# reflect_label selects ./gtsrb_example/<reflect_label>.png as the reflection.
for reflect_label in [1, 3, 4, 5, 6, 8, 9]:
    start_time = time.time()
    timage = _load_reflection('./gtsrb_example/{0}.png'.format(reflect_label))
    print(timage.shape)

    t_x_train2 = _reflect_all(t_x_train, timage, reflect_label, 'train')
    t_x_test2 = _reflect_all(t_x_test, timage, reflect_label, 'test')
    print(t_x_train2.shape, t_x_test2.shape)

    with open('reflection_{0}_5_6_trojan_test.pkl'.format(reflect_label), 'wb') as f:
        pickle.dump((t_x_test2, t_y_test), f)

    with open('reflection_{0}_5_6_trojan_train.pkl'.format(reflect_label), 'wb') as f:
        pickle.dump((t_x_train2, t_y_train), f)
    end_time = time.time()
    print('time', end_time - start_time)

# for i in range(0, 10):
#     scipy.misc.imsave('trojan_reflection/{0}.png'.format(i), t_x_train[i])
#     print('trojan', i, t_y_train[i])

