import json
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# ---- Experiment configuration ----
# `network` selects which accuracy folder is read; the special value 'joint'
# reads both wrn28-10 and wrn34-10. Alternative settings left by the author:
#   network = 'wrn34-10', dataset = 'cifar100'
#   network = 'resnet50', dataset = 'imagenet'
network, dataset = 'joint', 'cifar100'

# True: keep the "natural" entries (Baseline, CutMix, Mixup) in the plots,
# False: drop them before plotting.
nat = True

# Figures are written below <cwd>/figure/discussion/<network>_<dataset>.
basedir = os.getcwd()
save_dir = os.path.join(basedir, 'figure/discussion', network + '_' + dataset)

# Create the output directory up front so later savefig calls do not fail.
if not os.path.exists(save_dir):
    os.makedirs(save_dir, exist_ok=True)

### Input locations
# Per-model SHAP statistics (standard deviation and mean), one JSON file each.
shap_std_path = './results/shap_statistics/shap_statistics_total_std.json'
shap_mean_path = './results/shap_statistics/shap_statistics_total_mean.json'

# Root of the per-model accuracy results: <acc_path>/<network>/<dataset>/<model>/
acc_path = './results/class_acc'

# First (or only) architecture: wrn28-10 in joint mode, else the chosen network.
first_network = 'wrn28-10' if network == 'joint' else network
acc_dir1 = f'{acc_path}/{first_network}/{dataset}'
acc_paths1 = os.listdir(acc_dir1)
# Pre-seeding the three natural entries fixes their position at the front of
# the dict; the entries found on disk are filled in later.
acc_path_dicts1 = {name: [] for name in ('baseline_val', 'cutmix_val', 'mixup_val')}

# Joint mode additionally reads the wrn34-10 results.
if network == 'joint':
    acc_dir2 = f'{acc_path}/wrn34-10/{dataset}'
    acc_paths2 = os.listdir(acc_dir2)
    acc_path_dicts2 = {name: [] for name in ('baseline_val', 'cutmix_val', 'mixup_val')}
# Display table: method[<group>][<model key>] -> [legend name, matplotlib colour].
# <group> is 'natural' or '<network>_<dataset>'; <model key> is the result
# directory name with any '_val' suffix removed.
# NOTE(review): the non-natural keys look like RobustBench model IDs - confirm.
method = {
    # Entries listed as "natural" (no '<network>_<dataset>' group of their own).
    "natural": {
        "baseline": ["Baseline", "dimgray"],
        "cutmix": ["CutMix", "darkorange"],
        "mixup": ["Mixup", "deepskyblue"],
    },
    "wrn28-10_cifar10": {
        "Xu2023Exploring_WRN-28-10": ["DyART", "rosybrown"],
        "Wang2020Improving": ["MART", "brown"],
        "Pang2022Robustness_WRN28_10": ["SCORE", "red"],
        "Wang2023Better_WRN-28-10": ["diffusion-augmented AT", "orange"],
        "Rade2021Helper_ddpm": ["HAT", "gold"],
        "Wu2020Adversarial_extra": ["AWP", "olive"],
        "Carmon2019Unlabeled": ["RST", "green"],
        "Zhang2020Geometry": ["GAIRAT", "cyan"],
        "Sridhar2021Robust": ["RLPE", "blue"],
        "Gowal2021Improving_28_10_ddpm_100m": ["IRUGD", "purple"],
        "Sehwag2020Hydra": ["HYDRA", "magenta"],
    },
    "wrn28-10_cifar100": {
        "Pang2022Robustness_WRN28_10": ["SCORE", "red"],
        "Wang2023Better_WRN-28-10": ["diffusion-augmented AT", "orange"],
        "Cui2023Decoupled_WRN-28-10": ["IKL", "yellow"],
        "Rebuffi2021Fixing_28_10_cutmix_ddpm": ["FDA", "green"],
    },
    "wrn34-10_cifar10": {
        "Zhang2020Attacks": ["FAT", "red"],
        "Chen2024Data_WRN_34_10": ["DefEAT", "orange"],
        "Addepalli2022Efficient_WRN_34_10": ["DAJAT", "yellow"],
        "Addepalli2021Towards_WRN34": ["OA-AT", "green"],
        "Huang2020Self": ["SAT", "cyan"],
        "Rade2021Helper_extra": ["HAT", "blue"],
        "Wu2020Adversarial": ["AWP", "purple"],
        "Cui2020Learnable_34_10": ["LBGAT", "magenta"],
        "Zhang2019You": ["YOPO", "brown"],
        "Zhang2019Theoretically": ["TRADES", "olive"],
    },
    # Two LBGAT and two IKL variants share a legend name but not a colour.
    "wrn34-10_cifar100": {
        "Cui2020Learnable_34_10_LBGAT6": ["LBGAT", "red"],
        "Cui2020Learnable_34_10_LBGAT9_eps_8_255": ["LBGAT", "orange"],
        "Addepalli2021Towards_WRN34": ["OA-AT", "yellow"],
        "Cui2023Decoupled_WRN-34-10": ["IKL", "green"],
        "Cui2023Decoupled_WRN-34-10_autoaug": ["IKL", "cyan"],
        "Sehwag2021Proxy": ["Proxy", "blue"],
        "Jia2022LAS-AT_34_10": ["LAS-AT", "purple"],
        "Chen2021LTD_WRN34_10": ["LTD", "magenta"],
        "Addepalli2022Efficient_WRN_34_10": ["DAJAT", "brown"],
    },
    "resnet50_imagenet": {
        "Salman2020Do_R50": ["Salman et al.", "red"],
        "Engstrom2019Robustness": ["Engstrom et al.", "orange"],
        "Wong2020Fast": ["Cheap-AT", "yellow"],
    },
}

# ---------------------------------------------------------------------------
# Scatter plots: relative FGSM accuracy drop (x) against a SHAP statistic (y).
# One figure per statistic ('std', 'mean') is written to
# <save_dir>/<statistic>/attack_fgsm_plot.svg.
# ---------------------------------------------------------------------------

# Result-directory names and legend names of the "natural" entries.
_NATURAL_KEYS = ('baseline_val', 'cutmix_val', 'mixup_val')
_NATURAL_NAMES = ('Baseline', 'CutMix', 'Mixup')

# Legend suffix per architecture; architectures not listed get no suffix.
_LABEL_SUFFIX = {'wrn28-10': '_28', 'wrn34-10': '_34'}

# In the joint plot the natural entries of the second architecture (wrn34-10)
# use these colours so they can be told apart from the wrn28-10 ones, which
# use the colours from method['natural'].
_JOINT_SECOND_NATURAL_COLORS = {'Baseline': 'red', 'CutMix': 'green', 'Mixup': 'cyan'}

_Y_LABELS = {'std': 'SD of Shapley Value', 'mean': 'Mean of Shapley Value'}


def _read_json(path):
    """Parse the JSON file at ``path``; the handle is closed on return."""
    with open(path, 'r') as handle:
        return json.load(handle)


def _collect_accuracies(acc_dir, entries, path_dict, attacks):
    """Return ``{model key: [accuracy for each attack]}`` for one architecture.

    ``path_dict`` is updated in place: every name in ``entries`` is mapped to
    its ``<acc_dir>/<entry>/<attack>.json`` paths. Keys already present keep
    their position, so pre-seeded keys stay in front.
    """
    for entry in entries:
        path_dict[entry] = [f'{acc_dir}/{entry}/{attack}.json' for attack in attacks]

    accuracies = {}
    for key, json_paths in path_dict.items():
        values = []
        for json_path in json_paths:
            acc_json = _read_json(json_path)
            if json_path.split('/')[-1] == 'robustbench_acc.json':
                # Not reached with the attack names used below; kept so such
                # a file would still contribute clean and robust accuracy.
                values.append(float(acc_json['clean']))
                values.append(float(acc_json['robust']))
            else:
                values.append(float(acc_json['total acc']))
        accuracies[key] = values
    return accuracies


def _build_frame(acc_dicts, arch, dataset_name, method_table, shap_std_json, shap_mean_json):
    """Build one DataFrame (a row per model) for architecture ``arch``.

    Columns: ``attack_fgsm`` (1 - fgsm_acc / clean_acc), ``shap_std``,
    ``shap_mean`` and ``color``. The index holds the legend names. The colour
    is taken from the model's own key in ``method_table`` so that models
    sharing a legend name keep their individual colours.

    Raises:
        ValueError: if a model has fewer than two accuracy values.
        KeyError: if a model is missing from the SHAP files or the table.
    """
    columns = {'attack_fgsm': [], 'shap_std': [], 'shap_mean': [], 'color': []}
    names = []
    for key, acc in acc_dicts.items():
        if len(acc) < 2:
            raise ValueError(f'expected clean and FGSM accuracy for {key!r}, got {acc}')
        clean = acc[0] / 100
        fgsm = acc[1] / 100
        columns['attack_fgsm'].append(1 - fgsm / clean)

        shap_key = f'{arch}_{dataset_name}_{key}'
        columns['shap_std'].append(float(shap_std_json[shap_key]))
        columns['shap_mean'].append(float(shap_mean_json[shap_key]))

        group = 'natural' if key in _NATURAL_KEYS else f'{arch}_{dataset_name}'
        name, color = method_table[group][key.replace('_val', '')]
        names.append(name)
        columns['color'].append(color)
    return pd.DataFrame(columns, index=names)


def _plot_scatter(groups, x_col, sta, save_fig_path, with_natural):
    """Draw every group into one scatter figure and save it.

    ``groups`` is a list of ``(frame, legend suffix, natural colour override)``
    tuples; the override is a ``{legend name: colour}`` dict or ``None``.
    When ``with_natural`` is false the natural rows are dropped first.
    """
    if not with_natural:
        groups = [(frame.drop(index=list(_NATURAL_NAMES)), suffix, override)
                  for frame, suffix, override in groups]

    # Pad the x range slightly, but never beyond [0, 1].
    x_min = min(frame[x_col].min() for frame, _, _ in groups) - 0.01
    x_max = max(frame[x_col].max() for frame, _, _ in groups) + 0.01
    if x_min < 0:
        x_min = 0.0
    if x_max > 1:
        x_max = 1.0

    plt.figure(figsize=(12, 8))
    for frame, suffix, override in groups:
        # Iterate the columns directly: row access is positional, so
        # duplicate legend names in the index are handled correctly.
        rows = zip(frame.index, frame[x_col], frame[f'shap_{sta}'], frame['color'])
        for name, x_val, y_val, color in rows:
            if override is not None and name in override:
                color = override[name]
            plt.scatter(x_val, y_val, label=name + suffix, marker='o', color=color)
    plt.xlim(x_min, x_max)

    plt.xticks(fontsize=14)
    plt.yticks(fontsize=14)
    plt.legend(loc='lower left', bbox_to_anchor=(1.0, 0.0), fontsize=12)
    # NOTE(review): the plotted x value is a fraction in [0, 1], not a
    # percentage - confirm whether the label or the scale should change.
    plt.xlabel('Accuracy drop (%)', fontsize=20)
    plt.ylabel(_Y_LABELS[sta], fontsize=20)
    plt.tight_layout()
    plt.savefig(save_fig_path)
    plt.close()


att = ['attack_no', 'attack_fgsm']

# shap json load (std, mean)
shap_std_json = _read_json(shap_std_path)
shap_mean_json = _read_json(shap_mean_path)

if network == 'joint':
    acc_dicts1 = _collect_accuracies(acc_dir1, acc_paths1, acc_path_dicts1, att)
    acc_dicts2 = _collect_accuracies(acc_dir2, acc_paths2, acc_path_dicts2, att)
    df1 = _build_frame(acc_dicts1, 'wrn28-10', dataset, method, shap_std_json, shap_mean_json)
    df2 = _build_frame(acc_dicts2, 'wrn34-10', dataset, method, shap_std_json, shap_mean_json)
    plot_groups = [
        (df1, '_28', None),
        (df2, '_34', _JOINT_SECOND_NATURAL_COLORS),
    ]
else:
    # Use the configured network for the SHAP keys, the method table and the
    # legend suffix, so settings other than wrn28-10 resolve their own data.
    acc_dicts1 = _collect_accuracies(acc_dir1, acc_paths1, acc_path_dicts1, att)
    df1 = _build_frame(acc_dicts1, network, dataset, method, shap_std_json, shap_mean_json)
    plot_groups = [(df1, _LABEL_SUFFIX.get(network, ''), None)]

standard = ['attack_fgsm']
for sta in ['std', 'mean']:
    graph_fig_dir = f'{save_dir}/{sta}'
    os.makedirs(graph_fig_dir, exist_ok=True)
    for s in standard:
        save_fig_path = f'{graph_fig_dir}/{s}_plot.svg'
        _plot_scatter(plot_groups, s, sta, save_fig_path, nat)