import os
import numpy as np
import argparse
import time 
from sklearn.metrics import classification_report, accuracy_score
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 
import tensorflow as tf
# from tensorflow.python.keras.datasets import cifar10, cifar100
from nncf.tensorflow.layers.wrapper import NNCFWrapper
from nncf.tensorflow.quantization.layers import FakeQuantize
from art.estimators.classification import KerasClassifier

# Turn off TensorFlow eager execution for the whole script; this runs at
# import time, before any model is loaded by the functions below.
# NOTE(review): presumably required by the ART / NNCF tooling imported above — confirm.
tf.compat.v1.disable_eager_execution()

def parser():
    """Build the command-line argument parser for the evaluation script.

    Returns:
        argparse.ArgumentParser: parser exposing ``--data_path`` (required)
        and ``--dataset`` (defaults to ``'cifar100'``).
    """
    # Named differently from the function so the local does not shadow it.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--data_path', type=str, required=True,
                            help='path to all trained models and generated adv examples')
    # This help text used to be a copy of the --data_path one.
    arg_parser.add_argument('--dataset', type=str, default='cifar100',
                            help='dataset to evaluate on: cifar10 or cifar100')
    return arg_parser

# Parse the command line at import time; functions in this file read the
# resulting module-level ``args`` (``args.dataset``, ``args.data_path``).
args = parser().parse_args()

# Registry of discovered file paths: one initially empty list per known
# sub-directory name. process() appends to these lists.
all_data = {
    key: []
    for key in ('BM', 'PM', 'QM', 'advBM', 'onnxBM', 'onnxPM', 'advPM', 'advQM')
}

def getModelName(s):
    """Return the model name embedded in the path *s*.

    The name is the text between the first ``'classifier-'`` and the first
    ``'-<dataset>-on'``, where ``<dataset>`` is the global ``args.dataset``.
    """
    prefix = 'classifier-'
    suffix = '-' + args.dataset + '-on'
    begin = s.find(prefix) + len(prefix)
    end = s.find(suffix)
    return s[begin:end]

def getAttackName(s):
    """Extract the attack name from the path of an adversarial-example file.

    Only the file name is inspected, so a dash in a parent directory
    (e.g. ``/my-data/...``) cannot corrupt the result. The attack name is the
    text between the first ``'-'`` and the first ``'-x-test'`` of the file name.

    Args:
        s: path (or bare file name) of the adversarial-example file.

    Returns:
        The attack name, or the file name itself when ``'-x-test'`` is absent.
    """
    name = os.path.basename(s)
    end = name.find('-x-test')
    if end == -1:
        # Slicing with find()'s -1 would silently return a wrong substring.
        return name
    # '-x-test' was found, so the file name contains at least one '-'.
    start = name.find('-') + 1
    return name[start:end]

def getPrunedParams(s):
    """Extract the pruning/quantization parameter tag from a model file path.

    Only the file name is inspected, so an underscore in a parent directory
    (e.g. ``/trained_models/...``) cannot corrupt the result. The tag is the
    text between the first ``'_'`` and the first ``'-classi'`` of the file
    name; without an underscore it starts at the beginning of the file name.

    Args:
        s: path (or bare file name) of the model file.

    Returns:
        The parameter tag, or the file name itself when ``'-classi'`` is absent.
    """
    name = os.path.basename(s)
    end = name.find('-classi')
    if end == -1:
        # Slicing with find()'s -1 would silently return a wrong substring.
        return name
    # find() gives -1 when there is no underscore, so the slice starts at 0.
    start = name.find('_') + 1
    return name[start:end]

def get_testset(dataset):
    """Load the test split of the requested CIFAR dataset.

    Args:
        dataset: ``'cifar10'`` or ``'cifar100'``.

    Returns:
        ``(x_test, y_test)``: the images cast to float32 and divided by 255,
        and the labels exactly as returned by the Keras loader.

    Raises:
        NotImplementedError: for any other dataset name.
    """
    if dataset not in ('cifar10', 'cifar100'):
        raise NotImplementedError

    if dataset == 'cifar10':
        source = tf.keras.datasets.cifar10
    else:
        source = tf.keras.datasets.cifar100
    _, (images, labels) = source.load_data()

    images = images.astype('float32') / 255.0
    return images, labels

def process_results(predictions, ground_true, inf_time, filesize):
    """Print accuracy and macro-averaged precision / recall / F1 for predictions.

    Args:
        predictions: 2-D array of per-class scores; the arg-max along axis 1
            is taken as the predicted class.
        ground_true: true labels, passed to ``classification_report``.
        inf_time: inference time in seconds (display only).
        filesize: model size, or ``None`` to omit it. It is divided by 1e6
            and shown as MB, so it is presumably in bytes — confirm with callers.
    """
    predicted_classes = np.argmax(predictions, axis=1)
    report = classification_report(ground_true, predicted_classes, output_dict=True)
    macro = report.get("macro avg")
    print('Accuracy : {:.4f}; precision : {:.4f}; recall : {:.4f}; f1-score : {:.4f}; inference time : {:.3f}s'.format(
        report.get("accuracy"), macro.get("precision"), macro.get("recall"),
        macro.get("f1-score"), inf_time), end='')
    if filesize is not None:
        # The value is filesize / 1e6, i.e. megabytes; the label used to say "B".
        print('; Size: {:.2f}MB'.format(filesize / 1e6))
    else:
        print()

def get_num_params(model):
    """Return the total number of weight elements of *model* as an int.

    Sums the product of each weight's ``get_shape()`` over both
    ``model.trainable_weights`` and ``model.non_trainable_weights``.
    """
    def _element_count(weights):
        # Number of scalar entries across all tensors in *weights*.
        return np.sum([np.prod(w.get_shape()) for w in weights])

    trainable = _element_count(model.trainable_weights)
    frozen = _element_count(model.non_trainable_weights)
    return int(trainable + frozen)

def evaluate(classifier, x_test, y_test, filesize=None, params=None):
    """Run *classifier* on *x_test* and print accuracy and inference time.

    Args:
        classifier: object with a ``predict(x, verbose=0)`` method returning
            per-class scores (arg-max along axis 1 is the predicted class).
        x_test: inputs passed to ``predict``.
        y_test: ground-truth labels passed to ``accuracy_score``.
        filesize: optional model size in bytes; printed in MB when given.
        params: optional parameter count; printed with thousands separators
            when given together with *filesize*.
    """
    # perf_counter is the clock intended for measuring elapsed intervals.
    start_time = time.perf_counter()
    predictions = classifier.predict(x_test, verbose=0)
    elapsed_time = time.perf_counter() - start_time

    predicted_classes = np.argmax(predictions, axis=1)
    acc = accuracy_score(y_test, predicted_classes)
    print('Accuracy : {:.2f}; Inference time : {:.3f}s '.format(acc * 100, elapsed_time), end='')

    if filesize is None:
        print()
        return

    details = '; Size: {:.2f}MB'.format(filesize / 1e6)
    if params is not None:
        # Formatting None with ',' raises TypeError, so only add the count
        # when the caller actually supplied one.
        details += '; Params: {num:,}'.format(num=params)
    print(details)
    

def process(path):
    """Collect model and adversarial-example file paths found under *path*.

    For every sub-directory of *path* whose name is a key of the global
    ``all_data`` (``'Logs'`` is always ignored), the paths of its entries are
    appended to ``all_data[<sub-directory name>]`` in sorted order.

    Args:
        path: dataset directory to scan.
    """
    for folder in os.listdir(path):
        folder_path = os.path.join(path, folder)
        if not os.path.isdir(folder_path) or folder == 'Logs':
            continue
        if folder not in all_data:
            # Unknown sub-directory: indexing all_data with it would raise KeyError.
            continue
        # Sorted so the collected order does not depend on os.listdir's order.
        for entry in sorted(os.listdir(folder_path)):
            all_data[folder].append(os.path.join(folder_path, entry))


def eval_BM(x_test, y_test):
    """Evaluate every model file in ``all_data['BM']`` on the clean test data.

    For each file, prints the model name followed by the metrics reported by
    ``evaluate`` (including file size and parameter count).
    """
    banner = '#' * 20
    print(banner, 'BASE MODEL RESULTS ON CLEAN DATA', banner, sep='\n')
    for model_file in all_data['BM']:
        print()
        print(getModelName(model_file))
        print('\t', end='')
        model = tf.keras.models.load_model(model_file)
        size_bytes = os.path.getsize(model_file)
        evaluate(model, x_test, y_test, size_bytes, get_num_params(model))

def eval_PM(x_test, y_test):
    """Evaluate every model file in ``all_data['PM']`` on the clean test data.

    Models are loaded with ``compile=False``. For each file, prints the model
    name with its parameter tag, then the metrics reported by ``evaluate``.
    """
    banner = '#' * 20
    print(banner, 'PRUNED MODEL RESULTS ON CLEAN DATA', banner, sep='\n')
    for model_file in all_data['PM']:
        print()
        print('{} ({})'.format(getModelName(model_file), getPrunedParams(model_file)))
        print('\t', end='')
        model = tf.keras.models.load_model(model_file, compile=False)
        size_bytes = os.path.getsize(model_file)
        evaluate(model, x_test, y_test, size_bytes, get_num_params(model))

def eval_QM(x_test, y_test):
    """Evaluate every model file in ``all_data['QM']`` on the clean test data.

    Models are loaded with ``compile=False`` and the NNCF layer classes
    registered as custom objects. For each file, prints the model name with
    its parameter tag, the file path, then the metrics from ``evaluate``.
    """
    banner = '#' * 20
    print(banner, 'QUANTIZED MODEL RESULTS ON CLEAN DATA', banner, sep='\n')
    nncf_objects = {'FakeQuantize': FakeQuantize, 'NNCFWrapper': NNCFWrapper}
    for model_file in all_data['QM']:
        print()
        print('{} ({})'.format(getModelName(model_file), getPrunedParams(model_file)))
        print('\t', end='')
        print(model_file)
        model = tf.keras.models.load_model(model_file, compile=False,
                                           custom_objects=nncf_objects)
        size_bytes = os.path.getsize(model_file)
        evaluate(model, x_test, y_test, size_bytes, get_num_params(model))

def eval_advBM_on_BM(y_test):
    """Evaluate every model in ``all_data['BM']`` on the files in ``all_data['advBM']``.

    Only the first 1000 adversarial samples and the first 1000 labels are
    used. An adversarial file that cannot be loaded is reported and skipped.

    Args:
        y_test: ground-truth labels of the test set.
    """
    print('#'*20)
    print('BASE MODEL RESULTS on advBM DATA')
    print('#'*20)
    for modelpath in all_data['BM']:
        print()
        print('Victim: ', getModelName(modelpath))
        print('-'*7)
        classifier = tf.keras.models.load_model(modelpath)
        for advpath in all_data['advBM']:
            print('  Attack: ' + getAttackName(advpath))
            print('\t', end='')
            try:
                x_adv_test = np.load(advpath)
            except (OSError, ValueError) as err:
                # Skip this file: continuing would evaluate the array left over
                # from the previous iteration, or fail on an undefined name.
                print('could not load {}: {}'.format(advpath, err))
                continue
            evaluate(classifier, x_adv_test[:1000], y_test[:1000])


def eval_advBM_on_PM(y_test):
    """Evaluate every model in ``all_data['PM']`` on the files in ``all_data['advBM']``.

    Only the first 1000 adversarial samples and the first 1000 labels are
    used. An adversarial file that cannot be loaded is reported and skipped.

    Args:
        y_test: ground-truth labels of the test set.
    """
    print('#'*20)
    print('PRUNED MODEL RESULTS on advBM DATA')
    print('#'*20)
    for modelpath in all_data['PM']:
        print()
        print('Victim: ' + getModelName(modelpath) + ' (' + getPrunedParams(modelpath) + ')')
        print('-'*7)
        classifier = tf.keras.models.load_model(modelpath)
        for advpath in all_data['advBM']:
            print('  Attack: ' + getAttackName(advpath))
            print('\t', end='')
            try:
                x_adv_test = np.load(advpath)
            except (OSError, ValueError) as err:
                # Skip this file: continuing would evaluate the array left over
                # from the previous iteration, or fail on an undefined name.
                print('could not load {}: {}'.format(advpath, err))
                continue
            evaluate(classifier, x_adv_test[:1000], y_test[:1000])


def eval_advPM_on_PM(y_test):
    """Evaluate every model in ``all_data['PM']`` on the files in ``all_data['advPM']``.

    The full adversarial array and the full label array are used. An
    adversarial file that cannot be loaded is reported and skipped.

    Args:
        y_test: ground-truth labels of the test set.
    """
    print('#'*20)
    print('PRUNED MODEL RESULTS on advPM DATA')
    print('#'*20)
    for modelpath in all_data['PM']:
        print()
        print('Victim: ' + getModelName(modelpath) + ' (' + getPrunedParams(modelpath) + ')')
        print('-'*7)
        classifier = tf.keras.models.load_model(modelpath)
        for advpath in all_data['advPM']:
            print('  Attack: ' + getAttackName(advpath))
            print('\t', end='')
            try:
                x_adv_test = np.load(advpath)
            except (OSError, ValueError) as err:
                # Skip this file: continuing would evaluate the array left over
                # from the previous iteration, or fail on an undefined name.
                print('could not load {}: {}'.format(advpath, err))
                continue
            evaluate(classifier, x_adv_test, y_test)

def main():
    """Collect the files of the selected dataset and run the enabled evaluations.

    Raises:
        NotImplementedError: if ``args.dataset`` is neither ``'cifar10'`` nor
            ``'cifar100'``.
    """
    if args.dataset not in ('cifar10', 'cifar100'):
        raise NotImplementedError

    # Index the entry of data_path whose name equals the chosen dataset.
    for entry in os.listdir(args.data_path):
        if entry == args.dataset:
            process(os.path.join(args.data_path, entry))

    # Report which dataset is evaluated and where the script runs.
    print('Evaluation on: ' + args.dataset.upper())
    print(os.getcwd())
    print()
    for key in ("BM", "PM", "advBM", "advPM"):
        all_data[key].sort()

    # Clean test data.
    x_test, y_test = get_testset(args.dataset)

    separator = '*-' * 20

    # Base models on clean data.
    eval_BM(x_test, y_test)
    print(separator)
    print()

    # Pruned models on clean data (disabled).
    # eval_PM(x_test, y_test)
    # print(separator)
    # print()

    # Quantized models on clean data (disabled).
    # eval_QM(x_test, y_test)
    # print(separator)
    # print()

    # Base models on advBM data.
    eval_advBM_on_BM(y_test)
    print(separator)
    print()

    # Pruned models on advBM data.
    eval_advBM_on_PM(y_test)
    print(separator)
    print()

    # Pruned models on advPM data (disabled).
    # eval_advPM_on_PM(y_test)
    # print(separator)
    # print()

# Run the evaluation only when this file is executed as a script.
if __name__ == '__main__':
    main()

# Run:    
# python evaluate.py --data_path <path_to_data_folder> --dataset <cifar10 / cifar100 >