import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing import image

import glob, os

# for measuring similarity of weights
from sklearn.metrics.pairwise import cosine_similarity
from cleverhans.tf2.attacks.carlini_wagner_l2 import carlini_wagner_l2
from cleverhans.tf2.attacks.fast_gradient_method import fast_gradient_method
from cleverhans.tf2.attacks.projected_gradient_descent import projected_gradient_descent
from cleverhans.tf2.attacks.momentum_iterative_method import momentum_iterative_method
from cleverhans.tf2.attacks.spsa import spsa
import pandas as pd

import warnings


def _generate_adversarial(kind, model, x, target_label, clip_min, clip_max):
    """Run the cleverhans attack named *kind* on the single-image batch *x*, targeting *target_label*.

    Raises:
        ValueError: if *kind* is not one of CW, FGM, PGD, Momentum, SPSA.
    """
    if kind == "CW":
        # Only the C&W call receives one-hot targets (1000 classes); the other attacks get the label array as-is.
        one_hot_target = tf.keras.utils.to_categorical(target_label, num_classes=1000)
        return carlini_wagner_l2(model_fn=model, x=x, y=one_hot_target, targeted=True,
                                 clip_min=clip_min, clip_max=clip_max, max_iterations=30)
    # NOTE(review): FGM/PGD/Momentum are not given clip_min/clip_max, so with cleverhans' defaults
    # their outputs are presumably not clipped to the input range — confirm whether that is intended.
    if kind == "FGM":
        return fast_gradient_method(model_fn=model, x=x, eps=0.1, norm=np.inf,
                                    y=target_label, targeted=True)
    if kind == "PGD":
        return projected_gradient_descent(model_fn=model, x=x, y=target_label,
                                          targeted=True, eps=0.1, eps_iter=0.01, nb_iter=40,
                                          norm=np.inf)
    if kind == "Momentum":
        return momentum_iterative_method(model_fn=model, x=x, targeted=True,
                                         y=target_label, eps=0.05, eps_iter=0.06, nb_iter=10,
                                         norm=np.inf)
    if kind == "SPSA":
        return spsa(model_fn=model, x=tf.convert_to_tensor(x), targeted=True,
                    y=target_label, eps=0.1, nb_iter=100,
                    spsa_samples=128,  # Number of samples to estimate the gradient
                    spsa_iters=1, clip_min=clip_min, clip_max=clip_max)
    raise ValueError(f"Unknown attack kind {kind!r}; use CW, FGM, PGD, Momentum or SPSA")


def create_dataset(main_folder_path, img_size, model, preprocess_function, clip_min,
                   clip_max, kind='FGSM', vulnerability_tab=None, labels_csv='dev_dataset.csv'):
    """Load every PNG in a folder and build a targeted adversarial copy of each image.

    Args:
        main_folder_path: folder containing the ``*.png`` images. Each file name up to its
            first '.' is looked up as ``ImageId`` in *labels_csv*.
        img_size: target size passed to ``image.load_img`` when loading each file.
        model: classifier passed to the attack as ``model_fn``.
        preprocess_function: applied to each loaded image array before attacking.
        clip_min, clip_max: input-range bounds forwarded to the CW and SPSA attacks.
        kind: one of "CW", "FGM" (also accepted as "FGSM", the default), "PGD",
            "Momentum", "SPSA".
        vulnerability_tab: indexable mapping from a 0-based true label to the attack's
            target label. Required.
        labels_csv: CSV with ``ImageId`` and ``TrueLabel`` columns. ``TrueLabel`` is
            shifted by -1 to obtain the 0-based label.

    Returns:
        ``(X_original, X_adversarial, y)`` as numpy arrays: the preprocessed images, their
        adversarial counterparts, and one length-1 label array per image (so ``y`` is (N, 1)).

    Raises:
        ValueError: unknown *kind*, missing *vulnerability_tab*, or an image without exactly
            one matching row in *labels_csv*.
    """
    if vulnerability_tab is None:
        raise ValueError("vulnerability_tab is required to choose the attack target for each image")
    if kind == "FGSM":
        # The historical default name: FGM with norm=np.inf (as used below) is FGSM.
        kind = "FGM"
    if kind not in ("CW", "FGM", "PGD", "Momentum", "SPSA"):
        raise ValueError(f"Unknown attack kind {kind!r}; use CW, FGM, PGD, Momentum or SPSA")

    labels_df = pd.read_csv(labels_csv)
    X_adversarial = []
    X_original = []
    y = []
    # Sorted so the sample order does not depend on the filesystem's listing order.
    for img_path in sorted(glob.glob(os.path.join(main_folder_path, "*.png"))):
        # Use the basename so nested folders and OS-specific separators work.
        img_label = os.path.basename(img_path).split('.')[0]
        label = (labels_df.loc[labels_df.ImageId == img_label].TrueLabel - 1).values
        if label.size != 1:
            raise ValueError(f"Expected exactly one TrueLabel row for ImageId {img_label!r} "
                             f"in {labels_csv}, found {label.size}")

        img = image.load_img(img_path, target_size=img_size)
        img_array = preprocess_function(image.img_to_array(img))
        img_batch = np.expand_dims(img_array, axis=0)  # attacks operate on a batch of one
        target_label = vulnerability_tab[label]

        img_array_adversarial = _generate_adversarial(kind, model, img_batch, target_label,
                                                      clip_min, clip_max)

        X_adversarial.append(img_array_adversarial[0])
        y.append(label)
        X_original.append(img_array)
    return np.asarray(X_original), np.asarray(X_adversarial), np.asarray(y)


# These functions can be used to evaluate the performance of adversarial attacks, both in terms of their impact on the
# model's accuracy and on the consistency of the model's predictions (for the top k classes).

def test_adversarial_performance(model, X_original, X_adversarial, y, vulnerability_tab):
    """Score a targeted attack by comparing predictions on original and adversarial images.

    Args:
        model: classifier with ``predict`` and ``layers``. The kernel of its last layer
            is used as per-class weight vectors (assumed to be a Dense head of shape
            (features, classes) — TODO confirm for every supported model).
        X_original, X_adversarial: image batches of equal length.
        y: 0-based true labels, shape (N,) or (N, 1).
        vulnerability_tab: indexable mapping from a true label to the attack's target label.

    Returns:
        Tuple ``(accuracy_original, accuracy_adversarial, accuracy_drop, ASR,
        deterioration_rate, harmfulness_metric, TSR)`` where:

        - ASR is the fraction of samples whose prediction changed under attack.
        - deterioration_rate is the fraction of samples that were originally correct
          and whose prediction changed.
        - harmfulness_metric is the mean position of the adversarial prediction in the
          true class's ranking by cosine similarity of class weights, normalised to [0, 1].
        - TSR is the fraction of adversarial predictions equal to the target class.
    """
    from sklearn.metrics import accuracy_score

    y_true = np.ravel(y)  # create_dataset yields one length-1 label array per sample
    y_pred_original = np.argmax(model.predict(X_original), axis=1)
    y_pred_adversarial = np.argmax(model.predict(X_adversarial), axis=1)
    y_target = np.asarray(vulnerability_tab)[y_true]

    accuracy_adversarial = accuracy_score(y_true, y_pred_adversarial)
    accuracy_original = accuracy_score(y_true, y_pred_original)
    ASR = 1 - accuracy_score(y_pred_original, y_pred_adversarial)
    TSR = accuracy_score(y_target, y_pred_adversarial)

    # One row per class: the transposed kernel of the last layer (read once, used for both sides).
    class_weights = np.moveaxis(model.layers[-1].get_weights()[0], 0, -1)
    cs_tab = cosine_similarity(class_weights, class_weights)

    # Per-class ordering from most to least similar; independent of the samples.
    sorted_cs_args = np.argsort(-cs_tab, axis=1)
    # Each row of sorted_cs_args is a permutation of all classes, so exactly one position matches.
    ranks = np.argmax(sorted_cs_args[y_true] == y_pred_adversarial[:, np.newaxis], axis=1)
    max_rank = max(cs_tab.shape[1] - 1, 1)  # 999 for 1000 ImageNet classes
    harmfulness_metric = np.mean(ranks / max_rank)

    deterioration_rate = np.mean((y_true == y_pred_original) & (y_pred_adversarial != y_pred_original))
    return (accuracy_original, accuracy_adversarial, accuracy_original - accuracy_adversarial,
            ASR, deterioration_rate, harmfulness_metric, TSR)

def single_attack_results_generation(kind, model, dataset_name, size_img, vulnerability_tab, preprocess_function,
                                     clip_min=-1, clip_max=1):
    """Run one attack over a dataset folder, print its metrics and return them as a one-row DataFrame.

    Args:
        kind: attack name understood by ``create_dataset`` (e.g. "CW", "FGM", "PGD").
        model: classifier under attack.
        dataset_name: folder of PNG images passed to ``create_dataset``.
        size_img: image size used when loading the images.
        vulnerability_tab: mapping from true label to target label.
        preprocess_function: model-specific preprocessing applied to each image.
        clip_min, clip_max: input-range bounds forwarded to ``create_dataset``
            (defaults -1 and 1, the values previously hard-coded here).

    Returns:
        ``pandas.DataFrame`` with one row and columns attack_name, accuracy_original,
        accuracy_adversarial, accuracy_diff, ASR, deterioration_rate, DM, TSR.
    """
    # All warnings raised while generating and scoring the adversarial examples are suppressed.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore")
        X_original, X_adversarial, y = create_dataset(main_folder_path=dataset_name,
                                                      img_size=size_img,
                                                      model=model,
                                                      preprocess_function=preprocess_function,
                                                      clip_min=clip_min,
                                                      clip_max=clip_max,
                                                      kind=kind,
                                                      vulnerability_tab=vulnerability_tab)
        (accuracy_original, accuracy_adversarial, accuracy_diff,
         ASR, deterioration_rate, DM, TSR) = test_adversarial_performance(
            model, X_original, X_adversarial, y, vulnerability_tab)
    print(kind)
    print(
        f"accuracy_original: {accuracy_original}, \naccuracy_adversarial: {accuracy_adversarial}, "
        f"\naccuracy_diff: {accuracy_diff}, \nASR: {ASR}, \ndeterioration_rate: {deterioration_rate}, "
        f"\nDM: {DM}, \nTSR: {TSR}")
    results_row = {
        "attack_name": kind,
        "accuracy_original": accuracy_original,
        "accuracy_adversarial": accuracy_adversarial,
        "accuracy_diff": accuracy_diff,
        "ASR": ASR,
        "deterioration_rate": deterioration_rate,
        "DM": DM,  # harmfulness metric from test_adversarial_performance
        "TSR": TSR
    }
    return pd.DataFrame([results_row])

def gather_results(model, dataset_name, size_img, vulnerability_tab, preprocess_function,
                   attacks=("CW", "FGM", "PGD", "Momentum", "SPSA")):
    """Evaluate each attack in turn and collect one result row per attack.

    Args:
        model, dataset_name, size_img, vulnerability_tab, preprocess_function: forwarded
            unchanged to ``single_attack_results_generation``.
        attacks: attack names to run, in order. Defaults to the five attacks this
            function always ran.

    Returns:
        ``pandas.DataFrame`` with columns attack_name, accuracy_original,
        accuracy_adversarial, accuracy_diff, ASR, deterioration_rate, DM, TSR and one row
        per attack. An empty *attacks* yields an empty frame with those columns.
    """
    columns = ["attack_name", "accuracy_original", "accuracy_adversarial", "accuracy_diff", "ASR", "deterioration_rate",
               "DM", "TSR"]
    rows = [
        single_attack_results_generation(kind, model, dataset_name, size_img, vulnerability_tab, preprocess_function)
        for kind in attacks
    ]
    if not rows:
        return pd.DataFrame(columns=columns)
    # Concatenate once instead of growing from an empty frame (deprecated in pandas >= 2.1).
    return pd.concat(rows, ignore_index=True).reindex(columns=columns)

def _similarity_matrix_path(similarity_source):
    """Return the .npy file holding the class-similarity matrix for *similarity_source*."""
    paths = {
        "LLAMA": "similarity_matrix_llama.npy",
        "CLIP": "similarity_matrix_CLIP.npy",
        "BERT": "similarity_matrix_BERT.npy",
        "WUP": "WUP_sim_ref.npy",
    }
    try:
        return paths[similarity_source]
    except KeyError:
        raise ValueError(f"Wrong source model name provided: {similarity_source!r}, "
                         f"use LLAMA/CLIP/BERT/WUP") from None


def _vulnerability_column(variant):
    """Return the column of the descending-similarity ranking that selects the attack target.

    "MS" picks column 1 and "LS" picks the last column. Column 1 is the most similar
    class other than the row's own class, assuming a class's self-similarity is its row
    maximum. The last column is the least similar class.
    """
    columns = {"MS": 1, "LS": -1}
    try:
        return columns[variant]
    except KeyError:
        raise ValueError(f"Wrong attack variant name provided: {variant!r}, use MS or LS") from None


def _load_model(model_name):
    """Build the ImageNet-pretrained Keras classifier *model_name* and return ``(model, preprocess_input)``."""
    if model_name == "MobileNetV2":
        from tensorflow.keras.applications import MobileNetV2
        from tensorflow.keras.applications.mobilenet_v2 import preprocess_input
        return MobileNetV2(weights='imagenet'), preprocess_input
    if model_name == "EfficientNetV2B0":
        from tensorflow.keras.applications import EfficientNetV2B0
        # Built without its internal preprocessing layer, so inputs are prepared by
        # resnet_v2.preprocess_input instead.
        from tensorflow.keras.applications.resnet_v2 import preprocess_input
        return EfficientNetV2B0(weights='imagenet', include_preprocessing=False), preprocess_input
    if model_name == "ResNet50V2":
        from tensorflow.keras.applications import ResNet50V2
        from tensorflow.keras.applications.resnet_v2 import preprocess_input
        return ResNet50V2(weights='imagenet'), preprocess_input
    raise ValueError(f"Wrong model name provided: {model_name!r}, "
                     f"use MobileNetV2, EfficientNetV2B0 or ResNet50V2")


def results_saving(model_name, dataset_name="DEV", size_img=(224, 224), similarity_source="LLAMA", variant="MS"):
    """Run all attacks against *model_name* and write the metrics to a CSV file.

    The file is written to the current directory as
    ``{similarity_source}_{model.name}_{variant}.csv``.

    Args:
        model_name: "MobileNetV2", "EfficientNetV2B0" or "ResNet50V2".
        dataset_name: folder of PNG images to attack.
        size_img: ignored; every supported model is run at (224, 224). Kept for
            interface compatibility.
        similarity_source: "LLAMA", "CLIP", "BERT" or "WUP"; selects the
            class-similarity matrix file.
        variant: "MS" (target the most similar class) or "LS" (the least similar).

    Raises:
        ValueError: if *similarity_source*, *variant* or *model_name* is not recognised.
            The names are checked before any file is read.
    """
    matrix_path = _similarity_matrix_path(similarity_source)
    target_column = _vulnerability_column(variant)
    model, preprocess_input = _load_model(model_name)
    size_img = (224, 224)  # all supported models are used at this input size

    similarity_matrix = np.load(matrix_path)
    # Order each row's classes from most to least similar, then take the chosen column as the target.
    vulnerability_tab = np.argsort(-similarity_matrix)[:, target_column]

    df = gather_results(model, dataset_name, size_img, vulnerability_tab, preprocess_input)
    csv_filename = f"{similarity_source}_{model.name}_{variant}.csv"
    df.to_csv(csv_filename, index=False)

import argparse



def main(argv=None):
    """Parse command-line arguments and run the attack evaluation via ``results_saving``.

    Args:
        argv: argument list to parse. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``.

    Invalid ``--model``, ``--similarity_source`` or ``--variant`` values are rejected by
    argparse (usage error, exit code 2) before any model or file is loaded.
    """
    parser = argparse.ArgumentParser(description="My script with command-line args")

    parser.add_argument('--model', type=str, required=True,
                        choices=["MobileNetV2", "EfficientNetV2B0", "ResNet50V2"], help='Tested model')
    parser.add_argument('--similarity_source', type=str, required=True,
                        choices=["LLAMA", "CLIP", "BERT", "WUP"], help='Name of the similarity source model')
    parser.add_argument('--variant', type=str, required=True, choices=["MS", "LS"], help='Variant name')
    parser.add_argument('--dataset', type=str, required=True, help='Dataset root folder name')

    args = parser.parse_args(argv)
    results_saving(model_name=args.model, dataset_name=args.dataset, size_img=(224, 224),
                   similarity_source=args.similarity_source, variant=args.variant)


if __name__ == "__main__":
    main()
