import pathlib
import matplotlib
import numpy as np
import matplotlib.pyplot as plt
from skimage.transform import rescale
from imageio.v3 import imread
import json, os
from skimage.color import lab2rgb
from matplotlib.patches import Rectangle


# Global figure styling: render all text through LaTeX (with amsmath loaded
# for math in titles), use a Times serif face, and a 150-dpi default canvas.
plt.rcParams.update(
    {
        "text.usetex": True,
        "text.latex.preamble": r"\usepackage{amsmath}",
        "figure.dpi": 150,
        "font.family": "Times",
    }
)


def my_left_corner_crop(x, size=2048):
    """Return the top-left ``size`` x ``size`` corner of an image array.

    Only the first two axes are sliced, so a trailing channel axis is kept
    intact. Along any axis shorter than ``size`` the image is returned whole
    (ordinary slicing semantics, no padding).

    Args:
        x: Array supporting 2-D slicing, e.g. an (H, W) or (H, W, C) image.
        size: Edge length of the crop in pixels. Defaults to 2048, the crop
            size the rest of this script is written for.

    Returns:
        ``x[:size, :size]`` (a view when ``x`` is a NumPy array).
    """
    return x[:size, :size]


# The script sits one directory below `root_path`; the SinSR checkout is
# expected to be a sibling directory of `root_path`.
root_path = pathlib.Path(__file__).parent.parent
sinsr_path = root_path.parent / "SinSR"


dataset_basename = "LIU4K_v2_valid"
subdatasets = ["Animal", "Building", "Mountain", "Street"]

# Inputs (dataset, downscaled copies, masks) live under the SinSR checkout;
# SinSR outputs and the generated figures live under `root_path`.
downscaled_path = sinsr_path / "datasets" / f"{dataset_basename}_downscaled"
output_path = root_path / "output_sinsr" / dataset_basename
dataset_path = sinsr_path / "datasets" / dataset_basename
diffs_path = sinsr_path / "masks" / dataset_basename
fig_output_path = root_path / "figs-neurips"
fig_output_path.mkdir(exist_ok=True)

# Sorted "<subfolder>/<stem>" identifiers of every PNG in the dataset,
# skipping anything inside a folder named "Capture". Indices into this list
# select the figure's image.
valid_fnames = sorted(
    f"{png.parent.name}/{png.stem}"
    for png in dataset_path.resolve().glob("**/*.png")
    if png.is_file() and png.parent.name != "Capture"
)

# Indices into the sorted `valid_fnames` of the images to render. Other
# candidates previously tried: 94, 32. Only the first entry is plotted.
chosen_ids = [
    171,
]
IS = [valid_fnames[i] for i in chosen_ids]
I = IS[0]

zooms = [4, 8, 16]  # NOTE(review): not referenced in the visible script.
# Key into the thresholds JSON below ("0.1" was also used). Presumably the
# conformal risk level; confirm against the calibration script.
LEVEL = "0.0875"
BORDER_COLOR = "yellow"  # NOTE(review): not referenced in the visible script.
# Weight of the prediction when it is blended with the red overlay.
ALPHA = 0.3
# JSON is UTF-8 by specification. Pass the encoding explicitly so reading
# does not depend on the platform's locale default.
with open(
    root_path / "conformal/sinsr-thresholds-s64-dp64-bs.json", encoding="utf-8"
) as file:
    THRESHOLD = json.load(file)[LEVEL]

# A single row of five panels packed close together horizontally. "hspace"
# only has an effect when there is more than one row.
fig, axs = plt.subplots(
    nrows=1,
    ncols=5,
    figsize=(10, 1.9),
    gridspec_kw=dict(wspace=0.05, hspace=0.05),
)

# NOTE(review): `zoom` and `roi` are not referenced in the visible script
# (the inset loop computes its own window); kept for compatibility.
zoom = 16


def roi(x):
    """Return the bottom-left ``1/zoom`` x ``1/zoom`` tile of image ``x``.

    Slices the last ``rows // zoom``-ish rows and the first ``cols // zoom``
    columns. The module-level ``zoom`` is read at call time.
    """
    rows, cols = x.shape[0], x.shape[1]
    return x[((zoom - 1) * rows) // zoom : rows, 0 : cols // zoom]

# Ground truth: top-left crop scaled to [0, 1].
# NOTE(review): dividing by 255 assumes 8-bit PNGs; confirm for this dataset.
real_image = my_left_corner_crop(imread(dataset_path / f"{I}.png")) / 255

# Low-resolution panel: the ground truth rescaled by 0.25 on each spatial axis.
low_res = rescale(real_image, 0.25, channel_axis=2)
# Prediction is stored in Lab space; convert to RGB for display.
pred = lab2rgb(np.load(output_path / f"{I}_pred.npy"))
# Per-pixel sigma map, thresholded below and shown as the first panel.
mask = np.load(output_path / f"{I}_s64_varmask.npy")

if mask.shape != pred.shape[:2]:
    raise ValueError(
        f"mask shape {mask.shape} does not match prediction shape "
        f"{pred.shape[:2]}"
    )

# Solid red image shaped like the prediction. It is derived from `pred`
# rather than a hard-coded 2048x2048, so other crop sizes work as well.
# np.zeros defaults to float64, matching the original ones/zeros stack.
red_mask = np.zeros(pred.shape)
red_mask[..., 0] = 1.0
# Pixels whose mask value reaches the conformal threshold, replicated over
# the three colour channels.
selection = np.stack((mask, mask, mask), axis=2) >= THRESHOLD
# Selected pixels become (1 - ALPHA) * red + ALPHA * prediction; all other
# pixels keep the plain prediction.
pred_with_mask = np.where(selection, (1 - ALPHA) * red_mask + ALPHA * pred, pred)


# Draw each panel with a magnified inset of the same relative region.
# Panels: sigma map, low-res, ground truth, prediction, prediction + mask.
for ax, img in zip(axs, [mask, low_res, real_image, pred, pred_with_mask]):
    # Size the zoom window from this image's own (rows, cols). The low-res
    # panel is 4x smaller than the others, and array shapes are (height,
    # width), so width drives x and height drives y.
    h, w = img.shape[:2]
    # Columns: the second sixteenth of the width.
    # Rows: a strip one sixteenth tall whose lower edge is h/32 above the
    # image bottom.
    # ylim is passed as (bottom, top) with bottom > top because imshow puts
    # row 0 at the top of the axes.
    x1, x2 = w // 16, 2 * w // 16
    y1, y2 = h - h // 32, 15 * h // 16 - h // 32
    ax.imshow(img)
    axins = ax.inset_axes(
        [0.4, 0.4, 0.55, 0.55],  # [x0, y0, width, height] in axes fractions
        xlim=(x1, x2),
        ylim=(y1, y2),
        xticklabels=[],
        yticklabels=[],
    )
    axins.spines[:].set_color("white")
    axins.tick_params(axis="both", colors="white")

    axins.imshow(img)
    # Outline the zoomed region on the parent axes and connect it to the inset.
    ax.indicate_inset_zoom(axins, edgecolor="white")


# Tick marks carry no information for image panels; remove them everywhere.
for ax in np.ravel(axs):
    ax.set_xticks([])
    ax.set_yticks([])


# "\\sigma" is escaped explicitly: a bare "\s" is an invalid escape
# sequence (SyntaxWarning on Python 3.12+). A raw string cannot be used here
# because "\n" must stay a real newline. The runtime text is unchanged.
axs[0].set_title("$\\sigma$ with 64-pixel-wide\nGaussian blur", y=1.05)
axs[1].set_title("Low resolution", y=1.05)
axs[2].set_title("Ground truth", y=1.05)
axs[3].set_title("Prediction without\nconformal mask", y=1.05)
axs[4].set_title("Prediction with\nconformal mask", y=1.05)

fig.savefig(fig_output_path / "fig3.png", dpi=150, bbox_inches="tight")
