"""
@Description :   绘图
@Author      :   tqychy 
@Time        :   2025/03/20 15:38:56
"""
import cv2
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from mpl_toolkits.axes_grid1 import make_axes_locatable
from torch_geometric.utils import to_dense_adj, to_undirected

# Use the non-interactive Agg backend.
matplotlib.use("Agg")

# Global base font size for all plot text (legends use a smaller size).
fontsize = 20

# Global plotting defaults, applied in a single update.
plt.rcParams.update({
    # Render the minus sign as ASCII "-" rather than the Unicode minus glyph.
    "axes.unicode_minus": False,
    "figure.figsize": (15.2, 8.8),
    "font.size": fontsize,
    "axes.titlesize": fontsize,
    "axes.labelsize": fontsize,
    "xtick.labelsize": fontsize,
    "ytick.labelsize": fontsize,
    "legend.fontsize": fontsize - 8,
    "figure.titlesize": fontsize,
})



def vis_pairing_result(e_pred, e_gt, sim_mat, v_num, save_path):
    """
    Visualize predicted pairings against the ground truth and compute edge metrics.

    Both edge lists are symmetrized with ``to_undirected`` and converted to
    dense [v_num, v_num] adjacency matrices. A 1x3 figure is saved to
    ``save_path``:

    - the similarity matrix, with a colorbar;
    - the prediction, colored per cell (green = true positive,
      red = false positive, purple = false negative, white = otherwise);
    - the ground-truth adjacency.

    Metrics are computed over the symmetric dense matrices, so each
    off-diagonal undirected edge is counted at both (i, j) and (j, i).

    Args:
        e_pred (torch.Tensor): predicted edges, shape: [N, 2]
        e_gt (torch.Tensor): ground-truth edges, shape: [M, 2]
        sim_mat (torch.Tensor): similarity matrix shown in the first panel;
            presumably [v_num, v_num] -- confirm against callers.
        v_num (int): number of nodes
        save_path (str): path the figure is written to
    Returns:
        prec (float): precision of the predicted edges (0 if nothing is predicted)
        rec (float): recall of the predicted edges (0 if there are no ground-truth edges)
        f1 (float): F1 score of the predicted edges (0 if prec + rec == 0)
    """
    # [0] drops the batch dimension of to_dense_adj. Unlike squeeze(), it keeps
    # a 2-D matrix when v_num == 1. .cpu() also lets CUDA edge tensors through.
    gt_np = to_dense_adj(to_undirected(e_gt.T), max_num_nodes=v_num)[0].T.cpu().numpy().astype(int)
    pred_np = to_dense_adj(to_undirected(e_pred.T), max_num_nodes=v_num)[0].T.cpu().numpy().astype(int)
    sim_np = sim_mat.detach().cpu().numpy()

    # Any positive entry counts as an edge.
    gt_adj = gt_np > 0
    pred_adj = pred_np > 0
    tp_mask = pred_adj & gt_adj
    fp_mask = pred_adj & ~gt_adj
    fn_mask = ~pred_adj & gt_adj

    tp = np.sum(tp_mask)
    fp = np.sum(fp_mask)
    fn = np.sum(fn_mask)

    precision = tp / (tp + fp) if (tp + fp) != 0 else 0
    recall = tp / (tp + fn) if (tp + fn) != 0 else 0
    f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) != 0 else 0

    # RGB visualization of the prediction on a white background.
    colors = np.ones((*pred_adj.shape, 3))
    colors[tp_mask] = [0, 1, 0]      # Green - True Positive
    colors[fp_mask] = [1, 0, 0]      # Red - False Positive
    colors[fn_mask] = [0.5, 0, 0.5]  # Purple - False Negative

    fig, axes = plt.subplots(1, 3)
    try:
        # Similarity matrix. A dedicated colorbar axis keeps the bar's height
        # matched to the image.
        im = axes[0].imshow(sim_np, cmap='coolwarm')
        cax = make_axes_locatable(axes[0]).append_axes("right", size="5%", pad=0.05)
        fig.colorbar(im, cax=cax)
        axes[0].set_title("Similarity Matrix")

        # Prediction outcome per cell.
        axes[1].imshow(colors)
        axes[1].set_title(
            f"Prediction (F1: {f1_score:.2f}, Prec: {precision:.2f}, Rec: {recall:.2f})")

        # Ground truth.
        axes[2].imshow(gt_np, cmap='gray', vmin=0, vmax=1)
        axes[2].set_title("Ground Truth")

        fig.tight_layout()
        fig.savefig(save_path, bbox_inches='tight')
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)

    return precision, recall, f1_score


def vis_distribution(e_pred, scores, e_gt, save_path):
    """
    Plot the score distributions of correct vs. incorrect edges from fine matching.

    A predicted edge (u, v) counts as correct when its undirected form
    (min(u, v), max(u, v)) appears in ``e_gt``. For each non-empty group, the
    function draws a KDE curve labeled with the group mean and a dashed
    vertical line at that mean. The figure is saved to ``save_path``.

    Args:
        e_pred (torch.Tensor): predicted edges, shape: [N, 2]
        scores (torch.Tensor): edge scores, shape: [N]
        e_gt (torch.Tensor): ground-truth edges, shape: [M, 2]
        save_path (str): path the figure is written to
    """
    gt_set = {(min(u, v), max(u, v)) for u, v in e_gt.tolist()}

    true_scores = []
    false_scores = []
    for (u, v), score in zip(e_pred.tolist(), scores.tolist()):
        if (min(u, v), max(u, v)) in gt_set:
            true_scores.append(score)
        else:
            false_scores.append(score)

    groups = (
        (true_scores, 'green', 'True'),
        (false_scores, 'purple', 'False'),
    )

    # Use an explicit figure handle and always close it, so repeated calls do
    # not accumulate open figures.
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        plotted_any = False
        for group_scores, color, name in groups:
            if not group_scores:
                # The mean of no data is NaN and a KDE of it is meaningless.
                continue
            mean = np.mean(group_scores)
            # KDE curve labeled with the group mean, plus a dashed mean marker.
            sns.kdeplot(group_scores, color=color, ax=ax,
                        label=f'{name} Scores (mean={mean:.3f})')
            ax.axvline(mean, color=color,
                       linestyle='--', linewidth=2, alpha=0.7)
            plotted_any = True

        ax.set_xlabel('Score')
        ax.set_ylabel('Density')
        ax.set_title('Distributions of Scores')
        if plotted_any:
            ax.legend()
        ax.grid(True, alpha=0.3)
        fig.savefig(save_path, bbox_inches='tight')
    finally:
        plt.close(fig)

def _swap_xy_homogeneous(pcd):
    """Swap the two coordinate columns of ``pcd`` and append a column of ones."""
    return np.column_stack((pcd[:, 1], pcd[:, 0], np.ones(len(pcd))))


def build_pic(imgs, pcds, transformations):
    """
    Compose the transformed fragments onto one square canvas.

    For every key ``idx`` of ``transformations``, the point cloud ``pcds[idx]``
    (with its two columns swapped) and the image ``imgs[idx]`` are mapped by the
    affine matrix ``transformations[idx]``. The matrix is shifted so that all
    transformed points land in non-negative canvas coordinates. Warped images
    are summed onto the canvas, the transformed points are drawn as white dots
    (thickening the fragment edges), and values are clipped to [0, 255]. Each
    fragment's center is marked with a black circle.

    Args:
        imgs: indexable by the keys of ``transformations``. Each image is passed
            to ``cv2.warpAffine`` and added to a 3-channel float canvas, so it is
            presumably an H x W x 3 array -- confirm against callers.
        pcds: indexable by the keys of ``transformations``. Each entry is a
            [K, 2] point array whose two columns are swapped before transforming
            (presumably (row, col) -> (x, y); confirm).
        transformations (dict): idx -> affine matrix with 3 columns (2x3 or
            3x3); only the first two rows are used.

    Returns:
        canvas (np.ndarray): uint8 image of shape [S, S, 3], where S is the
            larger side of the transformed points' bounding box.
        adjusted_centers (dict): idx -> int32 array [x, y], the mean of the
            fragment's transformed points in canvas coordinates.

    Raises:
        ValueError: if ``transformations`` is empty.
    """
    # Pass 1: transform every fragment's points to get the joint bounding box
    # and each fragment's center (mean of its transformed x, y coordinates).
    all_points = []
    centers = {}
    for idx in transformations:
        T = transformations[idx]
        transformed = _swap_xy_homogeneous(pcds[idx]) @ T.T
        centers[idx] = transformed[:, :2].mean(axis=0)
        all_points.append(transformed)
    if not all_points:
        raise ValueError("build_pic requires at least one transformation")
    all_points = np.concatenate(all_points, axis=0)

    min_x, max_x = np.min(all_points[:, 0]), np.max(all_points[:, 0])
    min_y, max_y = np.min(all_points[:, 1]), np.max(all_points[:, 1])

    # Express centers relative to the canvas origin (the bounding-box minimum).
    offset = np.array([min_x, min_y])
    adjusted_centers = {idx: (center - offset).astype(np.int32)
                        for idx, center in centers.items()}

    # Square canvas covering the bounding box.
    canvas_width = int(np.ceil(max_x - min_x))
    canvas_height = int(np.ceil(max_y - min_y))
    canvas_size = max(canvas_height, canvas_width)
    canvas = np.zeros((canvas_size, canvas_size, 3), dtype=np.float32)

    # Pass 2: warp each fragment onto the canvas with the translation shifted
    # by the bounding-box minimum.
    for idx in transformations:
        T = transformations[idx]
        adjusted_T = np.array([
            [T[0, 0], T[0, 1], T[0, 2] - min_x],
            [T[1, 0], T[1, 1], T[1, 2] - min_y],
        ], dtype=np.float32)
        transformed = _swap_xy_homogeneous(pcds[idx]) @ adjusted_T.T
        transformed_img = cv2.warpAffine(imgs[idx], adjusted_T, (canvas_size, canvas_size),
                                         flags=cv2.INTER_LINEAR,
                                         borderMode=cv2.BORDER_CONSTANT,
                                         borderValue=(0, 0, 0))
        # Thicken the fragment edges: a filled white dot at every transformed
        # point. Plain Python ints are the portable point type for cv2.
        for px, py in transformed.astype(int):
            cv2.circle(canvas, (int(px), int(py)), 2, (255, 255, 255), -1)
        canvas += transformed_img
        np.clip(canvas, 0, 255, out=canvas)

    # Mark each fragment's center.
    for center in adjusted_centers.values():
        x, y = int(center[0]), int(center[1])
        if 0 <= x < canvas_size and 0 <= y < canvas_size:
            cv2.circle(canvas, (x, y), radius=6,
                       color=(0, 0, 0), thickness=6)

    return canvas.astype(np.uint8), adjusted_centers

def draw_dashed_line(img, pt1, pt2, color, thickness=1, line_length=5, space_length=5):
    """
    Draw a dashed line on an image, in place.

    The segment from ``pt1`` to ``pt2`` is walked in steps of
    ``line_length + space_length`` pixels. At each step a solid dash of
    ``line_length`` pixels is drawn with ``cv2.line``; the last dash is
    shortened so it never overshoots ``pt2``. Nothing is drawn when
    ``pt1 == pt2``.

    Args:
        img (np.ndarray): image to draw on (modified in place).
        pt1 (tuple): start point (x1, y1).
        pt2 (tuple): end point (x2, y2).
        color (tuple): line color, e.g. (B, G, R).
        thickness (int, optional): line thickness. Defaults to 1.
        line_length (int, optional): length of each solid dash. Defaults to 5.
        space_length (int, optional): gap between consecutive dashes. Defaults to 5.
    """
    x1, y1 = float(pt1[0]), float(pt1[1])
    dx, dy = float(pt2[0]) - x1, float(pt2[1]) - y1
    dist = np.hypot(dx, dy)
    if dist == 0:
        # Zero-length segment: the direction would be 0/0, so there is nothing to draw.
        return
    ux, uy = dx / dist, dy / dist

    for i in range(0, int(dist), line_length + space_length):
        seg = min(line_length, dist - i)
        # Compute both endpoints from the exact float position and round once,
        # so truncation error does not compound between start and end.
        sx, sy = x1 + ux * i, y1 + uy * i
        start = (int(round(sx)), int(round(sy)))
        end = (int(round(sx + ux * seg)), int(round(sy + uy * seg)))
        cv2.line(img, start, end, color, thickness)