# ===================
# Part 1: Importing Libraries
# ===================
import matplotlib.pyplot as plt

# ===================
# Part 2: Data Preparation
# ===================
import numpy as np

# Seed NumPy's global RNG (no random draws are made in the visible part of
# this script; the call is kept so module-level state is unchanged).
np.random.seed(0)

# Success rates in percent, keyed by law firm. Each list holds one value per
# case type, in the same order as ``xticklabels``; ``np.nan`` marks a cell
# with no value.
_rates_by_firm = {
    "Firm A": [75.8, 62.3, 81.5, 90.0, 58.7, 77.6],
    "Firm B": [88.3, 77.9, 55.8, 65.4, 72.7, 80.1],
    "Firm C": [60.5, 70.8, np.nan, 85.3, 78.9, np.nan],
    "Firm D": [82.1, 85.4, 74.6, 92.0, 88.2, 76.5],
}

# Rows follow the insertion order of the mapping above, so row i of ``data``
# belongs to ``yticklabels[i]``.
yticklabels = list(_rates_by_firm)
data = np.array(list(_rates_by_firm.values()))

# Column labels (one per case type) and the figure's text elements.
xticklabels = ["Criminal", "Civil", "Employment", "Family", "Immigration", "Corporate"]
title = "Average Success Rates of Legal Cases Across Law Firms"
xlabel = "Types of Legal Cases"
ylabel = "Law Firms"

# ===================
# Part 3: Plot Configuration and Rendering
# ===================
# Render ``data`` as a heatmap with a colorbar, hatch the cells that have no
# value (NaN), label both axes, and print each available value in its cell.

# Boolean mask of the missing (NaN) cells; used below to place the hatches.
mask = np.isnan(data)

# Take the grid size from the data itself so ticks and annotations keep
# matching the matrix if rows or columns are added or removed.
n_rows, n_cols = data.shape

# Colour scale: viridis, normalised to the observed non-NaN range.
cmap = plt.get_cmap("viridis")
norm = plt.Normalize(vmin=np.nanmin(data), vmax=np.nanmax(data))

fig, ax = plt.subplots(figsize=(10, 8))
cax = ax.imshow(data, cmap=cmap, norm=norm)
cbar = fig.colorbar(cax, ax=ax, extend="both")

# Overlay a hatched, unfilled square on every NaN cell. imshow centres cell
# (row i, column j) at (x=j, y=i), hence the half-cell offset for the corner.
for i, j in zip(*np.where(mask)):
    ax.add_patch(
        plt.Rectangle(
            (j - 0.5, i - 0.5), 1, 1, fill=False, hatch="//", edgecolor="black"
        )
    )

# Title and axis labels, set on the Axes explicitly rather than through
# pyplot's "current axes" state.
ax.set_title(title, fontsize=16, fontweight="bold")
ax.set_xlabel(xlabel, fontsize=14)
ax.set_ylabel(ylabel, fontsize=14)

# One tick per column / row, labelled with the case types and firm names.
ax.set_xticks(range(n_cols))
ax.set_xticklabels(xticklabels, rotation=45, ha="right", fontsize=12)
ax.set_yticks(range(n_rows))
ax.set_yticklabels(yticklabels, rotation=0, fontsize=12)

# Values more than 20% above the overall (NaN-ignoring) mean are emphasised
# in bold white; everything else is plain black. The threshold does not
# depend on the cell, so it is computed once instead of once per cell.
# NOTE(review): the emphasised values sit at the bright end of viridis, where
# white text may have low contrast -- confirm the rendering is legible.
highlight_threshold = np.nanmean(data) * 1.2

# Annotate every cell that has a value; NaN cells only get the hatch above.
for (i, j), value in np.ndenumerate(data):
    if np.isnan(value):
        continue
    if value > highlight_threshold:
        text_style = {"color": "white", "fontweight": "bold"}
    else:
        text_style = {"color": "black"}
    ax.text(j, i, f"{value:.1f}", ha="center", va="center", **text_style)

# ===================
# Part 4: Saving Output
# ===================
# Apply a tight layout to minimize white space, then write the figure to a
# PDF file (the figure is saved, not shown interactively).
fig.tight_layout()
fig.savefig("heatmap_48.pdf", bbox_inches="tight")
