import numpy as np
import matplotlib.pyplot as plt
import os

# Problem setup: states are numbered s0 .. s{num_states - 1}. The driver's
# test states (s0 and s2) are scored specially by similarity() and are
# highlighted on the plot axes.
num_states, driver_states = 5, [0, 2]

# Global matplotlib style applied to every figure this script draws.
plt.rcParams.update(
    {
        # Typography
        "font.family": "serif",
        "font.size": 12,
        # Axes frame: thin black border
        "axes.edgecolor": "black",
        "axes.linewidth": 1.0,
        # Tick marks point away from the plotting area, major length 5
        "xtick.direction": "out",
        "ytick.direction": "out",
        "xtick.major.size": 5,
        "ytick.major.size": 5,
    }
)

# Similarity function
def similarity(s, s_prime):
    """Return the hand-assigned similarity score between states ``s`` and ``s_prime``.

    Rules are checked in priority order:

    * ``1.0`` -- the two states are identical;
    * ``0.8`` -- both states are driver's test states (members of the
      module-level ``driver_states``);
    * ``0.5`` -- the state indices differ by exactly one;
    * ``0.2`` -- any other pair.
    """
    if s == s_prime:
        return 1.0
    # Driver's test states take precedence over plain index adjacency.
    if s in driver_states and s_prime in driver_states:
        return 0.8
    return 0.5 if abs(s - s_prime) == 1 else 0.2

# Pairwise similarity matrix: f[row, col] = similarity(s_row, s_col).
f = np.array(
    [[similarity(row, col) for col in range(num_states)] for row in range(num_states)]
)

# Create figure and axis
fig, ax = plt.subplots(figsize=(6, 5))

# Plot the similarity matrix. vmin/vmax pin the colour scale to the range of
# scores that similarity() can return (0.2 .. 1.0).
im = ax.imshow(f, cmap='GnBu', vmin=0.2, vmax=1.0, aspect='equal')

# Annotate each cell with its value. The GnBu scale ends in dark blue at vmax,
# where fixed black text is hard to read. im.norm maps vmin..vmax onto 0..1,
# so switch to white text once a cell is past the midpoint of the colour scale.
for i in range(num_states):
    for j in range(num_states):
        text_color = 'white' if im.norm(f[i, j]) > 0.5 else 'black'
        ax.text(j, i, f"{f[i, j]:.2f}", ha='center', va='center', color=text_color)

# Set up the axes: one major tick at the centre of every cell, labelled s0, s1, ...
ax.set_xticks(np.arange(num_states))
ax.set_yticks(np.arange(num_states))

x_labels = [f"s{i}" for i in range(num_states)]
y_labels = [f"s{i}" for i in range(num_states)]

ax.set_xticklabels(x_labels)
ax.set_yticklabels(y_labels)

# Highlight driver's test states on both axes with bold magenta tick labels.
# Label position i corresponds to state s_i because the major ticks above are
# placed at 0 .. num_states - 1 in order.
for tick_labels in (ax.get_xticklabels(), ax.get_yticklabels()):
    for i, label in enumerate(tick_labels):
        if i in driver_states:
            label.set_color('magenta')
            label.set_weight('bold')

# Draw white separators between cells. Minor ticks sit on the cell edges
# (half-integer positions), the minor grid is drawn along them, and the minor
# tick marks themselves are hidden so only the grid lines remain.
ax.set_xticks(np.arange(-0.5, num_states, 1), minor=True)
ax.set_yticks(np.arange(-0.5, num_states, 1), minor=True)
ax.grid(which='minor', color='white', linestyle='-', linewidth=1.5)
ax.tick_params(which='minor', bottom=False, left=False)

# Labels: rows are s_i, columns are s_j (matches f[i, j]).
ax.set_xlabel('State $s_j$')
ax.set_ylabel('State $s_i$')

# Legend-style note below the axes explaining the magenta highlighting.
ax.text(0.5, -0.15, "Driver's test states shown in magenta", 
        transform=ax.transAxes, ha='center', color='magenta', fontsize=10)

# Finalise the layout, write the figure into results/ (created on demand,
# no error if it already exists), then display it interactively.
plt.tight_layout()
os.makedirs('results', exist_ok=True)
plt.savefig('results/state_similarity_matrix.pdf', dpi=300)
plt.show()
