import os

import torch
import numpy as np

import lpips

from tqdm import tqdm

from torchmetrics.image import StructuralSimilarityIndexMeasure

# Shared SSIM metric object. reduction='none' keeps one score per image, and
# data_range=1.0 expects inputs scaled to [0, 1]; `main` rescales images with
# x * 0.5 + 0.5 before scoring.
ssim_fn = StructuralSimilarityIndexMeasure(reduction='none', data_range=1.0)
# Shared LPIPS perceptual-distance network with a VGG backbone. It is built once
# at import time and reused by every compute_lpips call.
lpips_fn = lpips.LPIPS(net="vgg")

def main(
        file_path="output_CIFAR10",
        mu=100,
        C=1,
        prior_size=256,
        file_idx=0
):
    """Run the dot-product identification attack on one saved batch and report metrics.

    For each image ``j`` in the batch, a candidate ("prior") set is built from the
    first ``prior_size`` clipped images, with slot 0 replaced by image ``j``.
    The candidate with the largest dot product against the clipped noisy
    observation of image ``j`` is taken as the reconstruction. Choosing slot 0
    counts as a success.

    Args:
        file_path: Directory containing the ``.npz`` batch files.
        mu: Value embedded in the batch file name (formatted as ``{mu:04d}``).
        C: Clipping bound. Each image (and its noisy version) is divided by
            ``clip(||img||_2 / C, 1, 1e5)``.
        prior_size: Maximum number of candidates in the prior set.
        file_idx: Index of the batch file to load.

    Returns:
        dict with keys ``"success_rate"``, ``"mse"``, ``"ssim"`` and ``"lpips"``
        (all floats). The same values are also printed.

    Raises:
        ValueError: If the batch file contains no images.
    """
    print(f"----------- mu = {mu}, C = {C}, file_idx = {file_idx} -----------")
    batch_file = os.path.join(file_path, f"mu_{mu:04d}_C_{int(C)}_batch_{file_idx}.npz")

    # Close the NpzFile's underlying file handle once the arrays are read.
    with np.load(batch_file) as data:
        imgs = data["img"]
        noisy = data["noisy"]

    n_imgs = imgs.shape[0]
    if n_imgs == 0:
        raise ValueError(f"No images found in {batch_file}")

    # Per-sample clipping factor: images whose L2 norm exceeds C are scaled
    # down to norm C. Smaller images are left unchanged (kappa clipped to >= 1).
    norms = np.linalg.norm(imgs.reshape(n_imgs, -1), axis=1)
    kappa = np.clip((norms / C).reshape(-1, 1, 1, 1), a_min=1.0, a_max=100000)
    img_clip = np.float32(imgs / kappa)
    noisy_clip = np.float32(noisy / kappa)
    # Rescale to [0, 1] for the image-quality metrics.
    # NOTE(review): assumes stored images are in [-1, 1]; confirm against the producer.
    imgs = imgs * 0.5 + 0.5

    pred_imgs = np.ones(imgs.shape, dtype=np.float32)
    # Shared candidate pool. It is copied per target below so that writing the
    # target into slot 0 never mutates img_clip (or, indirectly, imgs). The old
    # view-based code overwrote imgs[0] and corrupted the metrics for sample 0.
    prior_pool = img_clip[:prior_size]

    success = 0
    for j in (pbar := tqdm(range(n_imgs))):
        prior_set = prior_pool.copy()
        prior_set[0] = img_clip[j]
        pred_id = compute_dotProduct(prior_set, noisy_clip[None, j]).argmax()
        # Slot 0 holds target j; every other slot k is the untouched pool image k.
        if pred_id == 0:
            success += 1
            pred_imgs[j] = imgs[j]
        else:
            pred_imgs[j] = imgs[pred_id]
        pbar.set_description(f"Success rate = {success / (j + 1)}")

    success_rate = success / n_imgs
    mse = compute_mse(pred_imgs, imgs).mean()
    ssim = compute_ssim(pred_imgs, imgs).mean()
    lpips_score = compute_lpips(pred_imgs, imgs).mean()

    print(f"Success rate = {success_rate}, MSE = {mse}, SSIM = {ssim}, LPIPS = {lpips_score}")
    return {
        "success_rate": float(success_rate),
        "mse": float(mse),
        "ssim": float(ssim),
        "lpips": float(lpips_score),
    }

def compute_dotProduct(x, y):
    """Return the per-sample inner product of two 4-D batches.

    The element-wise product (with broadcasting, so ``y`` may have a leading
    dimension of 1) is summed over axes 1, 2 and 3. This leaves one value per
    entry along axis 0.
    """
    elementwise = x * y
    per_sample = elementwise.sum(axis=(1, 2, 3))
    return per_sample

def compute_l2Dist(x, y):
    """Return the per-sample *squared* Euclidean distance between two 4-D batches.

    Squared differences are summed over axes 1, 2 and 3; no square root is taken.
    """
    delta = x - y
    squared = delta * delta
    return squared.sum(axis=(1, 2, 3))

@torch.no_grad()
def compute_lpips(x, y):
    """Return the per-pair LPIPS distance between image batches ``x`` and ``y``.

    ``y`` is treated as the reference. Its global min/max define an affine map
    onto ``[0, 1]``, and the *same* map is applied to ``x``, so identical inputs
    still score 0. The old code rescaled only ``y``, which put ``x`` and ``y``
    on different scales. Both are then shifted to ``[-1, 1]`` before being
    passed to ``lpips_fn``.

    Args:
        x: Predicted images. NOTE(review): presumably ``(N, C, H, W)``; confirm.
        y: Reference images with the same shape as ``x``.

    Returns:
        1-D tensor with one LPIPS value per image pair.
    """
    # Cast to float32 so float64 numpy inputs match the network's float32 weights.
    x = torch.as_tensor(x, dtype=torch.float32)
    y = torch.as_tensor(y, dtype=torch.float32)

    # Out-of-place ops only: as_tensor may share memory with the caller's arrays.
    y_min = y.min()
    y_span = y.max() - y_min
    if y_span <= 0:
        # Constant reference: avoid 0/0 (NaN) and just shift.
        y_span = torch.ones_like(y_span)
    x = (x - y_min) / y_span
    y = (y - y_min) / y_span

    x = (x - 0.5) * 2
    y = (y - 0.5) * 2
    scores = lpips_fn(x, y).reshape(-1)
    return scores

def compute_mse(x, y, reduce_sum=False):
    """Return the mean squared error between ``x`` and ``y`` for each sample.

    The squared difference is averaged over axes 1, 2 and 3. When
    ``reduce_sum`` is true, the per-sample values are summed into one scalar.
    """
    per_sample = np.mean(np.square(x - y), axis=(1, 2, 3))
    if not reduce_sum:
        return per_sample
    return np.sum(per_sample)
    
def compute_psnr(x, y, data_range=1.0):
    """Return per-sample MSE and PSNR (in dB) between image batches ``x`` and ``y``.

    Args:
        x, y: Arrays of identical shape with at least 4 dims (``compute_mse``
            averages over axes 1-3).
        data_range: Peak signal value (max - min of the valid pixel range).
            Defaults to 1.0, i.e. images in ``[0, 1]``.

    Returns:
        Tuple ``(mse, psnr)`` of per-sample arrays. Samples with zero MSE get
        ``psnr = inf``.
    """
    mse = compute_mse(x, y, reduce_sum=False)
    # mse == 0 (identical samples) correctly yields inf; silence the
    # divide-by-zero RuntimeWarning instead of leaking it to the caller.
    with np.errstate(divide="ignore"):
        psnr = 10 * np.log10(data_range ** 2 / mse)
    return mse, psnr

def compute_ssim(x, y):
    """Return per-image SSIM between predictions ``x`` and targets ``y``.

    Uses the module-level ``ssim_fn`` (``reduction='none'``, ``data_range=1.0``),
    so inputs are expected to be in ``[0, 1]``.
    """
    # torchmetrics rejects preds/target with different dtypes; cast both so
    # float32 predictions can be compared with float64 targets.
    preds = torch.as_tensor(x, dtype=torch.float32)
    target = torch.as_tensor(y, dtype=torch.float32)
    scores = ssim_fn(preds, target)
    # Calling a torchmetrics Metric also accumulates its internal state. Reset it
    # so repeated calls (one per `main` run) don't keep growing that state.
    ssim_fn.reset()
    return scores

if __name__ == "__main__":
    # Evaluate each mu value in turn, leaving every other argument at its default.
    for mu_value in (100, 50, 30, 20, 10, 5, 3, 2, 1):
        main(mu=mu_value)