import os
from glob import glob
from tqdm import tqdm
from PIL import Image
from collections import defaultdict

from loss.pytorch_ssim import ssim
from torchvision.transforms.functional import pil_to_tensor
import lpips
import torch
import torch.nn as nn
from loss.pytorch_ssim import SSIM
from loss.watson_vgg import WatsonDistanceVgg
from main.utils import compute_ssim, compute_psnr, compute_msssim, eval_lpips, eval_psnr_ssim_msssim


# Device used for all metric computation. Defaults to 'cuda:1' as before; set
# the EVAL_DEVICE environment variable (e.g. 'cuda:0' or 'cpu') to run on a
# different device without editing the script.
device = os.environ.get('EVAL_DEVICE', 'cuda:1')
# VGG-backed LPIPS perceptual metric, moved to the evaluation device once and
# reused for every image pair below.
lpips_metric = lpips.LPIPS(net='vgg').to(device)

# Directory of generated images to evaluate; keep exactly one line active.
# source_dir = 'gen_zod'
# source_dir = 'gen_pqim'
source_dir = 'gen_gs'
# Directory of ground-truth images the generated images are compared against.
gt_dir = 'output_images_wo_wm'

# Non-recursive: only PNGs directly inside each directory ('**' without
# recursive=True behaves exactly like '*', so '*.png' states the intent).
# Sorting is what pairs source and gt files; the check below verifies it.
source_files = sorted(glob(f'{source_dir}/*.png'))
gt_files = sorted(glob(f'{gt_dir}/*.png'))

# Fail early with a clear message instead of a ZeroDivisionError when the
# averages are printed at the end of the script.
if not source_files:
    raise FileNotFoundError(f'no .png files found in {source_dir!r}')
if not gt_files:
    raise FileNotFoundError(f'no .png files found in {gt_dir!r}')

### verify sorting ###
# Each generated image must pair with the ground-truth image of the same name:
# for source files the name is the basename up to the first '-', for gt files
# it is the basename up to the first '.'. Explicit raises (instead of assert)
# keep this check active when Python runs with -O.
if len(source_files) != len(gt_files):
    raise ValueError(
        f'length of source_files and gt_files mismatched, '
        f'{len(source_files)} != {len(gt_files)}'
    )

for source_file, gt_file in zip(source_files, gt_files):
    source_name = os.path.basename(source_file).split('-')[0]
    gt_name = os.path.basename(gt_file).split('.')[0]
    if source_name != gt_name:
        raise ValueError(
            f'source and gt sorting mismatched, {source_name} != {gt_name}'
        )
### verify sorting ###

# Per-image metric values, keyed by metric name; averaged at the end.
cache_metrics = defaultdict(list)

# Pure inference: disable autograd so the LPIPS forward pass does not build a
# computation graph (and, if eval_lpips returns a tensor, none is kept alive
# per image in cache_metrics).
with torch.no_grad():
    for source_file, gt_file in tqdm(zip(source_files, gt_files), total=len(source_files)):
        # The helpers receive file paths directly, passed as (gt, source).
        psnr_value, ssim_value, mssim_value = eval_psnr_ssim_msssim(
            gt_file,
            source_file,
            device,
        )
        lpips_value = eval_lpips(gt_file, source_file, lpips_metric, device)

        cache_metrics['psnr'].append(psnr_value)
        cache_metrics['ssim'].append(ssim_value)
        cache_metrics['mssim'].append(mssim_value)
        cache_metrics['lpips'].append(lpips_value)

def _mean(values):
    """Return the arithmetic mean of *values*, or NaN if there are none."""
    return sum(values) / len(values) if values else float('nan')


print('-' * 50)
# Printed labels are the upper-cased keys: PSNR, SSIM, MSSIM, LPIPS.
for metric_key in ('psnr', 'ssim', 'mssim', 'lpips'):
    print(f"{metric_key.upper()} = {_mean(cache_metrics[metric_key])}")
print('-' * 50)