import jax
import jax.numpy as jnp
import numpy as np
from pathlib import Path
from flax import linen as nn
from imageio.v3 import imread
from skimage.color import rgb2lab
from tqdm import tqdm
from icecream import ic
from skimage.filters import gaussian

# Parent of the directory that contains this script.
base_path = Path(__file__).parent.parent
# SinSR checkout, located next to base_path.
sinsr_path = base_path.parent.joinpath('SinSR')

# Dataset split name and the sub-categories iterated over by the main loop.
dataset_base = 'LIU4K_v2_valid'
subdatasets = ['Animal', 'Building', 'Mountain', 'Street']
# Folder whose PNG files are scanned to build `valid_fnames`.
diffs_path = sinsr_path.joinpath('masks', dataset_base)

# Number of generated runs per image (folders output_0 ... output_{runs - 1}).
runs = 9

def my_left_corner_crop(x):
    """Return the top-left 2048 x 2048 window of ``x``.

    The first two axes are treated as rows and columns; any trailing axes
    (e.g. colour channels) are kept untouched. An axis shorter than 2048 is
    returned whole, following ordinary slicing semantics.
    """
    side = 2048
    rows = slice(None, side)
    cols = slice(None, side)
    return x[rows, cols]

def save_images(images, save_path: Path, filename: str = ''):
    """Save each image as ``upscaled_<i>.npy`` inside ``save_path / filename``.

    Args:
        images: Iterable of array-likes; each one is written with ``np.save``
            and numbered by its position in the iterable.
        save_path: Parent directory (``Path`` or ``str``). It and any missing
            ancestors are created when needed.
        filename: Name of the sub-directory to write into. The default ``''``
            writes directly into ``save_path``.
    """
    directory = Path(save_path) / f"{filename}"
    # parents=True so that a save_path that does not exist yet (possibly
    # several levels deep) does not raise FileNotFoundError.
    directory.mkdir(parents=True, exist_ok=True)

    for i, img in enumerate(images):
        np.save(directory / f"upscaled_{i}.npy", img)


@jax.jit
def normalize_lab(img):
    """Rescale CIELAB channels: L / 100, (a + 128) / 255, (b + 128) / 255.

    With the usual Lab ranges (L in [0, 100], a/b roughly in [-128, 127]) this
    maps every channel to approximately [0, 1].

    Channels are read from the last axis, so any leading shape is accepted
    (e.g. [batch, H, W, 3] or [H, W, 3]). Only the first three channels are
    used, and the result always has exactly three channels.
    """
    L, a, b = img[..., 0], img[..., 1], img[..., 2]

    L = L / 100
    a = (a + 128) / 255
    b = (b + 128) / 255

    return jnp.stack([L, a, b], axis=-1)

# "<parent-folder>/<stem>" ids for every PNG file found (recursively) under
# diffs_path. Computed once at import time by scanning the filesystem.
valid_fnames = [
    f"{png.parent.name}/{png.stem}"
    for png in diffs_path.resolve().glob("**/*.png")
    if png.is_file()
]

# @jax.jit
def calculate_mask(img, kernel_size, iters):
    """
    Calculates the local-variance mask for a batch of images using average pooling.

    For every pixel, the variance of the values in the ``kernel_size`` x
    ``kernel_size`` window around it is computed as E[X^2] - (E[X])^2,
    separately per image and per channel. Windows that overhang the image
    border are averaged only over pixels that actually exist, so edges are not
    inflated by zero padding. The step is applied ``iters`` times, each pass
    operating on the previous pass's output, and the final result is averaged
    over the batch and channel axes.

    Args:
        img: JAX array of shape [batch, height, width, channels].
        kernel_size: Side length of the square pooling window.
        iters: Number of variance passes to perform (0 returns the
            batch/channel mean of ``img`` itself).
    Returns:
        Variance mask as a JAX array of shape [height, width].
    """
    window = (kernel_size, kernel_size)

    def pooled(x):
        # 'SAME' keeps the spatial size. Flax zero-pads the border, so near the
        # edges this is (sum of in-image values) / (full window size).
        return nn.avg_pool(x, window_shape=window, strides=(1, 1),
                           padding='SAME')

    # Pooling a field of ones gives, per pixel, the same normalisation that
    # avg_pool applied. Dividing by it turns every pooled value into a mean
    # over in-image pixels only. This holds whether or not the installed Flax
    # counts padding. Shape [1, H, W, 1] broadcasts over batch and channels;
    # it never reaches 0 because each window contains its own centre pixel.
    valid_fraction = pooled(jnp.ones((1, *img.shape[1:3], 1), dtype=img.dtype))

    var_mask = img
    for _ in range(iters):
        mean = pooled(var_mask) / valid_fraction
        mean_square = pooled(var_mask ** 2) / valid_fraction

        # Variance: E[X^2] - (E[X])^2. Floating-point cancellation can leave
        # tiny negative values, which are clamped to zero.
        var_mask = jnp.maximum(mean_square - mean ** 2, 0.0)

    # Average over batch (axis 0) and channels (axis 3) -> [height, width].
    return jnp.mean(var_mask, axis=(0, 3))


def _process_image(image_id, result_folder_in, original_folder, mask_folder_out):
    """Compute and save all diff / variance masks for one image id.

    Reads the ``runs`` generated PNGs ``result_folder_in/output_<run>/<id>.png``
    and the original ``original_folder/<id>.png`` (cropped with
    ``my_left_corner_crop``). Then writes these ``.npy`` files into
    ``mask_folder_out``:

    * ``<id>_pred.npy``: run 0 in Lab, before normalisation.
    * ``<id>_k1_diff.npy``: per-pixel sum over normalised Lab channels of
      |original - run 0|.
    * ``<id>_k{16,32,64}_diff.npy``: that diff Gaussian-blurred with sigma
      16 / 32 / 64.
    * ``<id>_simple_varmask.npy``: per-pixel variance across runs, averaged
      over channels.
    * ``<id>_ker{19,25,31}_varmask.npy``: local-window variance from
      ``calculate_mask``.
    * ``<id>_s64_varmask.npy``: the simple variance mask blurred with sigma 64.
    """
    # Stack the runs on the host. rgb2lab is a NumPy/scikit-image routine, so
    # building a JAX array first would only force a device round trip.
    generated_images = np.stack([
        imread(str(result_folder_in / f"output_{run}" / f'{image_id}.png'))
        for run in range(runs)
    ])

    original_image = my_left_corner_crop(imread(original_folder / f'{image_id}.png'))
    original_image = jnp.array(rgb2lab(original_image), dtype=jnp.float32)

    # Convert all runs to Lab at once, then move to JAX as float32.
    upscaled_tensor = jnp.array(rgb2lab(generated_images), dtype=jnp.float32)

    # Save the first run (Lab, not yet normalised) for inspection.
    np.save(mask_folder_out / f'{image_id}_pred.npy',
            np.asarray(upscaled_tensor[0]))

    upscaled_tensor = normalize_lab(upscaled_tensor)
    original_image = normalize_lab(original_image[None, ...])[0]

    # Per-pixel L1 distance (summed over channels) between run 0 and the original.
    diff = jnp.sum(jnp.abs(original_image - upscaled_tensor[0]), axis=-1)
    np.save(mask_folder_out / f'{image_id}_k1_diff.npy', np.asarray(diff))

    # The "k" in these file names is the Gaussian sigma.
    for sigma_size in (16, 32, 64):
        np.save(mask_folder_out / f'{image_id}_k{sigma_size}_diff.npy',
                gaussian(diff, sigma=sigma_size))

    # Per-pixel variance across runs (axis 0), averaged over channels.
    # Computed once and reused for both the raw and the blurred mask.
    pixel_var = jnp.var(upscaled_tensor, axis=0).mean(axis=2)
    np.save(mask_folder_out / f'{image_id}_simple_varmask.npy',
            np.asarray(pixel_var))

    # Local-window variance masks.
    iters = 1
    for kernel_size in (19, 25, 31):
        local_var = calculate_mask(upscaled_tensor, kernel_size, iters)
        np.save(mask_folder_out / f'{image_id}_ker{kernel_size}_varmask.npy',
                np.asarray(local_var))

    np.save(mask_folder_out / f'{image_id}_s64_varmask.npy',
            gaussian(pixel_var, sigma=64))


if __name__ == '__main__':
    # Load generated images from SinSR
    for subdataset in subdatasets:
        dataset_fullname = dataset_base + '/' + subdataset

        result_folder_in = sinsr_path / 'output' / dataset_fullname
        original_folder = sinsr_path / 'datasets' / dataset_fullname

        mask_folder_out = base_path / 'output_sinsr' / dataset_fullname
        mask_folder_out.mkdir(parents=True, exist_ok=True)

        # Image ids are taken from the first run's folder. Every other run is
        # expected to contain a PNG with the same name. Sorting gives a
        # deterministic order and lets tqdm show a total.
        image_paths = sorted((result_folder_in / "output_0").glob('*.png'))
        for image in tqdm(image_paths, desc=subdataset):
            image_id = image.stem
            # if subdataset + '/' + image_id not in valid_fnames:
            #     continue
            _process_image(image_id, result_folder_in, original_folder,
                           mask_folder_out)
