# import os

# os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import tensorflow as tf
import numpy as np

from scriptify import scriptify

import sys

sys.path.append("../")
from attack import CW
from attack import IGD_L1
from attack import IGD_L2
from attack import RNS
from attack import PGDs
from attack import Simple
from latent_space.utils import get_data
from latent_space.utils import search_z_adv

sys.path.append("../../")
from model_configs import getPretrainedNetworks


def _to_sparse_labels(y_pred):
    """Collapse 2-D per-class scores to class indices; pass other shapes through."""
    if len(y_pred.shape) == 2:
        return np.argmax(y_pred, -1)
    return y_pred


def _refine_counterfactuals(refiner, model, adv_x, y_pred_adv, attack_kwargs,
                            num_class, rns_epsilon_ratio, rns_steps,
                            rns_stepsize, rns_p, rns_use_kl, alpha,
                            **refiner_kwargs):
    """Run a second-stage search (``RNS`` or ``Simple``) seeded with ``adv_x``.

    Args:
        refiner: The second-stage search function (``RNS`` or ``Simple``).
        model: Classifier passed through to ``refiner``.
        adv_x: First-stage counterfactuals to start from.
        y_pred_adv: First-stage predictions for ``adv_x``; 2-D scores are
            reduced to class indices with ``argmax`` before being passed on.
        attack_kwargs: The keyword arguments given to ``search_counterfactuals``.
            ``epsilon``, ``clamp`` and ``batch_size`` are required;
            ``return_probits`` is optional (default ``False``).
        num_class, rns_epsilon_ratio, rns_steps, rns_stepsize, rns_p,
        rns_use_kl: See ``search_counterfactuals``.
        alpha: Forwarded unchanged as ``alpha``.
        **refiner_kwargs: Extra keywords for ``refiner`` (e.g. ``a`` for RNS).

    Returns:
        ``(adv_x, y_pred_adv)`` — the first two values returned by ``refiner``.
    """
    adv_x, y_pred_adv, _ = refiner(
        model,
        adv_x,
        _to_sparse_labels(y_pred_adv),
        res=10,
        baseline=0,
        max_steps=rns_steps,
        # Default step size lets the search traverse the epsilon ball twice
        # over ``rns_steps`` iterations.
        adv_step_size=rns_stepsize if rns_stepsize is not None else 2 *
        attack_kwargs['epsilon'] * rns_epsilon_ratio / rns_steps,
        clamp=attack_kwargs['clamp'],
        adv_epsilon=attack_kwargs['epsilon'] * rns_epsilon_ratio,
        p=rns_p,
        num_class=num_class,
        alpha=alpha,
        return_probits=attack_kwargs.get('return_probits', False),
        batch_size=attack_kwargs['batch_size'],
        use_kl=rns_use_kl,
        **refiner_kwargs)
    return adv_x, y_pred_adv


def search_counterfactuals(model,
                           x,
                           y_sparse,
                           search_method='CW',
                           vae=None,
                           vae_samples=100,
                           transform=None,
                           num_class=2,
                           rns_a=0.5,
                           rns_epsilon_ratio=1.0,
                           rns_steps=300,
                           rns_stepsize=None,
                           rns_p=2,
                           rns_use_kl=False,
                           **kwargs):
    """Search counterfactual examples for ``x`` under ``model``.

    ``search_method`` is matched by substring, in this order:

    * contains ``'CW'`` / ``'IGD_L1'`` / ``'IGD_L2'`` / ``'PGDs'``: run that
      first-stage attack, keep only the successful outputs (``is_adv``), and,
      if the name also contains ``'RNS'``, refine them with ``RNS``.  For the
      ``PGDs`` branch only, a name containing ``'Simple'`` (and not ``'RNS'``)
      refines with ``Simple`` instead.
    * exactly ``'AE'`` or ``'VAE'``: search the latent space of ``vae`` with
      ``search_z_adv``; results are returned unfiltered.

    Args:
        model: Classifier under attack.
        x: Inputs to search counterfactuals for.
        y_sparse: Labels for ``x`` passed to the first-stage attack.
        search_method: Method selector, see above.
        vae: Autoencoder used by the ``'AE'``/``'VAE'`` methods (required there).
        vae_samples: ``num_samples`` for ``search_z_adv``.
        transform: ``transform`` for ``search_z_adv``.
        num_class: Number of classes (forwarded to IGD attacks and refiners).
        rns_a: ``a`` argument for ``RNS``.
        rns_epsilon_ratio: Scales ``kwargs['epsilon']`` for the refinement stage.
        rns_steps: ``max_steps`` for the refinement stage.
        rns_stepsize: Refinement step size; defaults to
            ``2 * epsilon * rns_epsilon_ratio / rns_steps``.
        rns_p: ``p`` for the refinement stage.
        rns_use_kl: ``use_kl`` for the refinement stage and ``search_z_adv``.
        **kwargs: Forwarded to the first-stage attack.  The refinement stage
            also reads ``epsilon``, ``clamp``, ``batch_size`` and optional
            ``return_probits``; the latent search reads ``epsilon``, ``steps``
            and optional ``return_probits``.

    Returns:
        ``(adv_x, y_pred_adv, is_adv)``.  For the attack branches ``adv_x`` and
        ``y_pred_adv`` contain only successful examples, while ``is_adv`` is
        returned exactly as the first-stage attack produced it.

    Raises:
        ValueError: ``'AE'``/``'VAE'`` requested without a ``vae``.
        NotImplementedError: ``search_method`` matches no known method.
    """
    if 'CW' in search_method:
        adv_x, y_pred_adv, is_adv = CW(model, x, y_sparse, **kwargs)
        alpha, supports_simple = 1, False
    elif 'IGD_L1' in search_method:
        adv_x, y_pred_adv, is_adv = IGD_L1(model, x, y_sparse,
                                           num_classes=num_class, **kwargs)
        alpha, supports_simple = 1, False
    elif 'IGD_L2' in search_method:
        adv_x, y_pred_adv, is_adv = IGD_L2(model, x, y_sparse,
                                           num_classes=num_class, **kwargs)
        alpha, supports_simple = 1, False
    elif 'PGDs' in search_method:
        adv_x, y_pred_adv, is_adv = PGDs(model, x, y_sparse, **kwargs)
        alpha, supports_simple = 1.0, True
    elif search_method in ['AE', 'VAE']:
        if vae is None:
            raise ValueError(
                "VAE model is not provided for searching the latent space.")

        z = vae.encode(x)
        adv_x, y_pred_adv, is_adv = search_z_adv(
            vae,
            model,
            x,
            z,
            y_sparse,
            epsilon=kwargs['epsilon'],
            steps=kwargs['steps'],
            num_samples=vae_samples,
            direction='random',
            p=2,
            transform=transform,
            detemintristic=True,
            return_probits=kwargs.get('return_probits', False),
            use_kl=rns_use_kl)
        # The latent search is returned as-is: no filtering, no refinement.
        return adv_x, y_pred_adv, is_adv
    else:
        raise NotImplementedError(f"{search_method} is not implemented.")

    # Keep only the examples for which the first-stage attack succeeded.
    adv_x = adv_x[is_adv]
    y_pred_adv = y_pred_adv[is_adv]

    refine_args = dict(attack_kwargs=kwargs,
                       num_class=num_class,
                       rns_epsilon_ratio=rns_epsilon_ratio,
                       rns_steps=rns_steps,
                       rns_stepsize=rns_stepsize,
                       rns_p=rns_p,
                       rns_use_kl=rns_use_kl,
                       alpha=alpha)
    if 'RNS' in search_method:
        adv_x, y_pred_adv = _refine_counterfactuals(RNS, model, adv_x,
                                                    y_pred_adv, a=rns_a,
                                                    **refine_args)
    elif supports_simple and 'Simple' in search_method:
        adv_x, y_pred_adv = _refine_counterfactuals(Simple, model, adv_x,
                                                    y_pred_adv, **refine_args)

    return adv_x, y_pred_adv, is_adv


def save_into_numpy(root, strategy, data, split, adv_x, y_pred_adv, is_adv):
    """Write the counterfactual arrays of one run as three ``.npy`` files.

    Files are written under ``root`` as
    ``{strategy}_{data}_{split}_counterfactuals.npy`` (``adv_x``),
    ``..._counterfactuals_pred.npy`` (``y_pred_adv``) and
    ``..._counterfactuals_is_adv.npy`` (``is_adv``).
    """
    outputs = {
        f"{root}/{strategy}_{data}_{split}_counterfactuals.npy": adv_x,
        f"{root}/{strategy}_{data}_{split}_counterfactuals_pred.npy":
        y_pred_adv,
        f"{root}/{strategy}_{data}_{split}_counterfactuals_is_adv.npy":
        is_adv,
    }
    # dicts keep insertion order, so files are written in the original order.
    for target, array in outputs.items():
        np.save(target, array)


class ProbitLayer(tf.keras.layers.Layer):
    """Expand a single-column score ``p`` into the two columns ``[1 - p, p]``.

    Used to give one-output (binary) models a two-column output so that
    column ``j`` holds the score of class ``j``.
    """

    def __init__(self, **kwargs):
        # Forward base-Layer options (name, dtype, trainable, ...).  Keras'
        # default ``from_config`` calls ``cls(**config)`` with these keys, so
        # a no-argument ``__init__`` breaks reloading a saved model.
        super(ProbitLayer, self).__init__(**kwargs)

    def build(self, input_shape):
        # No weights to create.
        pass

    def call(self, inputs):
        # Concatenate along axis 1: (N, 1) -> (N, 2) as [1 - p, p].
        return tf.concat([1.0 - inputs, inputs], axis=1)


if __name__ == "__main__":

    def _wrap_binary_output(model):
        """Give a one-unit-output model a two-column ``[1 - p, p]`` output.

        Models whose output already has more than one column are returned
        unchanged.  Wrapped models are recompiled so ``predict`` works as before.
        """
        if model.output.shape[1] == 1:
            model = tf.keras.models.Model(model.input,
                                          ProbitLayer()(model.output))
            model.compile(loss='categorical_crossentropy',
                          metrics='acc',
                          optimizer='adam')
        return model

    def _predicted_labels(probits):
        """Return the class the model predicts for each row of ``probits``."""
        probits = np.asarray(probits)
        if probits.ndim == 2 and probits.shape[1] > 1:
            # One column per class (ProbitLayer emits [1 - p, p]), so argmax is
            # the prediction for binary and multi-class models alike.  Do not
            # threshold column 0: it is P(class 0) and would invert the labels.
            return np.argmax(probits, -1)
        # A single score per example: treat it as P(class 1).
        return np.where(np.ravel(probits) >= 0.5, 1, 0)

    @scriptify
    def script(data,
               training_strategy='baseline,loo,rs',
               use_testset=True,
               use_pred_as_labels=True,
               number_of_models=100,
               search_method='CW',
               epsilon=0.141,
               confidence=1.0,
               stepsize=None,
               steps=1000,
               batch_size=32,
               gpu=0,
               saving_path=None):
        """Search counterfactuals for the pretrained networks of ``data``.

        Args:
            data: Dataset name passed to ``get_data`` / ``getPretrainedNetworks``.
            training_strategy: Comma-separated subset of ``baseline``, ``loo``
                and ``rs`` selecting which pretrained networks to attack.
            use_testset: Attack the test split if True, else the train split.
            use_pred_as_labels: Use each model's own predictions as labels
                instead of the dataset labels.
            number_of_models: How many ``loo`` / ``rs`` networks to process.
            search_method: Forwarded to ``search_counterfactuals``.
            epsilon: Search radius as a fraction of the data range
                (max - min of the training inputs).
            confidence, steps, batch_size: Forwarded to the search.
            stepsize: Search step size; defaults to ``2 * epsilon / steps``
                (computed after rescaling ``epsilon``).
            gpu: Index of the GPU to use when GPUs are present.
            saving_path: Directory for ``save_into_numpy``; nothing is saved
                when None.

        Raises:
            ValueError: An unknown entry in ``training_strategy``.
        """
        gpus = tf.config.experimental.list_physical_devices('GPU')
        if gpus:
            # Use only the requested GPU and allocate its memory on demand.
            tf.config.experimental.set_visible_devices(gpus[gpu], 'GPU')
            for device in tf.config.experimental.get_visible_devices('GPU'):
                tf.config.experimental.set_memory_growth(device, True)
        else:
            print("No GPU found; running on CPU.", flush=True)

        (X_train, y_train), (X_test, y_test), _ = get_data(data)
        model_paths = getPretrainedNetworks(data)

        clamp = [X_train.min(), X_train.max()]
        # ``epsilon`` is given relative to the data range; make it absolute.
        epsilon *= clamp[1] - clamp[0]

        print(f"Data Range {clamp}")

        if stepsize is None:
            stepsize = 2 * epsilon / steps

        if use_testset:
            data_x, data_y_sparse = X_test, y_test
        else:
            data_x, data_y_sparse = X_train, y_train
        split = 'test' if use_testset else 'train'

        def run_one(model_path, tag):
            """Load one network, search its counterfactuals, optionally save."""
            model = _wrap_binary_output(tf.keras.models.load_model(model_path))

            labels = data_y_sparse
            if use_pred_as_labels:
                # Computed per model from its own predictions, so labels from
                # one network never influence how the next one is labelled.
                labels = _predicted_labels(
                    model.predict(data_x, batch_size=batch_size))

            adv_x, y_pred_adv, is_adv = search_counterfactuals(
                model,
                data_x,
                labels,
                search_method=search_method,
                epsilon=epsilon,
                confidence=confidence,
                stepsize=stepsize,
                steps=steps,
                batch_size=batch_size,
                clamp=clamp)

            if saving_path is not None:
                save_into_numpy(saving_path, tag, data, split, adv_x,
                                y_pred_adv, is_adv)

        for t in training_strategy.split(','):
            t = t.strip()
            if t == 'baseline':
                print("> > > > > > Running on baseline network < < < < < <",
                      end='\n',
                      flush=True)
                run_one(model_paths.baseline, 'baseline')

            elif t == 'loo':
                print(
                    f"> > > > > > Running on {number_of_models} Leave-One-Out networks < < < < < <",
                    end='\n',
                    flush=True)
                for i in range(number_of_models):
                    run_one(model_paths.loo_list[i], f'loo[{i}]')

            elif t == 'rs':
                print(
                    f"> > > > > > Running on {number_of_models} Random-Seed networks < < < < < <",
                    end='\n',
                    flush=True)
                for i in range(number_of_models):
                    run_one(model_paths.rs_list[i], f'rs[{i}]')

            else:
                raise ValueError(f"{t} is not a valid training_strategy")
