r"""Richardson-Lucy Deconvolution Method"""

from PIL import Image
from pathlib import Path
import numpy as np
import cv2
from torch.utils import data
from torch import Tensor
import matplotlib.pyplot as plt
from scipy.signal import convolve2d

import zipfile
import io

from torch_fidelity.fidelity import calculate_metrics
from torch import Tensor
from torch.utils import data
import torchvision


from utils import *
from datasets import Array3D, Features, load_from_disk

from functools import partial

class ZipDataset(data.Dataset):
    r"""Zip image dataset.

    Eagerly decodes every file entry of a zip archive into an in-memory list of
    RGB PIL images and serves them as tensors through
    ``torchvision.transforms.PILToTensor``.
    """

    def __init__(self, archive: Path):
        """Load every non-directory entry of ``archive`` as an RGB image.

        Args:
            archive: Path (or path string) of the zip archive to read.
        """
        self.images = []

        with zipfile.ZipFile(archive, mode='r') as file:
            for info in file.infolist():
                # Directory entries carry no image payload; opening them as an image fails.
                if info.is_dir():
                    continue
                # Named `fp` (not `data`) so the torch `data` module name is not shadowed.
                with file.open(info) as fp:
                    # Convert inside the `with` so decoding happens while the member is open.
                    img = Image.open(fp).convert('RGB')

                self.images.append(img)

    def __len__(self) -> int:
        """Return the number of images loaded from the archive."""
        return len(self.images)

    def __getitem__(self, i: int) -> Tensor:
        """Return image ``i`` converted with ``PILToTensor`` (no value rescaling)."""
        return torchvision.transforms.PILToTensor()(self.images[i])

    @staticmethod
    def zip(archive: Path, images: list):
        """Write ``images`` into a new zip archive as ``IMG_{i}.png`` entries.

        Args:
            archive: Destination path; opened with mode ``'w'``, so any
                existing file is replaced.
            images: Sequence of PIL images, each encoded as PNG in order.
        """
        with zipfile.ZipFile(archive, mode='w') as file:
            for i, img in enumerate(images):
                buffer = io.BytesIO()
                img.save(buffer, 'png')
                file.writestr(f'IMG_{i}.png', buffer.getvalue())


# Helper: Gaussian PSF (Point Spread Function)
def gaussian_kernel(sigma):
    """Build a normalized square 2-D Gaussian point spread function.

    The kernel has odd side length ``2 * int(3 * sigma) + 1``, so it has a
    centre pixel, and its entries sum to 1.

    Args:
        sigma: Standard deviation of the Gaussian, in pixels. Must be > 0.

    Returns:
        ``np.ndarray`` of shape ``(size, size)``.

    Raises:
        ValueError: If ``sigma`` is not strictly positive. A zero sigma would
            otherwise produce a NaN kernel, and a negative one an empty kernel.
    """
    # `not sigma > 0` also rejects NaN.
    if not sigma > 0:
        raise ValueError(f"sigma must be positive, got {sigma!r}")
    # Truncate the Gaussian at ~3 standard deviations on each side of the centre.
    radius = int(3 * sigma)
    ax = np.arange(-radius, radius + 1, dtype=float)
    xx, yy = np.meshgrid(ax, ax)
    kernel = np.exp(-(xx**2 + yy**2) / (2.0 * sigma**2))
    return kernel / np.sum(kernel)

# Richardson–Lucy algorithm
def richardson_lucy(image, psf, iterations=30):
    """Deconvolve a 2-D image with the Richardson-Lucy multiplicative update.

    The estimate starts as a flat field of 0.5. Each iteration does three things:
    re-blur the estimate with ``psf``, divide the observation by that
    re-blurred estimate, and multiply the estimate by the ratio convolved with
    the flipped ``psf``. All convolutions keep the input size
    (``mode='same'``) and treat the image as periodic (``boundary='wrap'``).

    Args:
        image: Observed 2-D array; values below 1e-8 are clamped up to 1e-8.
        psf: 2-D point spread function.
        iterations: Number of updates; with 0 the flat 0.5 start is returned.

    Returns:
        The deconvolved estimate, with the same shape as ``image``.
    """
    tiny = 1e-8  # floor that keeps the divisions below finite
    observed = np.maximum(image, tiny)
    psf_flipped = psf[::-1, ::-1]  # flipped kernel used for the back-projection step
    estimate = np.full_like(observed, 0.5)

    for _step in range(iterations):
        reblurred = convolve2d(estimate, psf, mode='same', boundary='wrap')
        ratio = observed / np.maximum(reblurred, tiny)
        estimate *= convolve2d(ratio, psf_flipped, mode='same', boundary='wrap')

    return estimate

class Counter:
    """Mutable integer counter that starts at zero.

    Shared by reference (for example through ``functools.partial``) so a
    per-row callback can number the files it writes.
    """

    def __init__(self):
        # Current value; only inc() changes it.
        self.cnt = 0

    def inc(self):
        """Increase the stored count by one."""
        self.cnt = self.cnt + 1

    def get(self):
        """Return the stored count."""
        return self.cnt

def deblur_transform(row, psf, save_path, counter, iterations=50):
    """Deblur one dataset row with Richardson-Lucy and save it as a PNG.

    Used as a ``datasets.Dataset.map`` callback.

    Args:
        row: Mapping whose ``'y'`` entry is a channels-last ``(H, W, C)``
            image array with pixel values in [-2, 2].
        psf: 2-D point spread function passed to ``richardson_lucy``.
        save_path: Output directory (``pathlib.Path``); the file is named
            ``{counter.get()}.png``.
        counter: ``Counter`` used to number the output files; it is
            incremented once per call.
        iterations: Richardson-Lucy iterations per channel (default 50).

    Returns:
        ``{'y': deblurred}``, with the deblurred image mapped back to the
        [-2, 2] value range.
    """
    # Richardson-Lucy needs non-negative intensities: map [-2, 2] -> [0, 1].
    x = (row['y'] + 2) / 4
    deblurred = np.zeros_like(x)
    # Deconvolve every channel independently, whatever the channel count.
    for c in range(x.shape[2]):
        deblurred[:, :, c] = richardson_lucy(x[:, :, c], psf, iterations=iterations)
    deblurred = deblurred * 4 - 2  # back to [-2, 2]
    to_pil(deblurred).save(save_path / f'{counter.get()}.png')
    counter.inc()
    return {'y': deblurred}

# Module-level sink for deblurred PIL images: deblur_transform2 appends to it
# and measure() zips its contents for evaluation.
images = list()

def deblur_transform2(row, psf, iterations=50):
    """Deblur one dataset row with Richardson-Lucy and collect it in memory.

    Used as a ``datasets.Dataset.map`` callback. The result is converted with
    ``to_pil`` and appended to the module-level ``images`` list.

    Args:
        row: Mapping whose ``'y'`` entry is a channels-last ``(H, W, C)``
            image array with pixel values in [-2, 2].
        psf: 2-D point spread function passed to ``richardson_lucy``.
        iterations: Richardson-Lucy iterations per channel (default 50).

    Returns:
        ``{'y': deblurred}``, with the deblurred image mapped back to the
        [-2, 2] value range.
    """
    # Richardson-Lucy needs non-negative intensities: map [-2, 2] -> [0, 1].
    x = (row['y'] + 2) / 4
    deblurred = np.zeros_like(x)
    # Deconvolve every channel independently, whatever the channel count.
    for c in range(x.shape[2]):
        deblurred[:, :, c] = richardson_lucy(x[:, :, c], psf, iterations=iterations)
    deblurred = deblurred * 4 - 2  # back to [-2, 2]
    images.append(to_pil(deblurred))
    return {'y': deblurred}

def main(blurry_dataset: Path, save_path: Path, sigma: float = 2):
    """Deblur the ``train`` split of a saved dataset and write each image as a PNG.

    Args:
        blurry_dataset: Directory readable by ``datasets.load_from_disk`` that
            contains a ``train`` split with a ``'y'`` image column.
        save_path: Output directory (created if missing). Images are written as
            ``0.png``, ``1.png``, ... by ``deblur_transform``. The numbering
            assumes ``map`` runs in a single process, which it does by default.
        sigma: Standard deviation of the Gaussian PSF used for deconvolution.
    """
    save_path.mkdir(exist_ok=True, parents=True)

    dataset = load_from_disk(blurry_dataset)
    dataset.set_format('numpy')
    dataset = dataset['train']

    psf = gaussian_kernel(sigma)

    types = {'y': Array3D(shape=(32, 32, 3), dtype='float32')}

    counter = Counter()

    # The PNGs are written only as a side effect of the callback. A cache hit
    # from an earlier run would skip the callback and write nothing, so always
    # recompute.
    dataset.map(
        partial(deblur_transform, psf=psf, save_path=save_path, counter=counter),
        features=Features(types),
        load_from_cache_file=False,
    )

def measure(blurry_dataset: Path, archive: str, sigma: float = 2):
    """Deblur training images, zip them, and score them with FID and Inception Score.

    The first 49999 rows of the ``train`` split are deblurred with
    Richardson-Lucy. The results are written to ``archive`` as PNGs and
    compared by ``torch_fidelity.calculate_metrics`` against its
    ``'cifar10-train'`` reference.

    Args:
        blurry_dataset: Directory readable by ``datasets.load_from_disk`` that
            contains a ``train`` split with a ``'y'`` image column.
        archive: Path of the zip archive to (over)write with the deblurred images.
        sigma: Standard deviation of the Gaussian PSF used for deconvolution.

    Returns:
        The metrics dict produced by ``calculate_metrics``.
    """
    dataset = load_from_disk(blurry_dataset)
    dataset.set_format('numpy')
    # NOTE(review): CIFAR-10 train has 50000 images; confirm 49999 is
    # intentional and not an off-by-one.
    dataset = dataset['train'].select(range(49999))

    psf = gaussian_kernel(sigma)

    types = {'y': Array3D(shape=(32, 32, 3), dtype='float32')}

    # deblur_transform2 appends to the module-level `images` list; start empty
    # so repeated calls do not mix in images from earlier runs.
    images.clear()
    dataset.map(
        partial(deblur_transform2, psf=psf),
        features=Features(types),
        # The images are collected as a side effect; a cache hit would skip the
        # callback and leave `images` empty.
        load_from_cache_file=False,
    )

    ZipDataset.zip(archive, images)
    stats = calculate_metrics(
        input1=ZipDataset(archive),
        input2='cifar10-train',
        fid=True,
        isc=True,
    )
    print(stats)
    return stats

if __name__ == "__main__":
    # Blur level; selects which pre-built blurred dataset directory is loaded.
    corruption_severity = 2

    measure(
        archive='/data/vision/___/scratch/___ht/cifar_dir/datasets_for_eval/conditional/archive_rl_method.zip',
        blurry_dataset=Path(f'/data/vision/___/scratch/___ht/cifar_dir/hf/cifar-mask-gaussian-blur-{corruption_severity}'),
    )

    # Alternative entry point: write every deblurred image to its own PNG
    # instead of scoring the whole set.
    # main(
    #     blurry_dataset=Path(f'/data/vision/___/scratch/___ht/cifar_dir/hf/cifar-mask-gaussian-blur-{corruption_severity}'),
    #     save_path=Path(f'/data/vision/___/scratch/___ht/cifar_dir/datasets_for_eval/conditional/RL_method'),
    # )