import os

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

from scriptify import scriptify

import tensorflow as tf
import numpy as np

from attack.plain_counterfactual import ProbitLayer
from attack.plain_counterfactual import search_counterfactuals
from latent_space.utils import get_data
import latent_space.models as GM
from metrics import invalidation
from model_configs import getPretrainedNetworks

from os import path

from sklearn.preprocessing import MinMaxScaler

from alibi.explainers import CounterFactualProto


def load_from_numpy(root, strategy, search_method, data, split):
    """Load cached counterfactuals from an ``.npz`` archive.

    The archive is looked up at
    ``{root}/{strategy}_{data}_{search_method}_{split}.npz``.

    Args:
        root: Directory containing the archive.
        strategy: Strategy tag used as the first component of the file name.
        search_method: Search-method tag used in the file name.
        data: Dataset tag used in the file name.
        split: Split tag used as the last component of the file name.

    Returns:
        A ``(counterfactuals, is_adv)`` tuple holding the arrays stored under
        those two keys.

    Raises:
        FileNotFoundError: If the archive does not exist.
        KeyError: If either key is missing from the archive.
    """
    archive_path = f"{root}/{strategy}_{data}_{search_method}_{split}.npz"
    # np.load on an .npz returns a lazy NpzFile that keeps the file open;
    # read both arrays inside the context manager so the handle is released.
    with np.load(archive_path) as archive:
        return archive['counterfactuals'], archive['is_adv']


def unify_the_output_shape(model,
                           model_paths,
                           loss='categorical_crossentropy',
                           metrics=['acc'],
                           optimizer='adam'):
    """Wrap ``model`` so that its output has a uniform form.

    Two optional adjustments are applied to the model's output tensor:

    1. If ``model_paths.needs_layer`` exists and is truthy, an activation is
       appended: softmax when the output is rank 2 with more than one unit,
       sigmoid otherwise.
    2. If the (possibly activated) output has exactly one unit along axis 1,
       it is passed through ``ProbitLayer`` (project layer; presumably it
       expands a single probability into a two-class output -- confirm
       against ``attack.plain_counterfactual``).

    Args:
        model: Keras model exposing ``input`` and ``output``.
        model_paths: Configuration object; only its optional ``needs_layer``
            attribute is read.
        loss: Loss used when the wrapped model is re-compiled.
        metrics: Metrics used when the wrapped model is re-compiled.
        optimizer: Optimizer used when the wrapped model is re-compiled.

    Returns:
        ``model`` itself when no adjustment was needed, otherwise a new
        compiled ``tf.keras.models.Model`` sharing the original input.
    """
    source_input = model.input
    head = model.output
    rebuilt = False

    if getattr(model_paths, 'needs_layer', False):
        multi_unit = len(head.shape) == 2 and head.shape[1] > 1
        if multi_unit:
            head = tf.keras.layers.Activation('softmax',
                                              name='softmax_pred')(head)
        else:
            head = tf.keras.activations.sigmoid(head)
        rebuilt = True

    if head.shape[1] == 1:
        head = ProbitLayer()(head)
        rebuilt = True

    # Nothing was appended: hand back the caller's model untouched.
    if not rebuilt:
        return model

    wrapped = tf.keras.models.Model(source_input, head)
    wrapped.compile(loss=loss, metrics=metrics, optimizer=optimizer)
    return wrapped


def generate_counterfactuals(model, x, y_sparse, **kwargs):
    """Search counterfactuals for ``x`` with ``search_counterfactuals``.

    A ``MinMaxScaler`` is fitted on ``x``. When ``kwargs['vae']`` is not
    ``None`` the inputs are min-max scaled before the search; otherwise they
    are passed through unchanged. The scaler's ``inverse_transform`` is always
    handed to the search as ``transform`` (how it is used is decided by
    ``search_counterfactuals`` -- not visible here).

    Args:
        model: Model under attack, forwarded to ``search_counterfactuals``.
        x: Input samples (2-D, as required by ``MinMaxScaler``).
        y_sparse: Labels forwarded to ``search_counterfactuals``.
        **kwargs: Extra search options forwarded verbatim. Must contain the
            key ``'vae'`` (its value may be ``None``).

    Returns:
        ``(c, is_adv)``: the first and third values returned by
        ``search_counterfactuals``.

    Raises:
        KeyError: If ``'vae'`` is not present in ``kwargs``.
    """
    scaler = MinMaxScaler().fit(x)

    latent_model_given = kwargs['vae'] is not None
    search_inputs = scaler.transform(x) if latent_model_given else x

    # The middle return value (requested via return_probits=True) is not
    # needed by callers and is discarded.
    counterfactuals, _unused_probits, adv_mask = search_counterfactuals(
        model,
        search_inputs,
        y_sparse,
        vae_samples=256,
        return_probits=True,
        transform=scaler.inverse_transform,
        **kwargs)

    return counterfactuals, adv_mask


def alibi_generate_counterfactuals(cf,
                                   data_x,
                                   data_y_sparse,
                                   threshold=0.3,
                                   k=20):
    """Generate one counterfactual per row of ``data_x`` with an alibi explainer.

    Each row is explained individually via
    ``cf.explain(x[None, :], k=k, threshold=threshold)``.

    When the explainer reports no counterfactual for a row (its ``'cf'`` entry
    is ``None`` or missing), the original row and its original label are
    recorded instead, so the returned array keeps one row per input and the
    corresponding ``is_adv`` entry is ``False``. Previously this case crashed
    while subscripting ``None``.

    Args:
        cf: Fitted explainer exposing ``explain(X, k=..., threshold=...)``.
        data_x: Array of instances, one per row.
        data_y_sparse: Integer labels aligned with ``data_x``.
        threshold: Forwarded to ``cf.explain``.
        k: Forwarded to ``cf.explain``.

    Returns:
        ``(c, is_adv)`` where ``c`` stacks the counterfactuals row-wise and
        ``is_adv`` is a boolean array that is ``True`` where the
        counterfactual's class differs from the corresponding label.
    """
    c = []
    c_labels = []

    pb = tf.keras.utils.Progbar(target=data_x.shape[0])
    for i, x in enumerate(data_x):
        explanation = cf.explain(x[None, :], k=k, threshold=threshold)

        # Depending on the alibi version, a failed search leaves 'cf' as None
        # or omits it altogether.
        try:
            found = explanation['cf']
        except (KeyError, AttributeError):
            found = None

        if found is None:
            # Keep row alignment: fall back to the unmodified instance and
            # its own label, which makes is_adv False for this row.
            c.append(x[None, :])
            c_labels.append(data_y_sparse[i])
        else:
            c.append(found['X'])
            c_labels.append(found['class'])
        pb.add(1)

    if not c:
        # np.vstack rejects an empty list; return correctly shaped empties.
        empty_c = np.empty((0, ) + tuple(data_x.shape[1:]),
                           dtype=data_x.dtype)
        return empty_c, np.zeros(0, dtype=bool)

    c_labels = np.array(c_labels)
    is_adv = c_labels != data_y_sparse
    c = np.vstack(c)

    return c, is_adv


def _load_baseline_model(baseline_strategy, model_paths):
    """Load model A and return it with the canonical strategy tag.

    ``baseline_strategy`` is either ``'baseline'`` or ``'loo,<i>'`` /
    ``'rs,<i>'``; the latter two pick the i-th entry of
    ``model_paths.loo_list`` / ``model_paths.rs_list`` and yield the tags
    ``'loo[<i>]'`` / ``'rs[<i>]'``.

    Raises:
        ValueError: If the strategy is unknown or its index is missing/invalid.
    """
    if baseline_strategy == 'baseline':
        return tf.keras.models.load_model(model_paths.baseline), 'baseline'

    for prefix, list_attr in (('loo', 'loo_list'), ('rs', 'rs_list')):
        if baseline_strategy.startswith(prefix):
            try:
                i_th = int(baseline_strategy.split(',')[1])
            except (IndexError, ValueError) as err:
                raise ValueError(
                    f"{baseline_strategy} must look like '{prefix},<index>'"
                ) from err
            model = tf.keras.models.load_model(
                getattr(model_paths, list_attr)[i_th])
            return model, f'{prefix}[{i_th}]'

    raise ValueError(f"{baseline_strategy} is not a valid baseline_strategy")


def _comparison_targets(strategy, model_paths, number_of_models, save_path,
                        data, search_method, split):
    """Return ``[(model_file, result_file), ...]`` for one comparing strategy.

    Raises:
        ValueError: If ``strategy`` is not 'baseline', 'loo' or 'rs'.
    """
    if strategy == 'baseline':
        # NOTE(review): when model A is also 'baseline' this is the same file
        # name as the counterfactual cache and overwrites it without the
        # 'is_adv' key -- confirm this is intended.
        return [(model_paths.baseline,
                 f"{save_path}/{strategy}_{data}_{search_method}_{split}.npz")]

    if strategy == 'loo':
        # NOTE(review): unlike 'baseline' and 'rs', this file name does not
        # include search_method, so runs with different search methods share
        # files. Kept as-is so existing consumers keep working.
        return [(model_paths.loo_list[i],
                 f"{save_path}/{strategy}_{i}_{data}_{split}.npz")
                for i in range(number_of_models)]

    if strategy == 'rs':
        return [(model_paths.rs_list[i],
                 f"{save_path}/{strategy}_{i}_{data}_{search_method}_{split}.npz")
                for i in range(number_of_models)]

    raise ValueError(f"{strategy} is not a valid training_strategy")


if __name__ == "__main__":

    @scriptify
    def script(data,
               baseline_strategy="baseline",
               comparing_strategy="loo,rs",
               use_testset=True,
               use_pred_as_labels=True,
               overwrite=False,
               number_of_models=100,
               batch_size=256,
               gpu=0,
               use_the_existed=True,
               search_method="CW",
               epsilon=0.3,
               expected_confidence=0.8,
               stepsize=None,
               steps=100,
               num_interpolation=10,
               rns_a=0.5,
               rns_epsilon_ratio=1.0,
               rns_steps=200,
               rns_stepsize=None,
               rns_p=2,
               rns_use_kl=False,
               counterfactual_path="invalidation_rns_stab",
               save_path="invalidation_rns_stab",
               vae_latent_dim=32,
               vae_arch_string='d512.d256.d128',
               vae_weights=None,
               threshold=0.3,
               k_tree=20):
        """Generate counterfactuals for model A and measure their invalidation.

        Counterfactuals are produced (or loaded from a cache) for the baseline
        model A, saved to ``save_path``, and then re-evaluated by every model
        selected through ``comparing_strategy``.

        Args:
            data: Dataset name passed to ``get_data`` and
                ``getPretrainedNetworks``.
            baseline_strategy: ``'baseline'``, ``'loo,<i>'`` or ``'rs,<i>'``;
                selects model A.
            comparing_strategy: Comma-separated list drawn from
                ``'baseline'``, ``'loo'``, ``'rs'``; selects the models B.
            use_testset: Use the test split if true, else the training split.
            use_pred_as_labels: Replace the labels by model A's argmax
                predictions.
            overwrite: Regenerate counterfactuals even if a cache file exists.
            number_of_models: Number of ``loo`` / ``rs`` models to compare.
            batch_size: Batch size for prediction and search.
            gpu: Index of the physical GPU to make visible.
            use_the_existed: Reuse cached counterfactuals from
                ``counterfactual_path`` when available.
            search_method: ``'Proto'`` uses alibi's ``CounterFactualProto``;
                ``'AE'`` / ``'VAE'`` additionally build the model of that name
                from ``latent_space.models``; any value other than ``'Proto'``
                is forwarded to ``search_counterfactuals``.
            epsilon, expected_confidence, stepsize, steps, num_interpolation,
            rns_a, rns_epsilon_ratio, rns_steps, rns_stepsize, rns_p,
            rns_use_kl: Forwarded to ``search_counterfactuals``
                (``expected_confidence`` as ``confidence``).
            counterfactual_path: Directory searched for cached counterfactuals.
            save_path: Directory that receives all ``.npz`` outputs.
            vae_latent_dim, vae_arch_string, vae_weights: Construction
                arguments and weight file for the AE/VAE model.
            threshold, k_tree: Forwarded to ``cf.explain`` on the Proto path
                (``k_tree`` as ``k``).

        Returns:
            dict with rounded summary statistics: ``IV``, ``IV_std``,
            ``mAc_c``, ``mBc_c``, ``d_x2c``, ``d_x2c_l1``, ``d_x2c_std``,
            ``d_x2c_l1_std`` and ``success_rate``.

        Raises:
            ValueError: For an unknown baseline or comparing strategy.
        """
        if search_method == 'Proto':
            # The alibi path is run with eager execution disabled.
            tf.compat.v1.disable_eager_execution()

        gpus = tf.config.experimental.list_physical_devices('GPU')
        if gpus:
            # Guarded so the script can still run on a CPU-only machine.
            tf.config.experimental.set_visible_devices(gpus[gpu], 'GPU')

        for device in tf.config.experimental.get_visible_devices('GPU'):
            tf.config.experimental.set_memory_growth(device, True)

        (X_train, y_train), (X_test, y_test), _ = get_data(data)
        model_paths = getPretrainedNetworks(data)

        if use_testset:
            data_x, data_y_sparse, split = X_test, y_test, 'test'
        else:
            data_x, data_y_sparse, split = X_train, y_train, 'train'

        # Per-feature bounds of the training data, used to clamp the search.
        clamp = [X_train.min(axis=0), X_train.max(axis=0)]

        modelA, baseline_strategy = _load_baseline_model(
            baseline_strategy, model_paths)
        modelA = unify_the_output_shape(modelA, model_paths)

        if use_pred_as_labels:
            data_y_sparse = np.argmax(
                modelA.predict(data_x, batch_size=batch_size), -1)

        # Fail early / avoid losing a long run because the output directory
        # is missing.
        os.makedirs(save_path, exist_ok=True)

        # The cache is read from counterfactual_path, so that is where its
        # existence has to be checked.
        cached_file = (
            f"{counterfactual_path}/{baseline_strategy}_{data}_{search_method}_{split}.npz"
        )

        if use_the_existed and not overwrite and path.exists(cached_file):
            modelA_counterfactuals, is_adv = load_from_numpy(
                counterfactual_path, baseline_strategy, search_method, data,
                split)
            print("Find Exsiting Files ...")

        elif search_method == 'Proto':
            # Settings follow
            # https://docs.seldon.io/projects/alibi/en/stable/methods/CFProto.html
            # except that 100 iterations are used instead of 1000, which took
            # too long.
            # NOTE(review): the original also defined gamma=100., c_init=1.
            # and c_steps=2 but never passed them to the explainer; they are
            # still not passed here so results stay unchanged.
            shape = (1, ) + X_train.shape[1:]
            feature_range = (X_train.min(), X_train.max())

            cf = CounterFactualProto(modelA,
                                     shape,
                                     use_kdtree=True,
                                     theta=100.,
                                     feature_range=feature_range,
                                     max_iterations=100)
            cf.fit(X_train, trustscore_kwargs=None)  # find class prototypes
            modelA_counterfactuals, is_adv = alibi_generate_counterfactuals(
                cf,
                data_x,
                data_y_sparse,
                threshold=threshold,
                k=k_tree)

        else:
            if search_method in ['AE', 'VAE']:
                # The VAE/AE is trained on data scaled into [0, 1].
                clamp = [0, 1]

                vae = getattr(GM, search_method)(data_x.shape[1:],
                                                 vae_latent_dim,
                                                 arch_string=vae_arch_string,
                                                 normalize=clamp)
                vae.compile(optimizer='adam',
                            loss=tf.keras.losses.BinaryCrossentropy(),
                            metrics=[tf.keras.losses.MeanSquaredError()])
                # Call once so the variables exist before loading weights.
                _ = vae(tf.zeros((1, ) + data_x.shape[1:]))

                # Load the pre-trained VAE/AE weights.
                vae.load_weights(vae_weights)
            else:
                vae = None

            modelA_counterfactuals, is_adv = generate_counterfactuals(
                modelA,
                data_x.copy(),
                data_y_sparse.copy(),
                vae=vae,
                clamp=clamp,
                batch_size=batch_size,
                search_method=search_method,
                epsilon=epsilon,
                confidence=expected_confidence,
                stepsize=stepsize,
                steps=steps,
                num_class=model_paths.num_classes,
                num_interpolation=num_interpolation,
                rns_a=rns_a,
                rns_epsilon_ratio=rns_epsilon_ratio,
                rns_steps=rns_steps,
                rns_stepsize=rns_stepsize,
                rns_p=rns_p,
                rns_use_kl=rns_use_kl)

        is_adv = np.asarray(is_adv)

        modelA_counterfual_probits = modelA.predict(modelA_counterfactuals,
                                                    batch_size=batch_size)
        modelA_counterfual_preds = np.argmax(modelA_counterfual_probits, -1)

        np.savez(
            f"{save_path}/{baseline_strategy}_{data}_{search_method}_{split}.npz",
            counterfactuals=modelA_counterfactuals,
            is_adv=is_adv,
            pred=modelA_counterfual_preds,
            confidence=np.max(modelA_counterfual_probits, -1))

        # data_x is restricted to successful rows below. If the counterfactual
        # array still has one row per input (e.g. the Proto path) it must be
        # restricted the same way, otherwise the distance computation fails on
        # mismatched shapes. When the search already returned only successful
        # rows, or every row succeeded, this is a no-op.
        if (is_adv.dtype == bool
                and len(modelA_counterfactuals) == len(is_adv)
                and not is_adv.all()):
            modelA_counterfactuals = modelA_counterfactuals[is_adv]
            modelA_counterfual_probits = modelA_counterfual_probits[is_adv]
            modelA_counterfual_preds = modelA_counterfual_preds[is_adv]

        data_y_sparse = data_y_sparse[is_adv]
        data_x = data_x[is_adv]

        ################################################################
        # iv: invalidation rate
        # mBc_c: confidence of counterfactuals predicted by model B
        # d_x2c: distance between data and counterfactuals
        ################################################################

        all_iv = []
        all_mBc_c = []

        for t in comparing_strategy.split(','):
            targets = _comparison_targets(t, model_paths, number_of_models,
                                          save_path, data, search_method,
                                          split)
            for model_file, result_file in targets:
                modelB = tf.keras.models.load_model(model_file)
                modelB = unify_the_output_shape(modelB, model_paths)

                iv, (mBc_c, mBc_p) = invalidation(modelA_counterfactuals,
                                                  modelA_counterfual_preds,
                                                  modelB,
                                                  batch_size=batch_size,
                                                  aggregation=None,
                                                  return_pred_B=True)

                all_iv.append(iv)
                all_mBc_c.append(mBc_c.mean())

                np.savez(result_file,
                         counterfactuals=modelA_counterfactuals,
                         pred=mBc_p,
                         confidence=mBc_c)

        avg_iv, avg_std = np.mean(all_iv), np.std(all_iv)
        avg_mBc_c = np.mean(all_mBc_c)

        # Per-sample distances between each input and its counterfactual.
        displacement = data_x - modelA_counterfactuals
        l2_dist = np.linalg.norm(displacement, axis=-1)
        l1_dist = np.linalg.norm(displacement, axis=-1, ord=1)

        avg_d_x2c, avg_d_x2c_std = np.mean(l2_dist), np.std(l2_dist)
        avg_d_x2c_l1, avg_d_x2c_l1_std = np.mean(l1_dist), np.std(l1_dist)
        avg_mAc_c = np.max(modelA_counterfual_probits, -1).mean()
        success_rate = np.mean(is_adv)

        print("\n\n")
        print(
            ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> Result <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<"
        )
        print(
            f"Success Rate [EPS={epsilon}, {search_method}]:     {success_rate}"
        )
        print(f"Invalidation Rate:     {avg_iv}+-{avg_std}")
        print(
            f"Confidence on Counterfactual ({baseline_strategy}):     {avg_mAc_c}"
        )
        print(
            f"Confidence on Counterfactual ({comparing_strategy}):     {avg_mBc_c}"
        )
        print(
            f"L2 Distance (Data Range: {np.min(clamp[0]), np.max(clamp[1])}):     {avg_d_x2c}+-{avg_d_x2c_std}"
        )
        print(
            f"L1 Distance (Data Range: {np.min(clamp[0]), np.max(clamp[1])}):     {avg_d_x2c_l1}+-{avg_d_x2c_l1_std}"
        )

        return_dict = {
            "IV": round(float(avg_iv), 2),
            "IV_std": round(float(avg_std), 2),
            "mAc_c": round(float(avg_mAc_c), 2),
            "mBc_c": round(float(avg_mBc_c), 2),
            "d_x2c": round(float(avg_d_x2c), 4),
            "d_x2c_l1": round(float(avg_d_x2c_l1), 4),
            "d_x2c_std": round(float(avg_d_x2c_std), 4),
            "d_x2c_l1_std": round(float(avg_d_x2c_l1_std), 4),
            "success_rate": round(float(success_rate), 2),
        }

        return return_dict
