import os
import pdb

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch

def summarize(tensor, alive_dim=None):
    """Return mean and std of ``log(|tensor| + 1)`` over every non-kept dim.

    Args:
        tensor: torch tensor to summarize.
        alive_dim: dimensions to keep (all others are reduced). Defaults to
            ``[0]``. Negative indices are accepted and normalized.

    Returns:
        ``(mean, std)`` tensors of ``log(|tensor| + 1)`` computed over the
        reduced dimensions (``std`` uses torch's default unbiased estimator).
    """
    if alive_dim is None:
        alive_dim = [0]
    ndim = tensor.dim()
    # Normalize negative indices so e.g. ``[-1]`` keeps the last dim instead of
    # being silently ignored by the set difference.
    keep = {d % ndim for d in alive_dim} if ndim else set()
    del_dim = [d for d in range(ndim) if d not in keep]
    # Compute the log-magnitude once and reuse it for both statistics.
    log_mag = (tensor.abs() + 1).log()
    return log_mag.mean(dim=del_dim), log_mag.std(dim=del_dim)

def make_bar(data, x, y, hue, ax, yerr=None, width=0.175, **karg):
    """Draw a grouped bar chart of column *y* from *data* onto *ax*.

    One group of bars is drawn per unique value of column *x*. Within each group
    there is one bar per unique value of column *hue*. The bars sit side by side
    and the group is centred on its x tick.

    Args:
        data: DataFrame containing columns *x*, *y*, *hue* (and *yerr* if given).
        x: column whose unique values define the bar groups / x tick labels.
        y: column holding the bar heights.
        hue: column whose unique values define the bars within each group.
        ax: axes object to draw on (used via ``bar``, ``legend``,
            ``set_xticks`` and ``set_xticklabels``).
        yerr: optional column holding the error-bar sizes.
        width: width of each individual bar.
        **karg: accepted for call compatibility; currently ignored.

    Returns:
        The same *ax*, after drawing bars, a legend and the x ticks.
    """
    labels = data[x].unique()
    hues = data[hue].unique()
    centers = np.arange(len(labels))
    # Offset of the first bar so the whole group is centred on its tick.
    first_offset = -(len(hues) - 1) * width / 2

    for i, hue_value in enumerate(hues):
        # Reindex by the shared label order so heights AND error bars line up
        # with the tick positions (labels missing for this hue become NaN).
        subset = data[data[hue] == hue_value].set_index(x).reindex(labels)
        bar_kwargs = {"label": hue_value}
        if yerr is not None:
            bar_kwargs["yerr"] = subset[yerr].to_numpy()
        ax.bar(centers + first_offset + i * width, subset[y].to_numpy(), width, **bar_kwargs)

    ax.legend()
    ax.set_xticks(centers)
    ax.set_xticklabels(labels)
    return ax

def get_each_df(metric_name, sources=None):
    """Collect one metric's per-step summaries from several runs into a long DataFrame.

    Args:
        metric_name: key looked up in each run's summary dict.
        sources: iterable of ``(run_name, summaries)`` pairs. Each
            ``summaries[metric_name]`` must be an indexable pair whose elements
            ``[0]`` (mean) and ``[1]`` (std) are equal-length 1-D array-likes
            (e.g. CPU torch tensors). Defaults to the module-level ``num_ve``,
            ``num_vp`` and ``num_subvp`` dicts, looked up at call time.

    Returns:
        DataFrame with columns ``name``, ``t`` (position along the summary),
        ``mean`` and ``std`` (as floats), one row per (run, t).
    """
    if sources is None:
        sources = (("ve", num_ve), ("vp", num_vp), ("subvp", num_subvp))

    names, steps, means, stds = [], [], [], []
    for run_name, summaries in sources:
        pair = summaries[metric_name]
        mean = np.asarray(pair[0], dtype=float)
        std = np.asarray(pair[1], dtype=float)
        count = len(mean)
        names.extend([run_name] * count)
        steps.extend(range(count))
        means.extend(mean.tolist())
        stds.extend(std.tolist())

    # Build column-wise directly so each column gets a proper dtype (the old
    # list-of-rows + transpose produced all-object columns).
    return pd.DataFrame({"name": names, "t": steps, "mean": means, "std": stds})

def _load_npz_as_tensors(path):
    """Load every array stored in the .npz archive at *path* as a torch tensor."""
    # Use the archive as a context manager so its file handle is released;
    # indexing the archive reads each array fully into memory first.
    with np.load(path) as archive:
        return {k: torch.from_numpy(archive[k]) for k in archive}


subvp = _load_npz_as_tensors("result/subvp/cifar10_ddpmpp_deep_continuous/test_ckpt_18_score_dpdt.npz")
vp = _load_npz_as_tensors("result/vp/cifar10_ddpmpp_deep_continuous/test_ckpt_8_score_dpdt.npz")
ve = _load_npz_as_tensors("result/ve/cifar10_ncsnpp_deep_continuous/test_ckpt_12_score_dpdt.npz")

# Per-key (mean, std) of log(|x| + 1), keeping dim 0 (see `summarize`).
num_subvp = {k: summarize(v) for k, v in subvp.items()}
num_vp = {k: summarize(v) for k, v in vp.items()}
num_ve = {k: summarize(v) for k, v in ve.items()}

# savefig does not create missing directories.
os.makedirs("score_fig", exist_ok=True)

# NOTE(review): assumes ve/vp/subvp archives share the same keys as subvp.
for name in num_subvp:
    df = get_each_df(name)
    df[["mean", "std"]] = df[["mean", "std"]].astype(float)
    for metric in ["mean", "std"]:
        fig, ax = plt.subplots(1, 1, figsize=(5, 5))
        sns.lineplot(data=df, x="t", y=metric, hue="name", ax=ax)
        fig.tight_layout()
        fig.savefig(f"score_fig/{name}_{metric}.pdf")
        # Close this specific figure so figures do not accumulate in the loop.
        plt.close(fig)

