import os
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
import numpy as np
import matplotlib.pyplot as plt
import tensorflow.keras.backend as K
import instances
import random
import pdb

# TF1-style session options: allow_growth lets TensorFlow allocate GPU memory
# on demand instead of reserving it all up front.
# NOTE(review): `config` is never passed to a tf.Session in the code visible
# in this module — confirm whether a session should be created with it.
config = tf.ConfigProto()
config.gpu_options.allow_growth = True

# Which classifier is attacked; selects one of the branches below.
MODEL = 'cifar10vgg'  # alternatives: 'NiN', "ResNet18"

if MODEL == 'NiN':
    # NOTE(review): v2 behaviour was disabled at import time, yet eager mode
    # is enabled here, and tf.train.Saver is not supported under eager
    # execution in TF 2.x — confirm this branch actually runs.
    tf.enable_eager_execution()
    import cifar10_NiN_bn
    model = cifar10_NiN_bn.NiN_Model()
    saver = tf.train.Saver()
    checkpoint = tf.train.latest_checkpoint('./models/cifar10_NiN')
    # NOTE(review): `sess` is not defined anywhere before this line, so this
    # raises NameError as written. get_distortion and get_model_prediction
    # also rely on a module-level `sess` when MODEL == 'NiN'.
    saver.restore(sess, checkpoint)
elif MODEL == 'ResNet18':
    # ResNet18 built for 32x32x3 inputs and 10 classes; no weights are loaded
    # in this branch.
    # NOTE(review): get_distortion calls `model.model(...)` — confirm this
    # object exposes a `.model` attribute.
    from keras_contrib.applications.resnet import ResNet18
    model = ResNet18((32, 32, 3), 10)
else:
    # Default: the cifar10vgg wrapper from the local `models` package. Its
    # `.predict(...)` and `.model(...)` are used by get_distortion.
    from models import cifar10vgg
    model = cifar10vgg()

# GENERAL PARAMETERS
# Default `mode` argument of get_distortion.
# NOTE(review): get_distortion only accepts 'untargeted', 'targeted' or
# 'targeted-cw' and raises "mode not implemented" for this value, so callers
# must pass `mode` explicitly.
MODE = 'joint_untargeted'

# Shape of one image (height, width, channels). Flat masks and perturbations
# are reshaped to this before being saved.
IMG_SHAPE = [32,32,3]
# CIFAR-10 class names, indexed by the model's output class index.
labels = ['airplane','automobile','bird','cat','deer','dog','frog','horse','ship','truck']

# LOAD DATA: an indexable sample source that also exposes `.filenames`
# (see get_data_sample).
generator = instances.load_generator()


def get_data_sample(index):
    """Return one item of the module-level ``generator`` plus its file metadata.

    Args:
        index: Position of the sample within ``generator``.

    Returns:
        A 3-tuple ``(sample, stem, directory)``:
        ``sample`` is ``generator[index]``; ``stem`` is the file name of
        ``generator.filenames[index]`` without its extension; ``directory``
        is the directory part of that same path.
    """
    sample = generator[index]
    source_path = generator.filenames[index]
    stem, _ext = os.path.splitext(os.path.basename(source_path))
    return sample, stem, os.path.dirname(source_path)


def store_single_result(mapping, var, fname, rate, d, test_name):
    """Save a flat per-pixel map as a min-max normalised RGB PNG.

    The image goes to ``results/<test_name>/<fname>`` when ``rate == 0``,
    otherwise to ``results/<test_name>/<fname>/<rate>``. The file name is
    ``<var>_rate-<rate>_d-<d>.png``.

    Args:
        mapping: Array with ``prod(IMG_SHAPE)`` elements; reshaped to
            ``IMG_SHAPE`` before saving.
        var: Prefix of the output file name.
        fname: Sample name (directory component).
        rate: Directory component (when non-zero) and file-name tag.
        d: File-name tag.
        test_name: Experiment name (directory component under ``results``).
    """
    if rate == 0:
        savedir = os.path.join('results', test_name, fname)
    else:
        savedir = os.path.join('results', test_name, fname, f'{rate}')
    os.makedirs(savedir, exist_ok=True)

    mapping = np.reshape(mapping, IMG_SHAPE)
    mapping = mapping.squeeze()
    # Reverse the channel order (presumably BGR -> RGB; confirm against the
    # data pipeline).
    mapping = mapping[:, :, ::-1]

    # Min-max normalise to [0, 1]. A constant map has zero range: write zeros
    # instead of dividing by zero, which used to produce an all-NaN image.
    mapping = mapping - np.min(mapping)
    value_range = np.max(mapping)
    if value_range != 0:
        mapping = mapping / value_range
    else:
        mapping = np.zeros_like(mapping, dtype=float)

    plt.imsave(
        os.path.join(
            savedir,
            f'{var}_rate-{rate}_d-{d}.png',
        ),
        mapping,
        vmin=np.min(mapping),
        vmax=np.max(mapping),
        format='png',
    )


def store_s(s, fname, rate, d, test_name):
    """Save the saliency mask ``s`` as a red heat map.

    The mask is reshaped to ``IMG_SHAPE`` and averaged over its last axis.
    It is rendered with the ``Reds`` colormap on a fixed [0, 1] scale and
    written as ``s_rate-<rate>_d-<d>.png``. The target directory is
    ``results/<test_name>/<fname>``, plus a ``<rate>`` sub-directory when
    ``rate`` is non-zero.
    """
    rate_dir = () if rate == 0 else (f'{rate}',)
    savedir = os.path.join('results', test_name, fname, *rate_dir)
    os.makedirs(savedir, exist_ok=True)

    mask = np.reshape(s, IMG_SHAPE)
    heat = np.mean(mask, axis=-1).squeeze()
    target = os.path.join(savedir, f's_rate-{rate}_d-{d}.png')
    plt.imsave(target, heat, cmap='Reds', vmin=0.0, vmax=1.0, format='png')


def store_pert_img(x, s, p, fname, rate, d, test_name, optim, true_cl=None, clean_lab=None, target_lab=None, pert_lab=None):
    """Save the perturbed image as a PNG and, optionally, a prediction log.

    Output directory: ``results/<test_name>/<fname>`` when ``rate == 0``,
    otherwise ``results/<test_name>/<fname>/<rate>``.

    Args:
        x: Clean image; reshaped to ``IMG_SHAPE``.
        s: Saliency mask; reshaped to ``IMG_SHAPE`` (used only for
            ``optim == 'joint'``).
        p: Perturbation; reshaped to ``IMG_SHAPE``.
        fname: Sample name (directory component).
        rate: Directory component (when non-zero) and file-name tag.
        d: File-name tag.
        test_name: Experiment name (directory component under ``results``).
        optim: ``'joint'`` saves ``x + s * p``; ``'univariate'`` saves ``x + p``.
        true_cl: Ground-truth label written to the log. When ``None``, no
            log file is written.
        clean_lab: Index into ``labels`` of the prediction on the clean
            image (required when ``true_cl`` is given).
        target_lab: Accepted for interface compatibility; currently unused.
        pert_lab: Index into ``labels`` of the prediction on the perturbed
            image (required when ``true_cl`` is given).

    Raises:
        Exception: if ``optim`` is not ``'joint'`` or ``'univariate'``.
    """
    if rate == 0:
        savedir = os.path.join('results', test_name, fname)
    else:
        savedir = os.path.join('results', test_name, fname, f'{rate}')
    os.makedirs(savedir, exist_ok=True)

    x = np.reshape(x, IMG_SHAPE)
    s = np.reshape(s, IMG_SHAPE)
    p = np.reshape(p, IMG_SHAPE)

    if optim == 'joint':
        pert_x = x + s * p
    elif optim == 'univariate':
        pert_x = x + p
    else:
        raise Exception("optim not implemented")

    pert_x = pert_x.squeeze()
    # Reverse the channel order (presumably BGR -> RGB; confirm).
    pert_x = pert_x[:, :, ::-1]

    # Keep the perturbed image inside the clean image's value range (the
    # same clipping get_distortion applies to the network input), then
    # min-max normalise to [0, 1]. A constant image has zero range and is
    # written as zeros instead of NaNs.
    pert_x = np.clip(pert_x, a_min=np.min(x), a_max=np.max(x))
    pert_x = pert_x - np.min(pert_x)
    value_range = np.max(pert_x)
    if value_range != 0:
        pert_x = pert_x / value_range
    else:
        pert_x = np.zeros_like(pert_x, dtype=float)

    plt.imsave(
        os.path.join(
            savedir,
            f'pertimg_rate-{rate}_d-{d}.png'
        ),
        pert_x,
        vmin=np.min(pert_x),
        vmax=np.max(pert_x),
        # Must agree with the ".png" file name; this was 'jpg', which wrote
        # JPEG data into a .png file.
        format='png',
    )

    if true_cl is not None:
        log_path = os.path.join(savedir, f'prediction-log-rate-{rate}_d-{d}.txt')
        with open(log_path, 'w', encoding='utf-8') as f:
            # f-string so that non-str labels (e.g. ints) do not raise TypeError.
            f.write(f"True label: {true_cl}")
            f.write("\nModel prediction on clean image: " + labels[clean_lab])
            f.write("\nModel prediction on perturbed image: " + labels[pert_lab])



def store_saliency_importance(joint_s, rates, fname, d, test_name):
    """Aggregate the saliency masks of all rates into one importance map.

    The masks are summed over the rates, averaged over the channel axis,
    min-max normalised to [0, 1], and saved with the ``Reds`` colormap as
    ``results/<test_name>/<fname>/impmap_d-<d>.png``.

    Args:
        joint_s: ``len(rates) * prod(IMG_SHAPE)`` values, one mask per rate.
        rates: Only its length is used (the number of stacked masks).
        fname: Sample name (directory component).
        d: File-name tag.
        test_name: Experiment name (directory component under ``results``).
    """
    savedir = os.path.join('results', test_name, fname)
    os.makedirs(savedir, exist_ok=True)

    joint_s = np.reshape(joint_s, [len(rates)] + list(IMG_SHAPE))
    # Sum over rates, then average over the last (channel) axis.
    joint_s = np.mean(np.sum(joint_s, axis=0), axis=-1)

    # Normalise by the range (max - min), not by the raw max. Dividing by
    # max flipped the map when all values were negative and produced
    # inf/NaN when max == 0. A constant map becomes all zeros.
    joint_s = joint_s - np.min(joint_s)
    value_range = np.max(joint_s)
    if value_range != 0:
        joint_s = joint_s / value_range
    else:
        joint_s = np.zeros_like(joint_s, dtype=float)

    plt.imsave(
        os.path.join(
            savedir,
            f'impmap_d-{d}.png'
        ),
        joint_s,
        cmap='Reds',
        vmin=np.min(joint_s),
        vmax=np.max(joint_s),
        format='png',
    )


def get_distortion(x, model=model, mode=MODE, optim="joint"):
    """Build TF1 graph functions for the attack objective on input ``x``.

    The flat saliency mask ``s`` and perturbation ``p`` are placeholders of
    ``prod(x.shape)`` elements. The perturbed input (``x + s * p`` for
    ``optim='joint'``, ``x + p`` for ``optim='univariate'``) is clipped to
    ``[min(x), max(x)]`` and fed to the model. The loss depends on ``mode``.

    Args:
        x: Clean input accepted by ``model.predict`` (presumably a batch of
            one image — TODO confirm shape/layout).
        model: Classifier. Defaults to the module-level ``model``.
        mode: ``'untargeted'``, ``'targeted'`` or ``'targeted-cw'``.
            NOTE(review): the default ``MODE`` ('joint_untargeted') is not
            among these and raises.
        optim: ``'joint'`` or ``'univariate'``.

    Returns:
        A 6-tuple:
        ``loss_fn(s, p)``: loss value for flat ``s`` and ``p``;
        ``grad_s_fn(s, p)``: gradient of the loss w.r.t. flat ``s``;
        ``grad_p_fn(s, p)``: gradient of the loss w.r.t. flat ``p``;
        ``node``: index of the model's top class on the clean ``x``;
        ``target_node``: random target class (!= ``node``) for the targeted
        modes, ``None`` for ``'untargeted'``;
        ``out_all_fn(s, p)``: a *list* holding the squeezed model output
        for all classes (not unwrapped, unlike ``loss_fn``).

    Raises:
        Exception: for an unknown ``optim`` or ``mode``.

    Note:
        Every call adds new placeholders, ops and K.functions to the default
        graph (v2 behaviour is disabled at import time). The number of
        classes is hard-coded to 10.
    """

    # x_tensor is only used for its static shape.
    x_tensor = tf.constant(x, dtype=tf.float32)
    s_flat = tf.placeholder(tf.float32, (np.prod(x_tensor.shape),))
    s_tensor = tf.reshape(s_flat, x.shape)

    p_flat = tf.placeholder(tf.float32, (np.prod(x_tensor.shape),))
    p_tensor = tf.reshape(p_flat, x.shape)

    # Top-1 class on the clean input: argpartition with kth=-2 places the
    # largest element at position [-1].
    if MODEL == 'NiN':
        # NOTE(review): relies on a module-level `sess`, which the visible
        # module code never defines.
        pred = sess.run([model.y], feed_dict={model.x_input: x})
        node = np.argpartition(pred[0][0], -2)[-1]
    else:
        pred = model.predict(x)
        node = np.argpartition(pred[0, ...], -2)[-1] # index of the max output

    if optim == "joint":
        unprocessed = x + s_tensor * p_tensor
    elif optim == "univariate":
        # `s_tensor*0` keeps s_flat connected to the graph, so the gradient
        # w.r.t. s is zeros rather than None.
        unprocessed = x + p_tensor + s_tensor*0
    else:
        raise Exception("optim not implemented")

    # Used only by the NiN branch to pull the tensor's value out eagerly.
    def to_npy(ps):
        return ps.numpy()

    # Keep the perturbed input inside the clean input's value range.
    network_input = tf.clip_by_value(unprocessed, clip_value_min=np.min(x), clip_value_max=np.max(x))
    if MODEL == 'NiN':
        # NOTE(review): feed_dict values must be concrete arrays, not
        # tensors; feeding `ni` here likely fails, and the result would not
        # be differentiable w.r.t. s_flat/p_flat anyway — confirm.
        ni = tf.py_function(to_npy, [network_input], np.float64)
        out = sess.run([model.y], feed_dict={model.x_input: ni})
    else:
        out = model.model(network_input)
    if mode == 'untargeted':
        # Loss is the model output for the clean top class (presumably
        # minimised by the caller to move the prediction away from it).
        target_node = None
        loss = tf.squeeze(out[0][0][node]) if MODEL == 'NiN' else tf.squeeze(out[..., node])
    elif mode == 'targeted':
        # Random target among the 9 classes other than `node`.
        class_li = list(range(10))
        class_li.remove(node)
        new_class = random.randint(0, 8)
        target_node = class_li[new_class]
        # NOTE(review): from_logits=True assumes `out` is logits; confirm it
        # is not already a softmax output.
        cel = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        loss = cel([target_node], out[0][0]) if MODEL == 'NiN' else cel([target_node], out)
    elif mode == 'targeted-cw':
        class_li = list(range(10))
        class_li.remove(node)
        new_class = random.randint(0, 8)
        target_node = class_li[new_class]
        # Margin: max of the non-target outputs (target entry masked to 0)
        # minus the target output, plus `confidence`.
        tl_onehot = tf.expand_dims(tf.one_hot(target_node, 10),0)
        target_prob = tf.reduce_sum((tl_onehot*out), 1)
        other = tf.reduce_max((1-tl_onehot)*out, 1)
        confidence = 0.0
        loss = tf.squeeze(other-target_prob+confidence)
    else:
        raise Exception("mode not implemented")

    # All class outputs, exposed for inspection by the caller.
    loss_all = tf.squeeze(out[0][0]) if MODEL == 'NiN' else tf.squeeze(out)
    f_out_all = K.function([s_flat, p_flat], [loss_all])

    gradient = K.gradients(loss, [s_flat, p_flat])
    f_out = K.function([s_flat, p_flat], [loss])
    f_gradient = K.function([s_flat, p_flat], [gradient])

    # (loss_fn, grad_s_fn, grad_p_fn, node, target_node, out_all_fn)
    return lambda s, p: f_out([s, p])[0], lambda s, p: f_gradient([s, p])[0][0], lambda s, p: f_gradient([s, p])[0][1], node, target_node, lambda s, p: f_out_all([s, p])

def get_s_sum(s, channels=3):
    """Summarise a saliency mask.

    Args:
        s: Mask in (height, width, channels) layout, either already shaped
            or flattened in C order. This is the layout the rest of this
            file uses when it reshapes masks to ``IMG_SHAPE``.
        channels: Number of channel values per pixel (3 for RGB).

    Returns:
        ``(total, active_pixels)``: the sum of all entries of ``s``, and the
        number of pixels with at least one non-zero channel value.
    """
    s = np.asarray(s)
    # Group each pixel's trailing channel values. The previous
    # ``reshape(1, 3, -1)`` split the flat array into three contiguous
    # thirds, which for channels-last data combines values of *different*
    # pixels and therefore miscounted active pixels.
    per_pixel = s.reshape(-1, channels)
    active_pixels = np.count_nonzero(np.any(per_pixel != 0, axis=1))
    return np.sum(s), active_pixels
    
def _softmax(logits):
    """Return the softmax of ``logits`` along the last axis (NumPy only).

    Replaces ``tf.nn.softmax`` + ``tf.Session().eval()`` so that reporting a
    prediction no longer adds ops to the TF1 default graph, or opens a new
    session, on every call. The max is subtracted first for numerical
    stability.
    """
    logits = np.asarray(logits)
    exp = np.exp(logits - np.max(logits, axis=-1, keepdims=True))
    return exp / np.sum(exp, axis=-1, keepdims=True)


def get_model_prediction(x, s, p, node, target_node, mode, optim, pred1, logfname, fname):
    """Compare the clean and perturbed predictions, then print and log them.

    The clean prediction is recomputed from ``x``. The perturbed prediction
    is taken from ``pred1`` (or recomputed for ``MODEL == 'NiN'``). The
    top classes and their softmax scores are printed and appended to
    ``logfname``.

    Args:
        x: Clean input, fed to the model as-is.
        s: Saliency mask; reshaped to ``x.shape``.
        p: Perturbation; reshaped to ``x.shape``.
        node: Unused here (kept for interface compatibility).
        target_node: Target class index for the targeted modes.
        mode: ``'untargeted'``, ``'targeted'`` or ``'targeted-cw'``.
        optim: ``'joint'`` (perturbation ``s * p``) or ``'univariate'``
            (perturbation ``p``).
        pred1: Model output on the perturbed input (used when MODEL is not
            'NiN'; recomputed otherwise).
        logfname: Text file that the report is appended to.
        fname: Unused here (kept for interface compatibility).

    Returns:
        ``(success, l1_norm, perturbed_class, linf_norm)``. ``success`` is 1
        when the top class changed (``'untargeted'``) or equals
        ``target_node`` (targeted modes), and 0 otherwise, including for an
        unknown ``mode``.

    Raises:
        Exception: if ``optim`` is not ``'joint'`` or ``'univariate'``.
    """
    s = np.reshape(s, x.shape)
    p = np.reshape(p, x.shape)

    # The perturbation actually added to x.
    if optim == 'joint':
        delta = s * p
    elif optim == 'univariate':
        delta = p
    else:
        raise Exception("optim not implemented")
    norm = np.sum(np.abs(delta))
    linf_norm = np.max(np.abs(delta))

    # Top-1 indices via argpartition(kth=-2)[-1], i.e. the position of the max.
    if MODEL == 'NiN':
        # Only this branch re-runs the model on the perturbed input (clipped
        # to the clean value range), so the TF ops are built only here.
        pert_input = tf.convert_to_tensor(x + delta)
        pert_input = tf.clip_by_value(pert_input, clip_value_min=np.min(x), clip_value_max=np.max(x))
        # `sess` is the module-level session. It is no longer shadowed by a
        # local `with tf.Session() as sess`, which made this line raise
        # UnboundLocalError.
        pred0 = sess.run([model.y], feed_dict={model.x_input: x})
        pred1 = model.predict(pert_input, steps=1)
        logits0 = pred0[0][0]
        logits1 = pred1[0][0]
        node0 = np.argpartition(logits0, -2)[-1]
        node1 = np.argpartition(logits1, -2)[-1]
        probs1 = _softmax(logits1)
        pred0_percent = _softmax(logits0)[node0]
        pred1_old_class_percent = probs1[node0]
        pred1_new_class_percent = probs1[node1]
    else:
        pred0 = model.predict(x)
        pred1 = np.asarray([i.tolist() for i in pred1])
        node0 = np.argpartition(pred0[0, ...], -2)[-1]
        node1 = np.argpartition(pred1[0, ...], -2)[-1]
        probs1 = _softmax(pred1)
        pred0_percent = _softmax(pred0)[..., node0]
        pred1_old_class_percent = probs1[..., node0]
        pred1_new_class_percent = probs1[..., node1]

    print('\n------------------------\n')
    print(f'orig pred: {labels[node0]} ({node0}) | ',
          f'orig pred: {pred0_percent}% | ',
          f'pert pred: {labels[node1]} ({node1}) | ',
          f'pert pred new class: {pred1_new_class_percent}% | ',
          f'pert pred old class: {pred1_old_class_percent}% | ',
          )
    print('\n------------------------\n')

    with open(logfname, "a", encoding="utf-8") as f:
        f.write(f"\norig pred: {labels[node0]} ({node0})")
        f.write(f"\norig pred: {pred0_percent}%")
        f.write(f"\npert pred: {labels[node1]} ({node1})")
        f.write(f"\npert pred new class: {pred1_new_class_percent}%")
        f.write(f"\npert pred old class: {pred1_old_class_percent}% \n")

    if mode == 'untargeted':
        return int(node0 != node1), norm, node1, linf_norm
    # Previously `mode == 'targeted' or 'targeted-cw'`, which is always true.
    if mode in ('targeted', 'targeted-cw'):
        return int(target_node == node1), norm, node1, linf_norm
    # Unknown mode: no success, but keep the same 4-tuple shape as above.
    return 0, norm, node1, linf_norm
