from numba import njit, prange
import argparse
import numpy as np
import os
from tqdm import tqdm, trange
from PIL import Image
from skimage import color
import pathlib
from pathlib import Path
import json
from icecream import ic

# Command-line flags selecting which stages of the pipeline run.
parser = argparse.ArgumentParser(
    description=(
        "Calibrate conformal thresholds on the calibration subset or build "
        "prediction intervals for the test subset."
    )
)
parser.add_argument(
    "--residuals",
    action="store_true",
    help=(
        "recompute the per-image calibration residuals and save them as .npy; "
        "when omitted, calibration loads previously saved residuals"
    ),
)
parser.add_argument(
    "--test",
    action="store_true",
    help=(
        "skip threshold calibration: load the saved thresholds and write "
        "prediction intervals for the test subset"
    ),
)
args = parser.parse_args()
test = args.test
residuals = args.residuals


# ---------------------------------------------------------------------------
# Locations and dataset naming
# ---------------------------------------------------------------------------
# The project root is two levels above this file; the SinSR tree is looked up
# next to the project root.
root_path = Path(__file__).parent.parent
sinsr_path = root_path.parent / "SinSR"
DATASET_BASENAME = "LIU4K_v2"

# (calibration subset, test subset)
subdatasets = ("train", "valid")

# HR images (.png)
dataset_path = sinsr_path / "datasets"

# SR images (.npy)
output_path = root_path / "output_sinsr"

masks_path = sinsr_path / "masks"

# Everything this script writes (residuals, thresholds, intervals) goes here
diffs_path = root_path / "masks_angelopoulos_new"
diffs_path.mkdir(parents=True, exist_ok=True)

# Significance levels (1 - alpha is the confidence of the intervals)
ALPHAS = [0.1, *np.linspace(0.025, 0.50, 39)]

# The first subset calibrates; the second one, when present, is the test set
CALIB_SUB = subdatasets[0]
TEST_SUB = subdatasets[1] if subdatasets[1:] else subdatasets[0]
_calib_tag = f"{DATASET_BASENAME}_{CALIB_SUB}"
_test_tag = f"{DATASET_BASENAME}_{TEST_SUB}"

masks_calib_dir = masks_path / _calib_tag
masks_test_dir = masks_path / _test_tag
calib_hr_dir = dataset_path / _calib_tag
calib_sr_dir = output_path / _calib_tag
test_hr_dir = dataset_path / _test_tag
test_sr_dir = output_path / _test_tag
test_out_dir = diffs_path / f"{_test_tag}_intervals"
test_out_dir.mkdir(parents=True, exist_ok=True)

# Per-channel (L, a, b) value bounds of the Lab color space
LAB_MIN = np.asarray((0.0, -128.0, -128.0), dtype=np.float32)
LAB_MAX = np.asarray((100.0, 127.0, 127.0), dtype=np.float32)


def my_left_corner_crop(image_array, crop_size=2048):
    """
    Return the top-left ``crop_size`` x ``crop_size`` window of an image.

    Only the first two axes are sliced; trailing axes (e.g. channels) are kept
    whole. Slicing never pads, so an axis shorter than ``crop_size`` keeps its
    original length.
    """
    window = slice(None, crop_size)
    return image_array[window, window, ...]


def my_rgb2lab(image_array):
    """
    Convert an RGB image array to Lab color space.

    Integer inputs (e.g. uint8) are interpreted as [0, 255] and are always
    scaled to [0, 1]. Float inputs are expected in [0, 1]; for backward
    compatibility a float array whose maximum exceeds 1.0 is still treated as
    [0, 255] and scaled.
    Returns a float32 Lab image array.
    """
    # Remember the original kind before the cast below erases it.
    is_integer_input = np.issubdtype(image_array.dtype, np.integer)
    if image_array.dtype != np.float32 and image_array.dtype != np.float64:
        image_array = image_array.astype(np.float32)
    # Decide the scale from the dtype first: a very dark integer image whose
    # maximum is 0 or 1 would otherwise be mistaken for an already-normalized
    # [0, 1] image and be converted unscaled.
    if is_integer_input or image_array.max() > 1.0:
        image_array = image_array / 255.0  # scale 0-255 to 0-1
    # Convert to Lab color space
    lab_image = color.rgb2lab(image_array)
    return lab_image.astype(np.float32)


@njit(parallel=True)
def compute_abs_diff(hr, sr):
    """
    Element-wise |sr - hr| of two images with the same (H, W, C) shape.

    The outer (row) loop is a ``prange`` loop. The result is a new float32
    array shaped like ``hr``.
    """
    n_rows, n_cols, n_channels = hr.shape
    diff = np.empty_like(hr, dtype=np.float32)
    for row in prange(n_rows):
        for col in range(n_cols):
            for ch in range(n_channels):
                delta = sr[row, col, ch] - hr[row, col, ch]
                # Negative differences are negated to obtain the magnitude
                diff[row, col, ch] = -delta if delta < 0 else delta
    return diff


@njit(parallel=True)
def apply_threshold_array(sr, threshold, min_vals, max_vals):
    """
    Build symmetric per-pixel bounds ``sr - threshold`` and ``sr + threshold``.

    For channel ``c`` the lower bound is raised to ``min_vals[c]`` when it falls
    below it, and the upper bound is lowered to ``max_vals[c]`` when it exceeds
    it. Returns ``(lower, upper)``: two new float32 arrays shaped like ``sr``.
    """
    n_rows, n_cols, n_channels = sr.shape
    lower = np.empty_like(sr, dtype=np.float32)
    upper = np.empty_like(sr, dtype=np.float32)
    for row in prange(n_rows):
        for col in range(n_cols):
            for ch in range(n_channels):
                center = sr[row, col, ch]
                lo_limit = min_vals[ch]
                hi_limit = max_vals[ch]
                lo = center - threshold
                hi = center + threshold
                # One-sided clamps: lower against the minimum, upper against the maximum
                lower[row, col, ch] = lo_limit if lo < lo_limit else lo
                upper[row, col, ch] = hi_limit if hi > hi_limit else hi
    return lower, upper


def _list_basenames(search_dir, sr_dir):
    """
    Collect sorted "<parent folder>/<stem>" names of usable images.

    Every .png found recursively under ``search_dir`` is considered. Files
    whose parent folder is named "Capture" are skipped, as are files without a
    matching "<parent folder>/<stem>_pred.npy" under ``sr_dir``.
    """
    found = []
    for png in search_dir.resolve().glob("**/*.png"):
        if not png.is_file() or png.parent.name == "Capture":
            continue
        rel_name = png.parent.name + "/" + png.stem
        if os.path.isfile(os.path.join(sr_dir, rel_name + "_pred.npy")):
            found.append(rel_name)
    found.sort()
    return found


# Calibration and test image names (without extension)
valid_fnames = _list_basenames(masks_calib_dir, calib_sr_dir)
test_fnames = _list_basenames(test_hr_dir, test_sr_dir)

if residuals:
    # Calibration residuals: per-pixel |SR - HR| in normalized Lab space, one
    # flat array per calibration image, kept in memory and saved to disk.
    residuals_list = []
    lab_span = LAB_MAX - LAB_MIN  # loop-invariant normalization range
    print("Computing residuals on calibration set...")
    for base in tqdm(valid_fnames, desc="Calibration images"):
        hr_path = os.path.join(calib_hr_dir, base + ".png")
        sr_path = os.path.join(calib_sr_dir, base + "_pred.npy")
        # Pillow opens files lazily; the context manager releases the file
        # handle as soon as the pixels have been copied into an array.
        with Image.open(hr_path) as hr_img:
            hr_array = np.array(hr_img.convert("RGB"))
        # NOTE(review): only the HR image is cropped; the SR array is assumed
        # to already match the 2048 x 2048 crop - confirm.
        hr_array = my_left_corner_crop(hr_array, crop_size=2048)
        hr_lab = (my_rgb2lab(hr_array) - LAB_MIN) / lab_span
        # The SR array is normalized directly, without a color conversion
        # (presumably it is stored in Lab already - confirm).
        sr_array = np.load(sr_path)
        sr_lab = (sr_array - LAB_MIN) / lab_span
        # Flattened absolute residuals, shared by the in-memory list and the file
        flat_diff = np.abs(sr_lab - hr_lab).ravel()
        residuals_list.append(flat_diff)

        save_path = diffs_path / f"{base}_residuals.npy"
        os.makedirs(save_path.parent, exist_ok=True)
        np.save(save_path, flat_diff)

def _risk_upper_bound(residuals_list, threshold, delta):
    """
    Upper confidence bound on the miscoverage of [pred - threshold, pred + threshold].

    The empirical risk is the fraction of residuals above ``threshold`` in each
    image, averaged over the images. The Hoeffding term
    ``sqrt(log(1 / delta) / (2 * n))`` is added, where ``n`` is the number of
    images in ``residuals_list``.
    """
    n = len(residuals_list)
    miscoverage = sum(1 - np.mean(res <= threshold) for res in residuals_list) / n
    return miscoverage + np.sqrt(np.log(1 / delta) / (2 * n))


# Thresholds are stored per calibration subset; calibration writes this file
# and test mode reads it.
thresholds_path = diffs_path / f"{DATASET_BASENAME}_{CALIB_SUB}_thresholds.json"

if not test:

    if not residuals:
        print("Loading residuals for calibration set...")
        # Load residuals from .npy files (without rebinding the `residuals` flag)
        residuals_list = [
            np.load(diffs_path / f"{base}_residuals.npy")
            for base in tqdm(valid_fnames, desc="Loading residuals")
        ]
    if not residuals_list:
        raise SystemExit(
            "No calibration residuals available; cannot compute conformal thresholds."
        )

    delta = 0.05
    max_threshold = 1.0  # largest interval half-width that is tried
    delta_threshold = 0.02
    n_steps = int(round(max_threshold / delta_threshold))
    # Candidates from max_threshold down to 0, built from the step index so
    # rounding error does not accumulate through repeated subtraction.
    grid = [round(max_threshold - k * delta_threshold, 10) for k in range(n_steps + 1)]
    # The risk bound of a candidate does not depend on alpha: evaluate each
    # grid point at most once and reuse it for every alpha.
    grid_bounds = []

    thresholds = {}
    print("Calculating conformal thresholds...")
    for alpha in tqdm(ALPHAS):
        # Keep the smallest candidate reached while the bound still holds,
        # i.e. the last one *before* the bound is first violated.
        threshold = None
        for k, candidate in enumerate(grid):
            if k == len(grid_bounds):
                grid_bounds.append(_risk_upper_bound(residuals_list, candidate, delta))
            if grid_bounds[k] > alpha:
                break
            threshold = candidate
        if threshold is None:
            # Even the widest interval violates the bound (for example when the
            # Hoeffding term alone exceeds alpha); fall back to the widest one.
            threshold = grid[0]
            tqdm.write(
                f"Warning: risk bound not met for alpha={alpha:.4f}; "
                f"using maximum threshold {threshold:.4f}"
            )

        thresholds[alpha] = threshold

        print(f"Conformal threshold: {threshold:.4f}")

        # Rewrite the JSON after every alpha so the thresholds computed so far
        # are on disk if the run is interrupted.
        with open(thresholds_path, "w", encoding="utf-8") as f:
            json.dump(thresholds, f, indent=4)

    print(f"Conformal thresholds have been saved to: {thresholds_path}")

else:
    print("Loading conformal thresholds for test set...")
    # Load thresholds from JSON file
    with open(thresholds_path, "r", encoding="utf-8") as f:
        thresholds = json.load(f)

    # Application: construct prediction intervals for each test image
    print("Applying conformal prediction to test set images...")
    # Use the first alpha stored in the file (JSON keys are strings)
    chosen_alpha = next(iter(thresholds))
    chosen_threshold = thresholds[chosen_alpha]
    lab_span = LAB_MAX - LAB_MIN
    # sr_lab below is normalized so that LAB_MIN maps to 0 and LAB_MAX to 1,
    # so the bounds must be clamped in that same normalized space.
    norm_min = np.zeros_like(LAB_MIN)
    norm_max = np.ones_like(LAB_MAX)
    for base in tqdm(test_fnames, desc="Test images"):
        sr_path = test_sr_dir / f"{base}_pred.npy"
        tqdm.write(f"Processing {base} with threshold {chosen_threshold:.4f}")
        lower_path = test_out_dir / f"{base}_{chosen_alpha}_lower.npy"
        # Create the folder the output files are written into
        os.makedirs(lower_path.parent, exist_ok=True)

        # Load and preprocess SR image (no need to load HR for prediction intervals)
        sr_array = np.load(sr_path)
        sr_lab = (sr_array - LAB_MIN) / lab_span
        # Apply threshold to get lower and upper bound images
        lower_bound, upper_bound = apply_threshold_array(
            sr_lab, np.float32(chosen_threshold), norm_min, norm_max
        )
        # Save the lower and upper bound arrays
        np.save(lower_path, lower_bound)
        np.save(test_out_dir / f"{base}_{chosen_alpha}_upper.npy", upper_bound)
        # Save the prediction interval width map (same shape as the image)
        interval_width = upper_bound - lower_bound
        np.save(test_out_dir / f"{base}_width.npy", interval_width)

    print(f"Conformal prediction intervals have been saved to: {test_out_dir}")
