import torch
import torch.nn.functional as F
import numpy as np
from rmbg import single_image_rmbg
from  PIL import Image
from tqdm import tqdm

# Match masks to each other via an affine (box-to-box) transformation.
def mask2bound(mask):
    """Return the bounding-box center and size of a mask in [-1, 1] coordinates.

    Pixel indices are converted to the normalized coordinate of the pixel
    center with ``((i + 0.5) / size - 0.5) * 2``, i.e. the coordinate system
    used by ``F.affine_grid`` / ``F.grid_sample`` with ``align_corners=False``.

    Args:
        mask: Tensor whose last two dimensions are ``(H, W)``. Leading
            dimensions (batch / channel) are merged, so ``[H, W]`` and
            ``[1, 1, H, W]`` are both accepted. Any non-zero entry counts as
            foreground.

    Returns:
        ``(center, size)``: two tensors of shape ``[2]`` holding
        ``(x_c, y_c)`` and ``(x_l, y_l)``. The size is measured between the
        centers of the first and last occupied column / row, so a mask that
        is a single pixel wide has size 0 along that axis.

    Raises:
        ValueError: If the mask contains no foreground pixel.
    """
    H, W = mask.shape[-2:]
    # A pixel is foreground if it is set in any of the leading dimensions.
    occupied = (mask.reshape(-1, H, W) != 0).any(dim=0)
    cols = torch.nonzero(occupied.any(dim=0)).flatten()
    rows = torch.nonzero(occupied.any(dim=1)).flatten()
    if cols.numel() == 0:
        # Without this guard the indexing below fails with an opaque IndexError.
        raise ValueError("mask2bound: mask contains no foreground pixels")

    def to_normalized(index, extent):
        # Center of pixel `index` on an axis of `extent` pixels, in [-1, 1].
        return ((index + 0.5) / extent - 0.5) * 2.0

    x_min, x_max = to_normalized(cols[0], W), to_normalized(cols[-1], W)
    y_min, y_max = to_normalized(rows[0], H), to_normalized(rows[-1], H)

    x_c, y_c = (x_min + x_max) / 2.0, (y_min + y_max) / 2.0  # box center
    x_l, y_l = x_max - x_min, y_max - y_min  # box extent
    center = torch.tensor([x_c.item(), y_c.item()])
    size = torch.tensor([x_l.item(), y_l.item()])
    return center, size


def transformInp(inp, center1, center2, length1, length2):
    """Resample ``inp`` so that the content of box 1 appears inside box 2.

    Every output location ``p`` (normalized [-1, 1] coordinates,
    ``align_corners=False``) is sampled from the input at
    ``(p - center2) / length2 * length1 + center1``. Sampling uses the
    defaults of ``F.grid_sample`` (bilinear interpolation, zero padding), so
    the result is differentiable with respect to the box parameters.

    Args:
        inp: Input tensor of shape ``[N, C, H, W]``.
        center1: 2-element tensor ``(x, y)``, center of the box in the input.
        center2: 2-element tensor ``(x, y)``, center of the box in the output.
        length1: 2-element tensor ``(x, y)``, extent of the box in the input.
        length2: 2-element tensor ``(x, y)``, extent of the box in the output;
            must be non-zero because it is used as a divisor.

    Returns:
        Tensor with the same shape as ``inp``.
    """
    batch = inp.size(0)
    # affine_grid needs a floating-point theta; follow the input's dtype when possible.
    dtype = inp.dtype if inp.is_floating_point() else torch.float32
    identity = torch.tensor([[1., 0., 0.], [0., 1., 0.]], dtype=dtype, device=inp.device)
    # One identity matrix per batch element: affine_grid requires theta of shape [N, 2, 3].
    theta = identity.unsqueeze(0).repeat(batch, 1, 1)
    grid = F.affine_grid(theta, inp.size(), align_corners=False)

    def as_xy(param):
        # Bring a box parameter onto the grid's device/dtype and make it
        # broadcast over [N, H, W, 2]; `.to` keeps the autograd link.
        return param.to(device=grid.device, dtype=grid.dtype).view(1, 1, 1, 2)

    # Map output coordinates from box 2 back to box 1 in the input.
    grid = (grid - as_xy(center2)) / as_xy(length2) * as_xy(length1) + as_xy(center1)
    return F.grid_sample(inp, grid, align_corners=False)


def align_transform(srcMask, targetMask, optimize_steps=200, lr=1e-4):
    """Optimize the box parameters that warp ``srcMask`` onto ``targetMask``.

    The bounding box of ``targetMask`` is kept fixed as the output-side box.
    A second box (center and size) on the source side is optimized with Adam
    so that the warped source mask maximizes its soft IoU with the target.

    NOTE(review): the optimized box is initialized from the *target's*
    bounding box, so the first warp is the identity; confirm this is intended
    rather than starting from the bounding box of ``srcMask``.

    Args:
        srcMask: Mask to be aligned, ``[H, W]`` or ``[1, 1, H, W]``.
        targetMask: Reference mask, ``[H, W]`` or ``[1, 1, H, W]``.
        optimize_steps: Number of Adam iterations.
        lr: Adam learning rate.

    Returns:
        ``(center, length)``: the optimized source-side box parameters,
        detached from the autograd graph.
    """
    # Promote plain [H, W] masks to [1, 1, H, W].
    if srcMask.dim() == 2:
        srcMask = srcMask[None, None]
    if targetMask.dim() == 2:
        targetMask = targetMask[None, None]

    # Fixed output-side box taken from the target mask.
    ref_center, ref_size = mask2bound(targetMask)

    # Trainable source-side box.
    opt_center = ref_center.clone().requires_grad_(True)
    opt_size = ref_size.clone().requires_grad_(True)
    adam = torch.optim.Adam([opt_center, opt_size], lr=lr)

    # The float views do not change between iterations.
    src = srcMask.float()
    tgt = targetMask.float()

    print("Optimizing alignment parameters...")
    for _ in tqdm(range(optimize_steps)):
        adam.zero_grad()

        warped = transformInp(src, opt_center, ref_center, opt_size, ref_size)

        # Soft IoU between the warped source and the target.
        overlap = (warped * tgt).sum()
        combined = (warped + tgt).sum() - overlap
        neg_iou = -(overlap / (combined + 1e-6))  # minimizing -IoU maximizes IoU

        neg_iou.backward()
        adam.step()

    return opt_center.detach(), opt_size.detach()


def apply_alignment(srcImg, targetMask, transform_params):
    """Warp ``srcImg`` with previously computed alignment parameters.

    Args:
        srcImg: Image to warp, ``[C, H, W]`` or ``[1, C, H, W]``.
        targetMask: Target mask; its bounding box is used as the output-side
            box of the warp.
        transform_params: ``(center, length)`` pair of the source-side box,
            as returned by ``align_transform``.

    Returns:
        The warped image as a float tensor. A leading dimension of size 1 is
        removed, so a single image comes back as ``[C, H, W]``.
    """
    # transformInp works on batched input; add the batch axis when missing.
    batched = srcImg.unsqueeze(0) if srcImg.dim() == 3 else srcImg

    # Output-side box from the target mask.
    center_ref, length_ref = mask2bound(targetMask)

    center_render, length_render = transform_params
    warped = transformInp(
        batched.float(),
        center_render,
        center_ref,
        length_render,
        length_ref,
    )

    # Drop the batch axis again (no-op when the batch size is not 1).
    return warped.squeeze(0)


def align_to_mask(srcImg, srcImgPath, targetMask):
    """Align ``srcImg`` to ``targetMask`` using a mask extracted from its file.

    The source mask is element ``[1]`` of the result of
    ``single_image_rmbg(srcImgPath, None, True)``, resized to the spatial
    size of ``srcImg``. The alignment is then fitted with ``align_transform``
    and applied with ``apply_alignment``.

    Args:
        srcImg: Image tensor, ``[C, H, W]`` or ``[1, C, H, W]``.
        srcImgPath: Path of the image file that ``srcImg`` was loaded from
            (assumed; it is only passed on to ``single_image_rmbg``).
        targetMask: Target mask tensor, as accepted by ``align_transform``.

    Returns:
        The aligned image as returned by ``apply_alignment``.
    """
    height, width = srcImg.shape[-2:]
    # PIL's resize expects (width, height), the reverse of the tensor's (H, W).
    srcMaskPil = single_image_rmbg(srcImgPath, None, True)[1].resize((width, height))
    srcMask = pil_to_tensor(srcMaskPil)
    transform_params = align_transform(srcMask, targetMask)
    aligned_img = apply_alignment(srcImg, targetMask, transform_params)

    return aligned_img


def pil_to_tensor(pil_image):
    """Convert a PIL image to a float32 tensor scaled by 1/255.

    Args:
        pil_image: PIL image. Mode ``'L'`` is treated as single-channel; any
            other mode is expected to convert to an ``(H, W, C)`` array.

    Returns:
        Tensor of shape ``[1, 1, H, W]`` for mode ``'L'``, otherwise
        ``[1, C, H, W]``.
    """
    is_grayscale = pil_image.mode == 'L'
    scaled = np.array(pil_image, dtype=np.float32) / 255.0
    tensor_image = torch.from_numpy(scaled)
    if is_grayscale:
        # (H, W) -> (1, 1, H, W)
        return tensor_image.unsqueeze(0).unsqueeze(0)
    # (H, W, C) -> (1, C, H, W)
    return tensor_image.permute(2, 0, 1).unsqueeze(0)


def calculate_iou(image1, image2):
    """Return the (soft) intersection-over-union of two tensors.

    The intersection is the sum of the element-wise product and the union is
    the sum of both inputs minus that intersection; ``1e-6`` is added to the
    denominator so that two all-zero inputs yield 0 instead of NaN.

    Args:
        image1: First tensor (broadcastable against ``image2``).
        image2: Second tensor.

    Returns:
        Scalar tensor holding the IoU.
    """
    overlap = (image1 * image2).sum()
    total = (image1 + image2).sum()
    return overlap / (total - overlap + 1e-6)


# Manual smoke test: align a reference image to a source mask and report the IoU.
if __name__ == "__main__":
    srcImagePath = "/home/swu/cyh/MasterGraduationProject/application/work2/iterative_editing/images/textureless.png"
    srcMaskPath = "/home/swu/cyh/MasterGraduationProject/application/work2/iterative_editing/images/mask.png"
    refImagePath = "/home/swu/cyh/MasterGraduationProject/application/work2/iterative_editing/images/ref/chair000.png"

    img_res = 1024

    # Load the source image and its mask at the working resolution.
    srcImgPil = Image.open(srcImagePath).convert('RGB').resize((img_res, img_res))
    srcMaskPil = Image.open(srcMaskPath).convert('L').resize((img_res, img_res))
    srcImg = pil_to_tensor(srcImgPil)  # [1,C,H,W]
    srcMask = pil_to_tensor(srcMaskPil)  # [1,1,H,W]

    # Produce the reference image and its mask first: everything below
    # (resizing, tensor conversion) depends on these two results.
    refImgPil, refImgMaskPil = single_image_rmbg(refImagePath, None, True)
    refImgPil = refImgPil.resize((img_res, img_res))
    refImgMaskPil = refImgMaskPil.resize((img_res, img_res))

    # Shapes below assume an RGB image and a mode-'L' mask from
    # single_image_rmbg -- TODO confirm against that helper.
    refImg = pil_to_tensor(refImgPil)  # [1,C,H,W]
    refMask = pil_to_tensor(refImgMaskPil)  # [1,1,H,W]

    # Fit the alignment of the reference mask onto the source mask.
    transform_params = align_transform(refMask, srcMask, optimize_steps=200, lr=1e-4)

    print("对齐变换参数:")
    print("中心点:", transform_params[0].numpy())
    print("尺寸:", transform_params[1].numpy())

    # Warp the reference image with the fitted parameters.
    alignedImg = apply_alignment(refImg, srcMask, transform_params)

    # IoU before and after alignment; any non-zero pixel counts as foreground.
    refImgM = (refImg > 0.0).to(torch.float32)
    alignedImgM = (alignedImg > 0.0).to(torch.float32)
    iou_before = calculate_iou(refImgM, srcMask)
    iou_after = calculate_iou(alignedImgM, srcMask)
    print("对齐前IoU:", iou_before.item())
    print("对齐后IoU:", iou_after.item())

    # Convert the aligned image back to an 8-bit PIL image and save it.
    alignedImgPil = Image.fromarray((alignedImg.squeeze(0).permute(1, 2, 0).numpy() * 255).astype(np.uint8))
    alignedImgPil.save("/home/swu/cyh/MasterGraduationProject/application/work2/iterative_editing/images/textureless_aligned.png")







