

import random
import pickle
#import cv2
import tensorflow as tf
from tensorflow import keras
import numpy as np
import math
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.optimizers import SGD, Adam, Nadam
from tensorflow.keras import Model, Input
from tensorflow.keras.utils import to_categorical
from art.utils import load_mnist, load_cifar10
from tensorflow.keras.datasets import cifar100
from tensorflow.keras.applications import resnet50
from resnet_wrn_drop import resnet
from sam import SAM
from pipeline_test2 import *

from art.attacks.evasion import FastGradientMethod, ProjectedGradientDescent, CarliniL2Method, CarliniLInfMethod, DeepFool, SaliencyMapMethod, BasicIterativeMethod
from art.estimators.classification import KerasClassifier, TensorFlowV2Classifier
from sklearn.metrics import accuracy_score

# Dataset selector for the whole script. The branches below accept exactly
# 'cifar10', 'cifar100' or 'TI'; any other value prints a hint and exits.
dataset = 'cifar10'# cifar10, cifar100, TI

# Name of model being trained. It is reused further down as the log file stem
# ('log/' + name) and as the directory/file prefix for saved models
# ('models2/' + name + '/...').
#name = './00_models/02_TEST_padnet_CIFAR100=True,TGR=ON,alpha=10,padding=ON'
#name = '2_' + dataset + '_50CNT_AMMod_PGD_2_1'
name = 'TEST_' + dataset + '_50C_AMMod_PGD_2_1'
# Banner printed once at start-up so the run can be identified in captured output.
print("Model Trained: ", name, ". Original advmixup algorithm. Maximizes loss between correct class and adversarial example using PGD.")
#print("Model Trained: ", name, ". 50-50 adaptive advmixup algorithm with Dropout Layers inserted. Maximizes loss between correct class and adversarial example using PGD for 50 percent of examples and maximizes loss between NOTA class and the adversarial example for the other 50 percent of examples. Hardcoded 10s removed...")

# NOTE: a `global` statement at module scope has no effect; `input_shape` is
# simply assigned by whichever branch below runs.
global input_shape

# Each branch sets `num_class` and `input_shape`. The last class index
# (num_class - 1) is used elsewhere in this file as the extra NOTA/padding
# label, which is why num_class is one larger than the dataset's class count.
if dataset == 'cifar10':
    # Load CIFAR-10
    (X_train, y_train), (X_test, y_test), min_pixel_value, max_pixel_value = load_cifar10()
    # Collapse the per-sample label vectors to integer class indices.
    y_train = np.argmax(y_train, axis=1)
    y_test = np.argmax(y_test, axis=1)
    # Flatten every training image to a 3072-element row (reshaped back later).
    X_train = X_train.reshape(X_train.shape[0],3072)
    num_class = 11
    input_shape = [32,32,3]
    
elif dataset == 'cifar100':
    # Load CIFAR-100
    (X_train, y_train), (X_test, y_test) = cifar100.load_data(label_mode='fine')
    # Scale pixel values by 1/255.
    X_train = X_train / 255.0
    X_test = X_test / 255.0
    print("y_train examples", y_train.shape)
    #y_train = np.argmax(y_train, axis=1)
    # Drop singleton label dimensions; the two prints show the shape before/after.
    y_train = np.squeeze(y_train)
    #y_test = np.argmax(y_test, axis=1)
    y_test = np.squeeze(y_test)
    print("y_train examples", y_train.shape)
    # Flatten every training image to a 3072-element row (reshaped back later).
    X_train = X_train.reshape(X_train.shape[0],3072)
    num_class = 101
    input_shape = [32,32,3]

elif dataset == 'TI':
    # No arrays are loaded here for 'TI'; train() builds its datasets itself.
    num_class = 201
    input_shape = [64,64,3]
    
else:
    print("Please choose a dataset: 'cifar10', 'cifar100', 'TI'")
    exit()

#print("y_train examples", y_train[:20])
# Load Padding CLass
#(X_train, y_train) = pickle.load( open( "CIFAR10_with_padding_class.pkl", "rb" ) )

"""
# Load PGD Adversarial Examples
###########################################
(X_train_adv, y_train_adv) = pickle.load( open( "PGD_examples.pkl", "rb" ) )
X_train_adv = X_train_adv.reshape(X_train_adv.shape[0],3072)
y_train_adv = y_train_adv.astype(int)

# load Adaptive Examples
#########################################
path = '../adaptive_attack/CIFAR10/'
(X_train_adapt, y_train_adapt) = pickle.load( open( path + "adaptive_examples.pkl", "rb" ) )
X_train_adapt = X_train_adapt.reshape(X_train_adapt.shape[0],3072)
y_train_adapt = y_train_adapt.astype(int)


# Set PGD AEs as the padding class
for i in range(0,y_train_adv.shape[0]):
    y_train_adv[i] = 10
    
# Set Adaptive Examples as the padding class
for i in range(0,y_train_adapt.shape[0]):
    y_train_adapt[i] = 10
    
"""
# For the in-memory datasets (everything except 'TI'): restore the image
# layout, hold out the first 2000 training samples as the validation set, and
# build a shuffled, batched stream over the remaining non-augmented samples.
if (dataset != 'TI'):         
    # Back to (N, 32, 32, 3) image layout.
    X_train = X_train.reshape(X_train.shape[0],32,32,3)
    X_test = X_test.reshape(X_test.shape[0],32,32,3)
    y_train = y_train.astype(np.float64)
    y_test = y_test.astype(np.float64)
    # First 2000 samples become validation data; the rest stay as training data.
    X_val = X_train[:2000]
    y_val = y_train[:2000]
    X_train = X_train[2000:]
    y_train = y_train[2000:]
    # Clean (non-augmented) batches of 32; train() draws from `iterator` and
    # periodically re-creates it from `dataset_clean`.
    # NOTE(review): the shuffle buffer of 48000 presumably equals the number of
    # remaining training samples -- confirm for each dataset.
    dataset_clean = tf.data.Dataset.from_tensor_slices((X_train, y_train)).shuffle(48000).batch(32)
    # NOTE: `global` at module scope has no effect; `iterator` is a module-level name anyway.
    global iterator 
    iterator = iter(dataset_clean)

    
##################################################
# Wide ResNet CNN 
##################################################
# Hyper-parameters handed to the project's `resnet` factory (imported from
# resnet_wrn_drop). 'TI' uses a deeper/wider setting than the CIFAR datasets.
# NOTE(review): `My_wd` is passed as `wd`, presumably weight decay -- confirm
# against resnet_wrn_drop.
if dataset == 'TI':
    resnet_depth = 16 #28#34 #20
    resnet_width = 8 #10#10 #10
    My_wd=5e-4/2
else: 
    resnet_depth = 12#34 #20
    resnet_width = 6#10 #10
    My_wd=5e-4/2 #1e20 #5e-4/2
UseBinary=False
# Self-assignment: a no-op kept from the original script.
input_shape = input_shape
num_classes = num_class

# Module-level model shared by the attack helpers and train() below.
model = resnet(UseBinary,input_shape=input_shape, depth=resnet_depth, num_classes=num_classes,
               wd=My_wd,width=resnet_width,UseDrop=True)
#model = keras.applications.resnet50.ResNet50(
#    include_top=False, input_tensor=None, 
#    input_shape=(32, 32, 3), pooling=None, classes=101)
#######################################################################
# End Wide ResNet CNN
########################################################################
    # Ed Altered to replace with RESNET 50 preprocessing
def preprocess(image):
    """Return `image` cast to ``tf.float32``.

    No resizing or model-specific normalisation is applied; the attack loops
    call this before `tape.watch` so the watched value is a float32 tensor.
    """
    return tf.cast(image, tf.float32)


def _append_log_line(path, line):
    """Append `line` plus a newline to the text file at `path`."""
    with open(path, "a") as log_file:
        log_file.write(line + "\n")


def _chunked_accuracy(model, X, y, chunk_size):
    """Return the accuracy of `model` on (X, y), predicting `chunk_size` samples at a time.

    Stepping by `chunk_size` also covers a final partial chunk, so the number
    of predictions always equals the number of labels.
    """
    outputs = []
    for start in range(0, y.shape[0], chunk_size):
        outputs.extend(model(X[start:start + chunk_size]))
    return accuracy_score(y, np.argmax(outputs, axis=1))


def early_stopping(model, X_train, y_train, X_val, y_val, name, num_class=num_class, input_shape=input_shape):
    """Measure attack success rate and train/validation accuracy, logging each to `name`.txt.

    A Carlini L2 attack is run on validation samples 100..129 with the last
    layer's activation temporarily set to None. An adversarial example counts
    as a success when its predicted class differs from the label and is not
    the last class (``num_class - 1``).

    Args:
        model: Keras model to evaluate; its last layer's activation is set to
            softmax when this function finishes the attack.
        X_train, y_train: samples/labels used for the "Train Accuracy" figure.
        X_val, y_val: samples/labels used for the attack and "Validation Accuracy".
        name: path stem of the log file; results are appended to ``name + ".txt"``.
        num_class: number of model outputs, including the extra last class.
        input_shape: per-sample input shape handed to the ART classifier.

    Returns:
        Tuple ``(sr, acc)``: attack success rate and validation accuracy.
    """
    log_path = name + ".txt"
    nota_label = num_class - 1

    num_samples = 30
    X_advSrc = np.asarray(X_val[100:100+num_samples])
    y_advSrc = y_val[100:100+num_samples]

    # The attack is run with the last activation removed; restore softmax even
    # if the attack raises so the shared model is never left in that state.
    model.layers[-1].activation = None
    try:
        loss = keras.losses.CategoricalCrossentropy()
        classifier = TensorFlowV2Classifier(model=model, clip_values=(0,1), nb_classes=num_class,
                                            input_shape=input_shape, loss_object=loss)
        attack = CarliniL2Method(classifier=classifier, confidence=0, initial_const=1, max_iter=10, targeted=False)
        x_val_adv = attack.generate(x=X_advSrc, y=y_advSrc, NOTA=True)
        predictions = classifier.predict(x_val_adv)
    finally:
        model.layers[-1].activation = tf.keras.activations.softmax

    y_pred = np.argmax(predictions, axis=1)
    fooled = sum(
        1
        for pred, truth in zip(y_pred, np.asarray(y_advSrc))
        if pred != truth and pred != nota_label
    )
    sr = float(fooled) / y_advSrc.shape[0]
    print("attack_success: "+ str(sr))
    _append_log_line(log_path, "Attack success: " + str(sr))

    max_test = 500  # samples per forward pass

    acc = _chunked_accuracy(model, X_train, y_train, max_test)
    print("Train Accuracy:", acc)
    _append_log_line(log_path, "Train Accuracy: " + str(acc))

    acc = _chunked_accuracy(model, X_val, y_val, max_test)
    print("Validation Accuracy:", acc)
    _append_log_line(log_path, "Validation Accuracy: " + str(acc))

    return sr, acc


# Module-level Keras objects shared by the attack helpers below.
softmax = tf.keras.layers.Softmax()
#DLR_object = DifferenceLogitsRatioTensorFlowV2NOTA()
#TDLR_object = TargetDifferenceLogitsRatioTensorFlowV2NOTA()
# Loss used by PGD, PGD_50cent and NOTA_PGD; it is called with integer class labels.
loss_object = tf.keras.losses.SparseCategoricalCrossentropy()
cross_entropy = tf.keras.losses.CategoricalCrossentropy()

def PGD(X_train, y_train, input_shape):
    """Craft adversarial examples with iterative signed-gradient ascent on the module-level `model`.

    Starting from `X_train` plus small Gaussian noise, each of the 10 steps
    moves the input by 0.007 along the sign of the loss gradient, then clips
    the result to within `epsilon` of the clean input and to [0, 1].
    `epsilon` is drawn once per call from uniform(0.01, 0.9).

    Args:
        X_train: array reshapeable to (N, input_shape[0], input_shape[1], input_shape[2]).
        y_train: integer class labels passed to the sparse cross-entropy loss.
        input_shape: sequence of three ints (the per-sample image shape).

    Returns:
        Array of shape (N, prod(input_shape)) holding the flattened adversarial examples.
    """
    l = input_shape[0]
    w = input_shape[1]
    d = input_shape[2]
    X_train = X_train.reshape(X_train.shape[0], l, w, d)

    # add noise for starting point of adversarial examples
    noise = np.random.normal(0, 1, X_train.shape)
    x_adv = X_train + .001 * noise

    adv_steps = 10
    adv_step_size = .007
    epsilon = random.uniform(.01, .9)

    for _ in range(adv_steps):
        with tf.GradientTape() as tape:
            x_adv = preprocess(x_adv)
            tape.watch(x_adv)
            prediction = model(x_adv)  # This model should be a fully trained model
            loss = loss_object(y_train, prediction)

        # The update must happen inside the loop: previously it ran once after
        # the loop, so only a single step was ever applied.
        gradient = tape.gradient(loss, x_adv)
        signed_grad = tf.sign(gradient)
        x_adv = x_adv + adv_step_size * signed_grad
        x_adv = x_adv.numpy()

        # Project back into the epsilon-ball around the clean input, then into [0, 1].
        x_adv = np.clip(x_adv, X_train - epsilon, X_train + epsilon)
        x_adv = np.clip(x_adv, 0, 1)

    full_dim = l * w * d
    return x_adv.reshape(x_adv.shape[0], full_dim)

def PGD_50cent(X_train, y_train, num_class=num_class, input_shape=input_shape):
    """Craft adversarial examples, choosing the loss labels by a per-call coin flip.

    One random draw per call decides which labels feed the sparse
    cross-entropy loss for all 10 steps: the true labels `y_train`, or an
    array filled with the last class index (``num_class - 1``). Every step
    adds 0.007 times the sign of the loss gradient, i.e. it increases the
    loss against the chosen labels, then clips to within `epsilon` of the
    clean input and to [0, 1]. `epsilon` is drawn from uniform(0.01, 0.9).

    Returns:
        Array of shape (N, prod(input_shape)) with the flattened adversarial examples.
    """
    height, width, channels = input_shape[0], input_shape[1], input_shape[2]
    clean = X_train.reshape(X_train.shape[0], height, width, channels)

    # Start from the clean images plus a small amount of Gaussian noise.
    x_adv = clean + .001 * np.random.normal(0, 1, clean.shape)

    n_steps = 10
    step_size = .007
    epsilon = random.uniform(.01, .9)

    # Coin flip: 0 -> true labels, 1 -> all-last-class labels.
    use_last_class = random.randint(0, 1) != 0
    last_class_labels = np.asarray([num_class - 1] * len(clean))
    labels = last_class_labels if use_last_class else y_train

    for _ in range(n_steps):
        with tf.GradientTape() as tape:
            x_adv = preprocess(x_adv)
            tape.watch(x_adv)
            loss = loss_object(labels, model(x_adv))

        ascent_direction = tf.sign(tape.gradient(loss, x_adv))
        stepped = (x_adv + step_size * ascent_direction).numpy()

        # Keep the result within epsilon of the clean image and inside [0, 1].
        within_ball = np.clip(stepped, clean - epsilon, clean + epsilon)
        x_adv = np.clip(within_ball, 0, 1)

    return x_adv.reshape(x_adv.shape[0], height * width * channels)

def NOTA_PGD(X_train, y_train, num_class=11, input_shape=input_shape, model=model):
    """Craft adversarial examples from a mix of true-label and last-class losses.

    Per step the loss is ``alpha * L(y_train) +/- (1 - alpha) * L(last class)``
    with `alpha` re-drawn from uniform(0.25, 0.75) every step. The sign is
    fixed for the whole call by one uniform(0, 1) draw ("+" when it exceeds
    0.5). Each step adds 0.007 times the sign of the gradient, then clips to
    within `epsilon` (uniform(0.01, 0.9), drawn once) of the clean input and
    to [0, 1].

    NOTE(review): `num_class` defaults to the literal 11 here, unlike the
    sibling helpers that default to the module-level `num_class` -- confirm
    this is intended before using it with another dataset.

    Returns:
        Array of shape (N, prod(input_shape)) with the flattened adversarial examples.
    """
    height, width, channels = input_shape[0], input_shape[1], input_shape[2]
    clean = X_train.reshape(X_train.shape[0], height, width, channels)

    # Start from the clean images plus a small amount of Gaussian noise.
    start_noise = np.random.normal(0, 1, clean.shape)
    x_adv = clean + .001 * start_noise

    last_class_labels = np.ones(y_train.shape) * (num_class - 1)
    n_steps = 10
    step_size = .007
    epsilon = random.uniform(.01, .9)
    add_terms = np.random.uniform(0, 1) > 0.5

    for _ in range(n_steps):
        alpha = random.uniform(0.25, 0.75)
        with tf.GradientTape() as tape:
            x_adv = preprocess(x_adv)
            tape.watch(x_adv)
            prediction = model(x_adv)
            true_term = alpha * loss_object(y_train, prediction)
            last_class_term = (1 - alpha) * loss_object(last_class_labels, prediction)
            if add_terms:
                loss = true_term + last_class_term
            else:
                loss = true_term - last_class_term

        direction = tf.sign(tape.gradient(loss, x_adv))
        stepped = (x_adv + step_size * direction).numpy()
        within_ball = np.clip(stepped, clean - epsilon, clean + epsilon)
        x_adv = np.clip(within_ball, 0, 1)

    return x_adv.reshape(x_adv.shape[0], height * width * channels)



def gen_lamp_images(X_train, y_train, num_class=num_class, input_shape=input_shape):
    """Build a padding batch from clean samples and their adversarial counterparts.

    Adversarial examples are generated with `PGD_50cent`. For every sample
    except the last one, two rows labelled ``num_class - 1`` are produced:
      1. a weighted average of the clean sample and its adversarial example,
         with the clean weight drawn from uniform(0.05, 0.95);
      2. their plain mean plus Gaussian noise (scale 0.01), clipped to [0, 1].
    The output starts with a copy of sample 0 and ends with sample 1, both
    with their original labels, so it holds ``2 * N`` rows for N inputs.

    Args:
        X_train: array of shape (N, prod(input_shape)), N >= 2.
        y_train: array of N labels.
        num_class: number of classes including the extra last class.
        input_shape: sequence of three ints (the per-sample image shape).

    Returns:
        Tuple ``(images, labels)`` with shapes (2N, *input_shape) and (2N,).
    """
    height = input_shape[0]
    width = input_shape[1]
    depth = input_shape[2]
    full_dim = height * width * depth

    x_adv = PGD_50cent(X_train, y_train, num_class=num_class, input_shape=input_shape)

    padding_label = num_class - 1
    noise_scale = .01

    # Collect rows in lists and stack once at the end; stacking inside the
    # loop re-copies everything accumulated so far on every iteration.
    rows = [np.copy(X_train[0])]
    labels = [np.copy(y_train[0])]

    for i in range(X_train.shape[0] - 1):
        pair = np.vstack((X_train[i], x_adv[i]))
        midpoint = np.mean(pair, axis=0)

        # Randomly weighted blend of the clean and adversarial sample.
        alpha = random.uniform(0.05, 0.95)
        rows.append(np.average(pair, axis=0, weights=[alpha, 1 - alpha]))
        labels.append(padding_label)

        # Midpoint perturbed with Gaussian noise.
        perturbation = np.random.normal(0, noise_scale, full_dim)
        rows.append(np.clip(perturbation + midpoint, 0, 1))
        labels.append(padding_label)

    # add one more benign sample to match size
    rows.append(X_train[1])
    labels.append(y_train[1])

    c10 = np.vstack(rows)
    l10 = np.vstack(labels)
    l10 = l10.reshape(l10.shape[0],)
    c10 = c10.reshape(c10.shape[0], height, width, depth)
    return c10, l10
    


def random_batch(X, y, batch_size):
    """Return `batch_size` rows of `X` and `y` picked at random with replacement.

    The same randomly drawn indices are applied to both arrays, so the
    returned samples and labels stay aligned.
    """
    n_samples = X.shape[0]
    picks = np.random.randint(0, n_samples, size=batch_size)
    X_picked = X[picks]
    y_picked = y[picks]
    return X_picked, y_picked

def print_status_bar(iteration, total, loss, acc, metrics=None):
    """Print a one-line progress report that overwrites itself via a leading carriage return.

    Args:
        iteration: current progress count.
        total: count at which the line is terminated with a newline.
        loss: object with `.name` and `.result()`; always shown first.
        acc: validation accuracy, already formatted as a string.
        metrics: optional list of further objects with `.name` and `.result()`.
    """
    tracked = [loss]
    if metrics:
        tracked = tracked + metrics

    parts = []
    for metric in tracked:
        parts.append("{}: {:.4f}".format(metric.name, metric.result()))
    summary = " - ".join(parts)

    # Stay on the same line until the final iteration is reached.
    if iteration < total:
        line_end = ""
    else:
        line_end = "\n"

    prefix = "\r{}/{} - ".format(iteration, total)
    print(prefix + "Validation Accuracy: " + acc + " " + summary, end=line_end)

# Augmentation pipeline used by train() for the in-memory datasets
# (via datagen.flow): random rotation, horizontal/vertical shifts and
# horizontal flips. Brightness and zoom augmentation are disabled.
datagen = ImageDataGenerator(
    
    rotation_range=15,
    width_shift_range=0.1,
    height_shift_range=0.1,
    horizontal_flip=True,
    #brightness_range=[0.5,1.5],
    #zoom_range=[0.5,1.0],
    )
    
# TODO: implement an iterator over the clean (non-augmented) training data, then replace every use of random_batch with it.
    
def train():
    """Run the padding-based adversarial training loop on the module-level `model`.

    Every step assembles one combined batch from
      * the augmented batch yielded by `training_dataset`,
      * padding samples generated from that augmented batch (`gen_lamp_images`),
      * padding samples generated from a clean (non-augmented) batch, and
      * the clean batch itself,
    shuffles it, and applies the optimizer's `first_step` followed by its
    `second_step`, each with freshly computed gradients. Validation accuracy
    is printed after every step. Every 150 steps `early_stopping` is run, the
    model is saved, and the learning rate is adjusted once validation
    accuracy has failed to improve for `top_count` consecutive checks.

    Relies on names that are not defined in the visible part of this file:
    `n_epochs`, `sub_batch_size`, `data_dir` and the `get_*` dataset helpers
    (presumably provided by ``from pipeline_test2 import *`` -- verify).
    """
    global input_shape, X_val, y_val, iterator

    # NOTE(review): the two statements that set up the learning rate and base
    # optimizer were corrupted in the source (only "e-4" and
    # "=lr, beta_1=0.9)" survived). The mantissa 1 and the choice of Adam are
    # a reconstruction -- confirm against the original experiment settings.
    lr = 1e-4
    optimizer = Adam(learning_rate=lr, beta_1=0.9)
    optimizer = SAM(optimizer)
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy()
    mean_loss = keras.metrics.Mean()
    best_sr = 1.0
    best_acc = 0
    count = 0        # consecutive checks without a validation-accuracy improvement
    top_count = 35   # checks to wait before changing the learning rate
    clean_step = 1   # steps since the clean-data iterator was last re-created

    l = input_shape[0]
    w = input_shape[1]
    d = input_shape[2]
    full_dim = l * w * d

    if dataset != 'TI':
        n_steps = X_train.shape[0] // 32  # batch_size
        training_dataset = datagen.flow(X_train, y_train, sub_batch_size)

    else:  # dataset == 'TI'
        n_steps = 96000 // sub_batch_size
        valset_size = 4000
        class_mapping = get_class_mapping(data_dir)
        class_names = get_class_names(data_dir)
        val_split = get_val_split(data_dir)
        train_filenames, eval_filenames, val_filenames = get_train_val_filenames(data_dir)
        train_labels, eval_labels, val_labels = get_train_eval_val_labels(
            data_dir, val_split, train_filenames, eval_filenames, val_filenames)
        training_dataset = get_train_dataset(data_dir, train_filenames, train_labels,
                                             class_mapping, sub_batch_size)
        validation_dataset = get_val_dataset(data_dir, eval_filenames, eval_labels,
                                             class_mapping, valset_size)
        # New Clean training dataset and iterable are created to allow padding to be created
        # from non-augmented examples. Test iterables necessary since entire test set does not
        # fit in memory
        clean_training_dataset = get_clean_train_dataset(data_dir, train_filenames,
                                                         train_labels, class_mapping, sub_batch_size)
        clean_training_dataset_long = get_clean_train_dataset(data_dir, train_filenames,
                                                              train_labels, class_mapping, 1000)

        val_iter = validation_dataset.__iter__()
        clean_train_iter = clean_training_dataset.__iter__()
        clean_train_long_iter = clean_training_dataset_long.__iter__()
        X_val, y_val = next(val_iter)

    # training loop
    for epoch in range(1, n_epochs + 1):
        print("\nEpoch {}/{}".format(epoch, n_epochs))
        step = 1

        for X_batch, y_batch in training_dataset:
            # Periodically restart the clean-data iterator for the in-memory datasets.
            if clean_step % 500 == 0 and dataset != 'TI':
                iterator = iter(dataset_clean)
                clean_step = 1

            # add LAMP padding using augmented data
            X_amp = np.copy(X_batch).reshape(X_batch.shape[0], full_dim)
            y_amp = np.copy(y_batch)
            X_amp, y_amp = gen_lamp_images(X_amp, y_amp, num_class=num_class, input_shape=input_shape)
            X_batch = np.concatenate([X_batch, X_amp])
            y_batch = np.concatenate([y_batch, y_amp])

            # add LAMP padding using non-augmented data
            if dataset == 'TI':
                # Reconstructed: this line was corrupted in the source.
                X_amp1, y_amp1 = next(clean_train_iter)
            else:
                X_amp1, y_amp1 = iterator.get_next()
            # Keep untouched copies of the clean batch; they are appended below.
            X_nac = np.copy(X_amp1)
            y_nac = np.copy(y_amp1)
            X_amp1 = X_amp1.numpy().reshape(X_amp1.shape[0], full_dim)
            y_amp1 = y_amp1.numpy()
            X_amp1, y_amp1 = gen_lamp_images(X_amp1, y_amp1, num_class=num_class, input_shape=input_shape)
            X_batch = np.concatenate([X_batch, X_amp1])
            y_batch = np.concatenate([y_batch, y_amp1])

            # add original (clean) samples
            X_batch = np.concatenate([X_batch, X_nac])
            y_batch = np.concatenate([y_batch, y_nac])
            shuffler = np.random.permutation(X_batch.shape[0])
            X_batch = X_batch[shuffler]
            y_batch = y_batch[shuffler]

            # First pass: gradients for the optimizer's first step.
            with tf.GradientTape() as tape:
                predictions = model(X_batch, training=True)
                loss = loss_object(y_batch, predictions)
            gradients = tape.gradient(loss, model.trainable_variables)
            optimizer.first_step(gradients, model.trainable_variables)

            # Second pass: gradients recomputed after the first step.
            with tf.GradientTape() as tape:
                predictions = model(X_batch, training=True)
                loss = loss_object(y_batch, predictions)
            gradients = tape.gradient(loss, model.trainable_variables)
            optimizer.second_step(gradients, model.trainable_variables)

            # Validation predictions for the status line.
            if dataset == 'TI':
                max_test = 500
                y_pred = []
                X_val, y_val = next(val_iter)
                if y_val.shape[0] < valset_size:
                    X_val, y_val = next(val_iter)
                for j in range(0, valset_size // max_test):
                    temp = model(X_val[j * max_test:((j + 1) * max_test)])
                    y_pred.extend(temp)
            else:
                y_pred = model(X_val)

            mean_loss(loss)
            y_pred = np.argmax(y_pred, axis=1)
            acc = accuracy_score(y_val, y_pred)
            if dataset == 'TI':
                print_status_bar(step * sub_batch_size, 96000, mean_loss, str(acc))
            else:
                print_status_bar(step * sub_batch_size, y_train.shape[0], mean_loss, str(acc))

            if step % 150 == 0:
                if dataset == 'TI':
                    cln_X_train, cln_y_train = next(clean_train_long_iter)
                    sr, acc = early_stopping(model, cln_X_train, cln_y_train, X_val, y_val,
                                             'log/' + name, num_class=num_class, input_shape=input_shape)
                else:
                    sr, acc = early_stopping(model, X_train[:1000], y_train[:1000], X_val, y_val,
                                             'log/' + name, num_class=num_class, input_shape=input_shape)
                name_new = "models2/" + name + "/" + name + "_ASR=" + str(sr) + "_VAL_acc=" + str(acc)
                model.save(name_new)
                print("new_lr = ", optimizer.base_optimizer.learning_rate, " count = ", count)
                if acc > best_acc:
                    best_acc = acc
                    count = 0
                else:
                    count += 1

                if count == top_count:
                    # Once the rate has decayed below 5e-5, jump back up to 5e-4
                    # and shorten the patience; otherwise halve it.
                    if optimizer.base_optimizer.learning_rate < 0.00005:
                        lr = 0.0005
                        top_count = 25
                    else:
                        lr = lr * .5

                    optimizer.base_optimizer.learning_rate.assign(lr)
                    count = 0

                # Track the lowest attack success rate seen at >= 65% validation accuracy.
                if acc >= .65:
                    if sr < best_sr:
                        best_sr = sr
            step += 1
            clean_step += 1
            if step >= n_steps:
                break

def save_images(X, type):
    """Write up to the first ten 32x32x3 images in `X` as PNG files.

    `X` may be a single image or a batch, flattened or not; it is reshaped to
    (-1, 32, 32, 3). Image `i` is scaled by 255, converted to uint8 and
    written to ``'images/' + type + str(i) + '.png'``.

    This replaces two same-named definitions: the later one shadowed the
    batch version and referenced an undefined index `i`.

    Args:
        X: array-like whose size is a multiple of 32*32*3.
        type: string prefix for the output file names.
    """
    # Imported here because the module-level `import cv2` is commented out.
    import cv2

    images = np.asarray(X).reshape(-1, 32, 32, 3)
    for i, image in enumerate(images[:10]):
        pixels = np.array(image * 255.0, dtype=np.uint8)
        cv2.imwrite('images/' + type + str(i) + '.png', pixels)
    

                          
# Script entry point: training starts as soon as this file is executed.
train() 
#model.save('vanilla_model')
