import matplotlib.pyplot as plt
import numpy as np

plt.rcParams["axes.linewidth"] = 2

# Stem of the input .npy file; also reused as the stem of the saved figure.
name = "appleDoor_b_prior1"

# Sum over the first axis of the stored array. The result is presumably a
# 2-D (rows, cols) grid -- TODO confirm against how the .npy is produced.
# Cast to float: with an integer array, writing the normalised values in
# [0.2, 1.0] back below would silently truncate them to 0/1.
prior = np.load(f"{name}.npy").sum(axis=0).astype(float)

# Linearly rescale the non-zero cells into [0.2, 1.0] so every non-zero cell
# keeps a visible tint; zero cells stay 0 (rendered white later).
nonzero = prior != 0
if nonzero.any():
    minval = prior[nonzero].min()
    maxval = prior[nonzero].max()
    span = maxval - minval
    if span > 0:
        prior[nonzero] = (prior[nonzero] - minval) / span * 0.8 + 0.2
    else:
        # All non-zero cells share one value: nothing to stretch, and the
        # original formula would divide 0 by 0 (NaN). Show them at full strength.
        prior[nonzero] = 1.0

# The trailing character of `name` selects which palette row tints the prior.
SUFFIX_TO_COLOR_INDEX = {"0": 0, "1": 2}
try:
    aid = SUFFIX_TO_COLOR_INDEX[name[-1]]
except KeyError:
    raise ValueError(
        f"unsupported prior suffix in {name!r}; expected it to end in '0' or '1'"
    ) from None

# RGB palette, one row per colour. It is stored inverted (255 - c) so a cell's
# tint can be built by scaling from 0 and then flipped back to colour-on-white
# by the final `1 - ...` step below.
grey = [128, 128, 128]
colors = 255 - np.array([[255, 89, 94], [25, 130, 196], [138, 201, 38], [255, 202, 58]])

h, w = prior.shape

# imshow draws array row 0 at the top; flipping puts the prior's first row at
# the bottom of the rendered image.
prior = np.flip(prior, axis=0)

# Scale the (inverted) selected colour by each cell's intensity, broadcasting
# over the RGB channel axis.
im = (prior[:, :, np.newaxis] * colors[aid]).astype(float)

# Fixed grey cells, hard-coded by position (presumably walls/obstacles of the
# environment -- TODO confirm). They overwrite any prior intensity there.
im[:2, 3] = grey
im[-2:, 3] = grey
im[:4, 7] = grey

# Scale to [0, 1] and undo the inversion: zero-prior cells become white,
# full-prior cells show the palette colour, grey cells become mid-grey.
im = 1 - im / 255.0

# Render each cell of `im` as one pixel, outlined by thick black grid lines.
fig, ax = plt.subplots(figsize=(16, 10))
ax.imshow(im)

# Put ticks on cell edges (-0.5, 0.5, ..., n - 0.5) so the grid lines outline
# cells. Derive the counts from the image instead of hard-coding 10 x 5: extra
# ticks would expand the view limits and misalign the grid.
n_rows, n_cols = im.shape[:2]
ax.set_xticks(np.arange(-0.5, n_cols, 1))
ax.set_yticks(np.arange(-0.5, n_rows, 1))

# Hide tick marks and labels on every side; grid lines are unaffected.
ax.tick_params(
    which="both",
    bottom=False, top=False, left=False, right=False,
    labelbottom=False, labeltop=False, labelleft=False, labelright=False,
)
ax.grid(linewidth=2, color="black")

# Lay out only after the content exists, so the layout accounts for it.
fig.tight_layout()
fig.savefig(f"{name}_new.jpg", bbox_inches="tight")