import os

import torch
import numpy as np

import lpips

import glob

def main(
    lpips_fn,
    device,
    file_path="output_CIFAR10_Var",
    mu=100,
    C=1
):
    """Report the mean and spread of pairwise LPIPS between independent runs.

    Loads every ``mu_{mu:04d}_C_{C}_run*.npz`` file in ``file_path`` and takes
    the ``"denoised"`` array from each. It then scores the same image index
    in every unordered pair of runs with ``compute_lpips``.

    Args:
        lpips_fn: Distance callable forwarded to ``compute_lpips`` (e.g. an
            ``lpips.LPIPS`` module).
        device: Torch device the images are moved to before scoring.
        file_path: Directory containing the ``.npz`` run files.
        mu: Value of mu encoded (zero-padded to 4 digits) in the file names.
        C: Value of C encoded in the file names.

    Returns:
        ``(mean, std)``: the mean LPIPS over all pairs and images, and the
        per-image standard deviation across pairs averaged over images.

    Raises:
        ValueError: If fewer than two matching run files are found, since no
            pair of runs can be formed.
    """
    pattern = os.path.join(file_path, f"mu_{mu:04d}_C_{C}_run*.npz")
    # glob returns files in arbitrary filesystem order; sort for determinism.
    run_files = sorted(glob.glob(pattern))
    print(run_files)
    if len(run_files) < 2:
        raise ValueError(
            f"Need at least two run files matching {pattern!r} to compute "
            f"pairwise LPIPS, found {len(run_files)}."
        )

    denoised = None
    for i, batch_file in enumerate(run_files):
        # NpzFile keeps the archive open; close it as soon as it is read.
        with np.load(batch_file) as data:
            # Presumably maps stored values from [-1, 1] to [0, 1] -- confirm.
            denoised_i = data["denoised"] * 0.5 + 0.5
        if denoised is None:
            # Every slot is filled below, so no initial value is needed.
            denoised = np.empty(
                (len(run_files), *denoised_i.shape), dtype=np.float32
            )
        denoised[i] = denoised_i

    # Axis 0 indexes runs, axis 1 indexes images within a run.
    n_runs, n_images = denoised.shape[0], denoised.shape[1]

    # Compute LPIPS between the same image in every unordered pair of runs.
    pair_scores = []
    for i in range(n_runs):
        for j in range(i + 1, n_runs):
            scores = np.zeros((n_images,))
            for img in range(n_images):
                scores[img] = compute_lpips(
                    denoised[i, img], denoised[j, img], device, lpips_fn
                )[0].item()
            pair_scores.append(scores)

    pair_scores = np.stack(pair_scores)  # shape: (n_pairs, n_images)
    l_mean = pair_scores.mean()
    # Std across pairs for each image, then averaged over images.
    l_std = pair_scores.std(0).mean()
    print(f"----------- mu = {mu}, C = {C} -----------")
    print(f"LPIPS: Mean = {l_mean}, Std = {l_std}")
    return l_mean, l_std
    
@torch.no_grad()
def compute_lpips(x, y, device, lpips_fn):
    """Score one image pair with ``lpips_fn`` and return the result on CPU.

    Both inputs get a leading batch dimension of size 1 and are moved to
    ``device``. ``y`` is min-max rescaled to ``[0, 1]``. Both images are then
    mapped from ``[0, 1]`` to ``[-1, 1]``, the input range LPIPS expects by
    default.

    Args:
        x: First image as an array-like (presumably ``(C, H, W)`` with values
            in ``[0, 1]`` -- confirm against callers).
        y: Second image, same shape as ``x``. The caller's array is not
            modified.
        device: Torch device on which ``lpips_fn`` is evaluated.
        lpips_fn: Callable taking two batched tensors and returning distances.

    Returns:
        A 1-D CPU tensor of distances (one element for a single pair).
    """
    # torch.tensor copies its input, so the caller's arrays are never mutated.
    x = torch.tensor(x).unsqueeze(0).to(device)
    y = torch.tensor(y).unsqueeze(0).to(device)

    # NOTE(review): only y is min-max rescaled; x is used as given, so the
    # score is not symmetric in (x, y). Confirm this is intended.
    y = y - y.min()
    y_max = y.max()
    # A constant image has zero range; dividing would yield NaNs, so keep it at 0.
    if y_max > 0:
        y = y / y_max

    x = (x - 0.5) * 2
    y = (y - 0.5) * 2
    # Named `dist` (not `lpips`) to avoid shadowing the imported lpips module.
    dist = lpips_fn(x, y).reshape(-1)
    return dist.cpu()

if __name__ == "__main__":
    # Build one VGG-based LPIPS model and reuse it for every mu setting.
    run_device = "cuda" if torch.cuda.is_available() else "cpu"
    vgg_lpips = lpips.LPIPS(net="vgg").to(run_device)

    for mu_value in (100, 50, 30, 20, 10, 5, 3, 2, 1):
        main(lpips_fn=vgg_lpips, device=run_device, mu=mu_value)