import os, math
import numpy as np
import scipy.signal
from typing import List, Optional
import cv2
from torch import Tensor
import torch
import torch.nn as nn
import torch.nn.functional as F
from skimage.color import deltaE_cie76
import os


def rgb_ssim(img0, img1, max_val,
             filter_size=11,
             filter_sigma=1.5,
             k1=0.01,
             k2=0.03,
             return_map=False):
    """Compute the structural similarity (SSIM) of two 3-channel images.

    Adapted from the mip-NeRF implementation:
    https://github.com/google/mipnerf/blob/16e73dfdb52044dcceb47cda5243a686391a6e0f/internal/math.py#L58

    Local means, variances and covariance come from a separable Gaussian
    window (``filter_size`` taps, std ``filter_sigma``). The window is
    applied per channel with 'valid' convolution, so the SSIM map is
    ``filter_size - 1`` pixels smaller than the input along each spatial axis.

    Args:
        img0, img1: arrays of identical shape (H, W, 3).
        max_val: maximum possible pixel value; scales the constants
            c1 = (k1 * max_val) ** 2 and c2 = (k2 * max_val) ** 2.
        filter_size: number of Gaussian taps.
        filter_sigma: Gaussian standard deviation in pixels.
        k1, k2: SSIM stabilising constants.
        return_map: if True, return the per-pixel, per-channel SSIM map
            instead of its mean.

    Returns:
        The mean SSIM, or the SSIM map when ``return_map`` is True.
    """
    assert len(img0.shape) == 3
    assert img0.shape[-1] == 3
    assert img0.shape == img1.shape

    # 1-D Gaussian kernel normalised to unit sum.
    half = filter_size // 2
    offset = (2 * half - filter_size + 1) / 2
    taps = (np.arange(filter_size) - half + offset) / filter_sigma
    kernel = np.exp(-0.5 * taps**2)
    kernel = kernel / np.sum(kernel)

    col = kernel[:, None]
    row = kernel[None, :]

    def blur(stack):
        # Separable blur (vertical pass, then horizontal), one channel at a time;
        # cheaper than a full 2-D convolution.
        planes = []
        for ch in range(stack.shape[-1]):
            tmp = scipy.signal.convolve2d(stack[..., ch], col, mode='valid')
            planes.append(scipy.signal.convolve2d(tmp, row, mode='valid'))
        return np.stack(planes, -1)

    mean0 = blur(img0)
    mean1 = blur(img1)
    mean00 = mean0 * mean0
    mean11 = mean1 * mean1
    mean01 = mean0 * mean1

    # Numerical error can push variances below zero; clamp them. The covariance
    # magnitude is bounded by sqrt(var0 * var1), so clamp it while keeping its sign.
    var0 = np.maximum(0., blur(img0**2) - mean00)
    var1 = np.maximum(0., blur(img1**2) - mean11)
    cov = blur(img0 * img1) - mean01
    cov = np.sign(cov) * np.minimum(np.sqrt(var0 * var1), np.abs(cov))

    c1 = (k1 * max_val)**2
    c2 = (k2 * max_val)**2
    ssim_map = ((2 * mean01 + c1) * (2 * cov + c2)) / (
        (mean00 + mean11 + c1) * (var0 + var1 + c2))

    if return_map:
        return ssim_map
    return np.mean(ssim_map)

# Module-level cache of LPIPS models keyed by backbone name (filled lazily by rgb_lpips).
__LPIPS__ = dict()


def init_lpips(net_name, device):
    """Build an LPIPS v0.1 model in eval mode on ``device``.

    Args:
        net_name: backbone name, either 'alex' or 'vgg'.
        device: torch device to move the model to.

    Returns:
        The LPIPS module returned by ``lpips.LPIPS(...).eval().to(device)``.
    """
    assert net_name in ('alex', 'vgg')
    # Imported lazily so the rest of this module works without the lpips package.
    import lpips
    print(f'init_lpips: lpips_{net_name}')
    model = lpips.LPIPS(net=net_name, version='0.1')
    model = model.eval()
    return model.to(device)

def rgb_lpips(np_gt, np_im, net_name, device):
    """Return the LPIPS distance between two channels-last RGB images.

    Args:
        np_gt: reference image as a NumPy array of shape (H, W, 3). LPIPS is
            called with ``normalize=True``, so values are presumably in
            [0, 1] — confirm against callers.
        np_im: image to compare, same layout as ``np_gt``.
        net_name: LPIPS backbone, 'alex' or 'vgg'.
        device: torch device the model and inputs are placed on.

    Returns:
        The LPIPS distance as a Python float.
    """
    if net_name not in __LPIPS__:
        __LPIPS__[net_name] = init_lpips(net_name, device)
    # The cache is keyed by backbone only, so the cached model may live on a
    # different device than requested; Module.to is a no-op when it already matches.
    model = __LPIPS__[net_name].to(device)

    def to_tensor(arr):
        # ascontiguousarray: torch.from_numpy rejects negative strides (e.g. img[..., ::-1]).
        # float(): float64 inputs would otherwise clash with the float32 weights.
        tensor = torch.from_numpy(np.ascontiguousarray(arr))
        return tensor.permute(2, 0, 1).float().contiguous().to(device)

    gt = to_tensor(np_gt)
    im = to_tensor(np_im)
    # Evaluation only: skip building the autograd graph.
    with torch.no_grad():
        return model(gt, im, normalize=True).item()

def compute_edge_smoothness(image_path):
    """Return the variance of the Laplacian of an image (the "Edge Smoothness" score).

    The file is read as single-channel grayscale and filtered with OpenCV's
    Laplacian in 64-bit float. The variance of that response is returned;
    larger values mean a stronger second-derivative (high-frequency) response.

    Args:
        image_path (str): path of the image to read.

    Returns:
        float: variance of the Laplacian response.

    Raises:
        ValueError: if OpenCV cannot load the file.
    """
    gray = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if gray is None:
        raise ValueError("无法加载图像，请检查路径！")
    response = cv2.Laplacian(gray, cv2.CV_64F)
    return np.var(response)

# 示例测试
# def main():
#     image_path = "morphed_image.png"  # 请替换为你的图像路径
#     smoothness = compute_edge_smoothness(image_path)
#     print(f"Edge Smoothness: {smoothness}")
#     return 0

def compute_folder_edge_smoothness(folder_path, metric_fn=None):
    """Compute the Edge Smoothness of every image in a folder and return the mean.

    Images are processed in sorted filename order. Each score is printed,
    followed by the folder average.

    Args:
        folder_path (str): folder containing the images. Only regular files
            whose name ends (case-insensitively) in .png/.jpg/.jpeg/.bmp/.tiff
            are scored.
        metric_fn (callable, optional): per-image scorer taking an image path
            and returning a number. Defaults to ``compute_edge_smoothness``.
            Pass ``check_edge_integrity`` to average edge-fragment counts instead.

    Returns:
        avg_edge_smoothness (float): mean score over the scored images, or 0.0
        if the folder holds no matching images.
    """
    if metric_fn is None:
        # Score with the Laplacian-variance metric, matching this function's name,
        # docstring and printed label (it previously called check_edge_integrity).
        metric_fn = compute_edge_smoothness

    image_extensions = ('.png', '.jpg', '.jpeg', '.bmp', '.tiff')  # supported formats
    es_values = []

    # Sorted so the per-image log order is reproducible (os.listdir order is arbitrary).
    for filename in sorted(os.listdir(folder_path)):
        image_path = os.path.join(folder_path, filename)
        if not filename.lower().endswith(image_extensions):
            continue
        if not os.path.isfile(image_path):
            continue  # e.g. a sub-directory that happens to be named "x.png"
        es_value = metric_fn(image_path)
        es_values.append(es_value)
        print(f"{filename}: Edge Smoothness = {es_value:.4f}")

    avg_edge_smoothness = np.mean(es_values) if es_values else 0.0
    print(f"\n文件夹 '{folder_path}' 中所有图片的平均 Edge Smoothness: {avg_edge_smoothness:.4f}")
    return avg_edge_smoothness


def check_edge_integrity(image_path):
    """Count the disconnected Canny edge fragments in an image.

    A large number of separate fragments may indicate broken / shattered
    object boundaries. The count is printed and returned.

    Args:
        image_path (str): path of the image to read (loaded as grayscale).

    Returns:
        int: number of 8-connected edge components, excluding the background.

    Raises:
        ValueError: if OpenCV cannot load the file.
    """
    gray = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if gray is None:
        raise ValueError(f"无法加载图像: {image_path}")

    # Canny edge map with hysteresis thresholds 50 / 150.
    edge_map = cv2.Canny(gray, 50, 150)

    # Label 8-connected edge pixels; label 0 is the background, so drop it.
    label_count = cv2.connectedComponentsWithStats(edge_map, connectivity=8)[0]
    fragments = label_count - 1
    print(f"边缘碎片数量: {fragments}")

    return fragments


def calculate_delta_e(image1, image2):
    """Return the mean CIE76 colour difference (Delta E) between two BGR images.

    Both images are converted from OpenCV BGR to CIE L*a*b*, and the per-pixel
    CIE76 Delta E is averaged over the whole image.

    Args:
        image1, image2: BGR images of identical shape (e.g. from ``cv2.imread``).
            Integer images are scaled by their dtype's maximum (255 for uint8).
            Float images are assumed to already be in [0, 1].

    Returns:
        float: mean Delta E over all pixels.

    Raises:
        ValueError: if the two images have different shapes.
    """
    if image1.shape != image2.shape:
        raise ValueError(f"image shapes differ: {image1.shape} vs {image2.shape}")

    def to_lab(img):
        # For 8-bit input OpenCV stores Lab rescaled (L*255/100, a+128, b+128),
        # which deltaE_cie76 would misread. Converting from float32 in [0, 1]
        # instead yields true L* in [0, 100] and signed a*/b*, the units CIE76
        # is defined on; it also avoids unsigned wrap-around in the subtraction.
        if np.issubdtype(img.dtype, np.integer):
            img = img.astype(np.float32) / np.iinfo(img.dtype).max
        else:
            img = img.astype(np.float32)
        return cv2.cvtColor(img, cv2.COLOR_BGR2Lab)

    delta_e = deltaE_cie76(to_lab(image1), to_lab(image2))
    return np.mean(delta_e)

def _load_resized_bgr(path, size):
    """Read an image with OpenCV and resize it to ``size``; return None if unreadable."""
    img = cv2.imread(path)
    if img is None:
        return None
    return cv2.resize(img, size)


def evaluate_color_consistency(folder_path, source_image, target_image, size=(720, 720)):
    """Evaluate the colour consistency of morph frames against the source and target.

    Every image file in ``folder_path`` is read in sorted filename order and
    resized to ``size``. It is then compared with the equally resized source
    and target images using ``calculate_delta_e``. Files OpenCV cannot decode
    are skipped.

    :param folder_path: folder containing the morphed frames
    :param source_image: path of the source image
    :param target_image: path of the target image
    :param size: (width, height) that every image is resized to before comparison
    :return: tuple ``(delta_e_source, delta_e_target, delta_e_diff, avg_delta_e)``:
        - per-frame Delta E to the source;
        - per-frame Delta E to the target;
        - the absolute change of the Delta E-to-source between consecutive frames;
        - the average of the three lists' means.
        An empty list (e.g. ``delta_e_diff`` with fewer than two frames) makes
        the average NaN, because ``np.mean([])`` is NaN.
    :raises ValueError: if the source or target image cannot be loaded
    """
    image_extensions = ('.png', '.jpg', '.jpeg', '.bmp', '.tiff')

    # Check for load failure BEFORE resizing: cv2.resize(None, ...) raises.
    source_img = _load_resized_bgr(source_image, size)
    target_img = _load_resized_bgr(target_image, size)
    if source_img is None or target_img is None:
        raise ValueError("无法加载源图像或目标图像")

    delta_e_source = []
    delta_e_target = []
    delta_e_diff = []
    prev_de_source = None

    # Sorted order matters: delta_e_diff compares consecutive frames, and
    # os.listdir returns entries in arbitrary order.
    for filename in sorted(os.listdir(folder_path)):
        if not filename.lower().endswith(image_extensions):
            continue
        morph_img = _load_resized_bgr(os.path.join(folder_path, filename), size)
        if morph_img is None:
            continue  # undecodable file: skip it

        de_source = calculate_delta_e(morph_img, source_img)
        de_target = calculate_delta_e(morph_img, target_img)
        delta_e_source.append(de_source)
        delta_e_target.append(de_target)

        # Change in Delta E-to-source between this frame and the previous one.
        if prev_de_source is not None:
            delta_e_diff.append(abs(de_source - prev_de_source))
        prev_de_source = de_source

    avg_delta_e = (np.mean(delta_e_source) + np.mean(delta_e_target) + np.mean(delta_e_diff)) / 3

    return delta_e_source, delta_e_target, delta_e_diff, avg_delta_e




if __name__ == "__main__":
    # --- Edge consistency ---------------------------------------------------
    # Folder of images to score; replace with your own path.
    edge_folder = "/workspace/projects/Frosting/metric/pic/freemorph/dg"
    smoothness = compute_folder_edge_smoothness(edge_folder)

    # --- Color consistency (disabled) ---------------------------------------
    # Uncomment to compare each morph frame against the source/target images.
    # folder_path = '/workspace/projects/Frosting/metric/pic/freemorph'
    # source_image = '/workspace/projects/Frosting/metric/pic/freemorph/input/00.png'
    # target_image = '/workspace/projects/Frosting/metric/pic/freemorph/input/06.png'

    # delta_e_source, delta_e_target, delta_e_diff, avg_delta_e = evaluate_color_consistency(folder_path, source_image, target_image)
    # print(f"Average Delta E: {avg_delta_e}")
    
