import os
import argparse
import numpy as np
from PIL import Image
from skimage.metrics import structural_similarity as ssim


def load_image(image_path):
    """Load an image file and return its pixel data as a NumPy array.

    Args:
        image_path: Path of the image file to open with PIL.

    Returns:
        The ``np.ndarray`` built from the decoded image.
    """
    # Image.open is lazy and keeps the underlying file open; the context
    # manager releases the handle deterministically once the pixel data
    # has been copied into the array.
    with Image.open(image_path) as img:
        return np.array(img)


def calculate_ssim(img1, img2):
    """Return the SSIM score between two image arrays.

    Args:
        img1: First image array.
        img2: Second image array; must have the same shape as ``img1``.

    Returns:
        The value returned by ``structural_similarity`` for the pair,
        computed with ``data_range=255``.

    Raises:
        ValueError: If the two arrays differ in shape.
    """
    if img1.shape == img2.shape:
        # A 3-D array is treated as having its channels on axis 2; any
        # other rank is passed without a channel axis.
        is_three_dim = len(img1.shape) == 3
        return ssim(
            img1,
            img2,
            data_range=255,
            channel_axis=2 if is_three_dim else None,
        )

    raise ValueError("图像尺寸不匹配")


def process_directories(base_dir, compare_dir):
    """Return the mean SSIM over all image pairs shared by two root directories.

    Each entry ``<name>`` of ``base_dir`` is matched with
    ``compare_dir/<name>``. The ``.png`` files inside their ``samples``
    sub-directories are sorted by file name and compared position by
    position.

    Args:
        base_dir: Root directory whose entries drive the traversal.
        compare_dir: Root directory compared against ``base_dir``.

    Returns:
        The mean of all per-pair SSIM scores, or ``0`` when no pair was
        compared.

    Raises:
        RuntimeError: If two matching ``samples`` directories contain a
            different number of ``.png`` files.
    """
    total_scores = []

    # Sorted so the traversal order does not depend on the file system.
    for subdir in sorted(os.listdir(base_dir)):
        base_samples = os.path.join(base_dir, subdir, 'samples')
        compare_samples = os.path.join(compare_dir, subdir, 'samples')

        # Both sides must be real directories: base_dir may hold stray files
        # or entries without a 'samples' folder, and listing those would raise.
        if not (os.path.isdir(base_samples) and os.path.isdir(compare_samples)):
            continue

        # Sorted lists of PNG file names on each side.
        base_images = sorted(
            f for f in os.listdir(base_samples) if f.endswith('.png')
        )
        compare_images = sorted(
            f for f in os.listdir(compare_samples) if f.endswith('.png')
        )

        # Pairing is positional, so the two lists must have equal length.
        if len(base_images) != len(compare_images):
            raise RuntimeError(f"图片数量不匹配：{subdir}")

        # Score every pair of images.
        for b_img, c_img in zip(base_images, compare_images):
            img1 = load_image(os.path.join(base_samples, b_img))
            img2 = load_image(os.path.join(compare_samples, c_img))
            total_scores.append(calculate_ssim(img1, img2))

    return np.mean(total_scores) if total_scores else 0


def main(a_dir, b_dir, c_dir):
    """Print the mean SSIM of ``a_dir`` against ``b_dir`` and ``c_dir``.

    Args:
        a_dir: Reference root directory.
        b_dir: First root directory compared against ``a_dir``.
        c_dir: Second root directory compared against ``a_dir``.
    """
    # Compute both scores before printing anything, so a failure in either
    # comparison produces no partial report.
    scores = {
        label: process_directories(a_dir, other_dir)
        for label, other_dir in (("b", b_dir), ("c", c_dir))
    }

    print("SSIM结果：")
    for label, score in scores.items():
        print(f"a vs {label}: {score:.4f}")


if __name__ == "__main__":
    # One row per command-line option: (flag, default directory, help text).
    cli_options = (
        ('--a_dir',
         "../runtime/cfgFalse_tokenFalse_numsteps32_maxlength128_bsz1_timestamp20250505_230345",
         '原始生成目录'),
        ('--b_dir',
         "../runtime/cfgTrue_tokenTrue_numsteps32_maxlength128_bsz1_timestamp20250507_232056",
         '加速算法目录'),
        ('--c_dir',
         "../runtime/cfgFalse_tokenFalse_numsteps16_maxlength128_bsz1_timestamp20250506_120840",
         '超参数修改目录'),
    )

    parser = argparse.ArgumentParser()
    for flag, default_path, help_text in cli_options:
        parser.add_argument(flag, default=default_path, help=help_text)
    args = parser.parse_args()

    main(args.a_dir, args.b_dir, args.c_dir)