import matplotlib as mpl
#mpl.use('Agg')
mpl.rcParams.update({'font.size': 40})
import matplotlib.pyplot as plt
import cv2
import numpy as np
import os
import seaborn as sns
import pandas as pd

class DiscreteDensity:
    """Categorical distribution over the integer categories ``0 .. N-1``.

    By default the unnormalised weights are a fixed 20-bin profile with a tall
    peak at index 5 and a lower peak at index 15. They are stored in ``cat``,
    and ``probs`` holds them normalised to sum to one.
    """

    # Default unnormalised weights, one per category (20 categories).
    _DEFAULT_WEIGHTS = (1, 2, 4, 7, 10, 12, 10, 7, 4, 2,
                        1, 2, 3, 5, 7, 8, 7, 5, 3, 2)

    def __init__(self, N=20, weights=None):
        """Build the distribution.

        Args:
            N: number of categories; must equal the number of weights.
            weights: optional 1-D sequence of non-negative, unnormalised
                weights, one per category. Defaults to the built-in 20-bin
                profile.

        Raises:
            ValueError: if the weights are not a 1-D sequence of length ``N``,
                contain a negative entry, or do not sum to a positive value.
        """
        self.N = N
        if weights is None:
            weights = self._DEFAULT_WEIGHTS
        # Copy so later mutation of a caller-owned array cannot change us.
        self.cat = np.array(weights)
        # Fail here with a clear message rather than later inside
        # np.random.choice ("a and p must have same size").
        if self.cat.shape != (N,):
            raise ValueError(
                f'expected {N} weights (one per category), '
                f'got array of shape {self.cat.shape}')
        if np.any(self.cat < 0):
            raise ValueError('weights must be non-negative')
        total = np.sum(self.cat)
        if total <= 0:
            raise ValueError('weights must sum to a positive value')
        self.probs = self.cat / total

    def sample(self, num, shuffle=True):
        """Draw ``num`` category indices i.i.d. (with replacement) from ``probs``.

        Args:
            num: number of draws (anything ``np.random.choice`` accepts as size).
            shuffle: accepted for backward compatibility and ignored. The draws
                are already independent, so there is nothing to shuffle.

        Returns:
            Integer ndarray of category indices in ``[0, N)``.
        """
        return np.random.choice(np.arange(self.N), num, p=self.probs)

def _kde_if_exists(path, key, label, color):
    """Overlay a KDE of array ``key`` from the npz file at ``path``, if it exists."""
    if not os.path.exists(path):
        return
    # Context manager closes the underlying NpzFile handle after reading.
    with np.load(path) as archive:
        samples = archive[key]
    sns.kdeplot(samples, label=label, color=color)


def main(ns=(5, 10, 20), label_file='../HSP90_fix_5.0/label.npz',
         out_dir='figure2'):
    """Plot true vs. inferred distributions for several sample sizes.

    For each ``n`` in ``ns`` this draws a 20-bin probability histogram of the
    first ``n`` entries of ``theta`` from ``label_file`` ("True"). If the files
    exist, it overlays KDEs of ``posterior_samples`` from
    ``sbisample{n}2.npz`` ("Bayes") and ``samples`` from ``pvisample{n}.npz``
    ("PVI"). Each figure is saved to ``{out_dir}/cryoem_demo{n}.pdf``.

    Args:
        ns: sample sizes to plot.
        label_file: npz file containing a ``theta`` array.
        out_dir: output directory for the PDFs; created if missing.
    """
    os.makedirs(out_dir, exist_ok=True)
    # The label file does not depend on n, so read it once.
    with np.load(label_file) as archive:
        all_theta = archive['theta']
    palette = sns.color_palette()

    for n in ns:
        theta = all_theta[:n]
        # The n == 20 panel is wider to make room for the external legend.
        wide = n == 20
        fig = plt.figure(figsize=(9, 5) if wide else (5, 5))
        sns.histplot(theta, label='True', bins=20, legend=False,
                     stat='probability')

        _kde_if_exists(f'sbisample{n}2.npz', 'posterior_samples', 'Bayes',
                       palette[1])
        _kde_if_exists(f'pvisample{n}.npz', 'samples', 'PVI', palette[2])

        if wide:
            plt.legend(loc=(1.01, 0.05), frameon=False)
        plt.title(f'n={n}')
        plt.xlim([0, 20])
        plt.ylim([0, 0.5])
        plt.xticks([])
        plt.yticks([])
        plt.xlabel('')
        plt.ylabel('')
        plt.tight_layout()
        plt.savefig(os.path.join(out_dir, f'cryoem_demo{n}.pdf'))
        # close (not clf) so figures do not accumulate across iterations
        plt.close(fig)



if __name__ == "__main__":
    # Generate the figures only when run as a script, not on import.
    main()
