import cv2
import numpy as np
import insightface
from insightface.app import FaceAnalysis
# from scipy.interpolate import Rbf  # 移除 Rbf
from skimage.transform import PiecewiseAffineTransform, warp

def get_landmarks(app, img):
    """Detect faces in ``img`` and return the 106-point 2D landmarks of the first face.

    Args:
        app: A prepared face analyser -- anything whose ``get(img)`` returns a
            list of face objects exposing a ``landmark_2d_106`` attribute
            (e.g. ``insightface.app.FaceAnalysis``).
        img: Image passed straight to ``app.get``; presumably a BGR array as
            returned by ``cv2.imread`` -- confirm against callers.

    Returns:
        A new ``np.float32`` array of the first face's landmarks
        (presumably shape (106, 2), (x, y) pixel coordinates -- TODO confirm).

    Raises:
        ValueError: if no face is detected, or the first detected face carries
            no 106-point landmarks.
    """
    faces = app.get(img)
    if not faces:
        raise ValueError("未检测到人脸")

    # Only the first detection is used; its ordering is whatever app.get returns.
    # getattr(..., None) covers both a missing attribute and a face object that
    # reports an absent field as None.
    raw = getattr(faces[0], "landmark_2d_106", None)
    if raw is None:
        raise ValueError("检测到的人脸缺少 106 点关键点 (landmark_2d_106)")

    # Convert once (the original converted twice) and always return a copy.
    landmarks = np.array(raw, dtype=np.float32)
    print(len(landmarks))  # debug output: number of landmark points
    return landmarks

def tps_warp_image(src_img, src_points, dst_points, dst_shape):
    """Warp ``src_img`` so that ``src_points`` land on ``dst_points``.

    Despite the name, this is not a thin-plate spline: it uses scikit-image's
    ``PiecewiseAffineTransform`` (one affine map per triangle of the control
    points).

    Args:
        src_img: Source image, (H, W) or (H, W, C). ``warp`` rescales it to
            floats in [0, 1] (e.g. uint8 values are divided by 255).
        src_points: (N, 2) array of control points in ``src_img``.
        dst_points: (N, 2) array of the matching positions in the output image.
        dst_shape: Output size as (height, width); the channel axis of
            ``src_img``, if any, is carried over automatically.

    Returns:
        The warped image as ``np.uint8`` in 0-255.

    Raises:
        ValueError: if the point arrays are not matching (N, 2) arrays.

    NOTE(review): output pixels outside the area covered by the triangulated
    ``dst_points`` do not lie in any triangle, so their values are not
    meaningful -- callers should restrict use to that region.
    """
    src_points = np.asarray(src_points, dtype=np.float64)
    dst_points = np.asarray(dst_points, dtype=np.float64)
    if (src_points.ndim != 2 or src_points.shape[1] != 2
            or src_points.shape != dst_points.shape):
        raise ValueError("src_points 和 dst_points 必须是形状相同的 Nx2 数组")

    tform = PiecewiseAffineTransform()
    # warp() performs inverse mapping: for every output pixel it asks where to
    # sample in the source, so the transform is estimated dst -> src.
    tform.estimate(dst_points, src_points)

    warped = warp(src_img, tform, output_shape=dst_shape, mode='edge')
    # warp() returns floats in [0, 1]. Round (rather than truncate, which biased
    # every pixel down by up to one level) and clip before casting to uint8.
    return np.clip(np.rint(warped * 255), 0, 255).astype(np.uint8)


def main(ref_path="/data/yangjiarui/diffae/datasets/sr_test/01005/12.png",
         hr_path="/data/yangjiarui/diffae/datasets/sr_test/01005/15.png",
         out_path="result.jpg"):
    """Warp the face in ``ref_path`` onto the face landmarks of ``hr_path`` and blend.

    The target image is resized to half of the source's width/height. The
    source is piecewise-affinely warped so its 106 landmarks match the
    target's. The warped face is pasted into the target inside the convex hull
    of the target landmarks, and the result is written to ``out_path``.

    Args:
        ref_path: Source image (profile / side-view face) to be warped.
        hr_path: Target image (frontal face) whose landmarks define the geometry.
        out_path: Where the blended result is written.

    Raises:
        ValueError: if either image cannot be read, or no face is found.
        OSError: if the result cannot be written.
    """
    # Initialise InsightFace (buffalo_l pack on CUDA).
    app = FaceAnalysis(name='buffalo_l', providers=['CUDAExecutionProvider'])
    app.prepare(ctx_id=0, det_size=(256, 256))

    ref = cv2.imread(ref_path)  # side-view face: warp source
    hr = cv2.imread(hr_path)    # frontal face: warp target
    # cv2.imread returns None instead of raising on a bad path, so validate
    # before any .shape access (the original touched ref.shape first).
    if hr is None or ref is None:
        raise ValueError("图像路径无效")

    # Resize the target to half of the source's width/height.
    hr = cv2.resize(hr, (ref.shape[1] // 2, ref.shape[0] // 2), interpolation=cv2.INTER_AREA)

    pts1 = get_landmarks(app, ref)
    pts2 = get_landmarks(app, hr)

    # Warp the source so its landmarks land on the target's landmarks.
    ref_warped = tps_warp_image(ref, pts1, pts2, (hr.shape[0], hr.shape[1]))

    # Binary mask (1 inside the convex hull of the target landmarks, 0 outside).
    hull2 = cv2.convexHull(np.round(pts2).astype(np.int32))
    mask = np.zeros_like(hr, dtype=np.float32)
    cv2.fillConvexPoly(mask, hull2, (1.0, 1.0, 1.0))

    # Blend: warped source face inside the hull, original target outside.
    # (The original computed only (1 - mask) * hr and discarded ref_warped.)
    result = mask * ref_warped + (1 - mask) * hr
    result = np.clip(result, 0, 255).astype(np.uint8)

    # cv2.imwrite signals failure via its return value, not an exception.
    if not cv2.imwrite(out_path, result):
        raise OSError(f"保存失败: {out_path}")
    print(f"保存成功: {out_path}")

if __name__ == "__main__":
    # Run the warp-and-blend demo only when executed as a script, not on import.
    main()