#!/usr/bin/env python
# coding: utf-8

# In[1]:


import math
import time
import timeit
import pickle
import random
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
from tqdm import tqdm
from tensorflow.keras import Sequential, layers
from tensorflow.keras.callbacks import ModelCheckpoint


# In[ ]:


def classify(image, model, channel=False):
    '''
    Run a single image through a network and report the
    winning class together with its score.


    Arguments
    ---------------------------------------------
    ** image (numpy.ndarray)

    ** model (tf.keras.Sequential())
    a neural network compatible with the image
    as its input; only its predict method is called

    ** channel (Boolean, default=False)
    Set to True when the image is 2D and a trailing
    channel dimension has to be appended before it can
    be fed to a convolutional layer.


    Returns
    ---------------------------------------------
    ** prediction (int)
    Index of the largest element of the model output

    ** confidence (float)
    Value of that largest element, rounded to 2 decimal
    digits

    '''
    # a batch axis (position 0) is always prepended; the channel
    # axis (position 3) is appended only when requested
    new_axes = [0, 3] if channel else 0
    batch = np.expand_dims(image, axis=new_axes)

    # forward pass through the network
    scores = model.predict(batch)

    # position and magnitude of the strongest output
    winner = scores.argmax()
    top_score = round(scores.max(), 2)

    return winner, top_score





def classify_plot(image, model, ax, channel=False):
    '''
    Draw an image on the given axis and caption it with
    the network's prediction for that image.


    Arguments
    ---------------------------------------------
    ** image (numpy.ndarray)

    ** model (tf.keras.Sequential())
    a neural network compatible with the image
    as its input

    ** ax (matplotlib.pyplot.axis())
    the axis the image is drawn on

    ** channel (Boolean, default=False)
    Forwarded to classify; set to True when the image
    is 2D and needs a channel dimension before being
    fed to a convolutional layer.


    Returns
    ---------------------------------------------
    None

    '''
    # ask the network for its label and score first
    label, score = classify(image, model, channel)

    # str() is used on purpose so numpy scalars are rendered
    # in their short printed form
    caption = 'labeled as ' + str(label) + ' with ' + str(score) + ' confidence'

    # draw the image without ticks and put the caption below it
    ax.imshow(image, cmap='Blues_r')
    ax.axis('off')
    ax.set_title(caption, y=-0.1)
    
    


    
def generate_input(feed_type, feed_shape):
    '''
    This function generates an array for feeding a network

    Arguments
    ---------------------------------------------
    ** feed type (str)
    The input's templates which can take the following
    values:
        - 'blank': all values being zero
        - 'constant': all values being .2
        - 'uniform': random values from a uniform distribution,
                     in range [0 1]
        - 'normal': random values from a normal distribution,
                    with mu = .5 and std = .13
        - 'exponential': random values from an exponential
                    distribution, with scale (1/lambda) = .13

    ** feed shape (int or tuple of ints)
    Shape of the input


    Returns
    ---------------------------------------------
    ** feed: (numpy.ndarray)
    An array in float32 type. For an unrecognised feed type
    the scalar np.nan is returned instead of an array, so
    callers can detect a typo.

    '''
    # all arrays are produced as float32 for compatibility with tf.keras

    # if blank
    if feed_type == 'blank':
        feed = np.zeros(feed_shape, dtype=np.float32)

    # if constant
    elif feed_type == 'constant':
        feed = np.full(feed_shape, .2, dtype=np.float32)

    # if uniform
    elif feed_type == 'uniform':
        # passing the shape as `size` draws the samples directly
        # in the requested layout (no flatten/reshape round trip)
        feed = np.random.uniform(0, 1, feed_shape).astype(np.float32)

    # if normal
    elif feed_type == 'normal':
        feed = np.random.normal(.5, .13, feed_shape).astype(np.float32)

    # if exponential
    elif feed_type == 'exponential':
        feed = np.random.exponential(.13, feed_shape).astype(np.float32)

    # and in case of a typo
    else:
        # return an invalid value
        feed = np.nan

    return feed




    
def generate_ade(image, model, target_label, channel=False,
                 max_itr=1000, target_loss=0, target_conf=1.01):
    '''
    This function takes an input template and generates
    an ADversarial Example (ADE) with it. To generate
    an ADE, the function uses a gradient-based
    optimization attack which iteratively runs gradient
    descent (SGD on a sparse categorical cross-entropy
    loss) on the input image, while the model is called
    with training=False.


    Arguments
    ---------------------------------------------
    ** image (numpy.ndarray)

    ** model (tf.keras.Sequential())
    a neural network compatible with the image
    as its input

    ** target label (int):
    The class index, we want the model tags the ADE with

    ** channel (Boolean, default=False)
    The argument should be set to True, if the image
    is 2D and it needs a channel dimension before being
    fed to a convolutional layer.


    To stop the gradient descent loop we need more reliable
    stop conditions rather than simply the ADE being tagged
    with the target label. The loop stops as soon as any of
    these three conditions is met:

        ** max_itr (int, default=1,000)
        Upper bound of the iteration counter. The counter
        starts at 1, so at most max_itr - 1 gradient steps
        are taken.

        ** target_loss (float, default=0)
        The loop stops once the loss is no longer greater
        than this value.

        ** target_conf (float [0 1], default=1.01)
        The loop stops once the confidence reaches this value.
        The default is above 1, so with a softmax output this
        condition never fires.


    Returns
    ---------------------------------------------
    ** ADE (numpy.ndarray)
    The optimized input without the batch (and artificial
    channel) dimension

    ** loss (float)
    The loss of the model w.r.t. the target label, from the
    last forward pass

    ** confidence (float)
    The largest element of the model output from the last
    forward pass, with 2 decimal digits

    ** iteration (int)
    The final value of the iteration counter (number of
    gradient steps taken + 1)

    ** success (Boolean)
    The value indicate whether the last forward pass was
    classified by the model with the target label or not.

    ** duration (HH:MM:SS)
    Time spent for generating an ADE

    '''

    # check channel value to see if the image needs a
    # dimension for channels
    if channel:
        # if yes, add an extra 3rd dimension for the channel
        the_input = np.expand_dims(image, axis=[0,3])
    else:
        # otherwise, only add 0th dimension for simulating the batch
        the_input = np.expand_dims(image, axis=0)

    # convert the image to a tf-variable
    the_input = tf.Variable(the_input)

    # set the initial values of stop conditions
    itr = 1
    loss = np.inf
    conf = 0
    # stays None when the loop body never executes
    prediction = None

    # set the functions for computing loss and optimization
    cost = tf.keras.losses.SparseCategoricalCrossentropy()
    optimizer = tf.keras.optimizers.SGD()

    # keep the time for the ADE-generating process
    start = timeit.default_timer()

    # `itr < max_itr` (instead of testing `max_itr - itr` for zero)
    # also terminates when max_itr is smaller than 1
    while (itr < max_itr) and (loss > target_loss) and (conf < target_conf):
        # set the tape
        with tf.GradientTape() as tape:
            # and add the image to be watched
            tape.watch(the_input)
            # feed the image to the model, while keeping
            # the weights intact and get the prediction
            prediction = model(the_input, training=False)
            # calculate the confidence
            # NOTE(review): this is the largest output over all classes,
            # not necessarily the output of the target class - confirm
            # this is the intended stop criterion
            conf = round(prediction.numpy().max(), 2)
            # calculate the loss value
            loss = cost(target_label, prediction)

        # calculate the gradients of the image, wrt loss value
        grads = tape.gradient(loss, the_input)
        # optimize the image, regarding the gradients
        optimizer.apply_gradients(zip([grads], [the_input]))

        # increase the iteration index
        itr += 1

    # if the stop conditions were already satisfied before the first
    # step, evaluate the untouched input once so that prediction,
    # confidence and loss are well defined for the return values
    if prediction is None:
        prediction = model(the_input, training=False)
        conf = round(prediction.numpy().max(), 2)
        loss = cost(target_label, prediction)

    # stop the time for the ADE-generating process
    duration = timeit.default_timer() - start
    # convert the time from seconds to HH:MM:SS
    duration = time.strftime('%H:%M:%S', time.gmtime(duration))

    # NOTE(review): when the loop ran, `prediction`, `conf` and `loss`
    # come from the forward pass made before the final gradient step,
    # while the returned ADE already includes that step.
    # record the ADE as successful if the model classified it
    # with the target label
    success = bool(prediction.numpy().argmax() == target_label)

    # if an artificial channel dimension has already been
    # added to the image
    if channel:
        # remove it (together with the batch dimension)
        # before returning the ADE
        ade = the_input.numpy()[0,:,:,0]
    # otherwise
    else:
        # just remove the batch dimension, whatever the rank
        ade = the_input.numpy()[0]

    return ade, loss.numpy(), conf, itr, success, duration





def generate_aai(labels_list, feed_type, feed_shape, model, n, channel=False,
                 max_itr=1000, target_loss=0, target_conf=1.1, stats=None):
    ''' This function gets a list of target labels and
    generates ADEs with them in bulk and calculates the
    Average Adversarial Images (AAIs).


    Arguments
    ---------------------------------------------
    ** labels_list (list of ints)
    A list of target labels which we want to generate
    AAIs for

    ** feed type (str)
    The input image's templates which will be modified to ADEs.
    This argument can take the following values:
        - 'blank': all values being zero
        - 'constant': all values being .2
        - 'uniform': random values from a uniform distribution,
                     in range [0 1]
        - 'normal': random values from a normal distribution,
                    with mu=.5 and std=.13

    ** feed shape (tuple of ints)
    Shape of the input image's template

    ** model (tf.keras.Sequential())
    a neural network compatible with the input image's
    template

    ** n (int)
    The number of trials, i.e. the number of ADEs that are
    attempted for each target label

    ** channel (Boolean, default=False)
    The argument should be set to True, if the input image
    is 2D and it needs a channel dimension before being
    fed to a convolutional layer.


    To stop the gradient descent loop we need more reliable
    stop conditions rather than simply the ADE being tagged
    with the target label. This function uses a disjunctive
    statement of three stop conditions, all of which are
    forwarded to generate_ade:

        ** max_itr (int, default=1,000)
        Maximum number of running gradient descent on the input
        image before break the loop.

        ** target_loss (float, default=0)
        The loss of the model classifying the final ADE with the
        target label

        ** target_conf (float [0 1], default=1.1)
        The confidence of the model classifying the final ADE with
        the target label


    ** stats (pandas.dataframe, default=None)
    A dataframe for recording the stats of generating AAIs which has /
    should have the following columns:
    - 'target label': (int) the target label
    - 'feed': (str) the input type used for generating ADEs
    - 'aai': (numpy.ndarray) the accumulated successful ADEs
    - 'ades#': (int) the number of adversarial examples accumulated
                     in the AAI (successful number of trials)
    - 'trials': (int) the target number of / trials for generating ADEs
    - 'success': (float) ades#/trials with two decimal digits
    - 'avg loss': (float) the average loss value of the final ADEs
    - 'avg iter': (int) the average number of iterations of modifying
                        the input image to get the final ADE
    - 'avg conf': (float) the average confidence of the model
                          in classifying the final ADEs
    - 'prediction': (int) the class index which the generated AAI is
                          tagged with by the network
    - 'confidence': (float) the confidence of the network in labeling
                            the AAI
    - 'duration': (HH:MM:SS) time spent for generating the AAI
    In the case of default value, this argument creates a new dataframe
    with these columns for recording the data. A dataframe that is
    passed in is modified in place: the row indexed by each target
    label is written (and overwritten if it already exists).


    Returns
    ---------------------------------------------
    ** The stats dataframe

    '''
    # if there's no dataframe passed to the function
    # (`not stats` cannot be used here: the truth value of a
    # DataFrame is ambiguous and raises a ValueError)
    if stats is None:
        # set one with the required columns
        columns = ['target label', 'feed', 'aai', 'ades#', 'trials', 'success',
                   'avg loss', 'avg iter', 'avg conf', 'prediction', 'confidence', 'duration']
        stats = pd.DataFrame(columns = columns)

    # for each target label
    for label in labels_list:

        # set a placeholder for the AAI
        aai = np.zeros(feed_shape).astype(np.float32)

        # initialize the averaging variables
        loss_avg = 0
        iter_avg = 0
        conf_avg = 0
        # and the success count
        success_count = 0

        # keep the time for the ADE-generating process
        start = timeit.default_timer()

        # for n times
        for _ in tqdm(range(n)):

            # generate an input with the specified type and shape
            feed = generate_input(feed_type, feed_shape)

            # generate an ADE and get all returned values except the duration
            ade, loss, conf, iteration, success, _ = generate_ade(feed, model, label, channel,
                                                            max_itr, target_loss, target_conf)

            # if the ADE is tagged with the target label
            if success:
                # add it to the accumulated aai
                aai += ade
                # add the loss value of the model to the running loss sum
                loss_avg += loss
                # add the number of iterations to the running iteration sum
                iter_avg += iteration
                # add the confidence of the model to the running confidence sum
                conf_avg += conf
                # increase the number of successful ADEs one unit
                success_count += 1

        # stop the time for the AAI-generating process
        duration = timeit.default_timer() - start
        # convert the time from seconds to HH:MM:SS
        duration = time.strftime('%H:%M:%S', time.gmtime(duration))

        # calculate the success rate of ADE generating
        # by dividing the successful trials by the total trials
        # and round it to two decimal digits; with zero trials
        # there is no rate to report
        success_rate = round(success_count / n, 2) if n else None

        # in the case at least one successful ADE has been generated
        if success_count:
            # normalize the loss average by the number of successful ADEs
            loss_avg /= success_count
            # normalize the number of iterations by the number of successful ADEs
            # and round it to an integer
            iter_avg = round(iter_avg / success_count)
            # normalize the model confidence by the number of successful ADEs
            # and round it to two decimal digits
            conf_avg = round(conf_avg / success_count, 2)

        # otherwise set all these values to None
        else:
            loss_avg = None
            iter_avg = None
            conf_avg = None

        # NOTE(review): `aai` is the element-wise sum of the successful
        # ADEs; unlike the loss / iteration / confidence statistics it is
        # not divided by success_count before being classified and stored.
        # Confirm whether the unnormalized image is intended.
        # feed the computed AAI to the model and
        # get the prediction and its confidence
        pred, conf = classify(aai, model, channel)

        # fill a row in the stats dataframe with the obtained info
        stats.loc[label] = [label, feed_type, aai, success_count, n, success_rate,
                            loss_avg, iter_avg, conf_avg, pred, conf, duration]

    return stats





def plot_ten(images_list, generic_title):
    '''
    This function plots up to 10 images on a 2 x 5 grid.


    Arguments
    ---------------------------------------------
    ** images_list (list of numpy.ndarrays)
    A collection of images in form of 2D arrays, looked up
    with the indices 0 to 9. Entries that are missing (the
    lookup raises IndexError or KeyError) are skipped and
    their grid cell is left as an empty axis, which allows
    inspecting intermediate results with less than ten images.

    ** generic_title (str)
    A generic text which will be specified by digits
    0 to 9 and assigned to each image


    Returns
    ---------------------------------------------
    None

    '''
    # set up the frame for plots
    _, axs = plt.subplots(2, 5, figsize=(20,8))

    for i in range(10):

        # only the lookup is guarded, so that genuine plotting
        # errors are not silently swallowed
        try:
            image = images_list[i]
        # ignore the missing images
        except (IndexError, KeyError):
            continue

        # pick the grid cell of the i-th image (row-major order)
        ax = axs[i//5, i%5]
        # plot the image
        ax.imshow(image, cmap='Blues_r')
        # turn off the ticks
        ax.axis('off')
        # add the title
        ax.set_title(generic_title + ' ' + str(i))



        
def save_figs(images_list, generic_name):
    '''
    This function saves a list of 10 images, each one as a
    separate 8 x 8 inch figure at 100 dpi.


    Arguments
    ---------------------------------------------
    ** images_list (list of numpy.ndarrays)
    A list of exactly 10 images

    ** generic_name (str)
    A generic text which will be specified by digits
    0 to 9 and used as the name of the saved files


    Returns
    ---------------------------------------------
    None

    '''
    # for each image
    for i in range(10):
        # set the figure frame
        fig = plt.figure()
        try:
            # set the size
            fig.set_size_inches(8, 8)
            # turn off the ticks
            plt.axis('off')
            # plot the image
            plt.imshow(images_list[i], cmap='Blues_r')
            # save it
            plt.savefig(generic_name + ' ' + str(i), dpi = 100)
        finally:
            # and close it, even if plotting or saving failed,
            # so that no open figure is left behind
            plt.close(fig)
        

