import os

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
from matplotlib.ticker import MaxNLocator

sns.set_palette('colorblind')

# Checkpoints ckpt1 .. ckpt{n_iter - 1} are read from every run directory below.
n_iter = 14

# Runs for the CIFAR10 -> CIFAR10c-Fog plot; entry i is drawn with label[i].
dir_names = [
    'results/DANN_cifar10_cifar10c/DANN_FPA_RI',
    'results/DANN_cifar10_cifar10c/DANN_FPA_RI140739',
    'results/cifar10c_fog_adv_none_gn/DANN_FPA_RI',
]

# Runs for the CIFAR10-only plot; entry i is also drawn with label[i].
dir_names_homo = [
    'results/cifar10_none_gn/DANN_FPA_RI',
    'results/cifar10_none_gn/DANN_FPA_RI140739',
    'results/cifar10_adv_pgd7/DANN_FPA_RI',
]

# Legend text, index-aligned with both directory lists above.
label = ['Rand Init', 'Private Rand Init', 'AdvT Init']

# Open the figure used by the first plot and force integer ticks on its x-axis
# (the x values are FPA iteration numbers).
fig = plt.figure()
ax = fig.gca()
ax.xaxis.set_major_locator(MaxNLocator(integer=True))

def process_result(a):
    """Reduce a logged accuracy entry to the single value that gets plotted.

    A checkpoint's ``test_adv_acc`` entry may be a scalar or a sequence /
    tensor of values; for a sequence the last entry is taken.

    Args:
        a: A Python/NumPy scalar, a 0-d tensor, or any (possibly nested)
            sequence or tensor accepted by ``torch.as_tensor``.

    Returns:
        ``a`` unchanged when it is a scalar (0-d) or empty, otherwise ``a[-1]``.
    """
    # torch.as_tensor only inspects the shape here: unlike torch.tensor it does
    # not copy an existing tensor and does not emit the "copy construct from a
    # tensor" UserWarning.
    t = torch.as_tensor(a)
    # Scalars and empty containers have no "last entry"; pass them through.
    if t.dim() == 0 or t.numel() == 0:
        return a
    # Index every non-scalar, including a length-1 sequence, so callers always
    # get an element rather than sometimes a one-item container.
    return a[-1]
def _load_adv_accs(dir_name):
    """Return the plotted ``test_adv_acc`` value of each FPA checkpoint in ``dir_name``.

    Reads ``ckpt1.pth`` .. ``ckpt{n_iter - 1}.pth`` in order and reduces each
    ``test_adv_acc`` entry with ``process_result``.
    """
    # map_location='cpu' lets checkpoints written from a GPU run load on a
    # CPU-only machine and keeps loaded tensors on the CPU, where matplotlib
    # can convert them.
    return [
        process_result(
            torch.load(f"{dir_name}/ckpt{i}.pth", map_location='cpu')['test_adv_acc']
        )
        for i in range(1, n_iter)
    ]


def _plot_fpa_runs(run_dirs, title, out_path):
    """Draw one accuracy-vs-FPA-iteration curve per run on the current figure and save it.

    Args:
        run_dirs: Run directories; entry i is drawn with legend text ``label[i]``.
        title: Figure title.
        out_path: Destination passed to ``plt.savefig``; its parent directory
            is created first if it does not exist.
    """
    iters = range(1, n_iter)
    for idx, dir_name in enumerate(run_dirs):
        accs = _load_adv_accs(dir_name)
        plt.plot(iters, accs, label=label[idx])
        # Echo the third run's values (label[2]) to stdout.
        if idx == 2:
            print(accs)
    # Title, axis labels and legend only need setting once, after all curves.
    plt.title(title)
    plt.xlabel('FPA Iter')
    plt.ylabel('Acc')
    plt.legend()
    # savefig does not create directories; make sure the target exists.
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    plt.savefig(out_path)


# First plot: drawn on the figure opened during configuration.
_plot_fpa_runs(
    dir_names,
    'FPA Attack on DANN for CIFAR10' + r'$\rightarrow$' + 'CIFAR10c-Fog',
    './plot/figures/cifar10_FPA_adv_plot.pdf',
)

# Second plot: fresh figure, again with integer ticks on the iteration axis.
ax = plt.figure().gca()
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
_plot_fpa_runs(
    dir_names_homo,
    'FPA Attack on DANN for CIFAR10',
    './plot/figures/cifar10_FPA_homo_adv_plot.pdf',
)
