from metrics import Inception, cal_fid, ssim, PSNR, get_feature_images_incep
from lpips import LPIPS

import argparse
import os
import yaml
import logging 
import torch
import numpy as np
from tqdm import tqdm

from torchvision.utils import save_image
from data import define_dataloader
from models import get_model, sampler
from utils import denormalize, set_seed

# Command-line interface of the evaluation script. Each entry pairs an option
# flag with the keyword arguments handed to ``add_argument``; options are
# registered in the order listed here.
_CLI_OPTIONS = (
    ("--cfg", {"type": str, "default": "./configs/celeba_inpainting1024.yaml", "help": "the config files"}),
    ("--hr_shape", {"type": int, "default": 128, "help": "test image size"}),
    ("--ckpt", {"type": str, "default": "./results/celeba_aug_fft_mul/models/netG_latest.pth", "help": "the path of ckpt"}),
    ("--output", {"default": "./test_out/celeba_aug_fft_mul_700/", "help": "where to store the output"}),
    ("--gpu", {"type": int, "default": 0, "help": "gpu number"}),
)

parser = argparse.ArgumentParser()
for _option_name, _option_kwargs in _CLI_OPTIONS:
    parser.add_argument(_option_name, **_option_kwargs)
opt = parser.parse_args()

# Load the run configuration. The file is opened through a context manager so
# the handle is closed as soon as parsing is done (it was previously leaked).
# NOTE(security): yaml.Loader can construct arbitrary Python objects, so --cfg
# must only point at trusted files; yaml.safe_load is the option for untrusted
# input, provided the configs do not rely on Python-specific tags.
with open(opt.cfg, encoding="utf-8") as cfg_file:
    cfg = yaml.load(cfg_file, Loader=yaml.Loader)
# Apply the seed stored in the config through the project's set_seed helper.
set_seed(cfg["seed"])

# output settings ----------------------------------
# A non-empty --output replaces the "output" entry of the config; an empty
# string keeps whatever the config file specifies.
output_root = opt.output if opt.output != "" else cfg["output"]
cfg["output"] = output_root

# Result images are written to <output>/imgs; creating that directory also
# creates the output root itself when it does not exist yet.
img_dir = os.path.join(output_root, "imgs")
os.makedirs(img_dir, exist_ok=True)
# --------------------------------------------------

# logging ------------------------------------------
# INFO-level records are written to <output>/test.log via basicConfig, and an
# extra stream handler on the root logger echoes them to the terminal.
log_path = os.path.join(cfg["output"], "test.log")
logging.basicConfig(filename=log_path, level=logging.INFO)

console = logging.StreamHandler()
console.setLevel(logging.INFO)
root_logger = logging.getLogger('')
root_logger.addHandler(console)

# model & data --------------------------
# Use the GPU selected with --gpu when CUDA is present. Otherwise fall back to
# the CPU (with a warning) so the script does not fail at the first
# ``.to(device)`` call on a machine without CUDA.
if torch.cuda.is_available():
    device = torch.device(f"cuda:{opt.gpu}")
else:
    logging.warning("CUDA is not available; running on CPU instead of cuda:%s", opt.gpu)
    device = torch.device("cpu")

# Build the diffusion model from the "diffusion" section of the config and
# load the checkpoint into its denoising network only.
netG = get_model(cfg["diffusion"]).to(device)
# NOTE(security): checkpoint loading is pickle-based; only pass trusted files
# via --ckpt.
netG.denoise_fn.load_state_dict(torch.load(opt.ckpt, map_location=device))
print(f"load ckpt from: {opt.ckpt}")
netG.set_new_noise_schedule(device, phase="test")

# Override the validation "data_len" from the config with a fixed 100
# (presumably this caps the number of evaluated samples — confirm against the
# dataset class). Only the second loader returned by define_dataloader is used.
cfg["dataset"]["val"]["config"]["data_len"] = 100
_, val_data = define_dataloader(cfg["dataset"])


# eval ----------------------------------
# Per-batch metric values; they are averaged once the loop has finished.
total_psnr = []
total_ssim = []
total_lpips = []
# Per-batch Inception features of the ground truth / restored images. They are
# gathered in lists and concatenated a single time after the loop, rather than
# re-copying an ever-growing array on every iteration.
real_feats = []
pred_feats = []

incep = Inception().to(device)
incep.eval()
psnr = PSNR(255)  # presumably 255 is the peak signal value — confirm against metrics.PSNR
lpips_vgg = LPIPS(net="vgg").to(device)
lpips_vgg.eval()

netG.eval()
with torch.no_grad():
    for item in tqdm(val_data):
        gt = item["image"].to(device)
        con_img = item["con_image"].to(device)
        mask_img = item["mask_image"].to(device)
        mask = item["mask"].to(device)
        name = item["name"]

        # The conditioning image is used as condition, as y_t and as y_0;
        # only the first value returned by restoration() is kept.
        out, _ = netG.restoration(con_img, y_t=con_img, y_0=con_img, mask=mask, sample_num=1)
        # Alternative samplers, kept for reference:
        # out = sampler.edm_sampler(netG.denoise_fn, con_img=con_img, mask=mask)
        # out = sampler.ddpm_steps(con_img, con_img, mask, seq=range(0, 1000, 10), model=netG.denoise_fn, b=netG.betas, gammas=netG.gammas)

        total_psnr.append(psnr(out, gt).cpu().numpy())
        total_lpips.append(lpips_vgg(out, gt).squeeze().cpu().numpy())
        total_ssim.append(ssim(gt, out).cpu().numpy())

        real, pred = get_feature_images_incep(incep, gt, out)
        real_feats.append(real)
        pred_feats.append(pred)

        # Masked input, restoration and ground truth are joined along the last
        # dimension and saved under the first name of the batch.
        img_grid = denormalize(torch.cat([mask_img, out, gt], dim=-1))
        save_image(img_grid, os.path.join(img_dir, f"{name[0]}.png"), nrow=1, normalize=False)

    # Fail with a clear message instead of crashing inside cal_fid or on a
    # division by zero when the dataloader produced nothing.
    if not total_psnr:
        raise RuntimeError("validation dataloader yielded no batches; nothing to evaluate")

    all_real = np.concatenate(real_feats, axis=0)
    all_pred = np.concatenate(pred_feats, axis=0)
    fid = cal_fid(all_real, all_pred)
    avg_psnr = sum(total_psnr) / len(total_psnr)
    avg_ssim = sum(total_ssim) / len(total_ssim)
    avg_lpips = sum(total_lpips) / len(total_lpips)
    logging.info(f"fid: {fid}, lpips: {avg_lpips}, psnr: {avg_psnr}, ssim: {avg_ssim}")
    
    
