'''
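Drag/stretch helpers for continuous drag editing: move the values inside a mask
along handle->target drags, fill the vacated region, and plot debug figures.
Created 2024-02-26.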
'''

import torch
import copy
from tqdm import tqdm
import numpy as np
# Euclidean (p=2) distance module; get_circle uses it to measure the mask radius.
# Note: torch.nn.PairwiseDistance adds eps=1e-6 to the difference by default.
pdist = torch.nn.PairwiseDistance(p=2)

def draw_heatmap(data,background_image_path, save_path = './heat_map_nobg_v3.png', no_bg=0):
    """Plot ``data`` as a 3-D surface, optionally over a background image, and save it.

    The 2-D array ``data`` is upsampled to 512x512 with a bicubic spline and drawn as a
    semi-transparent 'coolwarm' surface. Unless ``no_bg`` is set, a grayscale copy of the
    background image is drawn as a flat plane at z = -0.1 underneath it.

    Args:
        data: 2-D array-like of surface heights (at least 4x4 for the cubic spline).
        background_image_path: image file read with ``skimage.io.imread``. It is loaded
            even when ``no_bg`` is set, because its height/width define the x/y extent.
        save_path: output image path.
        no_bg: 0 to draw the background plane; any other value omits it.
    """
    import numpy as np
    import matplotlib.pyplot as plt
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  (registers '3d' on old Matplotlib)
    from scipy.interpolate import RectBivariateSpline
    from skimage import io, img_as_float, transform

    # Upsample to 512x512. RectBivariateSpline with kx=ky=3 and s=0 is SciPy's documented
    # replacement for the removed interp2d(kind='cubic') on a regular grid.
    data = np.asarray(data, dtype=float)
    n_rows, n_cols = data.shape
    spline = RectBivariateSpline(np.arange(n_rows), np.arange(n_cols), data, kx=3, ky=3)
    data = spline(np.linspace(0, n_rows - 1, 512), np.linspace(0, n_cols - 1, 512))

    # Rotate by 180 degrees, then mirror left-right (net effect: an up-down flip) so the
    # image orientation matches the surface axes.
    background_image = img_as_float(io.imread(background_image_path))
    background_image = transform.rotate(background_image, 180)
    background_image = np.fliplr(background_image)

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')

    # x/y grid spanning the background image, with one node per data sample.
    x = np.linspace(0, background_image.shape[1], data.shape[1])
    y = np.linspace(0, background_image.shape[0], data.shape[0])
    x, y = np.meshgrid(x, y)

    if no_bg == 0:
        # Colour images contribute their first channel; grayscale images are used as-is.
        gray = background_image[:, :, 0] if background_image.ndim == 3 else background_image
        # plot_surface samples facecolors by grid index, so they must match the grid shape.
        if gray.shape != data.shape:
            gray = transform.resize(gray, data.shape)
        background_rgba = plt.get_cmap('gray')(gray)
        z_offset = np.full(data.shape, -0.1)
        ax.plot_surface(x, y, z_offset, facecolors=background_rgba, shade=False)

    surf = ax.plot_surface(x, y, data, cmap='coolwarm', edgecolor='none', alpha=0.75)
    ax.view_init(elev=25, azim=15)  # elevate and rotate the camera

    cbar = fig.colorbar(surf)
    cbar.set_label('Value')

    # Hide grid lines, pane fills and pane borders, then the axes themselves.
    ax.grid(False)
    for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        axis.pane.fill = False
        axis.pane.set_edgecolor('w')
    ax.set_axis_off()

    ax.set_title('3D Heat Map with Background Image')
    fig.savefig(save_path)
    plt.close(fig)  # release the figure; repeated calls would otherwise accumulate figures

def draw_heatmap_nobg(data, save_path='./heat_map_nobg.png'):
    """Plot a 2-D array as a 3-D 'coolwarm' surface (no background) and save it.

    Args:
        data: 2-D array-like; element [row, col] is drawn at (x=col, y=row).
        save_path: output image path (defaults to the previously hard-coded path).
    """
    import numpy as np
    import matplotlib.pyplot as plt
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  (registers '3d' on old Matplotlib)

    data = np.asarray(data)
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')

    x, y = np.meshgrid(np.arange(data.shape[1]), np.arange(data.shape[0]))
    surf = ax.plot_surface(x, y, data, cmap='coolwarm')

    # Hide grid lines, pane fills and pane borders, then the axes themselves.
    ax.grid(False)
    for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        axis.pane.fill = False
        axis.pane.set_edgecolor('w')
    ax.set_axis_off()

    cbar = fig.colorbar(surf)
    cbar.set_label('Value')

    ax.set_title('3D Heat Map')
    fig.savefig(save_path)
    plt.close(fig)  # release the figure; repeated calls would otherwise accumulate figures

def judge_edge(point_tuple,invert_code_d):
    """Clamp a (y, x) point into the spatial bounds of ``invert_code_d``.

    The bounds come from dims 2 (height) and 3 (width). In-range coordinates are
    returned unchanged. Coordinates below 0 become 0, and coordinates past the last
    valid index become that index (as an int).
    """
    height, width = invert_code_d.shape[2], invert_code_d.shape[3]
    row, col = point_tuple[0], point_tuple[1]
    clamped_row = min(max(row, 0), int(height - 1))
    clamped_col = min(max(col, 0), int(width - 1))
    return (clamped_row, clamped_col)

def draw_arrow_graph(points,vectors,save_path='./arrow_chart_v1.png'):
    """Plot start points, their displacement arrows and the moved points; save the figure.

    Args:
        points: (N, 2) array-like of start points, plotted as (x, y) columns.
        vectors: (N, 2) array-like of displacements with the same column order.
        save_path: output image path.
    """
    import matplotlib.pyplot as plt
    import numpy as np

    points = np.asarray(points).reshape(-1, 2)
    vectors = np.asarray(vectors).reshape(-1, 2)
    moved_points = points + vectors

    # Use a fresh figure so repeated calls do not draw on top of each other.
    fig, ax = plt.subplots()
    ax.scatter(points[:, 0], points[:, 1], color='red', label='Original Points')
    for (px, py), (vx, vy) in zip(points, vectors):
        ax.arrow(px, py, vx, vy, head_width=1, head_length=2, fc='red', ec='black')
    ax.scatter(moved_points[:, 0], moved_points[:, 1], color='orange', label='Moved Points')

    ax.set_title('Points Movement Visualization')
    ax.set_xlabel('X axis')
    ax.set_ylabel('Y axis')
    ax.legend()
    ax.axis('equal')  # same scale on both axes
    ax.grid(False)    # hide the grid
    fig.savefig(save_path)
    plt.close(fig)

def draw_arrow_graph_with_bg(points,vectors,img_path,save_path='./arrow_chart_v2.png'):
    """Draw red start points with blue displacement arrows (axes hidden) and save it.

    Args:
        points: (N, 2) array-like of arrow start points, plotted as (x, y) columns.
        vectors: (N, 2) array-like of displacements with the same column order.
        img_path: background image path. NOTE(review): the image is read with
            ``plt.imread`` but not drawn; the ``imshow`` call was disabled in the
            original. Confirm whether the background should be rendered.
        save_path: output image path (defaults to the previously hard-coded path).
    """
    import matplotlib.pyplot as plt
    import numpy as np

    points = np.asarray(points).reshape(-1, 2)
    vectors = np.asarray(vectors).reshape(-1, 2)

    background = plt.imread(img_path)  # noqa: F841  (loaded but unused; see docstring)
    fig, ax = plt.subplots()
    ax.scatter(points[:, 0], points[:, 1], color='red')
    for (px, py), (vx, vy) in zip(points, vectors):
        ax.arrow(px, py, vx, vy, head_width=1, head_length=2, fc='blue', ec='black')
    ax.axis('off')
    ax.grid(False)  # hide the grid
    fig.savefig(save_path)
    plt.close(fig)  # release the figure; repeated calls would otherwise accumulate figures

def get_rectangle(mask: torch.Tensor):
    """Return the axis-aligned bounding box of the non-zero entries of a 4-D ``mask``.

    Corners are (y, x) pairs built from the min/max of the last two index dims, as
    floating-point tensors on ``mask.device``. The names follow the original code:
    left_top=(min_y, min_x), left_bottom=(min_y, max_x), right_top=(max_y, min_x),
    right_bottom=(max_y, max_x).

    Returns:
        (rect, left_top, left_bottom, right_top, right_bottom), where ``rect`` stacks
        the four corners in that order into a (4, 2) tensor.
    """
    _batch, _channels, _height, _width = mask.shape  # unpacking enforces a 4-D mask
    coords = torch.nonzero(mask)[:, -2:]  # (y, x) of every non-zero entry
    float_dtype = torch.get_default_dtype()
    min_y, min_x = coords.min(dim=0).values.to(float_dtype)
    max_y, max_x = coords.max(dim=0).values.to(float_dtype)

    left_top = torch.stack((min_y, min_x))
    left_bottom = torch.stack((min_y, max_x))
    right_top = torch.stack((max_y, min_x))
    right_bottom = torch.stack((max_y, max_x))
    rect = torch.stack((left_top, left_bottom, right_top, right_bottom))
    return rect, left_top, left_bottom, right_top, right_bottom
    
def interpolation(x):
    """Fill zero pixels with an inverse-distance blend of their nearest non-zero neighbours.

    A pixel is treated as a hole when channel 0 is exactly 0. For each hole, the nearest
    non-zero pixel (again judged on channel 0) is searched to the left, right, up and
    down. Each one found contributes its full channel vector with weight 1/distance.
    The hole receives the weighted mean. Holes with no non-zero neighbour in any
    direction stay 0.

    Pixels are visited in raster order and ``x`` is updated in place. A hole filled
    earlier therefore counts as a non-zero neighbour for later holes; this matches the
    original behaviour.

    Args:
        x: 4-D tensor indexed as [batch, channel, row, col].

    Returns:
        ``x`` itself, modified in place.

    Raises:
        ValueError: if ``x`` is not 4-D.
    """
    if x.dim() != 4:
        raise ValueError("Input tensor x should have shape (1, C, N, M)")
    batch_size, _, N, M = x.shape

    for b in range(batch_size):
        # Python-side mirror of "channel 0 != 0", kept in sync with x as holes are filled.
        # Probing this list is far cheaper than creating a tensor per neighbour test.
        nonzero = (x[b, 0] != 0).tolist()
        for i in range(N):
            row = nonzero[i]
            for j in range(M):
                if row[j]:
                    continue
                values = []   # channel vectors of the non-zero neighbours found
                weights = []  # their inverse distances

                for k in range(1, j + 1):  # left
                    if row[j - k]:
                        values.append(x[b, :, i, j - k])
                        weights.append(1 / k)
                        break
                for k in range(1, M - j):  # right
                    if row[j + k]:
                        values.append(x[b, :, i, j + k])
                        weights.append(1 / k)
                        break
                for k in range(1, i + 1):  # up
                    if nonzero[i - k][j]:
                        values.append(x[b, :, i - k, j])
                        weights.append(1 / k)
                        break
                for k in range(1, N - i):  # down
                    if nonzero[i + k][j]:
                        values.append(x[b, :, i + k, j])
                        weights.append(1 / k)
                        break

                if weights:
                    total_weight = sum(weights)
                    x[b, :, i, j] = sum(w * v for w, v in zip(weights, values)) / total_weight
                    # Re-read the stored (dtype-cast) value so later neighbour tests see what x holds.
                    row[j] = bool(x[b, 0, i, j] != 0)

    return x

def get_circle(mask: torch.Tensor):
    """Return the circle circumscribing the bounding box of ``mask``'s non-zero entries.

    Returns:
        (center, radius): ``center`` is the (y, x) midpoint of the bounding-box
        diagonal, on ``mask.device``. ``radius`` is the ``pdist`` distance from the
        center to the box's (min_y, min_x) corner.
    """
    _, top_left, _, _, bottom_right = get_rectangle(mask=mask)
    center = ((top_left + bottom_right) / 2).to(device=mask.device)  # (y, x)
    radius = pdist(center, top_left)  # Euclidean distance
    return center, radius


def get_scale_factor(C, A, OA, d_OA, R, O):
    '''Fraction of the handle shift that point ``C`` receives inside circle (O, R).

    Cast a ray from the handle point A through C and let P be where it leaves the circle.
    The result is |PC| / |AP|, i.e. |CD| / |AB| when A moves to B and C moves to D.
    It is 1 at the handle itself and falls to 0 on the circle boundary.

    Args:
        C: (y, x) point being moved.
        A: (y, x) handle point.
        OA: ``A - O`` (precomputed by the caller).
        d_OA: ``|OA|`` (precomputed by the caller).
        R: circle radius.
        O: circle centre (unused; the caller passes ``OA``/``d_OA`` instead).

    Returns:
        0-d tensor scale factor.
    '''
    AC = C - A
    d_AC = torch.norm(AC)
    if d_AC == 0:
        # C coincides with the handle: the ray direction is undefined (the original divided
        # by zero and returned NaN). As C -> A, |PC| -> |AP|, so the limit is a full shift.
        return torch.ones_like(d_AC)
    e_AC = AC / d_AC
    L0 = torch.dot(AC, OA) / d_AC              # signed |GA|, G = foot of O on line AC; < 0 when angle > 90 deg
    L1 = torch.sqrt(R**2 - d_OA**2 + L0**2)    # |GP|
    AP = (L1 - L0) * e_AC                      # GP - GA along the ray direction
    PC = AC - AP
    return torch.norm(PC) / torch.norm(AP)     # |PC|/|AP| == |CD|/|AB|

def transform_point(point, shift_yx, scale_factor):
    """Move ``point`` by ``shift_yx`` scaled by ``scale_factor`` and round the result."""
    scaled_shift = shift_yx * scale_factor
    return torch.round(point + scaled_shift)


# Multi-pair drag with hole filling (interpolation)
def drag_stretch_multipoint_ratio_interp(invert_code,handle_points,target_points,mask_cp_handle,shift_yx=None,fill_mode='interpolation',move_mode="not recover"):
    """Move every masked pixel along several handle->target drags, then fill holes.

    For each non-zero pixel C of ``mask_cp_handle`` and each handle/target pair (A, B),
    the displacement ``get_scale_factor(C, A, ...) * (B - A)`` is computed. A and B are
    divided by 4 before use (presumably image -> latent scale; TODO confirm).
    The per-pair displacements are blended with weights proportional to
    ``1 / (|C - A| + 1e-4)``, normalised to sum to 1, so nearer handles dominate.
    The value of ``invert_code`` at C is copied to the rounded destination, clamped to
    the tensor border by ``judge_edge``. Pixels whose destination is non-finite are
    skipped.

    Args:
        invert_code: tensor indexed as [batch, channel, y, x]; it is not modified.
        handle_points: sequence of (y, x) tensors.
        target_points: sequence of (y, x) tensors, indexed in parallel with handle_points.
        mask_cp_handle: mask of the region to move; its last two index dims are (y, x).
            Assumed (1, 1, H, W) — TODO confirm.
        shift_yx: unused; kept for call compatibility (the shift is recomputed per pair).
        fill_mode: what the masked source region holds before the move: 'ori'
            (unchanged), '0' (zeros), 'interpolation' (zeros, filled by ``interpolation``
            after the move), or 'random' (uniform [0, 1) noise).
        move_mode: how collisions are resolved when several pixels land on the same
            destination. 'not recover' keeps the first value written; 'mean when recover'
            stores the mean of all values that landed there.

    Returns:
        A deep copy of ``invert_code`` with the drag applied.

    Raises:
        ValueError: if ``move_mode`` is not one of the supported values.
    """
    if move_mode not in ("not recover", "mean when recover"):
        raise ValueError(f"unknown move_mode: {move_mode!r}")

    invert_code_d = copy.deepcopy(invert_code)
    # Broadcast the single-channel mask over every channel of the code.
    region = (mask_cp_handle > 0).expand(-1, invert_code_d.shape[1], -1, -1)
    if fill_mode == 'ori':
        print("mask to ori")
    elif fill_mode == '0':
        print("mask to 0")
        invert_code_d[region] = 0
    elif fill_mode == "interpolation":
        print("mask to interpolation")
        invert_code_d[region] = 0  # zeros mark the holes that interpolation() fills later
    elif fill_mode == "random":
        print("random")
        invert_code_d[region] = torch.rand_like(invert_code_d)[region]

    coords = torch.nonzero(mask_cp_handle)[:, -2:]  # (n, 2) integer (y, x) of pixels to move
    O, R = get_circle(mask_cp_handle)
    device = mask_cp_handle.device

    pair_vectors = []  # per pair: (n, 2) displacement of every masked pixel
    pair_weights = []  # per pair: (n,) inverse distance of every masked pixel to the handle
    for point_i, handle in enumerate(handle_points):
        print(f"point rate: {point_i+1}/{len(handle_points)}")
        A = handle.to(device=device) / 4
        B = target_points[point_i].to(device=device) / 4
        pair_shift = B - A
        OA = A - O
        d_OA = torch.norm(OA)
        vectors, weights = [], []
        for C in tqdm(coords, desc="get factor"):
            vectors.append(get_scale_factor(C, A, OA, d_OA, R, O) * pair_shift)
            weights.append(1 / (torch.norm(C - A) + 0.0001))
        pair_vectors.append(torch.stack(vectors))
        pair_weights.append(torch.stack(weights))

    if pair_vectors:
        vectors = torch.stack(pair_vectors)  # (P, n, 2)
        weights = torch.stack(pair_weights)  # (P, n)
        ratio = weights / weights.sum(dim=0, keepdim=True)
        destinations = torch.round(coords + (ratio.unsqueeze(-1) * vectors).sum(dim=0))

        written = set()  # destinations already assigned (O(1) membership)
        landed = {}      # destination -> values that landed there ('mean when recover')
        for C, dest in zip(tqdm(coords, desc="drag stretch"), destinations):
            if not torch.isfinite(dest).all():
                # Skip rather than reuse the previous pixel's destination.
                print(f"has a err: non-finite destination {dest} for C: {C}")
                continue
            point_tuple = judge_edge((int(dest[0]), int(dest[1])), invert_code_d)
            value = invert_code[:, :, int(C[0]), int(C[1])]
            if move_mode == "not recover":
                if point_tuple in written:
                    continue  # keep the value that reached this destination first
                invert_code_d[:, :, point_tuple[0], point_tuple[1]] = value
            else:
                values = landed.setdefault(point_tuple, [])
                values.append(value)
                invert_code_d[:, :, point_tuple[0], point_tuple[1]] = sum(values) / len(values)
            written.add(point_tuple)
        print("point_new_l: \n", len(written))

    if fill_mode == "interpolation":
        invert_code_d = interpolation(invert_code_d)
    return invert_code_d


# Multi-pair drag (per-pair shifts blended by inverse distance to each handle)
def drag_stretch_multipoint_ratio(invert_code,handle_points,target_points,mask_cp_handle,shift_yx=None,fill_mode='ori'):
    """Move every masked pixel along several handle->target drags (inverse-distance blend).

    For each non-zero pixel C of ``mask_cp_handle`` and each handle/target pair (A, B),
    the displacement ``get_scale_factor(C, A, ...) * (B - A)`` is computed. A and B are
    divided by 4 before use (presumably image -> latent scale; TODO confirm).
    The per-pair displacements are blended with weights proportional to
    ``1 / (|C - A| + 1e-4)``, normalised to sum to 1. The value of ``invert_code`` at C
    is copied to the rounded destination.

    Destinations outside the tensor are skipped. This includes negative ones, which
    previously wrapped around to the opposite border through Python negative indexing.
    Non-finite destinations are also skipped. Later writes to the same destination
    overwrite earlier ones.

    Args:
        invert_code: tensor indexed as [batch, channel, y, x]; it is not modified.
        handle_points: sequence of (y, x) tensors.
        target_points: sequence of (y, x) tensors, indexed in parallel with handle_points.
        mask_cp_handle: mask of the region to move; its last two index dims are (y, x).
        shift_yx: unused; kept for call compatibility (the shift is recomputed per pair).
        fill_mode: 'ori' keeps the source region as-is; '0' zeroes it before the move.

    Returns:
        A deep copy of ``invert_code`` with the drag applied.
    """
    invert_code_d = copy.deepcopy(invert_code)
    if fill_mode == 'ori':
        print("mask to ori")
    if fill_mode == '0':
        print("mask to 0")
        # Broadcast the single-channel mask over every channel of the code.
        invert_code_d[(mask_cp_handle > 0).expand(-1, invert_code_d.shape[1], -1, -1)] = 0

    coords = torch.nonzero(mask_cp_handle)[:, -2:]  # (n, 2) integer (y, x) of pixels to move
    O, R = get_circle(mask_cp_handle)
    device = mask_cp_handle.device

    pair_vectors = []  # per pair: (n, 2) displacement of every masked pixel
    pair_weights = []  # per pair: (n,) inverse distance of every masked pixel to the handle
    for point_i, handle in enumerate(handle_points):
        print(f"point rate: {point_i+1}/{len(handle_points)}")
        A = handle.to(device=device) / 4
        B = target_points[point_i].to(device=device) / 4
        pair_shift = B - A
        OA = A - O
        d_OA = torch.norm(OA)
        vectors, weights = [], []
        for C in tqdm(coords, desc="get factor"):
            vectors.append(get_scale_factor(C, A, OA, d_OA, R, O) * pair_shift)
            weights.append(1 / (torch.norm(C - A) + 0.0001))
        pair_vectors.append(torch.stack(vectors))
        pair_weights.append(torch.stack(weights))
    if not pair_vectors:
        return invert_code_d  # no drag pairs: nothing to move

    vectors = torch.stack(pair_vectors)  # (P, n, 2)
    weights = torch.stack(pair_weights)  # (P, n)
    ratio = weights / weights.sum(dim=0, keepdim=True)
    destinations = torch.round(coords + (ratio.unsqueeze(-1) * vectors).sum(dim=0))

    height, width = invert_code_d.shape[-2:]
    for C, dest in zip(tqdm(coords, desc="drag stretch"), destinations):
        if not torch.isfinite(dest).all():
            print(f"has a err: non-finite destination {dest} for C: {C}")
            continue
        y, x = int(dest[0]), int(dest[1])
        if not (0 <= y < height and 0 <= x < width):
            print(f"has a err: destination {(y, x)} out of bounds for C: {C}")
            continue
        invert_code_d[:, :, y, x] = invert_code[:, :, int(C[0]), int(C[1])]
    return invert_code_d


def drag_stretch_multipoint(invert_code,handle_points,target_points,mask_cp_handle,shift_yx=None,fill_mode='ori'):
    """Move every masked pixel by the sum of its per-pair drag displacements.

    For each non-zero pixel C of ``mask_cp_handle`` and each handle/target pair (A, B),
    the displacement ``get_scale_factor(C, A, ...) * (B - A)`` is computed. A and B are
    divided by 4 before use (presumably image -> latent scale; TODO confirm).
    The per-pair displacements are summed without weighting, and the value of
    ``invert_code`` at C is copied to the rounded destination.

    Destinations outside the tensor are skipped. This includes negative ones, which
    previously wrapped around through Python negative indexing. Non-finite
    destinations are also skipped.

    Args:
        invert_code: tensor indexed as [batch, channel, y, x]; it is not modified.
        handle_points: sequence of (y, x) tensors.
        target_points: sequence of (y, x) tensors, indexed in parallel with handle_points.
        mask_cp_handle: mask of the region to move; its last two index dims are (y, x).
        shift_yx: unused; kept for call compatibility (the shift is recomputed per pair).
        fill_mode: 'ori' keeps the source region as-is; '0' zeroes it before the move.

    Returns:
        A deep copy of ``invert_code`` with the drag applied.
    """
    invert_code_d = copy.deepcopy(invert_code)
    if fill_mode == 'ori':
        print("mask to ori")
    if fill_mode == '0':
        print("mask to 0")
        # Broadcast the single-channel mask over every channel of the code.
        invert_code_d[(mask_cp_handle > 0).expand(-1, invert_code_d.shape[1], -1, -1)] = 0

    coords = torch.nonzero(mask_cp_handle)[:, -2:]  # (n, 2) integer (y, x) of pixels to move
    O, R = get_circle(mask_cp_handle)
    device = mask_cp_handle.device

    pair_vectors = []  # per pair: (n, 2) displacement of every masked pixel
    for point_i, handle in enumerate(handle_points):
        A = handle.to(device=device) / 4
        B = target_points[point_i].to(device=device) / 4
        pair_shift = B - A
        OA = A - O
        d_OA = torch.norm(OA)
        pair_vectors.append(torch.stack([
            get_scale_factor(C, A, OA, d_OA, R, O) * pair_shift
            for C in tqdm(coords, desc="drag stretch")
        ]))
    if not pair_vectors:
        return invert_code_d  # no drag pairs: nothing to move

    destinations = torch.round(coords + torch.stack(pair_vectors).sum(dim=0))

    height, width = invert_code_d.shape[-2:]
    for C, dest in zip(tqdm(coords, desc="drag stretch"), destinations):
        if not torch.isfinite(dest).all():
            print(f"has a err: non-finite destination {dest} for C: {C}")
            continue
        y, x = int(dest[0]), int(dest[1])
        if not (0 <= y < height and 0 <= x < width):
            print(f"has a err: destination {(y, x)} out of bounds for C: {C}")
            continue
        invert_code_d[:, :, y, x] = invert_code[:, :, int(C[0]), int(C[1])]
    return invert_code_d


# Single-pair drag
def drag_stretch(invert_code,handle_points,target_points,mask_cp_handle,shift_yx,fill_mode='ori'):
    """Move every masked pixel along a single handle drag.

    Each non-zero pixel C of ``mask_cp_handle`` is shifted by ``shift_yx / 4`` scaled by
    ``get_scale_factor`` relative to the first handle point (also divided by 4). The
    value of ``invert_code`` at C is copied to the rounded destination.

    Destinations outside the tensor are skipped. This includes negative ones, which
    previously wrapped around through Python negative indexing. Non-finite
    destinations are also skipped.

    Args:
        invert_code: tensor indexed as [batch, channel, y, x]; it is not modified.
        handle_points: sequence whose first element is the (y, x) handle tensor.
        target_points: unused; kept for call compatibility (``shift_yx`` gives the drag).
        mask_cp_handle: mask of the region to move; its last two index dims are (y, x).
        shift_yx: (y, x) drag vector in the same scale as ``handle_points``.
        fill_mode: 'ori' keeps the source region as-is; '0' zeroes it before the move.

    Returns:
        A deep copy of ``invert_code`` with the drag applied.
    """
    invert_code_d = copy.deepcopy(invert_code)
    if fill_mode == 'ori':
        print("mask to ori")
    if fill_mode == '0':
        print("mask to 0")
        # Broadcast the single-channel mask over every channel of the code.
        invert_code_d[(mask_cp_handle > 0).expand(-1, invert_code_d.shape[1], -1, -1)] = 0

    A = handle_points[0].to(device=mask_cp_handle.device) / 4  # (y, x)
    O, R = get_circle(mask_cp_handle)                         # (y, x)
    OA = A - O
    d_OA = torch.norm(OA)

    height, width = invert_code_d.shape[-2:]
    for index in tqdm(torch.nonzero(mask_cp_handle), desc="drag stretch"):
        C = index[-2:]  # (y, x)
        scale_factor = get_scale_factor(C, A, OA, d_OA, R, O)
        C_new = transform_point(point=C, shift_yx=shift_yx / 4, scale_factor=scale_factor)
        if not torch.isfinite(C_new).all():
            print(f"has a err: non-finite destination {C_new} for C: {C}")
            continue
        y, x = int(C_new[0]), int(C_new[1])
        if not (0 <= y < height and 0 <= x < width):
            print(f"has a err: destination {(y, x)} out of bounds for C: {C}")
            continue
        invert_code_d[:, :, y, x] = invert_code[:, :, int(C[0]), int(C[1])]
    return invert_code_d

def drag_stretch_patch(invert_code,handle_points,target_points,mask_cp_handle,shift_yx,half_patch_size=3):
    """Like ``drag_stretch`` but copy a square patch around each moved pixel.

    For each non-zero pixel C of ``mask_cp_handle`` the destination is computed as in
    ``drag_stretch``. The window of offsets ``[-half_patch_size, half_patch_size)``
    around C is then copied to the same window around the destination. Note this
    window is 2*h wide and not centred; this is kept from the original.

    Near the border, only the offsets where BOTH windows lie inside the tensor are
    copied. The original sliced with negative or mismatched bounds, which produced
    wrong or empty slices or a shape-mismatch error. Pixels with a non-finite
    destination are skipped.

    Args:
        invert_code: tensor indexed as [batch, channel, y, x]; it is not modified.
        handle_points: sequence whose first element is the (y, x) handle tensor.
        target_points: unused; kept for call compatibility (``shift_yx`` gives the drag).
        mask_cp_handle: mask of the region to move; its last two index dims are (y, x).
        shift_yx: (y, x) drag vector in the same scale as ``handle_points``.
        half_patch_size: half the side length of the copied window.

    Returns:
        A deep copy of ``invert_code`` with the patches copied.
    """
    invert_code_d = copy.deepcopy(invert_code)
    A = handle_points[0].to(device=mask_cp_handle.device) / 4  # (y, x)
    O, R = get_circle(mask_cp_handle)                         # (y, x)
    OA = A - O
    d_OA = torch.norm(OA)

    height, width = invert_code_d.shape[-2:]
    h = half_patch_size
    for index in tqdm(torch.nonzero(mask_cp_handle), desc="drag stretch"):
        C = index[-2:]  # (y, x)
        scale_factor = get_scale_factor(C, A, OA, d_OA, R, O)
        C_new = transform_point(point=C, shift_yx=shift_yx / 4, scale_factor=scale_factor)
        if not torch.isfinite(C_new).all():
            print(f"has a err: non-finite destination {C_new} for C: {C}")
            continue
        src_y, src_x = int(C[0]), int(C[1])
        dst_y, dst_x = int(C_new[0]), int(C_new[1])
        # Offset range (relative to each centre) for which both windows stay in bounds.
        y_lo, y_hi = max(-h, -src_y, -dst_y), min(h, height - src_y, height - dst_y)
        x_lo, x_hi = max(-h, -src_x, -dst_x), min(h, width - src_x, width - dst_x)
        if y_lo >= y_hi or x_lo >= x_hi:
            continue  # no overlap left to copy
        invert_code_d[:, :, dst_y + y_lo:dst_y + y_hi, dst_x + x_lo:dst_x + x_hi] = \
            invert_code[:, :, src_y + y_lo:src_y + y_hi, src_x + x_lo:src_x + x_hi]
    return invert_code_d


