import os
import cv2
import numpy as np
import torch
import lpips
from tqdm import tqdm
from skimage.metrics import structural_similarity, peak_signal_noise_ratio

def _list_images(folder, extensions):
    """Return the sorted names of entries in *folder* ending in one of *extensions*.

    Matching is case-insensitive; *extensions* must be a tuple of lower-case
    suffixes such as ``('.png', '.jpg')``.
    """
    return sorted(
        name for name in os.listdir(folder)
        if name.lower().endswith(extensions)
    )


def _bgr_to_lpips_tensor(img_bgr, device):
    """Convert an OpenCV BGR image (H, W, 3) to a (1, 3, H, W) RGB tensor in [-1, 1].

    The ``/ 127.5 - 1`` scaling assumes 8-bit pixel values (0..255), which is
    what ``cv2.imread`` returns with its default flags.
    """
    rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    tensor = torch.from_numpy(rgb).permute(2, 0, 1).float() / 127.5 - 1
    return tensor.unsqueeze(0).to(device)


def evaluate_image_folders(folder_gt, folder_pred, *, net='alex', device=None,
                           extensions=('.jpg', '.jpeg', '.png', '.bmp', '.webp')):
    """Compute mean PSNR, SSIM and LPIPS between same-named images in two folders.

    The ground-truth folder defines which file names are evaluated. A file
    missing from *folder_pred*, or an image either side fails to decode, is
    reported with a printed warning and skipped. If the prediction's size
    differs from the ground truth, the prediction is resized to match.

    Args:
        folder_gt: Directory of ground-truth (reference) images.
        folder_pred: Directory of predicted images with matching file names.
        net: Backbone name passed to ``lpips.LPIPS`` (default ``'alex'``).
        device: Torch device for the LPIPS model. ``None`` selects ``'cuda'``
            when available, otherwise ``'cpu'``.
        extensions: Case-insensitive file suffixes treated as images.

    Returns:
        dict with keys ``'PSNR'``, ``'SSIM'`` and ``'LPIPS'`` (the mean over
        evaluated pairs, or ``None`` when no pair was evaluated), and
        ``'Image_Pairs'`` (the number of evaluated pairs).

    Raises:
        FileNotFoundError: If *folder_gt* does not exist (from ``os.listdir``).

    Note:
        PSNR of a pixel-identical pair is infinite, so a single identical pair
        makes the reported mean PSNR ``inf``.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    loss_fn = lpips.LPIPS(net=net).to(device)
    # Normalize so str.endswith gets a tuple and matching stays case-insensitive.
    extensions = tuple(ext.lower() for ext in extensions)

    psnr_values = []
    ssim_values = []
    lpips_values = []

    image_names = _list_images(folder_gt, extensions)

    for name in tqdm(image_names, desc="Processing"):
        path_gt = os.path.join(folder_gt, name)
        path_pred = os.path.join(folder_pred, name)

        if not os.path.exists(path_pred):
            print(f"警告：预测文件缺失 {path_pred}")
            continue

        img_gt = cv2.imread(path_gt)
        img_pred = cv2.imread(path_pred)
        if img_gt is None or img_pred is None:
            print(f"警告：无法读取图片 {name}")
            continue

        # Both images come from cv2.imread with default flags (3-channel BGR),
        # so only height/width can differ. cv2.resize takes (width, height).
        if img_gt.shape != img_pred.shape:
            img_pred = cv2.resize(img_pred, (img_gt.shape[1], img_gt.shape[0]))

        # PSNR on the full colour image; data_range is inferred from the dtype.
        psnr_values.append(peak_signal_noise_ratio(img_gt, img_pred))

        # SSIM on grayscale versions of both images.
        gray_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2GRAY)
        gray_pred = cv2.cvtColor(img_pred, cv2.COLOR_BGR2GRAY)
        ssim_values.append(structural_similarity(gray_gt, gray_pred))

        # LPIPS takes RGB NCHW tensors scaled to [-1, 1].
        with torch.no_grad():
            lpips_value = loss_fn(
                _bgr_to_lpips_tensor(img_gt, device),
                _bgr_to_lpips_tensor(img_pred, device),
            ).item()
        lpips_values.append(lpips_value)

    # Averages over all successfully evaluated pairs.
    results = {
        'PSNR': np.mean(psnr_values) if psnr_values else None,
        'SSIM': np.mean(ssim_values) if ssim_values else None,
        'LPIPS': np.mean(lpips_values) if lpips_values else None,
        'Image_Pairs': len(psnr_values)
    }

    return results

# Command-line entry point.
if __name__ == "__main__":
    import argparse

    # The defaults are the original hard-coded experiment paths, so running the
    # script with no arguments behaves exactly as before.
    parser = argparse.ArgumentParser(
        description="Compute mean PSNR / SSIM / LPIPS between two image folders.")
    parser.add_argument(
        "--gt", dest="folder_gt",
        default="/root/autodl-tmp/TaylorSeer/TaylorSeer-FLUX/bdf2/interval1/5-1/Taylor",
        help="folder of ground-truth (reference) images")
    parser.add_argument(
        "--pred", dest="folder_pred",
        default="/root/autodl-tmp/TaylorSeer/TaylorSeer-FLUX/bdf2/interval5",
        help="folder of predicted images with file names matching the GT folder")
    args = parser.parse_args()

    results = evaluate_image_folders(args.folder_gt, args.folder_pred)
    print("评估结果：")
    for k, v in results.items():
        print(f"{k}: {v}")
