"""
Runs a conformal risk control procedure on a colorization task.
"""
import json
from pathlib import Path

import numpy as np
from skimage.color import rgb2lab, gray2rgb
from imageio.v3 import imread
from tqdm import tqdm

from src.conformal import ConformalCalibratedModel

# Directory one level above the folder that contains this script.
root_path = Path(__file__).parent.parent
# SinSR checkout, expected next to ``root_path``; the datasets live under it.
sinsr_path = root_path.parent.joinpath("SinSR")


# Name of the dataset folder processed by this script.
DATASET = "LIU4K_v2_train"

dataset_path = sinsr_path.joinpath("datasets", DATASET)

# GCP-Colorization checkout, expected next to ``root_path``.
colorizer_path = root_path.parent.joinpath("GCP-Colorization")

# Per-dataset colorizer results and their three sub-folders.
output_dir = colorizer_path.joinpath("results", DATASET)
fake_dir, gt_dir, gray_dir = (
    output_dir.joinpath(sub_folder) for sub_folder in ("fake", "gt", "gray")
)

# Folder holding the per-image ``.npy`` arrays for this dataset.
masks_dir = root_path.joinpath("masks-colorization", DATASET)


def my_rgb2lab(image_array: np.ndarray) -> np.ndarray:
    """
    Convert an image to Lab and rescale each channel with fixed constants.

    A 2-D (single-channel) input is first expanded with ``gray2rgb``. The
    result of ``rgb2lab`` is then rescaled per channel: the first channel is
    divided by 100, the other two are divided by 254 and shifted by 0.5.

    Args:
        image_array (np.ndarray): Input image, either 2-D grayscale or RGB.

    Returns:
        np.ndarray: Rescaled Lab image array.
    """
    # Divisor and offset per Lab channel, shaped (1, 1, 3) so they broadcast
    # over the pixel axes.
    # NOTE(review): presumably meant to bring all channels into roughly [0, 1]
    # -- confirm against the value ranges rgb2lab actually produces.
    channel_scale = np.array([[[100, 127 * 2, 127 * 2]]])
    channel_shift = np.array([[[0, 0.5, 0.5]]])

    rgb = gray2rgb(image_array) if image_array.ndim == 2 else image_array
    lab = rgb2lab(rgb)
    return lab / channel_scale + channel_shift


# Sorted stems of every PNG file below the dataset folder, skipping files
# whose immediate parent directory is named "Capture".
file_list = sorted(
    png.stem
    for png in dataset_path.resolve().glob("**/*.png")
    if png.is_file() and png.parent.name != "Capture"
)

# Keep four windows of 100 entries each, starting at indices 0, 400, 800
# and 1200 of the sorted list (shorter lists simply yield fewer entries).
file_list = [
    stem
    for start in (0, 400, 800, 1200)
    for stem in file_list[start:start + 100]
]


# Ground-truth images are read as stored here and converted to the rescaled
# Lab representation in a separate pass further down.
gt_images = [
    imread(gt_dir.joinpath(f"{stem}_randdirectionnum0.png"))
    for stem in tqdm(file_list, desc="read ground truth")
]

# Gray inputs are read and converted in a single pass.
gray_images = [
    my_rgb2lab(imread(gray_dir.joinpath(f"{stem}_randdirectionnum0.png")))
    for stem in tqdm(file_list, desc="read gray image")
]

# Second pass over the ground truth: replace each raw image by its converted
# version.
gt_images = list(map(my_rgb2lab, tqdm(gt_images, desc="prep ground truth")))

# One "<stem>_pred.npy" array per selected image.
preds = [
    np.load(masks_dir.joinpath(f"{stem}_pred.npy"))
    for stem in tqdm(file_list, desc="load preds")
]


# Values substituted into the "<stem>_k{diff_sigma}_diff.npy" and
# "<stem>_{kernel_name}_varmask.npy" file names; every combination of the two
# is calibrated and written to its own JSON file.
diff_sigmas = [16, 32, 64]
kernel_names = ['s64', 'simple']

# All threshold files land in the same folder, so create it once up front
# rather than on every iteration. This also surfaces a permission problem
# before any expensive calibration has run.
save_dir = root_path / 'colorization'
save_dir.mkdir(parents=True, exist_ok=True)

for diff_sigma in diff_sigmas:

    # The diff arrays depend only on diff_sigma, so load them once per outer
    # iteration and reuse them for every kernel.
    gdiffs = [
        np.load(masks_dir / f"{i}_k{diff_sigma}_diff.npy")
        for i in tqdm(file_list, desc="load diffs")
    ]

    for kernel_name in kernel_names:

        prob_masks = [
            np.load(masks_dir / f"{i}_{kernel_name}_varmask.npy")
            for i in tqdm(file_list, desc="load masks")
        ]

        # Clip each mask to the range [0, its own 0.95 quantile].
        gprob_masks = [np.clip(mask, 0, np.quantile(mask, 0.95))
                       for mask in tqdm(prob_masks, desc="clip masks")]

        print(f"kernel_name: {kernel_name}, diff_sigma: {diff_sigma}")

        conformal = ConformalCalibratedModel.calibrate(
            None,
            list(zip(gray_images, gt_images)),
            # Materialised like the argument above: a bare zip() is a
            # single-pass iterator and would appear empty on any second
            # traversal inside calibrate().
            list(zip(preds, gprob_masks)),
            # 39 evenly spaced values in [0.025, 0.50] plus one extra value.
            # NOTE(review): 2.5 lies far outside the linspace range -- confirm
            # that it is intended.
            alphas=list(np.linspace(0.025, 0.50, 39)) + [2.5,],
            diffs=gdiffs,
            method="dynprog"
        )

        save_path = save_dir / f"colorization-thresholds-{kernel_name}-d{diff_sigma}.json"

        with open(save_path, "w", encoding='utf-8') as file:
            json.dump(conformal.thresholds, file)
