import os
import sys

import foolbox as fb
from foolbox.attacks import *
import numpy as np
import tensorflow as tf
from matplotlib import pyplot as plt
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Dense, Activation
from tensorflow.keras.optimizers import Adam
#from runs.yaml_to_params import getParams, getFunctionFromModules
import pandas as pd
from tqdm import tqdm
import seaborn as sns
import wandb

def transform_to_logit(loaded_model, last_layer, reverse=1.0, biais = 0.0):
    """Return a compiled Keras model whose output is the logit vector of ``loaded_model``.

    The number of units of ``loaded_model.layers[last_layer]`` decides the wrapping:

    * more than 2 units: that layer's output is used as-is. It is wrapped in a new
      ``Sequential`` unless ``last_layer == -1``, in which case ``loaded_model``
      itself is compiled and returned.
    * 1 unit (single-output binary classifier): a ``Dense(2)`` head with fixed
      weights maps the scalar ``z`` to ``[reverse*(biais - z), reverse*(z - biais)]``,
      giving a two-logit output.
    * 2 units: that layer's output is used directly.

    Args:
        loaded_model: Keras model to wrap.
        last_layer: index into ``loaded_model.layers`` of the layer whose output is
            treated as the logits.
        reverse: scale (and sign) of the synthetic 2-logit head in the 1-unit case.
        biais: offset subtracted from the single output in the 1-unit case.

    Returns:
        A model compiled with sparse categorical cross-entropy and Adam(5e-3).
        NOTE: when ``last_layer == -1`` and the layer has more than 2 units, this is
        the caller's ``loaded_model`` object, recompiled in place.
    """
    last_layer_g = last_layer
    print(loaded_model.layers[0].input.shape[1:])
    output_logits = loaded_model.layers[last_layer].output.shape[-1]
    print("last layers has "+str(output_logits)+" logit(s)")
    if output_logits > 2:
        # Multi-class case: the selected layer already yields one logit per class.
        if last_layer_g != -1:
            new_model = Sequential()
            new_model.add(Model(loaded_model.inputs, loaded_model.layers[last_layer_g].output))
        else:
            new_model = loaded_model
        new_model.compile(loss='sparse_categorical_crossentropy',
                          optimizer=Adam(learning_rate=5e-3), metrics=["accuracy"])
        return new_model

    # From here on the selected layer has 1 or 2 outputs (binary classification only).
    if output_logits == 1:
        logits = Sequential()
        print(loaded_model.layers[last_layer].name)
        if 'batch_normalization' in loaded_model.layers[last_layer].name:
            print("WARNING : batchnorm layer may induce wrong adversarial sample")
        logits.add(Model(loaded_model.inputs, loaded_model.layers[last_layer].output))
        print(logits.layers[-1].layers[-1])

        # Fixed head z -> [reverse*(biais - z), reverse*(z - biais)].
        # No softmax: the outputs must stay logits.
        logits.add(Dense(2, activation=None))
        logits.layers[-1].set_weights([
            reverse * np.asarray([[-1, 1]]),
            np.asarray([reverse * biais, -reverse * biais]),
        ])
        last_layer_g = -1
    else:
        logits = loaded_model
    # `logits` now has a 2-unit layer at index `last_layer_g`.
    new_model = Sequential()
    new_model.add(Model(logits.inputs, logits.layers[last_layer_g].output))
    new_model.summary()
    new_model.compile(loss='sparse_categorical_crossentropy',
                      optimizer=Adam(learning_rate=5e-3), metrics=["accuracy"])
    return new_model


def _true_class_margin(logits, y_onehot, small_float):
    """Return, per row, the true-class logit minus the largest other-class logit.

    Assumes ``y_onehot`` is one-hot (TODO confirm for all generators).
    ``small_float`` (a very negative number) is added to the true-class entry so
    that ``np.max`` selects the best competing class.
    """
    runner_up = np.max(logits + y_onehot * small_float, axis=1, keepdims=True)
    return np.sum((logits - runner_up) * y_onehot, axis=1)


def compute_adversarial_robustness(model, attack_fct, test_gen, test_size, batch_size, last_layer=-1,
                                   force_recompute=False, bounds=(-1, 1), eps=None, device="/GPU:0"):
    """Attack ``model`` on batches drawn from ``test_gen`` and tabulate per-sample results.

    The model is first passed through ``transform_to_logit`` and then wrapped in a
    foolbox ``TensorFlowModel``. Exactly ``test_size // batch_size`` batches are
    drawn from ``test_gen``; any remainder is not evaluated.

    Args:
        model: Keras model to evaluate. If ``None``, the function returns ``None``.
        attack_fct: callable invoked as ``attack_fct(f_model, x, y, epsilons=eps)``
            and unpacked as ``(_, advs, success)`` (e.g. a foolbox attack).
        test_gen: iterator yielding ``(x_batch, y_true)``. ``y_true`` is used as
            one-hot labels (argmax over axis 1 gives the class index).
        test_size: number of samples to evaluate (rounded down to whole batches).
        batch_size: samples per batch drawn from ``test_gen``.
        last_layer: layer index forwarded to ``transform_to_logit``.
        force_recompute: currently unused.
        bounds: input value range given to foolbox.
        eps: sequence of attack epsilons. At least one is required.
        device: TensorFlow device used by the foolbox model wrapper.

    Returns:
        A pandas DataFrame with one row per evaluated sample and these columns:
        ``preds_<i>`` (clean logits), ``e_<eps>`` (attack success per epsilon),
        ``margin`` (clean true-class margin), ``success_attack``, ``margin_adv``,
        ``adv_noise_norm_l2`` and ``adv_noise_norm_linf``. The last four refer to
        the adversarial examples of the FIRST epsilon. Noise norms are 0 for
        samples that were misclassified before the attack.

    Raises:
        ValueError: if ``eps`` is empty.
    """
    if model is None:
        return None
    eps = [] if eps is None else list(eps)
    if not eps:
        raise ValueError("compute_adversarial_robustness needs at least one epsilon in `eps`")

    small_float = np.finfo(dtype="float32").min / 2.0
    model = transform_to_logit(model, last_layer=last_layer)
    f_model = fb.TensorFlowModel(model, bounds=bounds, device=device)

    cols_name = (
        ["preds_" + str(i) for i in range(model.output_shape[-1])]
        + ["e_" + str(e) for e in eps]
        + ["margin", "success_attack", "margin_adv", "adv_noise_norm_l2", "adv_noise_norm_linf"]
    )

    # Accumulate per-batch arrays and build the DataFrame once (avoids quadratic concat).
    batch_rows = []
    for _ in tqdm(range(test_size // batch_size)):
        x_batch, y_true = next(test_gen)
        y_true_ord = np.argmax(y_true, axis=1)
        x = tf.convert_to_tensor(x_batch, dtype="float32")
        y = tf.convert_to_tensor(y_true_ord, dtype="int32")

        out_pred = model.predict(x)
        margin = _true_class_margin(out_pred, y_true, small_float)
        # Pure numpy comparison: avoids mixing numpy arrays with tf tensors.
        well_classified = np.argmax(out_pred, axis=1) == y_true_ord

        _, advs, success = attack_fct(f_model, x, y, epsilons=eps)
        success = success.numpy()  # (len(eps), batch): one row of success flags per epsilon
        # With a list of epsilons foolbox returns one adversarial batch per epsilon.
        # Predicting on the whole list breaks a single-input model as soon as
        # len(eps) > 1, so use the first epsilon, consistent with success[0] below.
        adv = advs[0] if isinstance(advs, (list, tuple)) else advs

        out_adv = model.predict(adv)
        margin_adv = _true_class_margin(out_adv, y_true, small_float)
        noise = np.reshape(adv - x, (x.shape[0], -1))
        # Norms are zeroed for samples the model already misclassified.
        noise_norm = np.linalg.norm(noise, ord=2, axis=1) * well_classified
        noise_norm_inf = np.linalg.norm(noise, ord=np.inf, axis=1) * well_classified

        batch_rows.append(np.concatenate(
            [out_pred, success.T, margin[:, np.newaxis], success[0, :, np.newaxis],
             margin_adv[:, np.newaxis], noise_norm[:, np.newaxis], noise_norm_inf[:, np.newaxis]],
            axis=1,
        ))

    if not batch_rows:
        return pd.DataFrame(columns=cols_name)
    return pd.DataFrame(np.concatenate(batch_rows, axis=0), columns=cols_name)


def robustness_stats(df):
    """Aggregate a per-sample robustness table into accuracy-vs-epsilon curves.

    ``df`` is expected to have the columns written by
    ``compute_adversarial_robustness``:
    * one ``e_<eps>`` column per epsilon, holding the attack-success flag
      (0 means the attack failed);
    * ``margin``;
    * ``adv_noise_norm_l2`` and ``adv_noise_norm_linf``.

    Returns:
        dict with
        * ``eps``: epsilons parsed from the ``e_`` column names, in column order;
        * ``emp_acc``: per epsilon, fraction of rows whose success flag is 0;
        * ``th_acc``: per epsilon, fraction of rows where the flag is 0 and
          ``margin > eps`` (for non-negative eps);
        * ``adv_l2`` / ``adv_lin``: means of ``adv_noise_norm_l2`` /
          ``adv_noise_norm_linf``.
        For an empty ``df`` the accuracies are NaN instead of raising
        ZeroDivisionError.
    """
    eps = []
    c_eps = []
    for col in df.columns:
        if isinstance(col, str) and col.startswith('e_'):
            eps.append(float(col[2:]))
            c_eps.append(col)

    n_rows = len(df)
    margin = np.asarray(df['margin'].values, dtype=float)
    accs = []
    accs2 = []
    for v, c in zip(eps, c_eps):
        if n_rows == 0:
            # Accuracy of an empty set is undefined; report NaN like pandas' mean().
            accs.append(float('nan'))
            accs2.append(float('nan'))
            continue
        values = np.asarray(df[c].values, dtype=float)
        accs.append(int(np.sum(values == 0)) / n_rows)
        # Successful attacks (flag 1) contribute -v, so they are not counted for v >= 0.
        diff = margin * (1 - values) - v
        accs2.append(int(np.sum(diff > 0)) / n_rows)
    res = {'eps': eps, 'emp_acc': accs, 'th_acc': accs2, 'adv_l2': df['adv_noise_norm_l2'].mean(),
           'adv_lin': df['adv_noise_norm_linf'].mean()}
    return res


def draw_accuracy(res):
    """Plot empirical and theoretical robust accuracy against epsilon.

    ``res`` is a dict with ``eps``, ``emp_acc`` and ``th_acc`` lists, as returned
    by ``robustness_stats``. Both curves are drawn on the current axes, and the
    y-axis is fixed to [0, 1].
    """
    eps_values = res['eps']
    curves = (('emp_acc', "Emp ACC"), ('th_acc', "th ACC"))
    for key, label in curves:
        axes = sns.lineplot(x=eps_values, y=res[key], label=label)
    axes.set_ylim(0, 1)

def wandb_log_robustness(df):
    """Log robustness curves and mean adversarial-noise norms to Weights & Biases.

    ``df`` is summarised with ``robustness_stats``. Two line plots (empirical and
    theoretical robust accuracy vs. epsilon) and the mean L2 / Linf noise norms
    are sent in a single ``wandb.log`` call. Every ``wandb.log`` call advances the
    run step, so one call keeps all four entries on the same step.
    """
    res = robustness_stats(df)

    def _accuracy_plot(acc_key, title):
        # One (eps, accuracy) row per epsilon.
        rows = [[x, y] for x, y in zip(res['eps'], res[acc_key])]
        table = wandb.Table(data=rows, columns=["eps", "ACC"])
        return wandb.plot.line(table, "eps", "ACC", title=title)

    wandb.log({
        "Emp rob acc": _accuracy_plot('emp_acc', "Empirical robustness accuracy"),
        "Th rob acc": _accuracy_plot('th_acc', "Theoritical robustness accuracy"),
        'adv_l2': res['adv_l2'],
        'adv_lin': res['adv_lin'],
    })
