# File path: sinsr_to_jax_flax.py
import jax
import jax.numpy as jnp
import numpy as np
from pathlib import Path
from flax import linen as nn
from imageio.v3 import imread
from skimage.color import rgb2lab
from tqdm import tqdm
from icecream import ic
from skimage.filters import gaussian


# root_path is the parent of the directory that contains this file; the SinSR
# and GCP-Colorization folders are looked up next to root_path.
root_path = Path(__file__).parent.parent
dataset_base = 'LIU4K_v2_valid'

# Input dataset, located inside the SinSR folder.
sinsr_path = root_path.parent.joinpath('SinSR')
dataset_path = sinsr_path.joinpath("datasets", dataset_base)

# Colorization results: generated samples in "fake", references in "gt".
colorizer_path = root_path.parent.joinpath("GCP-Colorization")
output_dir = colorizer_path.joinpath("results", dataset_base)
fake_dir = output_dir.joinpath("fake")
gt_dir = output_dir.joinpath("gt")

# Destination for the computed masks; created at import time if missing.
masks_dir = root_path.joinpath("masks-colorization", dataset_base)
masks_dir.mkdir(exist_ok=True, parents=True)

# Number of generated samples read per image.
RUNS = 7

@jax.jit
def normalize_lab(img):
    """
    Rescale the three channels of a Lab image to a common unit scale.

    L is divided by 100; a and b are shifted by 128 and divided by 255. With
    the usual Lab ranges (L in [0, 100], a/b in [-128, 127]) this maps every
    channel onto [0, 1].

    Args:
        img: Array whose last axis holds (L, a, b). Any number of leading
            axes is accepted, e.g. [H, W, 3] or [B, H, W, 3].
    Returns:
        Array with the same leading axes as img and a last axis of size 3
        holding the rescaled (L, a, b) channels.
    """
    # Ellipsis indexing keeps the function independent of the input rank.
    L, a, b = img[..., 0], img[..., 1], img[..., 2]

    L = L / 100
    a = (a + 128) / 255
    b = (b + 128) / 255

    return jnp.stack([L, a, b], axis=-1)

# Sorted stems of every PNG found recursively under the dataset folder,
# skipping files whose immediate parent directory is named "Capture".
file_list = sorted(
    png.stem
    for png in dataset_path.resolve().glob("**/*.png")
    if png.is_file() and png.parent.name != "Capture"
)

# @jax.jit
def calculate_mask(img, kernel_size, iters):
    """
    Build a local-variance mask by repeatedly applying a windowed variance.

    Every iteration replaces the running tensor with E[X^2] - (E[X])^2, where
    both expectations are stride-1, 'SAME'-padded average pools over a
    kernel_size x kernel_size window. The final tensor is then averaged over
    axes 0 and 3.

    Args:
        img: JAX array of shape [batch, height, width, channels].
        kernel_size: Side length of the square pooling window.
        iters: Number of variance passes; with 0 the input itself is averaged.
    Returns:
        JAX array of shape [height, width].
    """
    window = (kernel_size, kernel_size)

    def local_mean(x):
        # Stride-1 average pooling with 'SAME' padding.
        return nn.avg_pool(x, window_shape=window, strides=(1, 1),
                           padding='SAME')

    current = img
    for _ in range(iters):
        # Variance: E[X^2] - (E[X])^2
        current = local_mean(current**2) - local_mean(current)**2

    # Collapse the batch (axis 0) and channel (axis 3) axes.
    return jnp.mean(current, axis=(0, 3))

if __name__ == '__main__':
    # For each dataset image, load the RUNS generated samples and the ground
    # truth from the colorizer results, then write the error and variance
    # masks to masks_dir as .npy files.

    for image in tqdm(file_list):
        # Stack the generated samples on the host: rgb2lab operates on NumPy
        # arrays, so building a JAX array first would only add a round trip.
        generated_images = np.stack([
            imread(fake_dir / f'{image}_randdirectionnum{run}.png')
            for run in range(RUNS)
        ])

        original_image = imread(gt_dir / f'{image}_randdirectionnum0.png')
        original_image = jnp.array(rgb2lab(original_image), dtype=jnp.float32)

        # Convert the whole stack to Lab, then move it to JAX as float32.
        upscaled_tensor = jnp.array(rgb2lab(generated_images),
                                    dtype=jnp.float32)

        # Save the middle run (index RUNS // 2) for inspection; it is stored
        # before normalization, i.e. as returned by rgb2lab.
        jnp.save(masks_dir / f'{image}_pred.npy',
                 upscaled_tensor[RUNS // 2])

        upscaled_tensor = normalize_lab(upscaled_tensor)
        # Add a leading batch axis for normalize_lab and strip it afterwards.
        original_image = normalize_lab(original_image[None, ...])[0]

        # Per-pixel absolute error between the ground truth and run 0, summed
        # over the three normalized Lab channels.
        diff = jnp.sum(jnp.abs(original_image - upscaled_tensor[0]), axis=-1)

        # Smoothed error masks at several Gaussian scales.
        for sigma_size in [16, 32, 64]:
            diff_mask = gaussian(diff, sigma=sigma_size)

            jnp.save(masks_dir / f'{image}_k{sigma_size}_diff.npy', diff_mask)

        # Variance across the runs, averaged over channels. Computed once and
        # reused for both the raw and the smoothed variance mask.
        varmask_simple_array = jnp.var(upscaled_tensor, axis=0).mean(axis=2)

        jnp.save(masks_dir / f'{image}_simple_varmask.npy',
                 varmask_simple_array)

        var_mask_array = gaussian(varmask_simple_array, sigma=64)
        jnp.save(masks_dir / f'{image}_s64_varmask.npy',
                 var_mask_array)
