# ---------------------------------------------------------------------
# 0.  Paths + publication settings: LaTeX/Computer Modern text, PDF font embedding
# ---------------------------------------------------------------------
import os
# Root directory for all data/output paths, taken from the environment
# (empty string when unset). A single trailing slash is dropped so the
# "{BASE_PATH}/scripts/..." paths built below do not contain "//".
BASE_PATH = os.environ.get("BASE_PATH", "")
BASE_PATH = BASE_PATH[:-1] if BASE_PATH.endswith('/') else BASE_PATH
import matplotlib as mpl
# Publication look: every text element is typeset by LaTeX in a serif face
# (Computer Modern), and PDF/PS output uses font type 42 (TrueType) embedding.
# NOTE(review): with text.usetex enabled, TeX-rendered text is presumably
# embedded from LaTeX's own (Type 1) fonts, so fonttype may only affect
# non-TeX text; confirm against the matplotlib usetex docs.
mpl.rc('text', usetex=True)
mpl.rc('font', family='serif', serif=["Computer Modern"])
mpl.rc('pdf', fonttype=42)
mpl.rc('ps', fonttype=42)

# Input matrices (.npy) stored under <BASE_PATH>/scripts/notebooks/data/corelogic/neurips.
# NOTE(review): with BASE_PATH unset (""), these resolve to root-absolute
# "/scripts/..." paths; confirm that is intended.
path_mean = BASE_PATH + "/scripts/notebooks/data/corelogic/neurips/diff_matrix_set_seq_nn_mean_std_may14.npy"
path_std = BASE_PATH + "/scripts/notebooks/data/corelogic/neurips/std_diff_matrix_set_seq_nn_may14.npy"

import numpy as np
import matplotlib.pyplot as plt

# Load the two matrices (mean first, then std) from their .npy files.
data_mean, data_std = map(np.load, (path_mean, path_std))

# Transition frequencies, used only to size the circles. Rows follow
# y_labels (initial state) and columns follow x_labels (end state), matching
# how cell (i, j) is placed and labelled in the plot. np.nan cells end up
# with a NaN circle size.
transition_freq = [
    # Current   30dd    60dd     90dd       F     REO  Paid Off
    [   165,     17,      1,   1433,  10743,    395,    552],   # F
    [  1104,    125,    297,  22702,   1982,     50,   1092],   # 90dd
    [   596,    744,   3013,   2312,     73, np.nan,     92],   # 60dd
    [  4926,   8466,   3314, np.nan, np.nan, np.nan,    342],   # 30dd
    [955508,   7361, np.nan, np.nan, np.nan, np.nan,  21639],   # Current
]

# Tick labels (one per column / row) and axis titles.
x_labels = ["Current", "30dd", "60dd", "90dd", "F", "REO", "Paid Off"]
y_labels = ["F", "90dd", "60dd", "30dd", "Current"]
y_title = "Initial State"
x_title = "End State"

# Output figure (PDF) written next to the input data.
save_path = BASE_PATH + "/scripts/notebooks/data/corelogic/neurips/performance_diff.pdf"


import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Circle

# ---------------------------------------------------------------------
# 1.  Pre-process inputs
# ---------------------------------------------------------------------
# Cast everything to float arrays so NaN can mark empty cells (e.g. the
# np.nan entries of transition_freq). NaN is carried through every step below.
mean  = np.asarray(data_mean, dtype=float)
std   = np.asarray(data_std,  dtype=float)
freq  = np.asarray(transition_freq, dtype=float)

# Circle size ∝ log(freq); treat anything <100 as the same minimum size.
freq_clipped  = np.where(freq < 100, 100, freq)
log_freq      = np.log(freq_clipped)
log_min, log_max = np.nanmin(log_freq), np.nanmax(log_freq)
log_span      = log_max - log_min

# scatter's `s` is the marker size in pt² (for the 'o' marker the diameter is
# sqrt(s) pt). Tune these two so circles stay inside cells.
S_MIN, S_MAX = 160, 2600                       # pt²  (≈ marker radius 6-26 pt)
if log_span > 0:
    log_norm = (log_freq - log_min) / log_span  # 0 … 1, NaN stays NaN
else:
    # Every non-NaN frequency clips to the same log value. Dividing by the
    # zero span would give 0/0 = NaN for every size, so use the minimum size
    # instead, keeping NaN for masked cells.
    log_norm = np.where(np.isnan(log_freq), np.nan, 0.0)
sizes = S_MIN + (S_MAX - S_MIN) * log_norm

# Normalise colour intensity by the largest |mean|.
abs_max      = np.nanmax(np.abs(mean)) or 1.0      # all-zero mean: avoid 0/0
alpha_raw    = np.abs(mean) / abs_max              # 0 … 1
MIN_ALPHA    = 0.30                                # show faint colour even for tiny values
alpha        = MIN_ALPHA + (1 - MIN_ALPHA) * alpha_raw

# ---------------------------------------------------------------------
# 2.  Plot
# ---------------------------------------------------------------------
# Grid dimensions are taken from the mean matrix: nr rows × nc columns.
nr, nc = mean.shape
# Ask for roughly one inch of figure per cell. Tick labels and axis titles
# take part of that space, so the drawn cells end up somewhat smaller.
fig, ax = plt.subplots(dpi=300, figsize=(float(nc), float(nr)))
# ---------------------------------------------------------------------
# Helper: cell-label number formatting
# ---------------------------------------------------------------------
def fmt_two_sig(x):
    """Format *x* with at most two significant figures for a cell label.

    NaN yields an empty string. Otherwise Python's ``'.2g'`` format is used,
    e.g. 0.00456 -> '0.0046', 12.345 -> '12', 0 -> '0'. Values that round to
    100 or more, or are below 1e-4 in magnitude, switch to exponent notation
    (e.g. 123 -> '1.2e+02').
    """
    return "" if np.isnan(x) else format(x, ".2g")
# Draw cell borders (light grey grid): one unfilled unit square per cell,
# with its lower-left corner at (column, row).
for i, j in np.ndindex(nr, nc):
    ax.add_patch(plt.Rectangle((j, i), 1, 1, fill=False, lw=.5, ec='grey'))

# Scatter circles + text, one pair per cell whose mean is not NaN.
for i, j in np.ndindex(nr, nc):
    if np.isnan(mean[i, j]):
        continue                                    # masked cell: leave it empty

    # --- circle: sign -> hue, log-frequency -> size, |mean| -> opacity ---
    colour = 'red' if mean[i, j] < 0 else 'green'
    ax.scatter(j + .5, i + .5, s=sizes[i, j], color=colour,
               alpha=alpha[i, j], edgecolors='none')

    # --- "mean\n(std)" label, both at two significant figures ---
    # The y axis is inverted later (ylim = (nr, 0)), so y = i + 0.72 lies in
    # the lower part of the cell, and va='top' hangs the label down from there.
    txt = f"{fmt_two_sig(mean[i, j])}\n({fmt_two_sig(std[i, j])})"
    txt_y = i + 0.72
    ax.text(j + 0.5, txt_y, txt, ha='center', va='top',
            fontsize=7.5, color='black',
            zorder=3)                               # above the circle

# ---------------------------------------------------------------------
# 3.  Axes cosmetics
# ---------------------------------------------------------------------
# Font sizes in points (raised from the earlier 8 / 9).
TICK_FONTSIZE = 11
LABEL_FONTSIZE = 13

# Cells span [0, nc] x [0, nr]. The y limits are reversed so matrix row 0 is on top.
ax.set_xlim(left=0, right=nc)
ax.set_ylim(bottom=nr, top=0)

# One tick at each cell centre, labelled with the state names.
ax.set_xticks(np.arange(nc) + .5)
ax.set_yticks(np.arange(nr) + .5)
ax.set_xticklabels(x_labels, fontsize=TICK_FONTSIZE, rotation=45, ha='right')
ax.set_yticklabels(y_labels, fontsize=TICK_FONTSIZE)
ax.set_xlabel(x_title, fontsize=LABEL_FONTSIZE)
ax.set_ylabel(y_title, fontsize=LABEL_FONTSIZE)

ax.set_aspect('equal')        # square cells
ax.tick_params(length=0)      # hide tick marks on both axes; labels stay

# Tighten padding, write the PDF, and close the figure.
fig.tight_layout(pad=0.4)
fig.savefig(save_path, format='pdf', bbox_inches='tight')
plt.close(fig)
