import numpy as np
import cv2
from insightface.app import FaceAnalysis
import sys
from wan.utils.human_landmark_runner import LandmarkRunner as HumanLandmark
from wan.utils.crop import parse_bbox_from_landmark

class Crop_face:
    """Detect a single face in an image and crop a square region around it.

    Detection uses insightface's ``FaceAnalysis`` (model pack ``buffalo_l``).
    Its 106-point landmarks are refined by ``LandmarkRunner`` and turned into
    a crop box by ``parse_bbox_from_landmark``. Both models are configured
    for CPU execution.
    """

    def __init__(self, insightface_root='./assets/insightface',
                 landmark_ckpt_path='./assets/landmark.onnx'):
        """Load the face detector and the landmark refiner on CPU.

        Args:
            insightface_root: Model root directory passed to ``FaceAnalysis``.
            landmark_ckpt_path: ONNX checkpoint path for the landmark refiner.
        """
        providers = ['CPUExecutionProvider']
        self.insightface_app = FaceAnalysis(
            name="buffalo_l", root=insightface_root, providers=providers
        )
        # ctx_id=-1 selects CPU (a GPU index would be passed here instead).
        self.insightface_app.prepare(ctx_id=-1, det_size=(640, 640))
        self.human_landmark_runner = HumanLandmark(
            ckpt_path=landmark_ckpt_path,
            onnx_provider='cpu',
            device_id=-1,
        )

    @staticmethod
    def _bbox_to_affine(bbox, dsize=512):
        """Build the affine matrix that maps *bbox* onto the output canvas.

        Args:
            bbox: ``[left, top, right, bot]`` in source-image pixels.
            dsize: Output size. Either an int (square output) or a
                ``(width, height)`` tuple/list.

        Returns:
            ``(M, out_size)``. ``M`` is a float32 2x3 matrix and
            ``out_size`` is the ``(width, height)`` tuple for
            ``cv2.warpAffine``.

        Raises:
            ValueError: If the box width ``right - left`` is not positive.
        """
        left, top, right, bot = bbox
        size = right - left
        if size <= 0:
            # Without this check the scale below divides by zero and the
            # warp silently produces garbage.
            raise ValueError(f"bbox must have positive width, got {bbox!r}")

        # Normalize dsize before any arithmetic so tuple/list sizes work too.
        if isinstance(dsize, (tuple, list)):
            out_w, out_h = dsize
        else:
            out_w = out_h = dsize

        src_center = np.array([(left + right) / 2, (top + bot) / 2], dtype=np.float32)
        tgt_center = np.array([out_w / 2, out_h / 2], dtype=np.float32)

        # A single isotropic scale makes the box width fill the output width,
        # which keeps the aspect ratio. For an int dsize this is exactly
        # dsize / size.
        s = out_w / size
        M = np.array(
            [[s, 0, tgt_center[0] - s * src_center[0]],
             [0, s, tgt_center[1] - s * src_center[1]]],
            dtype=np.float32,
        )
        return M, (out_w, out_h)

    @staticmethod
    def _crop_image_by_bbox(img, bbox, dsize=512):
        """Warp the ``[left, top, right, bot]`` region of *img* to *dsize*.

        *dsize* is an int for a square output or a ``(width, height)``
        tuple/list. The box is centred in the output and scaled so its width
        spans the output width.

        Raises:
            ValueError: If the box width is not positive.
        """
        M, out_size = Crop_face._bbox_to_affine(bbox, dsize)
        return cv2.warpAffine(img, M, dsize=out_size, flags=cv2.INTER_LINEAR)

    def _get_face(self, ref_image, dsize=512):
        """Detect the single face in *ref_image* and return a crop of it.

        Args:
            ref_image: HxWxC image array as loaded by ``cv2.imread``. It is
                passed unchanged to insightface and channel-reversed for the
                landmark runner.
            dsize: Output size forwarded to ``_crop_image_by_bbox``.

        Returns:
            ``(crop, bbox)``, where ``bbox`` is ``[left, top, right, bot]`` in
            *ref_image* coordinates. Returns ``None`` when the detector finds
            zero faces or more than one.
        """
        det_res = self.insightface_app.get(ref_image)
        # Ambiguous input (no face, or several faces) is rejected, not guessed at.
        if len(det_res) != 1:
            return None

        face = det_res[0]
        # Refine insightface's 106-point landmarks. The channel flip
        # presumably turns BGR into the RGB the runner expects — confirm.
        lmk = self.human_landmark_runner.run(ref_image[:, :, ::-1], face.landmark_2d_106)
        ret_bbox = parse_bbox_from_landmark(
            lmk,
            scale=2.2,
            vx_ratio_crop_driving_video=0.0,
            vy_ratio=-0.1,
        )["bbox"]
        # Assumes ret_bbox rows are box corners, with row 0 = top-left and
        # row 2 = bottom-right. Verify against parse_bbox_from_landmark.
        bbox = [ret_bbox[0, 0], ret_bbox[0, 1], ret_bbox[2, 0], ret_bbox[2, 1]]
        crop_face = self._crop_image_by_bbox(ref_image, bbox, dsize)

        return crop_face, bbox


if __name__ == '__main__':
    import os

    src_img_path = './crop_aihaop_s0.png'
    src_img = cv2.imread(src_img_path)
    # cv2.imread signals failure by returning None instead of raising.
    if src_img is None:
        sys.exit(f"Could not read image: {src_img_path}")

    cropper = Crop_face()
    # _get_face returns None unless exactly one face is found. Otherwise it
    # returns a (crop, bbox) tuple, and only the crop is written to disk.
    result = cropper._get_face(src_img)
    if result is None:
        sys.exit(f"Expected exactly one face in {src_img_path}")
    crop_face_img, _bbox = result

    # Insert the suffix before the extension only, not at every ".png" match.
    stem, ext = os.path.splitext(src_img_path)
    crop_face_path = f"{stem}_face{ext}"
    if not cv2.imwrite(crop_face_path, crop_face_img):
        sys.exit(f"Failed to write {crop_face_path}")
