import os
import torch
from PIL import Image
import cv2
import torchvision
import torchvision.transforms.functional as F
# from torchmetrics import SSIM, LPIPS
from torchmetrics.image import StructuralSimilarityIndexMeasure, LearnedPerceptualImagePatchSimilarity
from tqdm import tqdm
from transformers import MobileViTForImageClassification, MobileViTImageProcessor, PvtV2ForImageClassification, PvtImageProcessor

import warnings
warnings.filterwarnings('ignore')

# https://github.com/XPixelGroup/BasicSR/blob/033cd6896d898fdd3dcda32e3102a792efa1b8f4/basicsr/utils/color_util.py#L186
def rgb2ycbcr_pt(img, y_only=False):
    """Convert a batch of RGB images to YCbCr (ITU-R BT.601), PyTorch version.

    Uses the standard-definition television coefficients described at
    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.

    Args:
        img (Tensor): RGB images of shape (n, 3, h, w), float, values in [0, 1].
        y_only (bool): If True, return only the luma (Y) channel. Default: False.

    Returns:
        Tensor: YCbCr images of shape (n, 3, h, w), or (n, 1, h, w) when
        ``y_only`` is True, float, rescaled back to [0, 1].
    """
    # Move channels last so each pixel is a row vector that can be multiplied
    # by the (3, k) coefficient matrix.
    channels_last = img.permute(0, 2, 3, 1)

    if not y_only:
        coeffs = torch.tensor([
            [65.481, -37.797, 112.0],
            [128.553, -74.203, -93.786],
            [24.966, 112.0, -18.214],
        ]).to(img)
        offset = torch.tensor([16, 128, 128]).view(1, 3, 1, 1).to(img)
    else:
        coeffs = torch.tensor([[65.481], [128.553], [24.966]]).to(img)
        offset = 16.0

    # Coefficients and offsets are on the 0-255 scale; divide to get back to [0, 1].
    ycbcr = torch.matmul(channels_last, coeffs).permute(0, 3, 1, 2) + offset
    return ycbcr / 255.

def calculate_psnr_pt(img, img2, crop_border, test_y_channel=False):
    """Compute per-image PSNR (Peak Signal-to-Noise Ratio), PyTorch version.

    See https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio. The peak
    value is 1.0, and a 1e-8 epsilon keeps identical images finite (80 dB).

    Args:
        img (Tensor): Images of shape (n, 3/1, h, w), values in [0, 1].
        img2 (Tensor): Images of the same shape as ``img``, values in [0, 1].
        crop_border (int): Number of pixels trimmed from every edge before
            comparing; 0 disables cropping.
        test_y_channel (bool): If True, compare only the Y channel of YCbCr.
            Default: False.

    Returns:
        Tensor: float64 PSNR values, one per image, shape (n,).
    """
    assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.')

    if crop_border != 0:
        inner = (slice(None), slice(None),
                 slice(crop_border, -crop_border), slice(crop_border, -crop_border))
        img, img2 = img[inner], img2[inner]

    if test_y_channel:
        img, img2 = rgb2ycbcr_pt(img, y_only=True), rgb2ycbcr_pt(img2, y_only=True)

    # Work in float64 so small errors are not lost to float32 rounding.
    residual = img.to(torch.float64) - img2.to(torch.float64)
    mse = residual.pow(2).mean(dim=[1, 2, 3])
    return 10. * torch.log10(1. / (mse + 1e-8))

def _load_image(path, size):
    """Load an image file as a float tensor of shape (1, 3, h, w) in [0, 1].

    Args:
        path (str): Path of the image file.
        size (tuple[int, int]): Target size as ``(width, height)``, which is
            the order ``PIL.Image.resize`` expects.

    Returns:
        Tensor: The RGB, resized image with a leading batch dimension.
    """
    # The context manager closes the underlying file handle.
    with Image.open(path) as im:
        # Force RGB: grayscale / RGBA / palette PNGs would otherwise give 1 or
        # 4 channels, and LPIPS needs exactly 3.
        rgb = im.convert('RGB').resize(size)
    return F.to_tensor(rgb).unsqueeze(0)


def main(pred_dir, gt_dir, image_size=(299, 299)):
    """Print (and return) average PSNR, SSIM and LPIPS of predictions vs. ground truth.

    Each ``.png`` file in ``pred_dir`` is paired with the ``.png`` file of the
    same stem in ``gt_dir``. Both images are converted to RGB and resized to
    ``image_size`` before scoring.

    Args:
        pred_dir (str): Directory containing predicted images.
        gt_dir (str): Directory containing ground-truth images with matching names.
        image_size (tuple[int, int]): ``(width, height)`` both images are
            resized to. Default: (299, 299).

    Returns:
        dict: Averages over all pairs, keyed ``'psnr'``, ``'ssim'`` and ``'lpips'``.

    Raises:
        ValueError: If ``pred_dir`` contains no ``.png`` files.
    """
    names = sorted(
        name for name in os.listdir(pred_dir)
        if os.path.splitext(name)[1].lower() == '.png'
    )
    if not names:
        raise ValueError(f'No .png files found in {pred_dir!r}.')

    psnr_val_list = []
    ssim_val_list = []
    lpips_val_list = []

    # Images are in [0, 1]. Pin data_range rather than letting torchmetrics
    # estimate it from the data of each call.
    compute_ssim = StructuralSimilarityIndexMeasure(data_range=1.0)
    compute_lpips = LearnedPerceptualImagePatchSimilarity()

    # Pure evaluation: no autograd graph is needed.
    with torch.no_grad():
        for name in tqdm(names):
            # splitext keeps inner dots ("a.b.png" -> "a.b"), unlike split('.')[0].
            stem = os.path.splitext(name)[0]
            pred_img = _load_image(os.path.join(pred_dir, name), image_size)
            gt_img = _load_image(os.path.join(gt_dir, f'{stem}.png'), image_size)

            psnr = calculate_psnr_pt(pred_img, gt_img, 0)
            ssim = compute_ssim(pred_img, gt_img)
            # Rescale [0, 1] -> [-1, 1], the input range torchmetrics LPIPS
            # expects with its default normalize=False.
            lpips = compute_lpips(pred_img * 2 - 1, gt_img * 2 - 1)

            psnr_val_list.append(psnr.item())
            ssim_val_list.append(ssim.item())
            lpips_val_list.append(lpips.item())

    averages = {
        'psnr': sum(psnr_val_list) / len(psnr_val_list),
        'ssim': sum(ssim_val_list) / len(ssim_val_list),
        'lpips': sum(lpips_val_list) / len(lpips_val_list),
    }
    print("Ave PSNR: ", averages['psnr'])
    print("Ave SSIM: ", averages['ssim'])
    print("Ave LPIPS: ", averages['lpips'])
    return averages


if __name__ == '__main__':
    # Reference number from an earlier run: AdvGAN average PSNR = 40.0386.
    PRED_DIR = 'temp/1000/inversion'
    GT_DIR = 'third_party/Natural-Color-Fool/dataset/images'
    main(pred_dir=PRED_DIR, gt_dir=GT_DIR)