import matplotlib as mpl
mpl.use('Agg')
mpl.rcParams.update({'font.size': 22})
import matplotlib.pyplot as plt
import cv2
import numpy as np
import os
import seaborn as sns
import pandas as pd

class DiscreteDensity:
    """Categorical distribution over the integer indices ``0 .. N-1``.

    The unnormalized weights ``cat`` are divided by their sum to give
    ``probs``.  The default 20-entry profile has two bumps, with maxima at
    index 5 (weight 12) and index 15 (weight 8).

    Args:
        N: Number of categories; must equal the number of weights.
        cat: Optional 1-D sequence of non-negative weights, one per category.
            When omitted, the built-in 20-entry profile is used.

    Raises:
        ValueError: if ``cat`` is not 1-D, its length differs from ``N``,
            it contains negative weights, or its weights do not sum to a
            positive value.
    """

    # Built-in weight profile (20 categories) used when ``cat`` is omitted.
    _DEFAULT_CAT = (1, 2, 4, 7, 10, 12, 10, 7, 4, 2,
                    1, 2, 3, 5, 7, 8, 7, 5, 3, 2)

    def __init__(self, N=20, cat=None):
        self.N = N
        # np.array (not asarray) so the instance owns its own copy of the weights.
        self.cat = np.array(self._DEFAULT_CAT if cat is None else cat)
        # Fail early with a clear message instead of a size-mismatch error
        # from np.random.choice on the first call to sample().
        if self.cat.ndim != 1 or len(self.cat) != N:
            raise ValueError(
                f"expected {N} weights in a 1-D sequence, got shape {self.cat.shape}"
            )
        if np.any(self.cat < 0):
            raise ValueError("category weights must be non-negative")
        total = np.sum(self.cat)
        if total <= 0:
            raise ValueError("category weights must have a positive sum")
        self.probs = self.cat / total

    def sample(self, num, shuffle=True):
        """Draw ``num`` i.i.d. category indices in ``[0, N)`` with probabilities ``probs``.

        Args:
            num: Number of samples (anything ``np.random.choice`` accepts as ``size``).
            shuffle: Accepted for backward compatibility only; it has no effect,
                because independent draws are already in random order.

        Returns:
            A numpy array of integer category indices.
        """
        return np.random.choice(np.arange(self.N), num, p=self.probs)

def main(pvi_path='result/HSP90NoRot_5.0_10000/PVI_1_32/sample.npz',
         label_path='../HSP90_fix_5.0/label.npz',
         out_path='figure2/ce_dist.pdf'):
    """Plot the ground-truth parameter histogram against a KDE of the PVI samples.

    Overlays a 20-bin histogram (``stat='probability'``) of the flattened
    ``theta`` labels with a KDE of the PVI ``samples``. The x axis is limited
    to ``[0, 20]``, ticks and axis labels are hidden, and the figure is saved
    to ``out_path``.

    Args:
        pvi_path: ``.npz`` file containing a ``samples`` array.
        label_path: ``.npz`` file containing a ``theta`` array.
        out_path: Output figure path; its parent directory is created if missing.
    """
    # np.load on an .npz returns an NpzFile holding an open file; the
    # context manager closes it once the needed array has been read.
    with np.load(pvi_path) as data:
        pvi_sample = data['samples']
    with np.load(label_path) as data:
        theta = data['theta'].flatten()

    plt.rcParams.update({'font.size': 22})
    plt.figure(figsize=(5, 4))
    sns.histplot(theta, label='True', bins=20, legend=False, stat='probability')
    sns.kdeplot(pvi_sample, label='PVI', legend=False)
    plt.legend()
    # Ticks and axis labels are hidden: the figure only compares distribution shapes.
    plt.xticks([])
    plt.yticks([])
    plt.xlim([0, 20])
    plt.xlabel('')
    plt.ylabel('')
    plt.title('Parameter distribution')
    plt.tight_layout()

    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    plt.savefig(out_path)
    plt.close()




# Build the figure only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()