from pytorch_msssim import ssim, SSIM
import cv2
import os.path as osp
from EvalData import DiffSample
from torch.utils.data import DataLoader
import torch.nn.functional as F

from tqdm.auto import tqdm
import pathlib
import numpy as np
import torch
import lpips

# Rows removed from the top of every image before any metric is computed.
# NOTE(review): presumably this excludes the sky region (the dataset is built with a
# sky_mask directory) — confirm against the actual image layout.
CROP_TOP_ROWS = 64


def crop_top_rows(images, rows=CROP_TOP_ROWS):
    """Return ``images`` with the first ``rows`` entries of dim 2 removed.

    Assumes a 4-D ``(N, C, H, W)`` batch, so dim 2 is image height — TODO confirm
    against DiffSample's output.
    """
    return images[:, :, rows:, :]


def to_signed_range(images):
    """Map values from [0, 1] to [-1, 1], the input range LPIPS expects."""
    return images * 2. - 1.


def compute_psnr(gt, sample, max_val=1.0):
    """Return the PSNR in dB between ``gt`` and ``sample`` as a 0-d tensor.

    ``max_val`` is the peak signal value (1.0 for images scaled to [0, 1]).
    Identical inputs give an MSE of 0 and therefore a PSNR of ``inf``.
    """
    mse = F.mse_loss(gt, sample, reduction="mean")
    return 10 * torch.log10(max_val ** 2 / mse)


def main():
    """Score generated samples against ground truth and append the mean metrics.

    For every (gt, sample) pair from ``DiffSample``, PSNR, SSIM and LPIPS
    (VGG / AlexNet / SqueezeNet backbones) are computed on the images with the
    top ``CROP_TOP_ROWS`` rows removed. The per-batch values are averaged and
    appended as one line to ``cvact_masked_eval.txt``.

    Raises:
        RuntimeError: if the dataloader yields no batches.
    """
    val_data = DiffSample("/val_output", "cvusa_gen_list.txt", "Sat2Density/dataset/CVUSA/sky_mask")

    # batch_size=1 keeps the mean of per-batch metrics equal to the per-image mean.
    val_dataloader = DataLoader(val_data, batch_size=1, shuffle=False, pin_memory=True,
                                num_workers=16, drop_last=False)
    num_batches = len(val_dataloader)
    print('The number of batches = %d' % num_batches)
    if num_batches == 0:
        raise RuntimeError("validation dataloader is empty; nothing to evaluate")

    # Frozen LPIPS models, keyed by the metric name used in the logs and output file.
    lpips_models = {
        "lpips_val": lpips.LPIPS(net='vgg').cuda().requires_grad_(False),
        "lpips_alex_val": lpips.LPIPS(net='alex').cuda().requires_grad_(False),
        "lpips_squeeze_val": lpips.LPIPS(net='squeeze').cuda().requires_grad_(False),
    }
    totals = dict.fromkeys(("psnr", "ssim_val", *lpips_models), 0.0)

    # Evaluation only: no_grad avoids any autograd bookkeeping, and the tqdm
    # context manager guarantees the progress bar is closed.
    with torch.no_grad(), tqdm(total=num_batches, disable=False) as progress_bar:
        progress_bar.set_description("validating:")
        for gt, sample, _gt_path, _sample_path in val_dataloader:
            gt_crop = crop_top_rows(gt)
            sample_crop = crop_top_rows(sample)
            # Copy to the GPU once per batch instead of once per metric.
            gt_gpu = gt_crop.cuda()
            sample_gpu = sample_crop.cuda()
            gt_signed = to_signed_range(gt_gpu)
            sample_signed = to_signed_range(sample_gpu)

            logs = {
                "psnr": compute_psnr(gt_gpu, sample_gpu).item(),
                # SSIM is computed on the CPU copies of the cropped images.
                "ssim_val": ssim(sample_crop, gt_crop, data_range=1.).item(),
            }
            for key, model in lpips_models.items():
                logs[key] = model(gt_signed, sample_signed).mean().item()

            for key, value in logs.items():
                totals[key] += value
            progress_bar.set_postfix(**logs)
            progress_bar.update(1)

    means = {key: total / num_batches for key, total in totals.items()}

    # NOTE(review): the output name says "cvact" while the data list above is
    # CVUSA — confirm which file this result belongs in.
    with open("cvact_masked_eval.txt", "a") as f:
        # Plain floats (not tensor reprs) plus a trailing newline so that
        # successive runs append one result line each.
        f.write(f"psnr: {means['psnr']}, ssim_val: {means['ssim_val']}, "
                f"lpips_val:{means['lpips_val']}, lpips_alex_val: {means['lpips_alex_val']},"
                f"lpips_squeeze_val: {means['lpips_squeeze_val']}\n")


if __name__ == '__main__':
    main()