# analyze_results.py

import pandas as pd
import os
import numpy as np
import json
import re
import matplotlib.pyplot as plt

import argparse

# Command-line interface: one option naming the directory that holds the result files.
parser = argparse.ArgumentParser(description='Analyze results from json files')
parser.add_argument(
    '--json_dir',
    default='.',
    type=str,
    help='directory containing json files',
)
args = parser.parse_args()

# Avoid Type 3 fonts in PDF/PS output by selecting font type 42 for both backends.
plt.rcParams.update({'pdf.fonttype': 42, 'ps.fonttype': 42})

# Maps a cleaned-up file stem to the DataFrame loaded from that result file.
df_dict = {}

# Walk the results directory and load every .json file whose name does not
# contain 'success'.
for file in os.listdir(args.json_dir):
    if file.endswith(".json") and 'success' not in file:
        # os.listdir() yields bare file names, so the directory must be joined
        # back on; otherwise --json_dir only works when it is the current
        # working directory.
        json_path = os.path.join(args.json_dir, file)
        try:
            df = pd.read_json(json_path, orient='records')
        except Exception as exc:
            # Best effort: report the unreadable file (with the reason) and
            # keep going. Catching Exception rather than using a bare except
            # lets Ctrl-C / SystemExit propagate.
            print("Error reading in file: %s (%s: %s)" % (file, type(exc).__name__, exc))
            continue
        # Drop the extension and fold the 'kWTA_resnet' prefix into 'resnet'.
        key = file.replace('.json','').replace('kWTA_resnet', 'resnet')
        df_dict[key] = df


# dictionary should have all results in it

# iterate through each key in the dictionary, and determine how many times PGD failed to find an adversarial example when bruteforce or sampled did find one
# each df is expected to have the same columns: idx, label, inference, bruteforce, sampled, pgd, time
# (create_loss_plot additionally reads sampled_loss, bruteforce_loss and pgd_loss)

def create_loss_plot(df_row, out_dir):
    """Save a PNG of the per-run PGD loss curves for one result row.

    Rows whose 'advex_found' entry is falsy are skipped and nothing is
    written. Otherwise the figure is saved as <out_dir>/<idx>.png, creating
    out_dir first when it does not exist yet.

    Args:
        df_row: row/mapping providing 'advex_found', 'idx', 'sampled_loss',
            'bruteforce_loss' and 'pgd_loss'. 'pgd_loss' is indexed as a 2-D
            array whose first axis is the run and whose second axis is
            plotted along x (labelled 'PGD Iteration').
        out_dir: output directory; also used as the prefix of the plot title.

    Returns:
        None.
    """
    if not df_row['advex_found']:
        return

    example_id = df_row['idx']
    sampled_level = np.array(df_row["sampled_loss"])
    bruteforce_level = np.array(df_row["bruteforce_loss"])
    loss_curves = np.array(df_row["pgd_loss"])
    run_count = loss_curves.shape[0]

    fig, ax = plt.subplots(1, 1, figsize=(10, 5))

    # One faint red trace per PGD run.
    for run in range(run_count):
        ax.plot(loss_curves[run, :], color='red', alpha=0.2)

    # Dashed horizontal reference lines for the sampled and bruteforce losses.
    reference_lines = (
        (sampled_level, 'green', 'Sample Max'),
        (bruteforce_level, 'blue', 'Bruteforce Max'),
    )
    for level, colour, caption in reference_lines:
        ax.axhline(level, color=colour, linestyle='--', linewidth=2, label=caption, alpha=0.5)

    ax.set_xlabel('PGD Iteration')
    ax.set_ylabel('Loss')
    ax.set_title("%s_%d"%(out_dir, example_id))
    ax.set_xlim(0, loss_curves.shape[1])
    ax.legend()
    plt.tight_layout()

    # Write the figure to disk and release it.
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    target_path = os.path.join(out_dir, str(example_id) + '.png')
    plt.savefig(target_path)
    plt.close()


def pgd_found_advex_helper(num_runs=None):
    """Build a one-argument row function that calls pgd_found_advex with num_runs fixed.

    Args:
        num_runs: forwarded unchanged as the second argument of pgd_found_advex.

    Returns:
        A callable taking a single df_row, suitable for DataFrame.apply(axis=1).
    """
    def _row_predicate(df_row):
        return pgd_found_advex(df_row, num_runs)

    return _row_predicate

def pgd_found_advex(df_row, num_runs=None):
    """Report whether the 'pgd' results hold a nonzero entry off the inferred index.

    Args:
        df_row: row/mapping with 'inference' (an integer index) and 'pgd',
            which is indexed as a 2-D array: the first axis is limited to the
            first num_runs entries, the second axis is indexed by 'inference'.
            Presumably this is (run, class) - confirm against the producer.
        num_runs: how many leading rows of 'pgd' to consider; None means all.

    Returns:
        True if any considered entry outside column 'inference' is nonzero,
        otherwise False.
    """
    predicted = np.array(df_row['inference'])
    per_run = np.array(df_row['pgd'])
    run_limit = per_run.shape[0] if num_runs is None else num_runs

    considered = per_run[:run_limit]
    # Look at the columns on either side of the inferred one, skipping that column itself.
    left_of_predicted = considered[:, :predicted]
    right_of_predicted = considered[:, predicted + 1:]
    return bool(np.any(left_of_predicted != 0) or np.any(right_of_predicted != 0))

def sample_found_advex(df_row):
    """Report whether the 'sampled' results hold a nonzero entry off the inferred index.

    Args:
        df_row: row/mapping with 'inference' (an integer index) and 'sampled',
            a 1-D sequence indexed by that same kind of index.

    Returns:
        True if any entry of 'sampled' other than position 'inference' is
        nonzero, otherwise False.
    """
    predicted = np.array(df_row['inference'])
    counts = np.array(df_row['sampled'])
    # Split around the inferred position so that entry itself is ignored.
    before, after = counts[:predicted], counts[predicted + 1:]
    return bool(np.any(before != 0) or np.any(after != 0))

def bruteforce_found_advex(df_row):
    """Report whether the 'bruteforce' results hold a nonzero entry off the inferred index.

    Args:
        df_row: row/mapping with 'inference' (an integer index) and
            'bruteforce', a 1-D sequence indexed by that same kind of index.

    Returns:
        True if any entry of 'bruteforce' other than position 'inference' is
        nonzero, otherwise False.
    """
    predicted = np.array(df_row['inference'])
    counts = np.array(df_row['bruteforce'])

    # Entries below the inferred position are checked first ...
    if np.any(counts[:predicted] != 0):
        return True
    # ... then the entries above it; the inferred position itself is never inspected.
    return bool(np.any(counts[predicted + 1:] != 0))

def pgd_failed(df_row):
    """Report whether PGD missed an adversarial example that another method found.

    Returns:
        True when bruteforce_found_advex or sample_found_advex is true for the
        row while pgd_found_advex (over all runs) is not; False otherwise.
    """
    other_method_hit = bruteforce_found_advex(df_row) or sample_found_advex(df_row)
    if not other_method_hit:
        # Nothing for PGD to have missed.
        return False
    return not pgd_found_advex(df_row)

def sort_models_by_dim(x):
    """Return the integer that sits between the last '_' and the first 'dims_' in x.

    For example 'resnet18_64dims_20steps' gives 64. Raises ValueError when
    that token is not an integer.
    """
    before_dims = x.partition('dims_')[0]
    dim_token = before_dims.rpartition('_')[2]
    return int(dim_token)

def sort_models_by_steps(x):
    """Return the integer directly preceding the first 'steps' in x, or -1.

    The number is the text between the last '_' before 'steps' and 'steps'
    itself, e.g. 'resnet18_64dims_20steps' gives 20. Names without 'steps'
    yield -1.
    """
    if "steps" not in x:
        return -1
    before_steps = x.partition('steps')[0]
    return int(before_steps.rpartition('_')[2])
    

def sort_models_by_type(x):
    """Score a model name so that names sort into groups by the markers they contain.

    Each marker substring found in x adds its weight to the score:
    'infty' 3000, 'kWTA_spresnet' 10000 plus 1000 times the third
    '_'-separated token of x read as a float, 'resnet18' 500, '_adv_' 1000
    and 'diff' 50000.

    Returns:
        The summed score: an int, or a float when the 'kWTA_spresnet' term
        contributes. Names with no marker score 0.
    """
    # Weights are callables so the float parse only runs when its marker is present.
    # The rule order is kept fixed so the additions happen in the same sequence.
    scoring_rules = (
        ('infty', lambda: 3000),
        ('kWTA_spresnet', lambda: 10000 + 1000*float(x.split('_')[2])),
        ('resnet18', lambda: 500),
        ('_adv_', lambda: 1000),
        ('diff', lambda: 50000),
    )
    return sum(weight() for marker, weight in scoring_rules if marker in x)

def sort_models(x):
    """Return the composite sort key (type score, dims, steps) for model name x."""
    type_score = sort_models_by_type(x)
    dims = sort_models_by_dim(x)
    steps = sort_models_by_steps(x)
    return (type_score, dims, steps)

def loss_plot_helper(out_dir):
    """Build a one-argument row function that calls create_loss_plot with out_dir fixed.

    Returns:
        A callable taking a single df_row, suitable for DataFrame.apply(axis=1).
    """
    def _plot_row(df_row):
        return create_loss_plot(df_row, out_dir)

    return _plot_row

from scipy import stats

# Numbers of PGD runs that get their own columns. The header and the two
# format strings below are written out for exactly these three values.
_PGD_RUN_BUDGETS = (1, 10, 20)

_CSV_HEADER = 'model, bruteforce, sampled, pgd1, pgd10, pgd20, bruteforce_miss, sampled_miss, pgd1_miss, pgd10_miss, pgd20_miss, pgd1_mcnemar, pgd10_mcnemar, pgd20_mcnemar, bruteforce_hit, random_hit_gridsweep_count, pgd1_hit_gridsweep_count, pgd10_hit_gridsweep_count, pgd20_hit_gridsweep_count\n'
_CSV_ROW_FORMAT = "%s, %.3f, %.3f, %.3f, %.3f, %.3f, %.3f, %.3f, %.3f, %.3f, %.3f, %.5f, %.5f, %.5f, %d, %d, %d, %d, %d\n"
_SUMMARY_FORMAT = "%s:\tnat_acc: %.2f\tbruteforce: %.2f, sampled: %.2f, pgd1: %.2f, pgd10: %.2f,pgd20: %.2f, bruteforce_miss: %.2f, sampled_miss: %.2f, pgd1_miss: %.2f, pgd10_miss: %.2f, pgd20_miss: %.2f, pgd1_mcnemar_p_value: %.2f, pgd10_mcnemar_p_value: %.2f, pgd20_mcnemar_p_value: %.2f, pgd1_hit_gridsweep_count: %d/%d, pgd10_hit_gridsweep_count: %d/%d, pgd20_hit_gridsweep_count: %d/%d,"


def _mcnemar_p_value(first_only, second_only):
    """Return the McNemar chi-square p-value for two discordant-pair counts.

    See https://en.wikipedia.org/wiki/McNemar%27s_test. first_only and
    second_only are the numbers of rows where exactly one of the two methods
    succeeded. Returns -1. as a sentinel when there are no discordant pairs,
    since the statistic would divide by zero.
    """
    discordant = first_only + second_only
    if discordant == 0:
        return -1.
    statistic = (first_only - second_only) ** 2 / discordant
    # sf(x) is 1 - cdf(x) without the cancellation error for large statistics.
    return stats.chi2.sf(statistic, 1)


def _miss_rate(num_found, num_any_found):
    """Return 1 - num_found / num_any_found, or NaN when num_any_found is 0."""
    if num_any_found == 0:
        return float('nan')
    return 1. - num_found / num_any_found


# The with-block guarantees the CSV is flushed and closed even if a model fails mid-loop.
with open('overall_analysis.csv', 'w') as outfile:
    outfile.write(_CSV_HEADER)

    for key in sorted(df_dict, key=sort_models):
        df = df_dict[key]
        if df.empty:
            # An empty frame has no 'inference' column to work with; skip it
            # rather than abort the whole report.
            print("Skipping %s: no rows to analyze" % key)
            continue

        # Per-row success flags for each search method. df is modified in
        # place, so df_dict[key] sees the new columns as well.
        df['sample_found_advex'] = df.apply(sample_found_advex, axis=1).astype(bool)
        df['bruteforce_found_advex'] = df.apply(bruteforce_found_advex, axis=1).astype(bool)
        for budget in _PGD_RUN_BUDGETS:
            df['pgd%d_found_advex' % budget] = df.apply(pgd_found_advex_helper(budget), axis=1).astype(bool)
        df['pgd_found_advex'] = df.apply(pgd_found_advex_helper(), axis=1).astype(bool)
        df['advex_found'] = df['bruteforce_found_advex'] | df['sample_found_advex'] | df['pgd_found_advex']

        bruteforce = df['bruteforce_found_advex']
        sampled = df['sample_found_advex']
        pgd_columns = [df['pgd%d_found_advex' % budget] for budget in _PGD_RUN_BUDGETS]

        any_found = int(df['advex_found'].sum())
        bruteforce_hits = int(bruteforce.sum())
        # Rows found by bruteforce that the other method also found.
        sampled_hits = int((bruteforce & sampled).sum())
        pgd_hits = tuple(int((bruteforce & pgd).sum()) for pgd in pgd_columns)

        # McNemar's test of bruteforce against each PGD budget, using the two
        # discordant cells (bruteforce only, PGD only).
        p_values = tuple(
            _mcnemar_p_value(int((bruteforce & ~pgd).sum()), int((~bruteforce & pgd).sum()))
            for pgd in pgd_columns
        )

        # calculate natural accuracy
        natural_accuracy = np.mean(df['inference'] == df['label'])

        df['pgd_failed'] = df.apply(pgd_failed, axis=1)

        method_columns = [bruteforce, sampled] + pgd_columns
        found_rates = tuple(np.mean(column) for column in method_columns)
        miss_rates = tuple(_miss_rate(int(column.sum()), any_found) for column in method_columns)
        # Flattened (hits, bruteforce total) pairs for the "%d/%d" fields.
        hit_pairs = tuple(count for hits in pgd_hits for count in (hits, bruteforce_hits))

        print(_SUMMARY_FORMAT % ((key, natural_accuracy) + found_rates + miss_rates + p_values + hit_pairs))
        outfile.write(_CSV_ROW_FORMAT % (
            (key,) + found_rates + miss_rates + p_values
            + (bruteforce_hits, sampled_hits) + pgd_hits
        ))