import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.colors import Normalize
from matplotlib.patches import Patch

def calculate_dqs_delta(s, n, K, r):
    """
    Calculates the reward/penalty term (Delta) of the DQS metric.

    Delta = DQS(s, n) - s = s / (n + K) * (r * K - (1 - s) * n)

    Args:
        s: Score(s), scalar or array. Values are clipped into [0, 0.99999].
        n: Step count(s), scalar or array-like broadcastable against ``s``.
        K: Reference step count (the script below passes the cohort mean
            step count).
        r: Cohort difficulty factor (the script below passes 1 - s_mean).

    Returns:
        Delta with the broadcast shape of the inputs; a NumPy scalar when all
        inputs are scalars.
    """
    # Scores below 0 are treated as 0, and s is capped just below 1.
    # (1 - s) is only ever a multiplier in this formula, never a divisor, so
    # the upper cap is not needed to avoid a division by zero. It is kept so
    # that existing results are unchanged.
    s = np.clip(s, 0, 0.99999)
    n = np.asarray(n, dtype=float)

    # Build the denominator as a float array so that the zero guard works for:
    #   - scalar inputs, where in-place item assignment would raise TypeError;
    #   - integer inputs, where assigning 1e-6 in place would truncate to 0.
    denominator = n + K
    denominator = np.where(denominator == 0, 1e-6, denominator)

    delta = (s / denominator) * (r * K - (1 - s) * n)
    return delta

# --- 1. Define Parameters ---
# Cohort statistics anchoring the surface: the plot is centred on the point
# (s_mean, K), and the difficulty factor r is derived from the mean score.
s_mean = 0.7       # mean score of the cohort
K = float(30)      # mean step count of the cohort
r = 1 - s_mean     # cohort difficulty factor, r = 1 - s_mean

# --- 2. Create Grid Data based on Deviations ---
# The grid is expressed as deviations from the cohort means, so the origin of
# the plot corresponds to s = s_mean and n = K.

# x-axis: score deviation s_dev = s - s_mean, sampled over [-0.3, +0.3].
# Absolute scores therefore span s_mean - 0.3 .. s_mean + 0.3
# (0.4 .. 1.0 for s_mean = 0.7); anything below 0 is clamped in step 3.
s_dev_vals = np.linspace(-0.3, 0.3, 200)

# y-axis: step-count deviation n_dev = n - K, sampled over [-K, +K].
# Absolute step counts therefore span 0 .. 2K.
n_dev_vals = np.linspace(-K, K, 200)

# Create a meshgrid for deviations (rows follow n_dev, columns follow s_dev).
S_DEV, N_DEV = np.meshgrid(s_dev_vals, n_dev_vals)

# --- 3. Convert Deviations back to Absolute Values to Calculate Delta ---
# s = s_dev + s_mean and n = n_dev + K, clamped at 0 so that neither the
# score nor the step count is ever negative.
S_ABS = np.maximum(S_DEV + s_mean, 0)
N_ABS = np.maximum(N_DEV + K, 0)

# --- 4. Calculate Delta for each point on the grid ---
Delta = calculate_dqs_delta(S_ABS, N_ABS, K, r)

# --- 5. Create the 3D Plot ---
fig = plt.figure(figsize=(15, 15))
ax = fig.add_subplot(111, projection='3d')

# --- 6. Customize Colormap ---
# Colour each face by the sign of Delta. Delta == 0 is grouped with rewards,
# and the legend label below reflects that.
blue_color = '#0f87be'   # Penalty (Delta < 0)
orange_color = '#ff964b' # Reward (Delta >= 0)
face_colors = np.where(Delta < 0, blue_color, orange_color)

# Plot the surface using deviation values for the x/y axes.
# A stride of 1 draws every grid cell, giving the smoothest surface.
surf = ax.plot_surface(S_DEV, N_DEV, Delta, facecolors=face_colors,
                       linewidth=0, antialiased=True, alpha=0.7,
                       rstride=1, cstride=1)

# --- 7. Add Labels and Title ---
ax.set_xlabel('Score Deviation (s - s_mean)', fontsize=18, labelpad=18)
ax.set_ylabel('Step Count Deviation (n - K)', fontsize=18, labelpad=18)
ax.set_zlabel('Reward / Penalty (Δ)', fontsize=18, labelpad=18)
ax.set_title('DQS Reward-Penalty Surface (Centered at Mean)', fontsize=24, pad=20)

# Translucent Delta = 0 plane: the break-even reference between reward and penalty.
ax.plot_surface(S_DEV, N_DEV, np.zeros_like(S_DEV), alpha=0.4, color='gray', linewidth=0)

# Dashed lines on the zero plane marking s_dev = 0 and n_dev = 0
# (i.e. the cohort mean score and mean step count).
ax.plot([0, 0], [np.min(N_DEV), np.max(N_DEV)], [0, 0], color='black', linestyle='--', linewidth=1.5, alpha=0.7)
ax.plot([np.min(S_DEV), np.max(S_DEV)], [0, 0], [0, 0], color='black', linestyle='--', linewidth=1.5, alpha=0.7)

# Viewing angle chosen so all four deviation quadrants are visible.
ax.view_init(elev=50, azim=135)

# Custom legend; the labels match the sign split used for face_colors above.
legend_elements = [Patch(facecolor=orange_color, edgecolor='k', label='Reward (Δ ≥ 0)'),
                   Patch(facecolor=blue_color, edgecolor='k', label='Penalty (Δ < 0)')]
ax.legend(handles=legend_elements, loc='upper left', fontsize=18)

plt.tight_layout(pad=5.0)

plt.savefig("dqs_surface_centered.png", dpi=300, bbox_inches='tight')
plt.show()