import numpy as np
import scipy.signal
import cv2

def remove_black_background(img, threshold=10):
    """
    Zero out near-black background pixels and keep everything else.

    Works for grayscale images (2-D arrays) and multi-channel images
    (channels on the last axis).

    Args:
        img: Image array. A 2-D array is treated as grayscale; otherwise the
            last axis is treated as the channel axis.
        threshold: A pixel is kept if its value (grayscale) or any of its
            channel values (multi-channel) is strictly greater than this.
            Defaults to 10, the value that was previously hard-coded.

    Returns:
        A new array with the same shape and dtype as ``img``, in which every
        background pixel is set to 0. The input array is not modified.
    """
    # Foreground = brighter than the threshold; for multi-channel images a
    # single bright channel is enough to keep the whole pixel.
    bright = img > threshold
    mask = bright if img.ndim == 2 else np.any(bright, axis=-1)
    result = img.copy()
    result[~mask] = 0
    return result
    
def ssim(img0, img1, max_val=255, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03):
    """
    Compute the mean SSIM between two images after zeroing near-black background.

    If the spatial sizes differ, ``img0`` is resized (bilinear) to match
    ``img1``. Both images are then passed through ``remove_black_background``
    and compared with a separable Gaussian window. The per-pixel SSIM map
    (only positions where the window fits entirely, i.e. 'valid' convolution)
    is averaged over all positions and channels.

    Args:
        img0, img1: Grayscale (H, W) or multi-channel (H, W, C) image arrays.
            Any numeric dtype is accepted; the computation runs in float64.
        max_val: Dynamic range of the pixel values (255 for 8-bit images).
        filter_size: Width of the Gaussian window, in pixels.
        filter_sigma: Standard deviation of the Gaussian window.
        k1, k2: SSIM stabilisation constants; c1 = (k1 * max_val) ** 2 and
            c2 = (k2 * max_val) ** 2.

    Returns:
        The mean of the SSIM map as a NumPy float scalar.
    """
    # Bring img0 to img1's spatial size so the two maps line up pixel-for-pixel.
    if img0.shape[:2] != img1.shape[:2]:
        img0 = cv2.resize(img0, (img1.shape[1], img1.shape[0]), interpolation=cv2.INTER_LINEAR)

    # Compute in float64: squaring uint8 input (cv2.imread's default dtype)
    # would silently wrap around, and float32 loses precision in the
    # E[x^2] - E[x]^2 variance estimates below.
    img0 = remove_black_background(np.asarray(img0, dtype=np.float64))
    img1 = remove_black_background(np.asarray(img1, dtype=np.float64))

    # Give grayscale inputs an explicit channel axis. Each image is handled
    # independently so that a gray/color pair (or a single-channel image that
    # cv2.resize flattened to 2-D) still broadcasts correctly.
    if img0.ndim == 2:
        img0 = img0[..., None]
    if img1.ndim == 2:
        img1 = img1[..., None]

    # 1-D normalised Gaussian taps; `shift` centres the kernel for even
    # filter sizes (it is 0 for odd sizes).
    hw = filter_size // 2
    shift = (2 * hw - filter_size + 1) / 2
    f_i = ((np.arange(filter_size) - hw + shift) / filter_sigma) ** 2
    filt = np.exp(-0.5 * f_i)
    filt /= np.sum(filt)

    def filt_fn(z):
        """Blur every channel of `z` with the separable Gaussian ('valid' region only)."""
        return np.stack([
            scipy.signal.convolve2d(
                scipy.signal.convolve2d(z[..., i], filt[:, None], mode='valid'),
                filt[None, :], mode='valid')
            for i in range(z.shape[-1])], -1)

    # Local means, variances and covariance under the Gaussian window.
    mu0, mu1 = filt_fn(img0), filt_fn(img1)
    sigma00 = filt_fn(img0 ** 2) - mu0 ** 2
    sigma11 = filt_fn(img1 ** 2) - mu1 ** 2
    sigma01 = filt_fn(img0 * img1) - mu0 * mu1

    # Clamp rounding noise: variances cannot be negative, and |cov| is bounded
    # by sqrt(var0 * var1) (Cauchy-Schwarz).
    sigma00, sigma11 = np.maximum(0., sigma00), np.maximum(0., sigma11)
    sigma01 = np.sign(sigma01) * np.minimum(np.sqrt(sigma00 * sigma11), np.abs(sigma01))

    c1, c2 = (k1 * max_val) ** 2, (k2 * max_val) ** 2
    numer = (2 * mu0 * mu1 + c1) * (2 * sigma01 + c2)
    denom = (mu0 ** 2 + mu1 ** 2 + c1) * (sigma00 + sigma11 + c2)
    ssim_map = numer / denom

    # Overall SSIM = mean over all valid positions and channels.
    return np.mean(ssim_map)

def ssim_interpolation_error(A, B, G30, G60, G90, *, ideal_a=(0.7, 0.4, 0.1), ideal_b=(0.3, 0.6, 0.9)):
    """
    Measure how far the SSIM of intermediate frames deviates from a target curve.

    For each intermediate frame G (in the order G30, G60, G90), ``ssim(A, G)``
    and ``ssim(G, B)`` are compared with the corresponding entries of
    ``ideal_a`` and ``ideal_b``. The default targets correspond to frames at
    t = 0.3, 0.6, 0.9, with SSIM to A ideally ``1 - t`` and SSIM to B ideally
    ``t``.

    Args:
        A: Start image.
        B: End image.
        G30, G60, G90: Intermediate frames.
        ideal_a: Target SSIM between A and each intermediate frame.
        ideal_b: Target SSIM between each intermediate frame and B.

    Returns:
        The mean squared error against ``ideal_a`` plus the mean squared
        error against ``ideal_b``.

    Raises:
        ValueError: If ``ideal_a`` or ``ideal_b`` does not have exactly one
            entry per intermediate frame.
    """
    frames = (G30, G60, G90)
    # zip() would silently truncate a mismatched target list; fail loudly instead.
    if len(ideal_a) != len(frames) or len(ideal_b) != len(frames):
        raise ValueError(
            f"ideal_a and ideal_b must each have {len(frames)} entries, "
            f"got {len(ideal_a)} and {len(ideal_b)}")

    ssim_to_a = [ssim(A, g) for g in frames]
    ssim_to_b = [ssim(g, B) for g in frames]

    error_a = np.mean([(target - actual) ** 2 for target, actual in zip(ideal_a, ssim_to_a)])
    error_b = np.mean([(target - actual) ** 2 for target, actual in zip(ideal_b, ssim_to_b)])
    return error_a + error_b

# Manual test run
if __name__ == "__main__":
    def _load_float_image(path):
        """Read a color image with OpenCV and convert it to float32."""
        img = cv2.imread(path, cv2.IMREAD_COLOR)
        # cv2.imread returns None (it does not raise) for missing/unreadable files.
        if img is None:
            raise FileNotFoundError(f"Could not read image: {path}")
        return img.astype(np.float32)

    A = _load_float_image("/workspace/projects/Frosting/metric/pic/diffmorpher/ei/input.png")
    B = _load_float_image("/workspace/projects/Frosting/metric/pic/diffmorpher/ei/banana1.png")
    G30 = _load_float_image("/workspace/projects/Frosting/metric/pic/diffmorpher/0q0.png")
    G60 = _load_float_image("/workspace/projects/Frosting/metric/pic/diffmorpher/04.png")
    G90 = _load_float_image("/workspace/projects/Frosting/metric/pic/diffmorpher/0.9.png")

    error = ssim_interpolation_error(A, B, G30, G60, G90)
    print(f"SSIM 线性插值误差: {error:.4f}")