import pathlib
import os
import json
from torchvision.transforms import CenterCrop
from PIL import Image
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from skimage.filters import gaussian
from icecream import ic 
import cv2 as cv
from skimage.transform import rescale
from imageio.v3 import imread
from skimage.color import lab2rgb

# Directory layout. root_path is the grandparent directory of this file. The
# path is resolved first: Path.parent is purely lexical, so a relative
# __file__ (e.g. "scripts/fig.py") would otherwise give Path(".") and
# Path(".").parent is still Path("."), silently pointing sinsr_path at the
# wrong place.
root_path = pathlib.Path(__file__).resolve().parent.parent
# The SinSR checkout is expected to sit next to root_path.
sinsr_path = root_path.parent / "SinSR"


dataset_basename = "LIU4K_v2_valid"
subdatasets = ["Animal", "Building", "Mountain", "Street"]

# Inputs produced by the SinSR pipeline and this project's outputs.
downscaled_path = sinsr_path / "datasets" / (dataset_basename + "_downscaled")
output_path = root_path / "output_sinsr" / dataset_basename
dataset_path = sinsr_path / "datasets" / dataset_basename
diffs_path = sinsr_path / "masks" / dataset_basename
# Destination for rendered figures; created (with any missing parents) if absent.
fig_output_path = root_path / "figs-neurips"
fig_output_path.mkdir(parents=True, exist_ok=True)

# Sorted "<subfolder>/<stem>" keys for every PNG file found anywhere under the
# dataset directory, excluding files whose immediate parent folder is "Capture".
valid_fnames = sorted(
    f"{png.parent.name}/{png.stem}"
    for png in dataset_path.resolve().glob("**/*.png")
    if png.is_file() and png.parent.name != "Capture"
)

# Global plot styling: all text is rendered through LaTeX at a fixed size.
fontsize = 24
plt.rcParams.update({
    "text.usetex": True,
    "text.latex.preamble": r"\usepackage{amsmath}",
    "figure.dpi": 300,
    "font.size": fontsize,
    # With text.usetex, font.family must be a generic family (or a font name
    # TeX knows). "Times New Roman" is not recognised and silently falls back
    # to Computer Modern; serif + "Times" makes matplotlib load the Times
    # (mathptmx) LaTeX package instead.
    "font.family": "serif",
    "font.serif": ["Times"],
})

# Figure configuration.
I = valid_fnames[199]  # other indices noted by the author: 36, 47
LEVEL = '0.0875'  # key looked up in each thresholds JSON file
blurs = (16, 32, 64)  # gaussian sigmas applied to the variance mask (rows)
dps = (1, 16, 32, 64)  # dp values selecting the threshold files (columns)
ALPHA = 0.3  # weight of the prediction inside the red overlay

# Prediction (stored as Lab, converted to RGB) and its variance mask, blurred
# once per sigma in `blurs`.
pred_img = lab2rgb(np.load(output_path / f"{I}_pred.npy"))
base_varmask_img = np.load(output_path / f"{I}_ker31_varmask.npy")
varmasks = [gaussian(base_varmask_img, sigma=blur) for blur in blurs]


def _load_threshold(dp, blur):
    """Return the LEVEL entry of the conformal thresholds file for (dp, blur)."""
    path = root_path / f"conformal/sinsr-thresholds-k31-dp{dp}-m{blur}-bs-fix.json"
    with open(path, encoding="utf-8") as file:
        return json.load(file)[LEVEL]


# thresholds[i, j] belongs to dps[i] and blurs[j]; the shape follows the
# tuples instead of being hard-coded.
thresholds = np.array(
    [[_load_threshold(dp, blur) for blur in blurs] for dp in dps],
    dtype=float,
)


# Grid of overlays: one row per variance-mask blur, one column per dp value.
dps = (1, 16, 32, 64)  # must match the row order used to build `thresholds`
crop = 128  # pixels trimmed from every image border before display

if thresholds.shape != (len(dps), len(blurs)):
    raise ValueError(
        f"thresholds has shape {thresholds.shape}, "
        f"expected {(len(dps), len(blurs))} for dps={dps}, blurs={blurs}"
    )

fig, ax = plt.subplots(len(blurs), len(dps), figsize=(16, 12), gridspec_kw={
        'wspace': 0.1,  # horizontal spacing between plots
        'hspace': 0.1   # vertical spacing between plots
    }, constrained_layout=True, squeeze=False)

# The red-tinted image does not depend on the panel, so build it once:
# (1 - ALPHA) * pure red + ALPHA * prediction, via broadcasting over channels.
red = np.array([1.0, 0.0, 0.0])
overlay = (1 - ALPHA) * red + ALPHA * pred_img
height, width = pred_img.shape[:2]

for i, dp in enumerate(dps):
    for j, blur in enumerate(blurs):
        # (H, W, 1) boolean mask broadcasts against the (H, W, 3) images.
        selection = varmasks[j][..., np.newaxis] >= thresholds[i, j]
        panel = np.where(selection, overlay, pred_img)
        ax[j, i].imshow(panel[crop:height - crop, crop:width - crop])
        ax[j, i].set_xticks([])
        ax[j, i].set_yticks([])

# Labels are derived from the same tuples as the data so they cannot drift.
# The TeX fragments are raw strings: "\s" is an invalid Python escape sequence.
for j, blur in enumerate(blurs):
    ax[j, 0].set_ylabel(f"{blur}-pixel-wide blur\nfor " + r"$\sigma_{31}$",
                        fontsize=fontsize)
for i, dp in enumerate(dps):
    ax[0, i].set_title(f"{dp}-pixel-wide blur\nfor " + r"$D_p$",
                       fontsize=fontsize)

fig.savefig(fig_output_path / "fig5.png", dpi=150, bbox_inches='tight')
plt.close(fig)