
# ===================
# Part 1: Importing Libraries
# ===================
import matplotlib.pyplot as plt

# ===================
# Part 2: Data Preparation
# ===================
import numpy as np

# Seed NumPy's global RNG (nothing visible in this script draws random numbers).
np.random.seed(0)

# Placeholder confusion-matrix counts, one matrix per study. Rows are plotted
# along the y axis ("Actual Treatment") and columns along the x axis
# ("Predicted Treatment").
data_study_a = np.array(
    (
        (50, 2, 1, 0, 0, 0),
        (10, 45, 5, 0, 0, 0),
        (5, 10, 40, 5, 0, 0),
        (0, 0, 5, 50, 3, 1),
        (0, 0, 0, 10, 35, 5),
        (0, 0, 0, 0, 10, 40),
    )
)
data_study_b = np.array(
    (
        (45, 5, 0, 0, 0, 0),
        (5, 40, 10, 0, 0, 0),
        (0, 5, 35, 5, 5, 0),
        (0, 0, 5, 45, 5, 0),
        (0, 0, 0, 5, 40, 5),
        (0, 0, 0, 0, 5, 45),
    )
)

# Axis labels shared by both panels, followed by one title per panel
ylabel = "Actual Treatment"
xlabel = "Predicted Treatment"
titles = ["Study A", "Study B"]

# ===================
# Part 3: Plot Configuration and Rendering
# ===================
# A single row with two side-by-side panels (one per study) on a 6 x 5 inch canvas
fig, axes = plt.subplots(1, 2, figsize=(6, 5))

def plot_confusion_matrix(ax, data, title, *, split=2.5, cmap="YlGnBu"):
    """Draw one confusion matrix as a borderless, tick-free heatmap on ``ax``.

    The axis labels come from the module-level ``ylabel`` / ``xlabel`` strings.

    Args:
        ax: Matplotlib ``Axes`` to draw into.
        data: 2-D array of values; rows map to the y axis, columns to the x axis.
        title: Title shown above the heatmap.
        split: Data coordinate at which a white horizontal and a white vertical
            separator line are drawn. The default 2.5 falls between the third
            and fourth row/column. Pass ``None`` to draw no separators.
        cmap: Colormap name forwarded to ``imshow``.

    Returns:
        The ``AxesImage`` created by ``imshow``, suitable for ``fig.colorbar``.
    """
    im = ax.imshow(data, interpolation="nearest", cmap=cmap)
    # Ticks are deliberately hidden on both axes; only the labels are shown.
    ax.set(title=title, ylabel=ylabel, xlabel=xlabel, xticks=[], yticks=[])
    if split is not None:
        ax.axhline(y=split, color="white", linewidth=3)
        ax.axvline(x=split, color="white", linewidth=3)
    # Remove the frame so the heatmap itself forms the only visible boundary.
    for spine in ax.spines.values():
        spine.set_visible(False)
    return im

# Plot each confusion matrix
im1 = plot_confusion_matrix(axes[0], data_study_a, titles[0])
im2 = plot_confusion_matrix(axes[1], data_study_b, titles[1])

# Attach a horizontal colorbar beneath each subplot. Passing ``ax=`` makes
# matplotlib carve the colorbar out of that subplot's own gridspec slot, so it
# stays aligned with its heatmap and is positioned by ``tight_layout``. Axes
# placed by hand via ``fig.add_axes`` at fixed figure coordinates have no
# subplotspec, so ``tight_layout`` cannot account for them and warns.
fig.colorbar(im1, ax=axes[0], orientation="horizontal")
fig.colorbar(im2, ax=axes[1], orientation="horizontal")

# ===================
# Part 4: Saving Output
# ===================
# Tighten spacing between subplots, labels and colorbars, then write the figure
fig.tight_layout()
fig.savefig("heatmap_65.pdf", bbox_inches="tight")
