import os

import torch
import numpy as np
import glob

from torchmetrics.image import StructuralSimilarityIndexMeasure
import lpips

# Per-image SSIM (reduction='none'); data_range=1.0 means inputs are
# expected to span a value range of 1 (e.g. [0, 1]).
ssim_fn = StructuralSimilarityIndexMeasure(data_range=1.0, reduction='none')
# LPIPS with a VGG backbone, constructed once at import time and shared
# by every compute_lpips call.
lpips_fn = lpips.LPIPS(net="vgg")

def main(
        file_path="output_CIFAR10",
        mu=100,
        C=1,
        test_noisy=False
):
    """Print average MSE, PSNR, SSIM and LPIPS over saved denoising batches.

    Every ``mu_{mu:04d}_C_{int(C)}_batch_*.npz`` file in ``file_path`` is
    loaded. Its ``img``, ``noisy`` and ``denoised`` arrays are mapped with
    ``0.5 * a + 0.5``, i.e. from [-1, 1] to [0, 1], assuming that is how
    they were stored. The metrics are then averaged per image.

    Args:
        file_path: Directory holding the batch ``.npz`` files.
        mu: Value substituted into the file-name pattern (zero-padded to 4).
        C: Value substituted into the file-name pattern (cast to int).
        test_noisy: If True, score the noisy inputs instead of the denoised
            outputs. MSE/PSNR use them unclipped; SSIM/LPIPS use them
            clipped to [0, 1].
    """
    print(f"----------- mu = {mu}, C = {C} -----------")
    pattern = os.path.join(file_path, f"mu_{mu:04d}_C_{int(C)}_batch_*.npz")
    # Sort so the accumulation order (and hence float rounding) is reproducible.
    batch_files = sorted(glob.glob(pattern))

    mse_total = 0.0
    psnr_total = 0.0
    ssim_total = 0.0
    lpips_total = 0.0  # not `lpips`, which would shadow the imported module
    num_imgs = 0
    for batch_file in batch_files:
        # np.load returns an NpzFile that keeps the archive open; `with`
        # closes it once the arrays have been read into memory.
        with np.load(batch_file) as data:
            imgs = data["img"] * 0.5 + 0.5
            noisy = data["noisy"] * 0.5 + 0.5
            denoised = data["denoised"] * 0.5 + 0.5

        # MSE/PSNR are computed on the raw (unclipped) prediction.
        pred = noisy if test_noisy else denoised
        mse_i, psnr_i = compute_psnr(imgs, pred)
        mse_total += float(mse_i)
        psnr_total += float(psnr_i)

        if test_noisy:
            # SSIM/LPIPS see the noisy images clipped to [0, 1].
            pred = noisy.clip(0, 1)

        # float() turns the scalar tensors into plain numbers so the summary
        # prints values rather than `tensor(...)`.
        ssim_total += float(compute_ssim(imgs, pred))
        lpips_total += float(compute_lpips(imgs, pred))
        num_imgs += imgs.shape[0]

    if num_imgs == 0:
        # Nothing matched: report it instead of dividing by zero, so a sweep
        # over several mu values keeps going.
        print(f"No images found matching {pattern}")
        return

    print(f"MSE = {mse_total/num_imgs}, PSNR = {psnr_total/num_imgs}, "
          f"SSIM = {ssim_total/num_imgs}, LPIPS = {lpips_total/num_imgs}")

def compute_ssim(x, y, reduce_sum=True):
    """Return SSIM between image batches ``x`` and ``y`` using ``ssim_fn``.

    Args:
        x, y: Image batches of identical shape, as numpy arrays or tensors.
            ``ssim_fn`` was built with ``data_range=1.0``, so values are
            presumably in [0, 1].
        reduce_sum: If True, return the sum of the per-image SSIM values as
            a scalar tensor. Otherwise return the per-image tensor produced
            by ``ssim_fn``.
    """
    # Convert in both branches; previously only reduce_sum=True accepted
    # numpy input.
    scores = ssim_fn(torch.as_tensor(x), torch.as_tensor(y))
    # Calling a torchmetrics Metric also folds the batch into its running
    # (list-valued, since reduction='none') state. That state is never read
    # here, so clear it to keep memory from growing batch after batch.
    ssim_fn.reset()
    return scores.sum() if reduce_sum else scores
    
@torch.no_grad()
def compute_lpips(x, y):
    """Return the summed LPIPS distance between image batches ``x`` and ``y``.

    ``y`` is min-max rescaled to [0, 1] using the min/max over the whole
    batch; ``x`` is used as-is. Both are then mapped to [-1, 1] and passed
    to the module-level ``lpips_fn``.

    Args:
        x: Reference image batch (array-like), presumably in [0, 1].
        y: Image batch to compare against ``x`` (array-like).

    Returns:
        Scalar tensor: the sum of the per-image LPIPS distances.
    """
    # float32 presumably matches the LPIPS network weights; float64 input
    # would otherwise clash with them (TODO confirm). Only out-of-place ops
    # follow, so the caller's arrays are never modified even if as_tensor
    # shares their memory.
    x = torch.as_tensor(x, dtype=torch.float32)
    y = torch.as_tensor(y, dtype=torch.float32)

    # NOTE(review): this per-batch min-max stretch makes the score depend on
    # batch composition; confirm it is intended rather than clipping to [0, 1].
    y = y - y.min()
    y_range = y.max()
    if y_range > 0:  # constant batch: avoid 0/0 -> NaN
        y = y / y_range

    # Map [0, 1] -> [-1, 1].
    x = (x - 0.5) * 2
    y = (y - 0.5) * 2
    return lpips_fn(x, y).sum()
                   
def compute_mse(x, y, reduce_sum=True):
    """Return the mean squared error of each image in a batch.

    The mean is taken over every axis except the first (batch) axis. A 4-D
    batch therefore averages over axes (1, 2, 3) as before, and 3-D batches
    (e.g. grayscale without a channel axis) also work.

    Args:
        x, y: Array-likes of identical shape; axis 0 indexes images.
        reduce_sum: If True, return the sum of the per-image MSEs. Otherwise
            return the per-image values as an array of shape (N,).

    Returns:
        A numpy scalar (reduce_sum=True) or a 1-D array (reduce_sum=False).
    """
    sq_err = np.square(np.asarray(x) - np.asarray(y))
    per_image = np.mean(sq_err, axis=tuple(range(1, sq_err.ndim)))
    return np.sum(per_image) if reduce_sum else per_image
    
def compute_psnr(x, y, data_range=1.0):
    """Return ``(mse_sum, psnr_sum)`` over a batch of images.

    Per-image PSNR is ``10 * log10(data_range**2 / mse)``. The default
    ``data_range=1.0`` matches images scaled to [0, 1]. An image with zero
    MSE contributes ``inf``, and numpy emits a divide-by-zero warning.

    Args:
        x, y: Image batches of identical shape; axis 0 indexes images.
        data_range: Peak-to-peak value range of the images (e.g. 255 for
            8-bit data).

    Returns:
        Tuple of numpy scalars: summed per-image MSE and summed per-image PSNR.
    """
    per_image_mse = compute_mse(x, y, reduce_sum=False)
    per_image_psnr = 10 * np.log10(data_range ** 2 / per_image_mse)
    return per_image_mse.sum(), per_image_psnr.sum()

if __name__ == "__main__":
    # Evaluate each mu setting in turn, largest first.
    for mu in (100, 50, 30, 20, 10, 5, 3, 2, 1):
        main(mu=mu, test_noisy=False)