import matplotlib.pyplot as plt
import torch
import matplotlib
def max_before(y):
    """Return the running ("best so far") maximum of ``y`` along dim 0.

    Entry ``i`` of the result holds the maximum over every element of
    ``y[:i + 1]``. For a 1-D tensor this is the cumulative maximum; for an
    N-D tensor the maximum is taken over all elements of the first
    ``i + 1`` slices and broadcast across slice ``i``.

    Args:
        y: Tensor with at least one dimension.

    Returns:
        A new tensor with the same shape and dtype as ``y``.
    """
    y_max = torch.zeros_like(y)
    best = None
    for i in range(y.shape[0]):
        # Carry the maximum forward instead of rescanning the whole prefix
        # y[:i + 1] at every step (linear rather than quadratic work).
        current = torch.max(y[i])
        best = current if best is None else torch.max(best, current)
        y_max[i] = best
    return y_max
def _load_saved(ex_id, name, suffix):
    """Load the object saved at ``./saves/ex_{ex_id}_res_{name}{suffix}``."""
    return torch.load(f"./saves/ex_{ex_id}_res_{name}{suffix}")


def _plot_best_so_far(name_list, name_map, ex_id_list, out_file, n_steps=24):
    """Plot mean +/- std of the running-best ``Y_true`` per method; save and show.

    Args:
        name_list: Method keys as they appear in the saved file names, in
            plotting order.
        name_map: Maps each method key to its legend label.
        ex_id_list: Experiment ids to aggregate over.
        out_file: Path the figure is written to.
        n_steps: Length of each run's ``Y_true`` series; every loaded series
            must fit a row of this length.
    """
    # Start from a fresh figure so curves from an earlier plot cannot leak
    # into this one on backends where ``plt.show()`` leaves the figure open.
    plt.figure()
    steps = torch.arange(n_steps)
    for name in name_list:
        # One row per experiment; sized from the id list rather than a
        # hard-coded count.
        best_so_far = torch.zeros([len(ex_id_list), n_steps])
        for row, ex_id in enumerate(ex_id_list):
            best_so_far[row, :] = max_before(_load_saved(ex_id, name, "Y_true"))
        plt.errorbar(
            steps,
            best_so_far.mean(dim=0),
            yerr=best_so_far.std(dim=0),
            fmt='-o',
            linewidth=4,
            markersize=8,
            capsize=10,
            capthick=3,
            elinewidth=3,
            label=name_map[name],
        )
    plt.legend()
    plt.ylim([0.85, 1])
    plt.xlim([3, n_steps])
    plt.savefig(out_file)
    plt.show()


def _plot_fidelity_choice(name_list, name_map, ex_id_list, out_file):
    """Scatter every saved ``(x, Y)`` pair, one stacked subplot per method.

    Points are coloured by the matching entry of the saved ``ind_x`` tensor
    (0 -> red, 1 -> blue); presumably this is the fidelity index of each
    query -- confirm against the code that writes the saves.

    Args:
        name_list: Method keys, one subplot each, top to bottom.
        name_map: Maps each method key to its subplot title.
        ex_id_list: Experiment ids whose points are drawn.
        out_file: Path the figure is written to.
    """
    fig = plt.figure()
    gs = fig.add_gridspec(len(name_list), hspace=0)
    # squeeze=False keeps a 2-D array even for a single method; take the
    # only column.
    axs = gs.subplots(sharex=True, sharey=True, squeeze=False)[:, 0]
    c_map = {0: 'r', 1: 'b'}
    for ax, name in zip(axs, name_list):
        for ex_id in ex_id_list:
            y = _load_saved(ex_id, name, "Y")
            x = _load_saved(ex_id, name, "x")
            ind = _load_saved(ex_id, name, "ind_x")
            for xi in range(x.shape[0]):
                ax.scatter(x[xi], y[xi], c=c_map[int(ind[xi])], alpha=0.4)
        # Negative pad draws the title inside the axes, since the stacked
        # subplots have no vertical gap between them.
        ax.set_title(name_map[name], y=1.0, pad=-20)
        ax.set_xlim([0, 1])
    plt.savefig(out_file)
    plt.show()


def main():
    """Render the four result figures from the tensors under ``./saves``."""
    matplotlib.rc('font', **{'weight': 'normal', 'size': 18})
    ex_id_list = list(range(10))

    _plot_best_so_far(
        ["NUCB", "UCB", "UCBTS", "NVUCB"],
        {"NUCB": "MFNUCB", "UCBTS": "UCBTS", "NVUCB": "MFNVUCB", "UCB": "SepGP"},
        ex_id_list,
        '2_fids.png',
    )
    _plot_best_so_far(
        ["NUCB0", "UCB0", "UCBC0", "NVUCB0"],
        {"NUCB0": "NUCB", "UCB0": "UCB", "UCBC0": "iUCB", "NVUCB0": "NVUCB"},
        ex_id_list,
        '0_fid.png',
    )
    _plot_fidelity_choice(
        ["NVUCB", "NUCB", "UCB"],
        {"NVUCB": "MFNVUCB", "NUCB": "MFNUCB", "UCB": "SepGP"},
        ex_id_list,
        'fid_choice.png',
    )
    _plot_best_so_far(
        ["NUCB1", "UCB1", "UCBC1", "NVUCB1"],
        {"NUCB1": "NUCB", "UCB1": "UCB", "UCBC1": "iUCB", "NVUCB1": "NVUCB"},
        ex_id_list,
        '1_fid.png',
    )


if __name__ == "__main__":
    main()