import os
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
import numpy as np
import matplotlib.pyplot as plt
import tensorflow.keras.backend as K
from tensorflow.keras.applications.resnet50 import ResNet50
from tensorflow.keras.applications.inception_v3 import InceptionV3
from vit_keras import vit
from WideResNet import WRN
import instances
import random
import ssl
import pdb
import cv2

# SECURITY NOTE(review): this swaps the ssl module's default HTTPS context
# factory for the unverified one, so certificate verification is skipped for
# every request that relies on the default context in this process.
# Presumably done so pretrained weights can be downloaded -- confirm it is
# still needed.
ssl._create_default_https_context = ssl._create_unverified_context

# GENERAL PARAMETERS
# MODE is the default value of the `mode` argument of get_distortion.
# NOTE(review): get_distortion only handles 'untargeted', 'targeted' and
# 'targeted-cw'; 'joint_untargeted' falls through to its "mode not
# implemented" error unless the caller passes `mode` explicitly.
MODE = 'joint_untargeted'
# MODEL selects which classifier is built below and how samples are rescaled.
MODEL = 'Inceptionv3' # (options: 'ResNet50', 'Inceptionv3', 'ViT-B', 'WideResNet')

def inv_softmax(x, C):
    """Return ``log(x) + C`` computed element-wise with TensorFlow ops.

    In this file it is applied to a classifier's output to turn it back into
    logit-like scores; ``C`` is the additive offset.
    """
    log_values = tf.math.log(x)
    return log_values + C

# Instantiate the pretrained classifier named by MODEL.
if MODEL == 'ResNet50':
    orig_model = ResNet50(weights='imagenet')
elif MODEL == 'Inceptionv3':
    # For Inceptionv3, ALSO CHANGE THE n_PIXELS IN CONFIG
    orig_model = InceptionV3(weights='imagenet')
elif MODEL == 'ViT-B':
    orig_model = vit.vit_b16(
        image_size=224,
        activation='sigmoid',
        pretrained=True,
        include_top=True,
        pretrained_top=True,
    )
elif MODEL == 'WideResNet':
    orig_model = WRN()
else:
    raise Exception("Model not implemented")

# Inceptionv3 is the only option with 299x299 inputs; the rest use 224x224.
IMG_SHAPE = [299, 299, 3] if MODEL == 'Inceptionv3' else [224, 224, 3]

# WideResNet is used as-is; every other model gets an extra Lambda layer that
# applies inv_softmax (with C = log(10)) to the original output.
if MODEL == 'WideResNet':
    model = orig_model
else:
    logit_layer = tf.keras.layers.Lambda(
        lambda x: inv_softmax(x, tf.math.log(10.)),
        name='inv_softmax',
    )
    model = tf.keras.Model(orig_model.input, logit_layer(orig_model.output))

generator = instances.load_generator(model=MODEL)

def get_data_sample(index):
    """Return ``(sample, name)`` for entry ``index`` of the module generator.

    ``sample`` is ``generator[index]`` rescaled according to ``MODEL``:
    multiplied by 127 for Inceptionv3, unchanged for ResNet50, reshaped to
    ``(1, 224, 224, 3)`` for ViT-B and mapped through ``v * 255 - 127`` for
    WideResNet. ``name`` is the file name of ``generator.filenames[index]``
    without directory and extension.

    Raises:
        Exception: if ``MODEL`` is not one of the four supported options.
    """
    if MODEL not in ('Inceptionv3', 'ResNet50', 'ViT-B', 'WideResNet'):
        raise Exception("Model not implemented")

    raw = generator[index]
    if MODEL == 'Inceptionv3':
        sample = raw*127.
    elif MODEL == 'ResNet50':
        sample = raw
    elif MODEL == 'ViT-B':
        sample = raw.reshape(1, 224, 224, 3)
    else:  # WideResNet
        sample = raw*255.-127.

    base_name = os.path.basename(generator.filenames[index])
    stem = os.path.splitext(base_name)[0]
    return (sample, stem)


def store_single_result(mapping, var, fname, rate, d, test_name):
    """Save ``mapping`` as a min-max normalised PNG under ``results/``.

    The image is written to ``results/<test_name>/<fname>/`` (with an extra
    ``<rate>/`` sub-directory when ``rate`` is not 0) and is named
    ``<var>_rate-<rate>_d-<d>.png``. The directory is created if needed.

    Args:
        mapping: array that can be reshaped to ``IMG_SHAPE``.
        var: prefix used in the output file name.
        fname: name of the sample; used as a directory name.
        rate: rate value; used in the path and the file name.
        d: value embedded in the file name.
        test_name: name of the experiment; used as a directory name.
    """
    if rate == 0:
        savedir = os.path.join('results', test_name, fname)
    else:
        savedir = os.path.join('results', test_name, fname, f'{rate}')
    os.makedirs(savedir, exist_ok=True)
    mapping = np.reshape(mapping, IMG_SHAPE)
    mapping = mapping.squeeze()

    # For these models the last (channel) axis is reversed before saving.
    if MODEL in ['ResNet50', 'ViT-B']:
        mapping = mapping[:,:,::-1]

    # Min-max normalise to [0, 1]. A constant mapping has zero range; skip the
    # division in that case so the saved image is all zeros instead of NaN.
    mapping = mapping - np.min(mapping)
    peak = np.max(mapping)
    if peak > 0:
        mapping = mapping / peak

    plt.imsave(
        os.path.join(
            savedir,
            f'{var}_rate-{rate}_d-{d}.png',
        ),
        mapping,
        vmin=np.min(mapping),
        vmax=np.max(mapping),
        format='png',
    )


def store_s(s, fname, rate, d, test_name):
    """Save the channel-averaged ``s`` as a 'Reds' colour-mapped PNG.

    ``s`` is reshaped to ``IMG_SHAPE`` and averaged over its last axis; the
    result is drawn with a fixed colour range of [0, 1]. The file is
    ``s_rate-<rate>_d-<d>.png`` inside ``results/<test_name>/<fname>/``, plus
    a ``<rate>/`` sub-directory when ``rate`` is not 0.
    """
    path_parts = ['results', test_name, fname]
    if rate != 0:
        path_parts.append(f'{rate}')
    savedir = os.path.join(*path_parts)
    os.makedirs(savedir, exist_ok=True)

    target = os.path.join(savedir, f's_rate-{rate}_d-{d}.png')
    s_img = np.reshape(s, IMG_SHAPE)
    channel_mean = np.mean(s_img, axis=-1).squeeze()
    plt.imsave(
        target,
        channel_mean,
        cmap='Reds',
        vmin=0.0,
        vmax=1.0,
        format='png',
    )


def store_pert_img(x, s, p, fname, rate, d, test_name, optim):
    """Save the perturbed image and, for joint optimisation, the masked perturbation.

    Files are written with ``cv2.imwrite`` to ``results/<test_name>/<fname>/``
    (plus a ``<rate>/`` sub-directory when ``rate`` is not 0):

    * ``pertimg_rate-<rate>_d-<d>.png``: ``x + s * p`` (``optim='joint'``) or
      ``x + p`` (``optim='univariate'``), clipped to [-127, 127] and min-max
      rescaled to 0..255.
    * ``pxs_rate-<rate>_d-<d>.png`` (joint only): ``s * p`` with ``p`` first
      rescaled to ``[0, d/255]``, then min-max rescaled to 0..255.

    Args:
        x, s, p: arrays that can be reshaped to ``IMG_SHAPE``.
        fname: name of the sample; used as a directory name.
        rate: rate value; used in the path and the file names.
        d: value used in the file names and to scale ``p``.
        test_name: name of the experiment; used as a directory name.
        optim: ``'joint'`` or ``'univariate'``.

    Raises:
        Exception: if ``optim`` is neither ``'joint'`` nor ``'univariate'``.
    """
    if rate == 0:
        savedir = os.path.join('results', test_name, fname)
    else:
        savedir = os.path.join('results', test_name, fname, f'{rate}')
    os.makedirs(savedir, exist_ok=True)
    x = np.reshape(x, IMG_SHAPE)
    s = np.reshape(s, IMG_SHAPE)
    p = np.reshape(p, IMG_SHAPE)

    if optim == 'joint':
        pert_x = x + s * p
    elif optim == 'univariate':
        pert_x = x + p
    else:
        raise Exception("optim not implemented")

    # For these models the last (channel) axis is reversed before saving.
    reverse_channels = MODEL in ['ResNet50', 'ViT-B']

    pert_x = pert_x.squeeze()
    if reverse_channels:
        pert_x = pert_x[:,:,::-1]

    pert_x = np.clip(pert_x, a_min=-127., a_max=127.)
    pert_x = pert_x - np.min(pert_x)
    # Guard against a constant image: dividing by a zero range would give NaN.
    pert_peak = np.max(pert_x)
    if pert_peak > 0:
        pert_x = pert_x / pert_peak

    cv2.imwrite(os.path.join(
            savedir,
            f'pertimg_rate-{rate}_d-{d}.png'
        ), pert_x*255)

    if optim=='joint':
        p_min = np.min(p)
        p_range = np.max(p) - p_min
        # A constant perturbation has zero range; treat it as all zeros
        # instead of dividing by zero.
        if p_range > 0:
            p = (d/255.)*(p-p_min)/p_range
        else:
            p = np.zeros_like(p, dtype=float)
        pxs = s*p
        pxs = pxs.squeeze()
        if reverse_channels:
            pxs = pxs[:,:,::-1]
        pxs = (pxs - np.min(pxs))
        pxs_peak = np.max(pxs)
        if pxs_peak > 0:
            pxs = pxs / pxs_peak

        # The name previously contained a stray '<' ("_d<-"), which is not a
        # valid file-name character on Windows and did not match the
        # "_d-<d>" pattern of the other artefacts.
        cv2.imwrite(os.path.join(
                savedir,
                f'pxs_rate-{rate}_d-{d}.png',
            ), pxs*255.)


def store_saliency_importance(joint_s, rates, fname, d, test_name):
    """Save the importance map accumulated over all rates as a 'Reds' PNG.

    ``joint_s`` is reshaped to ``[len(rates)] + IMG_SHAPE``, summed over the
    first axis, averaged over the last axis and min-max normalised. The image
    is written to ``results/<test_name>/<fname>/impmap_d-<d>.png``.

    Args:
        joint_s: array holding one ``IMG_SHAPE``-sized entry per rate.
        rates: sequence whose length gives the number of stacked entries.
        fname: name of the sample; used as a directory name.
        d: value embedded in the file name.
        test_name: name of the experiment; used as a directory name.
    """
    savedir = os.path.join('results', test_name, fname)
    os.makedirs(savedir, exist_ok=True)
    joint_s = np.reshape(joint_s, [len(rates)] + IMG_SHAPE)
    joint_s = np.mean(np.sum(joint_s, axis=0), axis=-1)

    # Min-max normalise: shift to zero first, then divide by the max of the
    # *shifted* data (i.e. the range). Dividing by the max of the un-shifted
    # data, as before, does not map to [0, 1] and breaks for non-positive max.
    joint_s = joint_s - np.min(joint_s)
    peak = np.max(joint_s)
    if peak > 0:
        joint_s = joint_s / peak

    plt.imsave(
        os.path.join(
            savedir,
            f'impmap_d-{d}.png'
        ),
        joint_s,
        cmap='Reds',
        vmin=np.min(joint_s),
        vmax=np.max(joint_s),
        format='png',
    )

def get_distortion(x, model=model, mode=MODE, optim="joint"):
    """Build the attack loss for ``x`` and return callables that evaluate it.

    Two flat placeholders, ``s`` and ``p``, are reshaped to ``x.shape``. The
    network input is ``x + s * p`` (``optim='joint'``) or ``x + p``
    (``optim='univariate'``), clipped to ``[min(x), max(x)]`` and rescaled
    for the selected ``MODEL`` before being fed to ``model``.

    Args:
        x: input array; its ``shape``, minimum and maximum are used.
        model: Keras model to attack. The default is the module-level model,
            which already outputs logit-like scores.
        mode: ``'untargeted'``, ``'targeted'`` or ``'targeted-cw'``.
            NOTE(review): the default is the module-level ``MODE``; its
            current value is not one of these, so pass ``mode`` explicitly.
        optim: ``'joint'`` or ``'univariate'``.

    Returns:
        ``(loss_fn, grad_s_fn, grad_p_fn, node, target_node)`` where the three
        callables take ``(s, p)`` as flat arrays, ``node`` is the index of the
        highest score predicted for ``x`` and ``target_node`` is the randomly
        drawn target class (``None`` for ``'untargeted'``).

    Raises:
        Exception: for an unsupported ``optim`` or ``mode``.
    """
    x_tensor = tf.constant(x, dtype=tf.float32)
    s_flat = tf.placeholder(tf.float32, (np.prod(x_tensor.shape),))
    s_tensor = tf.reshape(s_flat, x.shape)

    p_flat = tf.placeholder(tf.float32, (np.prod(x_tensor.shape),))
    p_tensor = tf.reshape(p_flat, x.shape)

    # Rescaling applied to the data right before it enters the network.
    if MODEL in ['ResNet50', 'ViT-B']:
        def to_model_range(t):
            return t
    elif MODEL == 'WideResNet':
        def to_model_range(t):
            return (t + 127.) / 255.
    else:  # FOR INCEPTIONv3
        def to_model_range(t):
            return t / 127.

    # The module-level model already ends in the `inv_softmax` layer for every
    # non-WideResNet MODEL, so its output is used directly. The extra
    # `inv_softmax(...)` call that used to follow omitted the required `C`
    # argument and raised TypeError.
    pred = model.predict(to_model_range(x))
    node = np.argpartition(pred[0, ...], -2)[-1]

    if optim == "joint":
        unprocessed = x + s_tensor * p_tensor
    elif optim == "univariate":
        # `s_tensor*0` keeps `s` in the graph so its gradient is defined.
        unprocessed = x + p_tensor + s_tensor*0
    else:
        raise Exception("optim not implemented")

    network_input = tf.clip_by_value(unprocessed, clip_value_min=np.min(x), clip_value_max=np.max(x))
    out = model(to_model_range(network_input))

    def pick_target():
        # Draw a class index uniformly from the 1000 classes except `node`.
        class_li = list(range(1000))
        class_li.remove(node)
        return class_li[random.randint(0, 998)]

    if mode == 'untargeted':
        target_node = None
        loss = tf.squeeze(out[..., node])
    elif mode == 'targeted':
        target_node = pick_target()
        cel = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        loss = cel([target_node], out)
    elif mode == 'targeted-cw':
        target_node = pick_target()
        tl_onehot = tf.expand_dims(tf.one_hot(target_node, 1000),0)
        target_prob = tf.reduce_sum((tl_onehot*out), 1)
        other = tf.reduce_max((1-tl_onehot)*out, 1)
        confidence = 20.0
        # Hinge on the margin between the best other score and the target.
        loss = tf.maximum(0.0, tf.squeeze(other-target_prob+confidence))
    else:
        raise Exception("mode not implemented")

    gradient = K.gradients(loss, [s_flat, p_flat])
    f_out = K.function([s_flat, p_flat], [loss])
    f_gradient = K.function([s_flat, p_flat], [gradient])

    return lambda s, p: f_out([s, p])[0], lambda s, p: f_gradient([s, p])[0][0], lambda s, p: f_gradient([s, p])[0][1], node, target_node

def get_s_sum(s):
    """Return ``(sum of s, number of non-zero group means)``.

    ``|s|`` is reshaped to ``(1, 3, -1)`` and averaged over the axis of
    length 3; the second value counts how many of those averages are
    non-zero.

    NOTE(review): this grouping treats the data as three contiguous planes.
    If ``s`` is a flattened height x width x channel image, the groups are
    not per-pixel channel triples -- confirm the intended layout.
    """
    magnitudes = np.abs(s)
    planes = magnitudes.reshape(1, 3, -1)
    group_mean = planes.mean(1, keepdims=True)
    total = np.sum(s)
    nonzero_groups = np.sum(group_mean != 0)
    return total, nonzero_groups


def get_model_prediction(x, s, p, node, target_node, mode, optim, logfname, fname):
    """Compare the model's prediction on ``x`` and on the perturbed input.

    The perturbed input is ``x + s * p`` (``optim='joint'``) or ``x + p``
    (``optim='univariate'``), clipped to ``[min(x), max(x)]``. The predicted
    classes and their softmax scores are printed and appended to
    ``logfname``.

    Args:
        x: original input array.
        s, p: arrays that can be reshaped to ``x.shape``.
        node: unused; kept for interface compatibility.
        target_node: target class of the attack (only logged for
            ``'untargeted'``).
        mode: ``'untargeted'``, ``'targeted'`` or ``'targeted-cw'``.
        optim: ``'joint'`` or ``'univariate'``.
        logfname: path of the log file that is appended to.
        fname: sample name written to the log.

    Returns:
        ``(success, l1_norm, l2_norm, linf_norm)``. ``success`` is 1 when the
        predicted class changed (untargeted) or equals ``target_node``
        (targeted modes), else 0; it is 0 for any other ``mode``. The L1 and
        L2 norms are computed on the perturbation divided by 255, the
        L-infinity norm on the unscaled perturbation.

    Raises:
        Exception: if ``optim`` is not supported.
    """
    s = np.reshape(s, x.shape)
    p = np.reshape(p, x.shape)

    if optim == 'joint':
        l1_norm = np.sum(np.abs(s*(p/255)))
        linf_norm = np.max(np.abs(s*p))
        l2_norm = np.linalg.norm(s*(p/255))
        pert_input = x + s * p
    elif optim == 'univariate':
        l1_norm = np.sum(np.abs(p/255))
        linf_norm = np.max(np.abs(p))
        l2_norm = np.linalg.norm(p/255)
        pert_input = x + p
    else:
        raise Exception("optim not implemented")

    pert_input = tf.convert_to_tensor(pert_input)
    pert_input = tf.clip_by_value(pert_input, clip_value_min=np.min(x), clip_value_max=np.max(x))

    if MODEL in ['ResNet50', 'ViT-B']:
        pred0 = model.predict(x, steps=1)
        pred1 = model.predict(pert_input, steps=1)
    elif MODEL == 'WideResNet':
        # get_data_sample produces `v*255 - 127` for WideResNet, and
        # get_distortion undoes it with `(x + 127) / 255`; the previous
        # `(x - 127) / 255` here fed the model a differently scaled input.
        pred0 = model.predict((x + 127.) / 255., steps=1)
        pred1 = model.predict((pert_input + 127.) / 255., steps=1)
    else:  # FOR INCEPTIONv3
        pred0 = model.predict(x/127., steps=1)
        pred1 = model.predict(pert_input/127., steps=1)

    # The module-level model already ends in the `inv_softmax` layer for every
    # non-WideResNet MODEL, so the predictions are used directly. The extra
    # `inv_softmax(...)` calls that used to follow omitted the required `C`
    # argument and raised TypeError.
    node0 = np.argpartition(pred0[0, ...], -2)[-1]
    node1 = np.argpartition(pred1[0, ...], -2)[-1]

    pred0_percent = tf.nn.softmax(pred0)[..., node0]
    pred1_softmax = tf.nn.softmax(pred1)
    pred1_old_class_percent = pred1_softmax[..., node0]
    pred1_new_class_percent = pred1_softmax[..., node1]

    # Evaluate each score once and reuse it for both the console and the log.
    with tf.Session() as sess:
        orig_score, new_score, old_score = sess.run(
            [pred0_percent, pred1_new_class_percent, pred1_old_class_percent]
        )

    print('\n------------------------\n')
    print(f'Filename:  {fname} \n ',
          f'orig pred: {node0} | ',
          f'orig pred: {orig_score}% | ',
          f'pert target: {target_node} | ',
          f'pert pred: {node1} | ',
          f'pert pred new class: {new_score}% | ',
          f'pert pred old class: {old_score}% | ',
          )
    print('\n------------------------\n')

    # exist_ok avoids the check-then-create race of exists() + mkdir().
    os.makedirs("../logs", exist_ok=True)

    with open(logfname, "a") as f:
        f.write(f"\nFilename: {fname}")
        f.write(f"\norig pred: {node0}")
        f.write(f"\norig pred score: {orig_score}%")
        f.write(f"\npert target: {target_node}")
        f.write(f"\npert pred: {node1}")
        f.write(f"\npert pred new class: {new_score}%")
        f.write(f"\npert pred old class: {old_score}% \n")

    if mode == 'untargeted':
        return int(node0 != node1), l1_norm, l2_norm, linf_norm
    elif mode == 'targeted' or mode == 'targeted-cw':
        return int(target_node == node1), l1_norm, l2_norm, linf_norm
    else:
        # Previously returned the undefined name `norm` (NameError); return
        # the same 4-tuple shape as the other branches with success = 0.
        return 0, l1_norm, l2_norm, linf_norm
