"""
@Description :   pairingnet pairing 过程中的输入图像和边缘 encode
@Author      :   tqychy 
@Time        :   2025/01/13 19:07:36
"""
import torch


def pre_encoder1(img, full_pcd, k):
    """
    Convert contour point coordinates to k x k patches of a contour-only map.

    A blank canvas (``img * 0``) is created and every contour point on it is
    set to 1. The k x k neighbourhood around each contour point is then cut
    out. Only the contour line is encoded; the pixel values of ``img`` are
    ignored. Neighbourhood cells that fall outside the image read as 0 (zero
    padding), so points near any border are handled without wrap-around or
    out-of-range indexing.

    :param img: [bs, height, width]; only its shape, dtype and device are used
    :param full_pcd: [bs, n, 2] contour coordinates; [..., 0] indexes the
        height axis and [..., 1] the width axis; cast to long
    :param k: patch_size
    :return: pre-encoded contour feature [bs, n, k, k]; element [b, i, r, s]
        is the canvas value at (row_i - k//2 + r, col_i - k//2 + s)
    """
    device = img.device
    full_pcd = full_pcd.long()
    bs = full_pcd.size(0)

    # Pad k//2 cells before and k-1-k//2 after each spatial axis, so that
    # every patch lies inside the padded canvas (symmetric for odd k).
    lo = k // 2
    hi = k - 1 - lo
    # img * 0 yields a canvas with img's shape, device and arithmetic dtype.
    canvas = torch.nn.functional.pad(img * 0, (lo, hi, lo, hi))

    rows = full_pcd[:, :, 0]  # [bs, n]
    cols = full_pcd[:, :, 1]  # [bs, n]
    batch = torch.arange(bs, device=device)[:, None]  # [bs, 1]
    canvas[batch, rows + lo, cols + lo] = 1

    # TODO: patches could later be downsampled by sampling with a stride
    # (k would then need to be increased as well).
    offsets = torch.arange(k, device=device)
    # In padded coordinates the patch of the point (row, col) starts exactly at
    # (row, col). Broadcasting [bs,1,1,1], [bs,n,k,1] and [bs,n,1,k] index
    # tensors yields the [bs, n, k, k] patch stack in a single gather.
    c = canvas[batch[:, :, None, None],
               rows[:, :, None, None] + offsets[:, None],
               cols[:, :, None, None] + offsets]

    return c


def pre_encoder2(img, full_pcd, k):
    """
    Convert contour point coordinates to k x k patches of the input map.

    The values of ``img`` are kept; only the contour points are additionally
    set to 1. The k x k neighbourhood around each contour point is then cut
    out. If ``img`` is a 0/1 region mask (assumed — TODO confirm with
    callers), this encodes interior + exterior. Work is done on a padded copy,
    so the caller's ``img`` is never modified. Neighbourhood cells outside the
    image read as 0.

    :param img: [bs, height, width]
    :param full_pcd: [bs, n, 2] contour coordinates; [..., 0] indexes the
        height axis and [..., 1] the width axis; cast to long
    :param k: patch_size
    :return: pre-encoded contour feature [bs, n, k, k]; element [b, i, r, s]
        is the map value at (row_i - k//2 + r, col_i - k//2 + s)
    """
    device = img.device
    full_pcd = full_pcd.long()
    bs = full_pcd.size(0)

    # Pad k//2 cells before and k-1-k//2 after each spatial axis, so that
    # every patch lies inside the padded map (symmetric for odd k).
    lo = k // 2
    hi = k - 1 - lo
    # F.pad returns a new tensor, so writing the contour below does not touch
    # the caller's img.
    canvas = torch.nn.functional.pad(img, (lo, hi, lo, hi))

    rows = full_pcd[:, :, 0]  # [bs, n]
    cols = full_pcd[:, :, 1]  # [bs, n]
    batch = torch.arange(bs, device=device)[:, None]  # [bs, 1]
    canvas[batch, rows + lo, cols + lo] = 1

    offsets = torch.arange(k, device=device)
    # In padded coordinates the patch of the point (row, col) starts exactly at
    # (row, col). Broadcast index tensors give the [bs, n, k, k] result.
    c = canvas[batch[:, :, None, None],
               rows[:, :, None, None] + offsets[:, None],
               cols[:, :, None, None] + offsets]

    return c


def pre_encoder3(img, full_pcd, k):
    """
    Convert contour point coordinates to k x k patches of a 3-level map.

    The map is ``img * 2`` with the contour points overwritten by 1. If
    ``img`` is a 0/1 region mask (assumed — TODO confirm with callers), this
    gives interior = 2, contour line = 1 and exterior = 0. The k x k
    neighbourhood around each contour point is then cut out. Neighbourhood
    cells outside the image read as 0, i.e. as exterior under that assumption.

    :param img: tensor [bs, height, width]
    :param full_pcd: tensor [bs, n, 2] contour coordinates; [..., 0] indexes
        the height axis and [..., 1] the width axis; cast to long
    :param k: int patch_size
    :return: tensor pre-encoded contour feature [bs, n, k, k]; element
        [b, i, r, s] is the map value at (row_i - k//2 + r, col_i - k//2 + s)
    """
    device = img.device
    full_pcd = full_pcd.long()
    bs = full_pcd.size(0)

    # Pad k//2 cells before and k-1-k//2 after each spatial axis, so that
    # every patch lies inside the padded map (symmetric for odd k).
    lo = k // 2
    hi = k - 1 - lo
    # img * 2 is a new tensor, so the caller's img is never modified.
    canvas = torch.nn.functional.pad(img * 2, (lo, hi, lo, hi))

    rows = full_pcd[:, :, 0]  # [bs, n]
    cols = full_pcd[:, :, 1]  # [bs, n]
    batch = torch.arange(bs, device=device)[:, None]  # [bs, 1]
    canvas[batch, rows + lo, cols + lo] = 1

    offsets = torch.arange(k, device=device)
    # In padded coordinates the patch of the point (row, col) starts exactly at
    # (row, col). Broadcast index tensors give the [bs, n, k, k] result.
    c = canvas[batch[:, :, None, None],
               rows[:, :, None, None] + offsets[:, None],
               cols[:, :, None, None] + offsets]

    return c

def img_patch_encoder(img, full_pcd, k):  # e.g. img: [1, 3, 331, 318], k: patch size
    """
    Cut a 3-channel k x k image patch around every contour point.

    Only the first three channels of ``img`` are read. Patch cells outside the
    image read as 0 (zero padding), so points near any border are handled
    without wrap-around or out-of-range indexing.

    :param img: [bs, C, height, width] image with C >= 3
    :param full_pcd: [bs, n, 2] contour coordinates; [..., 0] indexes the
        height axis and [..., 1] the width axis; cast to long
    :param k: patch_size
    :return: patches [bs, n, 3, k, k]; element [b, i, ch, r, s] is
        img[b, ch, row_i - k//2 + r, col_i - k//2 + s]
    """
    device = img.device
    full_pcd = full_pcd.long()
    bs = full_pcd.size(0)

    # Pad k//2 cells before and k-1-k//2 after each spatial axis, so that
    # every patch lies inside the padded image (symmetric for odd k).
    lo = k // 2
    hi = k - 1 - lo
    padded = torch.nn.functional.pad(img, (lo, hi, lo, hi))

    offsets = torch.arange(k, device=device)
    # In padded coordinates the patch of the point (row, col) starts exactly at
    # (row, col). The index tensors broadcast to [bs, n, 3, k, k].
    batch = torch.arange(bs, device=device)[:, None, None, None, None]  # [bs,1,1,1,1]
    channels = torch.arange(3, device=device)[:, None, None]            # [3,1,1]
    rows = full_pcd[:, :, 0, None, None, None] + offsets[:, None]       # [bs,n,1,k,1]
    cols = full_pcd[:, :, 1, None, None, None] + offsets                # [bs,n,1,1,k]
    c = padded[batch, channels, rows, cols]

    return c  # shape: [bs, n (number of contour points), 3, k, k]