import torch
import numpy as np
import matplotlib.pyplot as plt

# Corruption names evaluated by this script (the plot title refers to CIFAR10c).
# Each name is used as a results sub-directory, and this list order fixes the
# left-to-right order of the bar groups in every figure.
common_corruptions = [
    'gaussian_noise',
    'shot_noise',
    'impulse_noise',
    'defocus_blur',
    'glass_blur',
    'motion_blur',
    'zoom_blur',
    'snow',
    'frost',
    'fog',
    'brightness',
    'contrast',
    'elastic_transform',
    'pixelate',
    'jpeg_compression',
]

def get_transfer_attack_results(corruption, level, root='./results'):
    """Collect the accuracies of the five compared settings for one corruption.

    Args:
        corruption: corruption name; used as a sub-directory name inside each
            result folder (e.g. an entry of ``common_corruptions``).
        level: 1-based severity level.
        root: directory containing the ``advS``, ``advT`` and ``DANN`` result
            folders. Defaults to ``'./results'``, the path that used to be
            hard-coded, so existing callers are unaffected.

    Returns:
        A list ``[AdvS, DANN+AdvS, AdvT, DANN+AdvT, DANN]``, where:

        * the AdvS / AdvT entries are ``adv_accs[level - 1]`` from
          ``<root>/adv{S,T}/<corruption>/accs_lvls.pt``;
        * the DANN entries are the last element of
          ``<root>/DANN/{advS,advT,none}/<corruption>/accs_level<level>.npy``.

    Raises:
        ValueError: if ``level`` is smaller than 1. Without this check a
            level of 0 would silently index ``adv_accs[-1]`` (the last level).
    """
    if level < 1:
        raise ValueError('level must be >= 1, got {}'.format(level))

    def adv_acc(setting):
        # torch-saved dict with one 'adv_accs' entry per severity level.
        data = torch.load('{}/{}/{}/accs_lvls.pt'.format(root, setting, corruption))
        return data['adv_accs'][level - 1]

    def dann_acc(setting):
        # One .npy file per level; the last element is reported
        # (presumably the accuracy at the end of training — confirm).
        data = np.load('{}/DANN/{}/{}/accs_level{}.npy'.format(root, setting, corruption, level))
        return data[-1]

    # Order must match the `labels`/`colors` lists used by the plotting code;
    # files are loaded left to right, as before.
    return [
        adv_acc('advS'),
        dann_acc('advS'),
        adv_acc('advT'),
        dann_acc('advT'),
        dann_acc('none'),
    ]

    

if __name__ == "__main__":
    import os

    # Colour and legend label for each entry returned by
    # get_transfer_attack_results, in the same order.
    colors = ['b', 'y', 'g', 'r', 'k']
    labels = ['AdvS', 'DANN+AdvS', 'AdvT', 'DANN+AdvT', 'DANN']
    out_dir = 'plot/figures'
    # savefig does not create missing directories.
    os.makedirs(out_dir, exist_ok=True)

    width = 0.2
    # Group positions along the x axis; 1.5 spacing leaves room for the bars.
    x = np.arange(len(common_corruptions)) * 1.5
    # Per-method offsets that centre each group of bars on its tick.
    offsets = (np.arange(len(labels)) - (len(labels) - 1) / 2) * width

    for level in range(1, 6):
        plt.figure(figsize=(23, 6))
        for idx, corruption in enumerate(common_corruptions):
            results = get_transfer_attack_results(corruption, level)
            print(results)
            for i, result in enumerate(results):
                # Label only the first group so the legend lists each method once;
                # matplotlib ignores labels starting with an underscore.
                plt.bar(x[idx] + offsets[i], result, width, color=colors[i],
                        label=labels[i] if idx == 0 else '_nolegend_')
        plt.legend(loc='upper right')
        plt.xticks(x, common_corruptions, rotation=45)
        plt.ylabel("Acc")
        plt.title("Transfer Attack for CIFAR10c (Level of Severity = {})".format(level))
        plt.tight_layout()
        plt.savefig(os.path.join(out_dir, 'transfer_attack_cifar10c_lvl{}.png'.format(level)))
        plt.close()
