import numpy as np
import matplotlib.pyplot as plt
from scipy.special import logsumexp

n = 5  # number of categories in the upper layer
d = 10  # number of categories in the lower layer (per upper-layer category)

# Ground-truth target distributions, re-drawn from a fixed seed.
# NOTE(review): presumably these must reproduce the targets used when the results
# under results/5_10_05_seed_1 were produced (the directory name suggests n=5,
# d=10, concentration 0.5, seed 1) — confirm. The seed and the order of the two
# Dirichlet draws below therefore must not change.
np.random.seed(1)
upper_layer_dist_target = np.random.dirichlet(np.full(n, 0.5))
lower_layer_dist_target = np.random.dirichlet(np.full(d, 0.5), size=n)

# Number of mixture components S; selects which saved result files are read.
S = 1
results_dir = 'results/5_10_05_seed_1'

upper_layer_phi = np.load(f'{results_dir}/upper_dist_{S}.npy')
lower_layer_phi = np.load(f'{results_dir}/lower_dist_{S}.npy')
# Row S-1 of kl_lists.npy is used; its last entry is the value reported on the plot.
kl_lists = np.load(f'{results_dir}/kl_lists.npy')
kl = kl_lists[S - 1][-1]


def _softmax_last_axis(logits):
    """Turn ``logits`` into probabilities along the last axis (stabilised via logsumexp)."""
    return np.exp(logits - logsumexp(logits, axis=-1, keepdims=True))


# Per-component distributions; the leading axis indexes the component.
upper_layer_components = _softmax_last_axis(upper_layer_phi)
lower_layer_components = _softmax_last_axis(lower_layer_phi)

# Equal-weight mixture: sum over the leading (component) axis divided by S.
upper_layer_dist = upper_layer_components.sum(0) / S
lower_layer_dist = lower_layer_components.sum(0) / S

# Two stacked panels sharing the x axis: the upper layer on top, the lower
# layer underneath.
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 10), sharex=True)

# One colour per mixture component. Indexed modulo its length below so that
# S > len(colors) cycles through the palette instead of raising IndexError.
colors = ['red', 'blue', 'pink', 'orange', 'purple']

# Upper-layer category i occupies the x-interval [i*d - 0.5, (i+1)*d - 0.5);
# the bottom panel places its d lower-layer categories at x = i*d + j.
# Each group boundary is drawn exactly once (interior ones were drawn twice).
for boundary in range(0, n * d + 1, d):
    ax1.axvline(boundary - 0.5, 0, 1)
    ax2.axvline(boundary - 0.5, 0, 1)

# Tallest upper-panel bar (stacked total or target), used for the y limit.
upper_peak = 0.0
for i, x in enumerate(range(0, n * d, d)):
    # Upper panel: the S component contributions (each divided by S) stacked
    # in the right half of the interval; the target probability in black in
    # the left half.
    bottom = 0
    for s in range(S):
        height = upper_layer_components[s, i] / S
        ax1.bar(x - 0.5 + d / 2, height, width=d / 2, align='edge',
                bottom=bottom, color=colors[s % len(colors)])
        bottom += height
    ax1.bar(x - 0.5, upper_layer_dist_target[i], color='black', width=d / 2, align='edge')
    upper_peak = max(upper_peak, bottom, upper_layer_dist_target[i])

    # Lower panel: stacked component contributions just right of each tick
    # (width 0.5); target in black just left of it (a negative width with
    # align='edge' extends the bar leftwards from x).
    positions = x + np.arange(d)
    bottom_lower = np.zeros(d)
    for s in range(S):
        heights = lower_layer_components[s, i] / S
        ax2.bar(positions, heights, width=0.5, align='edge',
                bottom=bottom_lower, color=colors[s % len(colors)])
        bottom_lower = bottom_lower + heights
    ax2.bar(positions, lower_layer_dist_target[i], width=-0.5, align='edge', color='black')

# Loop-invariant, so set once. Keep the fixed 0.6 (comparable across figures)
# unless some bar would otherwise be clipped by it.
ax1.set_ylim(0, 0.6 if upper_peak <= 0.6 else 1.05 * upper_peak)

# NOTE(review): axis labels are only drawn when S == 1 — presumably so figures
# for larger S can be placed beside it without repeated labels; confirm.
if S == 1:
    ax1.set_ylabel(r"$p(z_1)$", fontsize=20, color='black')
    ax2.set_ylabel(r"$p(z_2|z_1)$", fontsize=20, color='black')
ax2.set_xlabel(f"KL$(p||q)={round(kl,3)}$", fontsize=20)
ax1.set_title(f"Two-Level Hierarchical Model ($S={S}$)", fontsize=20)
# One tick per lower-layer category, labelled by its index within its group.
ax2.set_xticks(np.arange(n * d))
ax2.set_xticklabels([str(j % d) for j in range(n * d)])

# Almost no vertical gap between the two panels.
plt.subplots_adjust(hspace=0.01)

plt.savefig(f'results/5_10_05_seed_1/plotted_dist_{S}.pdf', format='pdf')
plt.show()
