import os

import torch
import numpy as np

import lpips
import glob

from skimage.restoration import denoise_wavelet
import bm3d

# LPIPS model (net="vgg") instantiated once at import time and reused by
# compute_lpips for every batch, so the model is not rebuilt on each call.
lpips_fn = lpips.LPIPS(net="vgg")

def main(
        file_path="output_CIFAR10",
        mu=100,
        C=1,
        wavelet=False
):
    """Denoise every saved batch for one (mu, C) setting and print mean metrics.

    Batch archives are located by globbing
    ``mu_{mu:04d}_C_{int(C)}_batch_*.npz`` inside ``file_path``. Each archive
    must contain the arrays ``"img"`` and ``"noisy"``. Every noisy image is
    denoised, compared against the matching ``"img"`` entry, and the MSE,
    PSNR and LPIPS averaged over all images are printed.

    NOTE(review): the code transposes ``noisy`` with ``(0, 2, 3, 1)`` and
    maps the denoised result through ``x * 0.5 + 0.5`` before comparing it
    with ``img``; presumably ``noisy`` is channel-first in [-1, 1] and
    ``img`` is channel-first in [0, 1] -- confirm against the script that
    wrote the archives.

    Args:
        file_path: Directory containing the ``.npz`` batch archives.
        mu: Integer used in the file-name pattern and to derive the base
            noise level ``C / mu``.
        C: Scale used in the file-name pattern, in the base noise level and
            in the per-image norm factor ``kappa``.
        wavelet: If true, denoise with ``denoise_wavelet``; otherwise use
            ``bm3d.bm3d``.

    Returns:
        None. Results are printed to stdout.
    """
    pattern = os.path.join(file_path, f"mu_{mu:04d}_C_{int(C)}_batch_*.npz")
    # Sorted so the accumulation order (and therefore float rounding) is
    # reproducible; glob makes no ordering guarantee.
    batch_files = sorted(glob.glob(pattern))

    sigma_val = C/mu

    mse_total = 0.0
    psnr_total = 0.0
    lpips_total = 0.0
    num_imgs = 0
    for batch_file in batch_files:
        # Close the archive as soon as both arrays have been read.
        with np.load(batch_file) as data:
            imgs = data["img"]
            noisy = data["noisy"]

        # Per-image factor: flattened L2 norm divided by C, clipped to
        # [1, 100000], shaped (N, 1, 1, 1) and stored as float32.
        flat_norms = np.linalg.norm(imgs.reshape(imgs.shape[0], -1), axis=1)
        kappa = np.float32(np.clip((flat_norms/C).reshape(-1, 1, 1, 1), a_min=1.0, a_max=100000))

        # One noise standard deviation per image, handed to the denoiser.
        noise_std = (C*sigma_val*kappa).reshape(-1)

        # Move the channel axis last; the wavelet call uses channel_axis=-1.
        noisy = noisy.transpose(0, 2, 3, 1)
        denoised = np.zeros_like(noisy)
        print("Start denoising")
        for idx, noisy_img in enumerate(noisy):
            if wavelet:
                denoised[idx] = denoise_wavelet(
                    noisy_img,
                    sigma=noise_std[idx],
                    channel_axis=-1,
                    convert2ycbcr=True,
                    rescale_sigma=True,
                )
            else:
                denoised[idx] = bm3d.bm3d(noisy_img, sigma_psd=noise_std[idx])
        print("Finished denoising")

        # Affine map x * 0.5 + 0.5, then restore the original axis order so
        # the result lines up with ``imgs``.
        denoised = ((denoised*0.5)+0.5).transpose(0, 3, 1, 2)
        mse_i, psnr_i = compute_psnr(imgs, denoised)
        mse_total += mse_i
        psnr_total += psnr_i

        # Convert to a plain float so no autograd graph is kept alive across
        # batches and the summary prints a number rather than a tensor repr.
        lpips_total += float(compute_lpips(imgs, denoised))

        num_imgs += imgs.shape[0]

    print(f"----------- mu = {mu}, C = {C} -----------")
    if num_imgs == 0:
        # Nothing matched (or every batch was empty): avoid dividing by zero.
        print(f"No images found for pattern {pattern}")
        return
    print(f"MSE = {mse_total/num_imgs}, PSNR = {psnr_total/num_imgs}, LPIPS = {lpips_total/num_imgs}")
   
def compute_lpips(x, y):
    """Return the LPIPS distance between ``x`` and ``y`` summed over the batch.

    ``y`` is min-max rescaled over the whole batch (global minimum
    subtracted, then divided by the resulting global maximum). Both inputs
    are then mapped through ``(v - 0.5) * 2`` and passed to the module-level
    ``lpips_fn``.

    NOTE(review): ``x`` is not rescaled, so it presumably already lies in
    [0, 1] -- confirm with callers.

    Args:
        x: Reference image batch (anything ``torch.tensor`` accepts).
        y: Image batch to compare against ``x``.

    Returns:
        A scalar tensor holding the sum of the per-image LPIPS values. It is
        computed under ``torch.no_grad()`` and carries no autograd graph.
    """
    # Evaluation only: skip building an autograd graph for the forward pass.
    with torch.no_grad():
        # torch.tensor copies, so the caller's arrays are never modified.
        x = torch.tensor(x)
        y = torch.tensor(y)

        y = y - y.min()
        y_peak = y.max()
        # A constant batch has zero range; dividing would turn every value
        # into NaN, so the (all-zero) batch is left as is in that case.
        if y_peak > 0:
            y = y / y_peak

        x = (x - 0.5) * 2
        y = (y - 0.5) * 2
        return lpips_fn(x, y).sum()
                   
def compute_mse(x, y, reduce_sum=True):
    """Mean squared error per image, optionally summed over the batch.

    The squared difference is averaged over axes 1, 2 and 3, leaving one
    value per entry along axis 0.

    Args:
        x: First 4-D array.
        y: Second array, subtracted element-wise from ``x``.
        reduce_sum: When true, return the sum of the per-image values;
            otherwise return the per-image array.

    Returns:
        A scalar if ``reduce_sum`` is true, else a 1-D array of per-image MSE.
    """
    per_image = np.mean(np.square(x - y), axis=(1, 2, 3))
    if not reduce_sum:
        return per_image
    return np.sum(per_image)
    
def compute_psnr(x, y):
    """Return the batch-summed MSE and batch-summed PSNR of ``x`` versus ``y``.

    PSNR is computed per image as ``10 * log10(1.0 / mse)``, i.e. with a
    peak signal value of 1.0, from the per-image MSE given by
    ``compute_mse(..., reduce_sum=False)``.

    Args:
        x: First 4-D array.
        y: Second array, compared element-wise with ``x``.

    Returns:
        Tuple ``(mse_sum, psnr_sum)``, each summed over the batch axis.
    """
    per_image_mse = compute_mse(x, y, reduce_sum=False)
    per_image_psnr = 10 * np.log10(1.0 / per_image_mse)
    return per_image_mse.sum(), per_image_psnr.sum()

if __name__ == "__main__":
    # Evaluate each mu value in turn with the default directory and C;
    # flip the flag to use the wavelet denoiser instead of BM3D.
    use_wavelet = False

    for mu_value in (100, 50, 30, 20, 10, 5, 3, 2, 1):
        main(mu=mu_value, wavelet=use_wavelet)