"""
@Description :   提取碎片的边缘，将边缘和碎片的其它元信息存储到 pkl 文件中
@Author      :   tqychy 
@Time        :   2024/12/28 17:39:36
"""
import os
import string

import cv2
import numpy as np
from scripts import data_preprocess


def get_fragment_pairs(matching_set, fragment_images, fragment_transforms, contours, current_nums,
                       proximity_threshold=100):
    """Find candidate matching pairs among the fragments of one source image.

    Every unordered pair ``(i, k)`` with ``i < k`` is examined. Each contour has its
    two columns swapped and is then mapped with the top two rows (the 2x3 affine part)
    of its own fragment transform. Pairs whose mapped bounding boxes are too far apart
    are skipped. For the remaining pairs, ``data_preprocess.get_corresbounding`` is
    asked for corresponding point indices, and non-empty results are recorded.

    Args:
        matching_set: dict whose ``'source_ind'``, ``'target_ind'`` and ``'GT_pairs'``
            lists are appended to in place.
        fragment_images: per-fragment images; only ``len(fragment_images)`` is used.
        fragment_transforms: indexable of 3x3 matrices, one per fragment.
        contours: per-fragment ``(N, 2)`` point arrays.
        current_nums: offset added to the local fragment indices stored in ``GT_pairs``.
        proximity_threshold: a pair is skipped when, on either axis, the product
            ``(max2 - min1) * (min2 - max1)`` exceeds this value. The default of 100
            matches the previously hard-coded constant.

    Returns:
        The same ``matching_set`` object.
    """
    num_fragments = len(fragment_images)
    if num_fragments < 2:
        return matching_set

    # Per-fragment work depends only on that fragment, so do it once instead of once per pair.
    swapped, transformed, bounds = [], [], []
    for idx in range(num_fragments):
        # (x, y) -> (y, x), as a fresh contiguous copy.
        contour = np.asarray(contours[idx])[:, [1, 0]]
        homogeneous = np.hstack((contour, np.ones((len(contour), 1))))
        mapped = np.matmul(homogeneous, fragment_transforms[idx][:2].T)
        swapped.append(contour)
        transformed.append(mapped)
        if len(mapped) == 0:
            bounds.append(None)  # no points -> this fragment cannot match anything
        else:
            bounds.append((mapped[:, 0].min(), mapped[:, 0].max(),
                           mapped[:, 1].min(), mapped[:, 1].max()))

    for i in range(num_fragments):
        if bounds[i] is None:
            continue
        min_x1, max_x1, min_y1, max_y1 = bounds[i]
        for k in range(i + 1, num_fragments):
            if bounds[k] is None:
                continue
            min_x2, max_x2, min_y2, max_y2 = bounds[k]
            # For each axis, the product is <= 0 when the two intervals overlap and
            # positive (growing with the gap) when they are disjoint.
            # Fragments that are not close cannot match.
            if ((max_x2 - min_x1) * (min_x2 - max_x1) > proximity_threshold
                    or (max_y2 - min_y1) * (min_y2 - max_y1) > proximity_threshold):
                continue
            # The helper is external. Hand it fresh arrays, as the former per-pair
            # recomputation did, in case it modifies its inputs.
            idx1, idx2 = data_preprocess.get_corresbounding(
                swapped[i].copy(), transformed[k].copy(), fragment_transforms[i][:2])
            if len(idx1) <= 0:
                continue
            matching_set['source_ind'].append(idx1)
            matching_set['target_ind'].append(idx2)
            matching_set['GT_pairs'].append([current_nums + i, current_nums + k])

    return matching_set


def _read_gt_transforms(fragment_path: str) -> np.ndarray:
    """Read ``gt.txt`` in *fragment_path* and return the inverses of its matrices.

    Blank lines and lines holding a single token are skipped. Every other line must
    contain at least 9 numbers; the first 9 form one row-major 3x3 matrix.
    The result has shape ``(M, 3, 3)``.

    Raises:
        ValueError: a matrix line has fewer than 9 values or a non-numeric token.
    """
    rows = []
    with open(os.path.join(fragment_path, 'gt.txt'), 'r') as gt_file:
        for line_no, line in enumerate(gt_file, start=1):
            tokens = line.split()
            if len(tokens) <= 1:
                continue
            if len(tokens) < 9:
                raise ValueError('gt.txt line {}: expected at least 9 values, got {}'.format(
                    line_no, len(tokens)))
            rows.append(np.asarray(tokens, dtype=float)[:9])
    transforms = np.asarray(rows, dtype=float).reshape(-1, 3, 3)
    return np.linalg.inv(transforms)


def _read_background(fragment_path: str) -> np.ndarray:
    """Return the whitespace-separated ints on the first line of ``bg.txt`` (the background pixel)."""
    with open(os.path.join(fragment_path, 'bg.txt'), 'r') as bg_file:
        return np.asarray(bg_file.readline().split(), dtype=int)


def _largest_contour(contours) -> np.ndarray:
    """Return the contour with the most points, reshaped to ``(N, 2)``.

    When several contours are equally long, the last one wins. With zero or one
    contour, the input is converted to a float array instead.
    """
    if len(contours) > 1:
        # ``max`` keeps the first maximum it sees, so iterating in reverse keeps the
        # last of several equally long contours.
        return max(reversed(contours), key=len).reshape(-1, 2)
    return np.asarray(contours, dtype=float).reshape(-1, 2)


def _smooth_contour(contour: np.ndarray, sigma: float = 3) -> np.ndarray:
    """Gaussian-smooth an ``(N, 2)`` contour and normalise its orientation.

    Coordinates are blurred along the point sequence with standard deviation *sigma*
    and truncated to ``int``. Points equal to their cyclic predecessor are removed.
    The point order is then reversed if needed, so that the mean z-component of
    ``cross(step, offset_from_centroid)`` is not positive. The original author
    describes this as making the contour counter-clockwise.
    """
    blurred = cv2.GaussianBlur(contour.reshape(-1, 1, 2).astype(np.float32), (0, 0), sigma)
    # reshape rather than squeeze, so that a single-point contour stays 2-D.
    smoothed = blurred.reshape(-1, 2).astype(int)
    # Remove duplicate points produced by the integer truncation.
    repeats = np.all(smoothed == np.roll(smoothed, 1, axis=0), axis=1)
    smoothed = smoothed[~repeats]

    step = smoothed - np.roll(smoothed, 1, axis=0)
    centroid = np.array([smoothed[:, 0].mean(), smoothed[:, 1].mean()])
    offset = smoothed - centroid
    # Explicit 2-D cross product: np.cross on 2-element vectors is deprecated in NumPy 2.
    cross_z = step[:, 0] * offset[:, 1] - step[:, 1] * offset[:, 0]
    if cross_z.mean() > 0:
        smoothed = smoothed[::-1]
    return smoothed


def extract_contour(fragment_path: str, matching_set: dict, *args):
    """Extract the outer edge of every fragment image in a directory and record it.

    *fragment_path* must contain:

    - ``fragment*.png`` images, ordered by the integer in ``name[9:-4]``;
    - ``gt.txt``, holding one 3x3 matrix per fragment (inverted before use);
    - ``bg.txt``, holding the background pixel value.

    For each fragment, the following steps are applied:

    1. Background pixels are set to 0 in every channel.
    2. The longest external contour of the foreground is taken.
    3. The contour is smoothed and its orientation normalised.

    Args:
        fragment_path: directory holding the fragment images of one source image.
        matching_set: dict of lists, extended in place under ``'full_pcd_all'``,
            ``'img_all'``, ``'belong_image'``, ``'shape_all'`` and ``'gt_pose'``.
            ``get_fragment_pairs`` additionally extends ``'source_ind'``,
            ``'target_ind'`` and ``'GT_pairs'``.
        *args: ``(cfg, logger)``. ``cfg`` is unused here; ``logger.debug`` is called
            once per fragment.

    Returns:
        The result of ``get_fragment_pairs`` (the same ``matching_set`` object).

    Raises:
        FileNotFoundError: ``gt.txt`` or ``bg.txt`` is missing.
        OSError: a fragment image cannot be read.
        ValueError: ``gt.txt`` is malformed, or holds fewer matrices than there are
            fragments. The fragment-count check happens before ``matching_set`` is modified.
    """
    cfg, logger = args  # cfg is accepted for interface compatibility but unused here

    # Filter *before* the numeric sort, so that other PNG files in the directory
    # cannot break ``int(...)``.
    fragment_names = sorted(
        (name for name in os.listdir(fragment_path)
         if name.startswith('fragment') and name.endswith('.png')),
        key=lambda name: int(name[9:-4]),
    )

    transforms = _read_gt_transforms(fragment_path)

    belong_image = os.path.basename(fragment_path)
    current_nums = len(matching_set['full_pcd_all'])

    img_all = []
    for name in fragment_names:
        img_path = os.path.join(fragment_path, name)
        img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
        if img is None:
            raise OSError('cannot read fragment image: {}'.format(img_path))
        # Contours are extracted from the transposed image; the image is transposed back below.
        img_all.append(img.transpose(1, 0, 2))
    # Shapes are those of the transposed arrays, as before.
    shapes = np.array([img.shape for img in img_all], dtype=int).reshape(-1, 3)

    if len(transforms) < len(img_all):
        raise ValueError('{}: gt.txt has {} transforms but there are {} fragments'.format(
            fragment_path, len(transforms), len(img_all)))

    # bg.txt is only needed (and only read) when there is at least one fragment.
    bg = _read_background(fragment_path) if img_all else None

    full_contour_all = []
    for idx, image in enumerate(img_all):
        logger.debug('fragment {} start'.format(idx + 1))

        is_background = (image == bg).all(axis=-1)
        image[is_background] = 0  # zero every channel, so 3- and 4-channel images both work
        img_all[idx] = image.transpose(1, 0, 2)  # back to the orientation imread returned

        # Binary mask: foreground 255, background 0.
        binary = np.where(is_background, 0, 255).astype(np.uint8)
        contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)

        full_contour_all.append(_smooth_contour(_largest_contour(contours), sigma=3))

    num_fragments = len(full_contour_all)
    matching_set['full_pcd_all'].extend(full_contour_all)
    matching_set['img_all'].extend(img_all)
    matching_set['belong_image'].extend([belong_image] * num_fragments)
    matching_set['shape_all'].extend(list(shapes))
    matching_set["gt_pose"].extend([transforms[i] for i in range(num_fragments)])

    return get_fragment_pairs(
        matching_set, img_all, transforms, full_contour_all, current_nums)


if __name__ == '__main__':
    # Library module: executing the file directly performs no work.
    ...
