import pathlib
import matplotlib
import numpy as np
import matplotlib.pyplot as plt
from skimage.transform import rescale
from imageio.v3 import imread
import json, os
from skimage.color import lab2rgb

# Global matplotlib settings applied to the figure built below: LaTeX text
# rendering (with amsmath loaded in the preamble), figure DPI, base font size
# and default font family.
plt.rcParams.update(
    {
        "text.usetex": True,
        "text.latex.preamble": r"\usepackage{amsmath}",
        "figure.dpi": 300,
        "font.size": 25,
        "font.family": "Times",
    }
)


# Name of the dataset whose results are plotted; used as a sub-directory name
# in every data location below.
DATASET = "LIU4K_v2_valid"

# The project root is two levels above this file. The SinSR and
# GCP-Colorization directories are looked up next to the project root.
root_path = pathlib.Path(__file__).parent.parent
_workspace = root_path.parent

sinsr_path = _workspace.joinpath("SinSR")
dataset_path = sinsr_path.joinpath("datasets", DATASET)

colorizer_path = _workspace.joinpath("GCP-Colorization")

# Per-dataset result folders of the colorizer.
output_dir = colorizer_path.joinpath("results", DATASET)
fake_dir, gt_dir, gray_dir = (
    output_dir / sub_dir for sub_dir in ("fake", "gt", "gray")
)

masks_dir = root_path.joinpath("masks-colorization", DATASET)

# Destination folder for the rendered figure; created if it does not exist.
fig_output_path = root_path.joinpath("figs-neurips")
fig_output_path.mkdir(exist_ok=True)

# Stems of all PNG files found recursively under the dataset directory,
# excluding files whose immediate parent folder is named "Capture". The list
# is sorted so that the indices in `chosen_ids` refer to a fixed ordering.
valid_fnames = sorted(
    png_path.stem
    for png_path in dataset_path.resolve().glob("**/*.png")
    if png_path.is_file() and png_path.parent.name != "Capture"
)


# Indices into `valid_fnames` of the images to display, one figure row each,
# in this order.
chosen_ids = [118, 185, 85, 88, 34, 172]


IS = [valid_fnames[index] for index in chosen_ids]

# Key of the entry to read from the thresholds JSON file.
LEVEL = "0.1"

# Load the threshold stored under LEVEL. The encoding is passed explicitly so
# decoding does not depend on the platform's default locale encoding.
with open(
    root_path / "colorization/colorization-thresholds-s64-d64.json",
    encoding="utf-8",
) as file:
    thresholds_by_level = json.load(file)
    THRESHOLD = thresholds_by_level[LEVEL]

# Grid dimensions: one row per selected image, five panels per row.
n_rows = len(IS)
n_cols = 5

# squeeze=False keeps `axs` a 2-D array even when only one image is selected,
# so two-index access (`axs[row, col]`) is valid for any number of rows.
fig, axs = plt.subplots(
    n_rows,
    n_cols,
    figsize=(4 * n_cols, 4.25 * n_rows),
    gridspec_kw={
        "wspace": 0.125,  # horizontal spacing between plots
        "hspace": 0.125,  # vertical spacing between plots
    },
    constrained_layout=True,
    squeeze=False,
)

# Blend weight given to the prediction inside the highlighted region; the red
# overlay receives the complementary weight (1 - ALPHA). It is the same for
# every row, so it is defined once outside the loop.
ALPHA = 0.3

# Pure red as an RGB triple. Broadcasting it against the prediction avoids
# building an overlay with a fixed 256x256 size, so any resolution works.
RED_RGB = np.array([1.0, 0.0, 0.0])

for row, stem in enumerate(IS):
    # 8-bit images scaled to the [0, 1] range expected by imshow for floats.
    real_image = imread(gt_dir / f"{stem}_randdirectionnum0.png") / 255
    low_res = imread(gray_dir / f"{stem}_randdirectionnum0.png") / 255

    # The stored prediction is converted from Lab to RGB for display.
    pred = lab2rgb(np.load(masks_dir / f"{stem}_pred.npy"))
    mask = np.load(masks_dir / f"{stem}_s64_varmask.npy")

    # NOTE(review): assumes `mask` is a 2-D (H, W) array aligned with `pred`;
    # the added trailing axis lets the comparison broadcast over channels.
    selection = mask[..., np.newaxis] >= THRESHOLD

    # Tint pixels at or above the threshold red; leave the others untouched.
    highlighted = np.where(
        selection, (1 - ALPHA) * RED_RGB + ALPHA * pred, pred
    )

    axs[row, 0].imshow(mask)
    # cmap/vmin/vmax only take effect if the gray image is single-channel;
    # matplotlib ignores them for RGB(A) input.
    axs[row, 1].imshow(low_res, cmap="gray", vmin=0, vmax=1)
    axs[row, 2].imshow(real_image)
    axs[row, 3].imshow(highlighted)
    axs[row, 4].imshow(pred)


# Remove tick marks from every panel of the grid.
for panel in axs.flat:
    panel.set_xticks([])
    panel.set_yticks([])

# Column headers, placed on the first row only (left to right).
column_titles = (
    r"$\sigma$ with 64-pixel-wide" + "\nGaussian blur",
    "Gray scale",
    "Ground truth",
    "Prediction with\nconformal mask",
    "Prediction without\nconformal mask",
)
for header_ax, title in zip(axs[0], column_titles):
    header_ax.set_title(title, y=1.05)


# Write the figure to disk with a tight bounding box at 150 DPI.
fig.savefig(
    fig_output_path.joinpath("fig-col.png"),
    bbox_inches="tight",
    dpi=150,
)
