from metrics import inceptionV3, cal_fid, ssim, PSNR, get_feature_images_incep
from lpips import LPIPS
import torch
import numpy as np
from tqdm import tqdm

def _evaluate(netG, val_data, device, metric_device, incep, lpips_vgg):
    """Restore every item of ``val_data`` with ``netG`` and score the results.

    Shared implementation behind :func:`eval_on_cpu` and :func:`eval_on_cuda`,
    which differ only in the device the metric networks run on.

    Args:
        netG: Generator. ``netG.eval()`` is called (and not undone), and
            ``netG.restoration(...)`` must return a 2-tuple whose first element
            is the restored image batch.
        val_data: Iterable of dict items; ``item["image"]`` and
            ``item["mask_image"]`` are read from each.
        device: Device the generator inputs are moved to.
        metric_device: Device the restored and ground-truth images are moved
            to before scoring; should match where ``incep`` and ``lpips_vgg``
            live.
        incep: Feature extractor passed to ``get_feature_images_incep``.
        lpips_vgg: LPIPS module.

    Returns:
        dict with keys ``"fid"``, ``"psnr"``, ``"ssim"`` and ``"lpips"``; the
        last three are means over the per-item values.

    Raises:
        ValueError: If ``val_data`` yields no items.
    """
    psnr = PSNR(255)
    psnr_vals, ssim_vals, lpips_vals = [], [], []
    # Collect per-item features and concatenate once at the end; growing an
    # array with np.concatenate inside the loop re-copies it every iteration.
    real_feats, pred_feats = [], []

    netG.eval()
    with torch.no_grad():
        for item in tqdm(val_data):
            gt = item["image"].to(device)
            # NOTE(review): the restoration input and y_t come from
            # item["image"] (the same data as the ground truth), not from
            # item["mask_image"]; confirm this is intended.
            con_img = item["image"].to(device)
            mask = item["mask_image"].to(device)

            out, _ = netG.restoration(con_img, y_t=con_img, y_0=gt, mask=mask, sample_num=1)

            out_m = out.to(metric_device)
            gt_m = gt.to(metric_device)
            psnr_vals.append(psnr(out_m, gt_m).cpu().numpy())
            lpips_vals.append(lpips_vgg(out_m, gt_m).squeeze().cpu().numpy())
            ssim_vals.append(ssim(gt_m, out_m).cpu().numpy())

            real, pred = get_feature_images_incep(incep, gt_m, out_m)
            real_feats.append(real)
            pred_feats.append(pred)

        if not psnr_vals:
            raise ValueError("val_data yielded no items; cannot compute metrics")

        fid = cal_fid(np.concatenate(real_feats, axis=0), np.concatenate(pred_feats, axis=0))

    return {
        "fid": fid,
        "psnr": sum(psnr_vals) / len(psnr_vals),
        "ssim": sum(ssim_vals) / len(ssim_vals),
        "lpips": sum(lpips_vals) / len(lpips_vals),
    }


def eval_on_cpu(netG, val_data, device):
    """Evaluate ``netG`` on ``val_data``, computing the metrics on the CPU.

    The generator runs on ``device``; its outputs and the ground truth are
    moved to the CPU for scoring. The Inception and LPIPS networks are used
    where they are constructed (not moved to ``device``).

    Returns:
        dict with keys ``"fid"``, ``"psnr"``, ``"ssim"`` and ``"lpips"``.

    Raises:
        ValueError: If ``val_data`` yields no items.
    """
    incep = inceptionV3()
    lpips_vgg = LPIPS(net="vgg")
    return _evaluate(netG, val_data, device, "cpu", incep, lpips_vgg)


def eval_on_cuda(netG, val_data, device):
    """Evaluate ``netG`` on ``val_data``, computing the metrics on ``device``.

    The Inception and LPIPS networks are moved to ``device`` so scoring runs
    alongside the generator without host round-trips for the images.

    Returns:
        dict with keys ``"fid"``, ``"psnr"``, ``"ssim"`` and ``"lpips"``.

    Raises:
        ValueError: If ``val_data`` yields no items.
    """
    incep = inceptionV3().to(device)
    lpips_vgg = LPIPS(net="vgg").to(device)
    return _evaluate(netG, val_data, device, device, incep, lpips_vgg)
    
def guidetsr_eval(model, val_data, device):
    """Compute the mean PSNR and SSIM of ``model`` over ``val_data``.

    If ``model`` is a ``torch.nn.Module``, it is switched to eval mode for the
    duration of the call. Each submodule's previous train/eval flag is
    restored afterwards, so the caller's training state is left untouched.

    Args:
        model: Callable mapping an ``item["lq"]`` batch to a restored batch.
        val_data: Iterable of dict items; ``item["lq"]`` is the model input and
            ``item["gt"]`` the reference. Both are moved to ``device``.
        device: Device the inputs are moved to.

    Returns:
        ``(avg_psnr, avg_ssim)``: means over the per-item values.

    Raises:
        ValueError: If ``val_data`` yields no items.
    """
    psnr = PSNR(255)
    psnr_vals, ssim_vals = [], []

    # Remember every submodule's mode instead of calling model.train() at the
    # end, which would also flip submodules the caller had kept in eval mode.
    is_module = isinstance(model, torch.nn.Module)
    prev_modes = [(m, m.training) for m in model.modules()] if is_module else []
    if is_module:
        model.eval()

    try:
        with torch.no_grad():
            for item in tqdm(val_data):
                lq = item["lq"].to(device)
                gt = item["gt"].to(device)

                out = model(lq)

                psnr_vals.append(psnr(out, gt).cpu().numpy())
                ssim_vals.append(ssim(gt, out).cpu().numpy())
    finally:
        for module, was_training in prev_modes:
            module.training = was_training

    if not psnr_vals:
        raise ValueError("val_data yielded no items; cannot compute metrics")

    avg_psnr = sum(psnr_vals) / len(psnr_vals)
    avg_ssim = sum(ssim_vals) / len(ssim_vals)
    return avg_psnr, avg_ssim

    
    

    