import pathlib
import matplotlib
import numpy as np
import matplotlib.pyplot as plt
from skimage.transform import rescale
from imageio.v3 import imread
import json, os
from skimage.color import lab2rgb

# Global plotting style for every figure produced by this script:
# LaTeX text rendering (amsmath available in labels), 300-dpi figures,
# 24-pt text in the Times font family.
plt.rcParams.update(
    {
        "text.usetex": True,
        "text.latex.preamble": r"\usepackage{amsmath}",
        "figure.dpi": 300,
        "font.size": 24,
        "font.family": "Times",
    }
)

def my_left_corner_crop(x, size=2048):
    """Return the top-left ``size`` x ``size`` window of ``x``.

    Only the first two axes are sliced, so any trailing axis (e.g. colour
    channels) is kept whole. Along an axis shorter than ``size`` the full
    extent is kept (ordinary slicing semantics). For a NumPy array the
    result is a view, not a copy.

    Args:
        x: Array-like supporting 2-D slicing, e.g. an (H, W) or (H, W, C) image.
        size: Side length of the square crop; defaults to 2048.

    Returns:
        ``x[:size, :size]``.
    """
    return x[:size, :size]

# Project root: the directory two levels above this file. The SinSR
# directory is expected to sit next to it.
root_path = pathlib.Path(__file__).parent.parent
sinsr_path = root_path.parent / "SinSR"

dataset_basename = "LIU4K_v2_valid"
# NOTE(review): subdatasets is not referenced in the visible code.
subdatasets = ["Animal", "Building", "Mountain", "Street"]

# Inputs read from the SinSR tree.
dataset_path = sinsr_path / "datasets" / dataset_basename
downscaled_path = sinsr_path / "datasets" / f"{dataset_basename}_downscaled"
diffs_path = sinsr_path / "masks" / dataset_basename

# Model outputs (predictions and variance masks) inside this project.
output_path = root_path / "output_sinsr" / dataset_basename

# Destination directory for the rendered figure (created if missing).
fig_output_path = root_path / "figs-neurips"
fig_output_path.mkdir(exist_ok=True)

# Image keys of the form "<parent-dir>/<stem>" for every PNG under the
# dataset, skipping files whose immediate parent directory is "Capture".
# Sorting makes the integer indices used to pick images reproducible.
# NOTE(review): only the immediate parent directory is kept in the key, so
# this assumes a one-level "<subset>/<image>.png" layout — confirm.
valid_fnames = sorted(
    f"{png.parent.name}/{png.stem}"
    for png in dataset_path.resolve().glob("**/*.png")
    if png.is_file() and png.parent.name != "Capture"
)

# Positions in the sorted valid_fnames of the images shown, one per figure row.
chosen_ids = [185, 83, 132]

IS = [valid_fnames[i] for i in chosen_ids]


def _load_threshold(json_path, level):
    """Return the value stored under string key *level* in a JSON file.

    The file is decoded explicitly as UTF-8 (the encoding JSON requires)
    rather than with the platform's locale default.

    Raises:
        KeyError: if *level* is not a key of the loaded object. The message
            lists the keys that do exist, because keys such as
            "0.18749999999999997" must match character for character.
    """
    with open(json_path, encoding="utf-8") as file:
        thresholds = json.load(file)
    try:
        return thresholds[level]
    except KeyError:
        raise KeyError(
            f"level {level!r} not found in {json_path}; "
            f"available levels: {sorted(thresholds)}"
        ) from None


# Level keys are looked up as exact strings in each thresholds file.
LEVEL_1 = "0.18749999999999997"
THRESHOLD_1 = _load_threshold(
    root_path / "conformal/sinsr-thresholds-s64-d128-bs.json", LEVEL_1
)

LEVEL_2 = "0.1"  # alternative key kept for reference: "0.0875"
THRESHOLD_2 = _load_threshold(
    root_path / "conformal/sinsr-thresholds-s64-dp64-bs.json", LEVEL_2
)

n_rows = len(IS)  # one row per chosen image
n_cols = 6  # sigma mask, low-res, ground truth, two masked predictions, raw prediction

# squeeze=False keeps `axs` 2-D even when only one image is chosen, so the
# axs[row, col] indexing used for plotting works for any number of rows.
fig, axs = plt.subplots(
    n_rows,
    n_cols,
    figsize=(24, 4.25 * n_rows),
    gridspec_kw={
        "wspace": 0.125,  # horizontal spacing between plots
        "hspace": 0.125,  # vertical spacing between plots
    },
    constrained_layout=True,
    squeeze=False,
)

# Weight of the prediction inside highlighted pixels; pure red gets 1 - ALPHA.
ALPHA = 0.3


def _highlight(rgb, sigma, threshold, alpha):
    """Blend pixels of *rgb* whose *sigma* is >= *threshold* towards pure red.

    Selected pixels become ``(1 - alpha) * red + alpha * rgb``; all other
    pixels are returned unchanged. The red layer takes its shape from *rgb*,
    so any image size works as long as *sigma* has the same height and width
    as *rgb* (rgb is expected to be (H, W, 3), sigma (H, W)).
    """
    red = np.zeros(rgb.shape)
    red[..., 0] = 1.0
    # Trailing singleton axis lets the 2-D selection broadcast over channels.
    selected = (sigma >= threshold)[..., np.newaxis]
    return np.where(selected, (1 - alpha) * red + alpha * rgb, rgb)


for row, I in enumerate(IS):
    # Top-left crop of the ground truth scaled by 1/255.
    # NOTE(review): dividing by 255 assumes 8-bit PNGs — confirm for this dataset.
    real_image = my_left_corner_crop(imread(dataset_path / f"{I}.png")) / 255
    # Ground truth downscaled by a factor of 0.25 along both spatial axes.
    low_res = rescale(real_image, 0.25, channel_axis=2)
    # The stored prediction is converted from CIELAB to RGB for display.
    pred = lab2rgb(np.load(output_path / f"{I}_pred.npy"))
    # Per-pixel map that is thresholded below (shown as sigma in column 0).
    mask = np.load(output_path / f"{I}_s64_varmask.npy")

    axs[row, 0].imshow(mask)
    axs[row, 1].imshow(low_res)
    axs[row, 2].imshow(real_image)
    # NOTE(review): column 3 ("non-semantic D_p" title) uses THRESHOLD_2 from the
    # s64-dp64 file and column 4 ("semantic D_p" title) uses THRESHOLD_1 from
    # the s64-d128 file — confirm this pairing is intended.
    axs[row, 3].imshow(_highlight(pred, mask, THRESHOLD_2, ALPHA))
    axs[row, 4].imshow(_highlight(pred, mask, THRESHOLD_1, ALPHA))
    axs[row, 5].imshow(pred)


# Every panel is an image, so remove all tick marks.
for axis in np.ravel(axs):
    axis.set(xticks=[], yticks=[])

# Column headers, placed on the top row only.
column_titles = [
    r"$\sigma$ with 64-pixel-wide" + "\nGaussian blur",
    "Low resolution",
    "Ground truth",
    "Prediction with\nconformal mask\n(non-semantic $D_p$)",
    "Prediction with\nconformal mask\n(semantic $D_p$)",
    "Prediction without\nconformal mask",
]
for col, title in enumerate(column_titles):
    axs[0, col].set_title(title, y=1.05)

fig.savefig(fig_output_path / "fig1.png", bbox_inches="tight", dpi=150)
