import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch

path = "./store/results/augment/cifar100/"


def segment_accuracy(avg_acc, segment_len=500):
    """Convert a running-average accuracy table into per-segment accuracies.

    For each segment start ``i`` (``0, segment_len, 2*segment_len, ...``) this
    computes::

        (avg_acc[i + segment_len, :] * (i + segment_len) - avg_acc[i, :] * i)
            .mean(0) / segment_len

    Assumption (TODO confirm against the code that writes ``*_avg_acc.pt``):
    row ``t`` holds the running mean over the first ``t`` steps. Under that
    assumption, each value is the mean accuracy inside ``[i, i + segment_len)``,
    averaged over the second axis.

    A segment is kept only when its end row ``i + segment_len`` exists, that
    is, when it is strictly less than ``avg_acc.shape[0]``. A trailing partial
    segment is therefore dropped.

    Args:
        avg_acc: 2-D tensor or array indexed as ``avg_acc[row, :]``.
        segment_len: Number of rows per segment; must be positive.

    Returns:
        ``(starts, values)`` as two equal-length numpy arrays. ``starts`` holds
        the segment start rows and ``values`` the segment accuracies as floats.

    Raises:
        ValueError: If ``segment_len`` is not positive.
    """
    if segment_len <= 0:
        raise ValueError(f"segment_len must be positive, got {segment_len}")
    n_rows = avg_acc.shape[0]
    # Same starts as stepping by segment_len and stopping once
    # i + segment_len >= n_rows; range() yields nothing when n_rows <= segment_len.
    starts = list(range(0, n_rows - segment_len, segment_len))
    values = [
        float(
            (avg_acc[i + segment_len, :] * (i + segment_len) - avg_acc[i, :] * i).mean(0)
            / segment_len
        )
        for i in starts
    ]
    return np.asarray(starts), np.asarray(values, dtype=float)


def main(methods=("Finetune", "Oracle", "ER"), segment_len=500):
    """Plot segment-wise accuracy curves for each method and show the figure.

    Reads ``<path><method>_avg_acc.pt`` for every method in ``methods``.
    """
    plt.figure(figsize=(10, 10))
    for method in methods:
        # map_location="cpu" lets tensors saved from GPU runs load on CPU-only hosts.
        avg_acc = torch.load(f"{path}{method}_avg_acc.pt", map_location="cpu")
        starts, seg_acc = segment_accuracy(avg_acc, segment_len)
        # x and y are built from the same starts, so their lengths always match.
        # They are passed by keyword because seaborn >= 0.12 rejects positional
        # data vectors.
        sns.lineplot(x=starts, y=seg_acc, label=method)
    plt.show()


if __name__ == "__main__":
    main()
