import json
import os

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch


def _ood_label(json_name, prefix='MSP_'):
    """Return the OOD legend label: the file's base name without extension,
    with a leading ``prefix`` removed when present."""
    stem = os.path.splitext(os.path.basename(json_name))[0]
    if stem.startswith(prefix):
        stem = stem[len(prefix):]
    return stem


def _conf_save_path(json_name):
    """Return the image path: ``json_name`` with its extension replaced by ``.jpg``."""
    return os.path.splitext(json_name)[0] + '.jpg'


def conf_vis(json_name, *, id_label='ImageNet'):
    """Plot KDEs of in-distribution vs. OOD confidence scores and save a JPEG.

    Args:
        json_name: Path to a JSON file holding two sequences of scores under
            the keys ``'id_conf'`` and ``'ood_conf'``.
        id_label: Legend label for the in-distribution curve.

    The image is written next to the JSON file, with the same name and a
    ``.jpg`` extension, at 300 dpi. The OOD curve is labelled with the file's
    base name minus a leading ``MSP_`` prefix.

    Raises:
        KeyError: if either expected key is missing from the JSON.
    """
    save_name = _conf_save_path(json_name)
    with open(json_name, 'r', encoding='utf-8') as f:
        jsc = json.load(f)

    id_conf = jsc['id_conf']
    ood_conf = jsc['ood_conf']

    # The style must be set before the axes are created for it to apply.
    sns.set_style('dark')
    fig, ax = plt.subplots()
    try:
        # kdeplot(fill=True) is what distplot(hist=False, kde=True, kde_kws=...)
        # delegated to; calling it directly avoids the deprecated wrapper.
        kde_kws = {'fill': True, 'linestyle': '-', 'linewidth': 1, 'ax': ax}
        sns.kdeplot(x=ood_conf, label=_ood_label(json_name), **kde_kws)
        sns.kdeplot(x=id_conf, label=id_label, **kde_kws)
        ax.tick_params(labelsize=13)
        ax.set_ylabel('Density', fontsize=13)
        ax.legend(markerscale=2.2, fontsize=13, loc='upper right')
        ax.set_yticks([])
        fig.savefig(save_name, dpi=300)
    finally:
        # Close the figure so later calls start from a clean canvas instead
        # of drawing on top of this plot.
        plt.close(fig)



if __name__ == '__main__':
    # Other confidence-result files this script has been pointed at; swap one
    # in below to plot it instead:
    #   ./conf_results/{iNaturalist,SUN,Places,Texture}.json
    #   ./conf_results/MSP_{iNaturalist,SUN,Places,Texture}.json
    #   ./conf_results/cifar/MSP_{SVHN,iSUN}.json
    json_name = './conf_results/cifar/MSP_Places.json'
    conf_vis(json_name)

