import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_palette('colorblind') 
from matplotlib.ticker import MaxNLocator


# Checkpoint directories to plot; label[i] is the legend entry for dir_names[i].
dir_names = [
    'results/DANN_cifar10_cifar10c/DANN_Bilevel',
]
label = [
    'Rand Init',
]

# Open one fresh figure for all curves and grab its axes. The x axis counts
# epochs, so restrict tick marks to whole numbers.
plt.figure()
ax = plt.gca()
ax.xaxis.set_major_locator(MaxNLocator(integer=True))

def process_result(a):
    """Reduce a recorded metric to a single value for plotting.

    The metric may be stored either as a single number or as a sequence of
    values (presumably a history over training steps/epochs — confirm against
    the code that writes the checkpoint). For a sequence, the last entry is used.

    Args:
        a: A scalar (Python number or 0-d tensor), or a non-empty sequence /
            tensor with at least one dimension.

    Returns:
        ``a`` unchanged when it is a scalar, otherwise ``a[-1]``.

    Raises:
        ValueError: if ``a`` is an empty sequence.
    """
    # Only the rank is needed. torch.as_tensor avoids the copy (and the
    # UserWarning) that torch.tensor() produces when ``a`` is already a tensor.
    if torch.as_tensor(a).dim() == 0:
        return a
    if len(a) == 0:
        # Fail here with a clear message instead of later inside plt.plot.
        raise ValueError("cannot take the last value of an empty result")
    # Also unwraps one-element sequences such as [0.7] -> 0.7.
    return a[-1]
import os

# Checkpoints are read from <dir_name>/ckpt1.pth ... <dir_name>/ckpt299.pth.
# NOTE(review): range(1, 300) excludes ckpt0 and ckpt300 — confirm this matches
# the checkpoint numbering of the training run.
epochs = range(1, 300)
output_path = './plot/figures/cifar10_bilevel_plot.pdf'

for idx, dir_name in enumerate(dir_names):
    accs = []
    for i in epochs:
        # map_location='cpu' lets checkpoints saved on a GPU load on CPU-only
        # machines and keeps any tensor values plottable by matplotlib.
        ckpt = torch.load(os.path.join(dir_name, "ckpt{}.pth".format(i)),
                          map_location='cpu')
        accs.append(process_result(ckpt['train_target_acc']))
    plt.plot(epochs, accs, label=label[idx])

# Figure decorations only need to be applied once, after all curves are drawn.
plt.title('Alternating Bilevel Optimization Attack on DANN for CIFAR10'
          + r'$\rightarrow$' + 'CIFAR10c-Fog')
plt.xlabel('Epoch')
plt.ylabel('Acc')
plt.legend()

# savefig does not create missing directories.
os.makedirs(os.path.dirname(output_path), exist_ok=True)
plt.savefig(output_path, bbox_inches='tight')
