import pathlib
import matplotlib
import numpy as np
import matplotlib.pyplot as plt
from skimage.transform import rescale
from imageio.v3 import imread
import json, os
from skimage.color import lab2rgb
from numba import njit, prange
import matplotlib.gridspec as gridspec

# Global matplotlib styling for the paper figure: LaTeX text rendering with
# amsmath available, high on-screen DPI and a large Times-style font.
# NOTE(review): with text.usetex=True matplotlib's TeX backend is documented to
# honour only generic families (serif, sans-serif, ...); confirm the rendered
# font really is Times and not the default Computer Modern.
plt.rcParams.update(
    {
        "text.usetex": True,
        "text.latex.preamble": r'\usepackage{amsmath}',
        "figure.dpi": 300,
        "font.size": 20,
        "font.family": "Times New Roman",
    }
)

def my_left_corner_crop(x, size=2048):
    """Return the top-left ``size`` x ``size`` window of *x*.

    Only the first two axes are sliced, so any trailing axes (e.g. colour
    channels) are kept intact. If *x* is shorter than *size* along an axis,
    that axis is returned whole (ordinary slice semantics).

    Args:
        x: Array-like supporting 2-D slicing (e.g. an (H, W) or (H, W, C) image).
        size: Side length of the square crop; defaults to 2048.

    Returns:
        ``x[:size, :size]``.
    """
    return x[:size, :size]

@njit(parallel=True)
def apply_threshold_array(sr, threshold, min_vals, max_vals):
    """
    Apply a symmetric threshold to every value of a 3-D array to get lower and upper bounds.

    For each element ``sr[i, j, k]`` the bounds are ``sr - threshold`` and
    ``sr + threshold``. BOTH bounds are then clamped into
    ``[min_vals[k], max_vals[k]]`` (per last-axis channel k), so the returned
    interval always lies inside the valid range and ``lower <= upper`` holds
    even when ``sr`` itself is slightly out of range.

    Parameters
    ----------
    sr : 3-D array indexed as (i, j, k); k selects the entry of min_vals/max_vals.
    threshold : scalar half-width of the interval.
    min_vals, max_vals : 1-D arrays of per-channel bounds (length >= sr.shape[2]).

    Returns
    -------
    (lower, upper) : two float32 arrays with the same shape as ``sr``.
    """
    lower = np.empty_like(sr, dtype=np.float32)
    upper = np.empty_like(sr, dtype=np.float32)
    for i in prange(sr.shape[0]):
        for j in range(sr.shape[1]):
            for k in range(sr.shape[2]):
                lo_k = min_vals[k]
                hi_k = max_vals[k]
                val = sr[i, j, k]
                low_val = val - threshold
                high_val = val + threshold
                # Clamp each bound on both sides: previously only the lower bound's
                # floor and the upper bound's ceiling were enforced, so an
                # out-of-range val could yield lower > max_vals or upper < min_vals
                # (an inverted interval).
                if low_val < lo_k:
                    low_val = lo_k
                elif low_val > hi_k:
                    low_val = hi_k
                if high_val < lo_k:
                    high_val = lo_k
                elif high_val > hi_k:
                    high_val = hi_k
                lower[i, j, k] = low_val
                upper[i, j, k] = high_val
    return lower, upper

root_path = pathlib.Path(__file__).parent.parent
sinsr_path = root_path.parent / "SinSR"


dataset_basename = "LIU4K_v2_valid"
subdatasets = ["Animal", "Building", "Mountain", "Street"]

downscaled_path = sinsr_path / "datasets" / (dataset_basename + "_downscaled")
output_path = root_path / "output_sinsr" / dataset_basename
dataset_path = sinsr_path / "datasets" / dataset_basename
diffs_path = sinsr_path / "masks" / dataset_basename

angelopoulos_path = root_path / "masks_angelopoulos_new"
kutiel_path = root_path / "masks_kutiel"

fig_output_path = root_path / "figs-neurips"
fig_output_path.mkdir(exist_ok=True)

# Image ids of the form "<subfolder>/<stem>" for every PNG under dataset_path,
# excluding the "Capture" subfolder. Sorted so that chosen_ids index a stable order.
valid_fnames = sorted(
    f"{fname.parent.name}/{fname.stem}"
    for fname in dataset_path.resolve().glob("**/*.png")
    if fname.is_file() and fname.parent.name != "Capture"
)

# Per-channel value ranges: unit cube for normalised data, and the Lab ranges
# (L in [0, 100], a/b in [-128, 127]) used to (de)normalise predictions.
NORM_MIN = np.array([0.0, 0.0, 0.0], dtype=np.float32)
NORM_MAX = np.array([1, 1, 1], dtype=np.float32)
LAB_MIN = np.array([0.0, -128.0, -128.0], dtype=np.float32)
LAB_MAX = np.array([100.0, 127.0, 127.0], dtype=np.float32)

chosen_ids = [110, 62, 32, 248]
EPS = 0.001
IS = [valid_fnames[i] for i in chosen_ids]


def _load_threshold(json_path, level):
    """Return the value stored under key *level* in the JSON object at *json_path*."""
    with open(json_path, "r", encoding="utf-8") as f:
        return json.load(f)[level]


LEVEL = "0.1"
THRESHOLD_NSEM = _load_threshold(root_path / "conformal/sinsr-thresholds-s64-dp64-bs.json", LEVEL)
THRESHOLD_ANG = _load_threshold(angelopoulos_path / "LIU4K_v2_train_thresholds.json", LEVEL)
THRESHOLD_KUT = _load_threshold(kutiel_path / "kutiel_LIU4K_v2_train_thresholds.json", LEVEL)

# JSON keys are matched as exact strings; this one looks like the repr of a float
# (~0.1875) produced when the file was written, so it must be kept verbatim.
LEVEL_SEM = "0.18749999999999997"
THRESHOLD_SEM = _load_threshold(root_path / "conformal/sinsr-thresholds-s64-d128-bs.json", LEVEL_SEM)

n_rows = len(IS)
n_cols = 6

# Opacity of the prediction underneath the red conformal-mask overlay.
ALPHA = 0.3


def _red_overlay(rgb, selected, alpha):
    """Blend pure red over *rgb* wherever *selected* is True.

    Selected pixels become ``(1 - alpha) * red + alpha * rgb``; the rest of *rgb*
    is returned unchanged. *rgb* is (H, W, 3); *selected* is a boolean (H, W)
    map that is broadcast across the channel axis. The red image is built from
    *rgb*'s own shape, so any image size works (previously fixed at 2048x2048).
    """
    red = np.zeros_like(rgb)
    red[..., 0] = 1.0
    return np.where(selected[..., np.newaxis], (1 - alpha) * red + alpha * rgb, rgb)


fig = plt.figure(figsize=(n_cols * 4.15, 4 * n_rows))
# Two outer columns per row: a 4-panel block and a 2-panel block (the
# Angelopoulos lower/upper bounds), the latter drawn with tighter inner spacing.
outer_gs = gridspec.GridSpec(n_rows, 2, width_ratios=[4, 2], wspace=0.075 / 2, hspace=0.15)

axs = np.empty((n_rows, n_cols), dtype=object)

for row, I in enumerate(IS):
    left_gs = gridspec.GridSpecFromSubplotSpec(1, 4, subplot_spec=outer_gs[row, 0], wspace=0.15, hspace=0.15)
    for col in range(4):
        axs[row, col] = fig.add_subplot(left_gs[0, col])

    right_gs = gridspec.GridSpecFromSubplotSpec(1, 2, subplot_spec=outer_gs[row, 1], wspace=0.075 / 8, hspace=0.15)
    for col in range(2):
        axs[row, 4 + col] = fig.add_subplot(right_gs[0, col])

    # Scale integer images by their dtype's maximum (255 for 8-bit PNGs) rather
    # than assuming 8-bit input.
    raw_image = imread(dataset_path / f"{I}.png")
    scale = np.iinfo(raw_image.dtype).max if np.issubdtype(raw_image.dtype, np.integer) else 1.0
    real_image = my_left_corner_crop(raw_image) / scale

    sr_img = np.load(output_path / f"{I}_pred.npy")
    sr_rgb = lab2rgb(sr_img)
    mask = np.load(output_path / f"{I}_s64_varmask.npy")
    sigma = np.load(output_path / f"{I}_simple_varmask.npy")
    # NOTE(review): divides by zero if sigma is identically zero — confirm that cannot happen.
    kutiel = np.minimum(THRESHOLD_KUT / (1 - sigma / np.max(sigma) + EPS), 1)

    # Angelopoulos bounds are computed in normalised [0, 1] Lab space, then mapped back.
    lower, upper = apply_threshold_array(
        (sr_img - LAB_MIN) / (LAB_MAX - LAB_MIN),
        np.float32(THRESHOLD_ANG),
        NORM_MIN,
        NORM_MAX,
    )
    lower = lower * (LAB_MAX - LAB_MIN) + LAB_MIN
    upper = upper * (LAB_MAX - LAB_MIN) + LAB_MIN

    axs[row, 0].imshow(real_image)
    axs[row, 1].imshow(_red_overlay(sr_rgb, mask >= THRESHOLD_NSEM, ALPHA))
    axs[row, 2].imshow(_red_overlay(sr_rgb, mask >= THRESHOLD_SEM, ALPHA))
    # NOTE(review): this panel shows 1 - kutiel as a grayscale map, not an overlaid
    # prediction; confirm the column title below is the intended caption.
    axs[row, 3].imshow(1 - kutiel, cmap="gray")
    axs[row, 4].imshow(lab2rgb(lower))
    axs[row, 5].imshow(lab2rgb(upper))

for ax in np.ravel(axs):
    ax.set_xticks([])
    ax.set_yticks([])


axs[0, 0].set_title("Ground truth", y=1.05)
axs[0, 1].set_title("Prediction with\nconformal mask\n(our non-semantic $D_p$)", y=1.05)
axs[0, 2].set_title("Prediction with\nconformal mask\n(our semantic $D_p$)", y=1.05)
axs[0, 3].set_title("Prediction with\nconformal mask\n[Kutiel et al., 2023]", y=1.05)
axs[0, 4].set_title("Lower bound\nconformal mask\n[Angelopoulos\n et al., 2022b]", y=1.05)
axs[0, 5].set_title("Upper bound\nconformal mask\n[Angelopoulos\n et al., 2022b]", y=1.05)

fig.savefig(fig_output_path / "fig6l.png", dpi=150, bbox_inches='tight')
plt.close(fig)
