import trackexp as tx
import matplotlib.pyplot as plt
import pandas

# ---------------------------------------------------------------------------
# Figure 1: test accuracy against wall-clock time, one run per loss function
# (experiments named "cifar100_resnet18_<loss>").
# ---------------------------------------------------------------------------
losses = ["CE", "FW"]

# Line appearance of each loss' curve.
loss_curve_color = dict(CE='k', FW='r')
loss_curve_style = dict(CE='--', FW='-')

# "training" table of each run, keyed by loss name.
dfs = {
    loss_name: tx.get_data(f"cifar100_resnet18_{loss_name}", "training")
    for loss_name in losses
}

plt.figure(figsize=(4, 2.25))
plt.rcParams.update({'font.size': 12})

# Column drawn on the y axis.
what_to_plot = 'test_acc_averaged model'

for ln in losses:
    run = dfs[ln]
    elapsed = run['wallclocktime']
    curve = run[what_to_plot]
    # The last logged value of the curve is shown next to the loss name.
    last_value = curve.iloc[-1]
    plt.plot(
        elapsed,
        curve,
        label=f"{ln} ({last_value})",
        color=loss_curve_color[ln],
        linestyle=loss_curve_style[ln],
    )

plt.title("CIFAR-100, ResNet18, DoG optim")
plt.xlabel('wall clock time (seconds)')
plt.ylabel('Test accuracy')
plt.legend(title="Loss (test acc)")
plt.grid(True)

# Fit labels and title inside the small canvas before writing the file.
plt.tight_layout()
filename = 'DoG_cifar100_resnet18.pdf'
plt.savefig(filename)



import trackexp as tx
import matplotlib.pyplot as plt
# ---------------------------------------------------------------------------
# Figure 2: test accuracy of every repeated run (experiments whose name
# starts with "exp_"), one curve per run, coloured by the loss function
# stored in the run's metadata.
# ---------------------------------------------------------------------------
plt.figure(figsize=(4,2.25))
plt.rcParams.update({'font.size': 12})

exp_names = [exp_info['name'] for exp_info in tx.list_experiments() if exp_info['name'].startswith('exp_')]

# Fetch each run's "training" table exactly once and keep it for plotting
# (it used to be fetched here, discarded, and fetched again in the plot loop).
# NOTE(review): the fetch presumably raises when a run has no training table,
# which is what the original "test if training table exists" call relied on;
# that still happens here, before anything is drawn -- confirm against trackexp.
exp_lf = {}
exp_df = {}
for exp_name in exp_names:
    exp_df[exp_name] = tx.get_data(exp_name, "training")
    exp_lf[exp_name] = tx.get_metadata(exp_name)['loss_func']

colors = {"CE":'k' , "FW":'r'}
has_been_plotted = {"CE":False, "FW":False}
for name, lf in exp_lf.items():
    df = exp_df[name]
    plot_kwargs = {"color": colors[lf]}
    # Only the first curve of each loss gets a label, so the legend shows
    # one entry per loss rather than one per run.
    if not has_been_plotted[lf]:
        plot_kwargs["label"] = lf
        has_been_plotted[lf] = True
    # No x values are passed, so the curve is drawn against the table's index.
    plt.plot(df['test_acc_averaged model'], **plot_kwargs)

# The labels are just the loss names (no accuracy value is shown in this
# figure), so the legend title says only "Loss".
plt.legend(title="Loss")
plt.title("CIFAR-100, ResNet18, DoG optim")

# NOTE(review): 'Epochs' assumes the training table has one row per epoch -- confirm.
plt.xlabel('Epochs')
plt.ylabel('Test accuracy')
plt.grid(True)
plt.tight_layout()
filename = 'DoG_cifar100_resnet18_reps.pdf'
plt.savefig(filename)

