import os

import torch
import numpy as np

import lpips

from tqdm import tqdm

from torchmetrics.image import StructuralSimilarityIndexMeasure

# SSIM metric configured with reduction='none' (one score per image rather
# than a batch aggregate) and data_range=1.0 (inputs scaled to [0, 1]).
# NOTE(review): ssim_fn is not referenced anywhere in the visible code.
ssim_fn = StructuralSimilarityIndexMeasure(reduction='none', data_range=1.0)
# LPIPS perceptual distance with a VGG backbone. It is constructed once at
# import time and reused by compute_lpips() for every comparison.
lpips_fn = lpips.LPIPS(net="vgg")

def main(
        file_path="output_CIFAR10",
        mu=100,
        C=1,
        prior_size=256,
        file_idx=0,
        test_noisy=False,
        dot_product=False
):
    """Evaluate the identification success rate for one saved batch.

    Loads ``<file_path>/mu_<mu:04d>_C_<int(C)>_batch_<file_idx>.npz``, which
    must contain the arrays ``img``, ``noisy`` and ``denoised``. For every
    image ``j`` a candidate ("prior") set is formed from the first
    ``prior_size`` images with slot 0 overwritten by image ``j``. The trial
    succeeds when the observation of image ``j`` scores best against slot 0.

    Scoring, in order of precedence:
      * ``dot_product``: largest dot product between the norm-clipped images
        and the norm-clipped noisy observation.
      * ``test_noisy``: smallest squared L2 distance to the noisy observation.
      * otherwise: smallest LPIPS distance to the denoised observation.

    Args:
        file_path: Directory holding the ``.npz`` batch files.
        mu: Value embedded in the batch file name. It must be an integer
            because it is formatted with ``:04d``.
        C: Norm bound used when clipping images for the dot-product score;
            ``int(C)`` is also part of the file name.
        prior_size: Number of candidates per trial, including the target.
        file_idx: Index of the batch file to load.
        test_noisy: Score against the noisy instead of the denoised
            observation.
        dot_product: Use the dot-product score; this overrides ``test_noisy``.

    Returns:
        The success rate as a float, or ``nan`` when the batch is empty.
    """
    print(f"----------- mu = {mu}, C = {C}, file_idx = {file_idx} -----------")
    batch_file = os.path.join(file_path, f"mu_{mu:04d}_C_{int(C)}_batch_{file_idx}.npz")

    # np.load on an .npz returns an NpzFile that keeps the archive open;
    # the context manager closes it once the arrays have been read.
    with np.load(batch_file) as data:
        imgs = data["img"]
        noisy = data["noisy"]
        denoised = data["denoised"]

    # Per-image factor kappa = clip(||img|| / C, 1, 1e5). Dividing by it
    # scales any image whose L2 norm exceeds C down to norm C (up to the
    # 1e5 cap) and leaves smaller images untouched. The noisy observation
    # is divided by the same per-image factor.
    kappa = np.clip((np.linalg.norm(imgs.reshape(imgs.shape[0], -1), axis=1)/C).reshape(-1, 1, 1, 1), a_min=1.0, a_max=100000)
    img_clip = imgs / kappa
    noisy_clip = noisy / kappa
    # Affine map x*0.5 + 0.5 for the L2 / LPIPS scores. This assumes the
    # stored arrays lie in [-1, 1] -- TODO confirm.
    imgs = imgs * 0.5 + 0.5
    noisy = noisy * 0.5 + 0.5
    denoised = denoised * 0.5 + 0.5

    # Choose the candidate pool and the observation once; neither depends
    # on the loop index.
    if dot_product:
        gallery, observation = img_clip, noisy_clip
    else:
        gallery = imgs
        # NOTE: the noisy observation is used unclipped, so values may
        # fall outside [0, 1].
        observation = noisy if test_noisy else denoised

    # Slicing returns a view, and writing the target into a view would
    # silently overwrite gallery[0]. Take one private copy instead. Slots
    # 1..prior_size-1 never change and slot 0 is rewritten every trial, so
    # one buffer suffices.
    # NOTE(review): for j < prior_size the target also occupies slot j, so
    # that trial has one fewer distinct distractor. argmax/argmin return the
    # first index on exact ties, so such ties count as successes.
    prior_set = gallery[:prior_size].copy()

    success = 0
    total_n = 0
    for j in (pbar := tqdm(range(imgs.shape[0]))):
        prior_set[0] = gallery[j]
        target_obs = observation[None, j]  # keep a leading batch axis
        if dot_product:
            pred_id = compute_dotProduct(prior_set, target_obs).argmax()
        elif test_noisy:
            pred_id = compute_l2Dist(prior_set, target_obs).argmin()
        else:
            pred_id = compute_lpips(prior_set, target_obs).argmin()
        if pred_id == 0:
            success += 1
        total_n += 1
        pbar.set_description(f"Success rate = {success/total_n}")

    # Guard the empty-batch case, which would otherwise divide by zero.
    success_rate = success / total_n if total_n else float("nan")
    print(f"Success rate = {success_rate}")
    return success_rate

def compute_dotProduct(x, y):
    """Return the dot product of each leading-axis entry of ``x`` with ``y``.

    ``x`` and ``y`` are multiplied with broadcasting, for example a
    candidate stack ``(N, C, H, W)`` against one observation
    ``(1, C, H, W)``. The product is then summed over every axis except the
    first, giving one score per leading-axis entry. For 4-D inputs this is
    exactly ``sum(axis=(1, 2, 3))``; it also works for other ranks >= 2,
    such as flattened ``(N, D)`` vectors.
    """
    product = x * y
    # Reduce all non-leading axes so the result has one value per candidate.
    return product.sum(axis=tuple(range(1, product.ndim)))

def compute_l2Dist(x, y):
    """Return the squared L2 distance between each leading-axis entry of ``x`` and ``y``.

    The difference ``x - y`` is taken with broadcasting, for example
    ``(N, C, H, W)`` candidates against a ``(1, C, H, W)`` observation. Its
    square is then summed over every axis except the first. For 4-D inputs
    this equals ``sum(axis=(1, 2, 3))``; other ranks >= 2 are handled the
    same way. No square root is taken.
    """
    diff = x - y
    # Reduce all non-leading axes so the result has one value per candidate.
    return (diff ** 2).sum(axis=tuple(range(1, diff.ndim)))

@torch.no_grad()
def compute_lpips(x, y):
    """Return LPIPS distances between the candidates ``x`` and the observation ``y``.

    Args:
        x: Candidate images. This presumably has shape ``(N, 3, H, W)`` with
            values already in ``[0, 1]``; it is used as-is. TODO confirm
            against callers.
        y: Observation to compare against. Unlike ``x``, it is min-max
            rescaled to ``[0, 1]`` over the whole array. A constant ``y``
            becomes all zeros instead of producing 0/0 NaNs.

    Returns:
        The output of ``lpips_fn`` flattened to a 1-D tensor, with one
        distance per compared pair.
    """
    # Cast to float32, PyTorch's default parameter dtype for the module-level
    # network; a float64 tensor would not match it. as_tensor may share
    # memory with a NumPy input, so every op below is out-of-place and the
    # caller's arrays are never modified.
    x = torch.as_tensor(x, dtype=torch.float32)
    y = torch.as_tensor(y, dtype=torch.float32)
    y = y - y.min()
    span = y.max()
    if span > 0:
        y = y / span
    # Map [0, 1] -> [-1, 1], the input range the LPIPS library works with.
    x = (x - 0.5) * 2
    y = (y - 0.5) * 2
    # Named so it does not shadow the imported `lpips` module.
    distances = lpips_fn(x, y).reshape(-1)
    return distances

if __name__ == "__main__":
    # Run the default scoring mode (LPIPS against the denoised observation)
    # once per mu value, largest first. Each value needs its own batch file
    # in main()'s default directory.
    mus = [100, 50, 30, 20, 10, 5, 3, 2, 1]
    for mu in mus:
        main(mu=mu, test_noisy=False, dot_product=False)