import numpy as np
import matplotlib.pyplot as plt
from matplotlib.path import Path
import matplotlib.patches as patches

# --- 1. 最终参数优化 ---
N_POINTS = 100              # 大幅增加点的数量使其更密集
N_LINES = 100               # 保持线的数量
LINE_ALPHA = 0.28           # 略提高线的不透明度以增强整体厚度
POINT_ALPHA = 0.95          # 进一步提高点的透明度
LINE_WIDTH = 2.4            # 略加粗线条，增强下图的“粗壮感”

COLORS = {
    'purple': '#8A6FDF',
    'blue': '#3A82E4',
    'red': '#F06E64' # 沿用红色，但视觉上是橘色
}

# --- 2. 数据生成函数 (无变化) ---
def generate_cluster(center, cov, n_points):
    return np.random.multivariate_normal(center, cov, n_points)

# --- 3. 绘图函数 (调整曲率计算方式，使用传入的因子决定凹凸) ---
def plot_flow_matching(ax, title, right_blue_center, right_red_center, blue_curvature_factor, red_curvature_factor, is_contrastive=False):
    """
    根据详细参数绘制流匹配图。
    blue_curvature_factor 和 red_curvature_factor 现在直接通过其符号控制凹凸：
    负值表示向内凹（控制点在连线下方），正值表示向外凸（控制点在连线上方）。
    is_contrastive: 如果为True，则流线从重叠开始然后分叉
    """
    # 关键改动：大幅减小协方差，让点云更加密集
    left_cov = [[0.001, 0], [0, 0.01]]   # 左侧点云更密集
    right_cov = [[0.001, 0], [0, 0.01]]  # 右侧点云更密集
    
    left_center = [-1.5, 0]

    # 生成数据
    points_left = generate_cluster(left_center, left_cov, N_POINTS)
    points_right_blue = generate_cluster(right_blue_center, right_cov, N_POINTS)
    points_right_red = generate_cluster(right_red_center, right_cov, N_POINTS)

    # 绘制散点
    ax.scatter(points_left[:, 0], points_left[:, 1], c=COLORS['purple'], s=10, alpha=POINT_ALPHA, zorder=2, edgecolors='none')
    ax.scatter(points_right_blue[:, 0], points_right_blue[:, 1], c=COLORS['blue'], s=10, alpha=POINT_ALPHA, zorder=2, edgecolors='none')
    ax.scatter(points_right_red[:, 0], points_right_red[:, 1], c=COLORS['red'], s=10, alpha=POINT_ALPHA, zorder=2, edgecolors='none')

    # 绘制连接曲线
    for i in range(N_LINES):
        # 蓝色曲线
        start_point = points_left[np.random.randint(0, N_POINTS)]
        end_point = points_right_blue[np.random.randint(0, N_POINTS)]
        
        if is_contrastive:
            # 对于对比流匹配：
            # 分叉稍晚一些，且在分叉处“圆筒状”不过度张开：
            # - cp1: 30% 处，y=0（保持重叠到更晚位置）
            # - cp2: 68% 处，y 为终点 y 的 0.45 倍（上下方向分离）
            dx = end_point[0] - start_point[0]
            control_point1 = [start_point[0] + dx * 0.30, 0]
            control_point2 = [start_point[0] + dx * 0.68, end_point[1] * 0.45]
            
            path_data = [(Path.MOVETO, start_point), 
                        (Path.CURVE4, control_point1), 
                        (Path.CURVE4, control_point2), 
                        (Path.CURVE4, end_point)]
        else:
            # 普通流匹配使用原来的二次曲线
            mid_point = (start_point + end_point) / 2
            control_point = [mid_point[0], mid_point[1] + blue_curvature_factor * (end_point[0] - start_point[0])]
            path_data = [(Path.MOVETO, start_point), (Path.CURVE3, control_point), (Path.CURVE3, end_point)]
        
        codes, verts = zip(*path_data)
        path = Path(verts, codes)
        patch = patches.PathPatch(path, facecolor='none', edgecolor=COLORS['blue'], lw=LINE_WIDTH, alpha=LINE_ALPHA, zorder=1)
        ax.add_patch(patch)

        # 红色（橘色）曲线
        start_point = points_left[np.random.randint(0, N_POINTS)]
        end_point = points_right_red[np.random.randint(0, N_POINTS)]
        
        if is_contrastive:
            # 红色曲线相同策略：稍晚分叉，逐步向下分离
            dx = end_point[0] - start_point[0]
            control_point1 = [start_point[0] + dx * 0.30, 0]
            control_point2 = [start_point[0] + dx * 0.68, end_point[1] * 0.45]
            
            path_data = [(Path.MOVETO, start_point), 
                        (Path.CURVE4, control_point1), 
                        (Path.CURVE4, control_point2), 
                        (Path.CURVE4, end_point)]
        else:
            # 普通流匹配使用原来的二次曲线
            mid_point = (start_point + end_point) / 2
            control_point = [mid_point[0], mid_point[1] + red_curvature_factor * (end_point[0] - start_point[0])]
            path_data = [(Path.MOVETO, start_point), (Path.CURVE3, control_point), (Path.CURVE3, end_point)]

        codes, verts = zip(*path_data)
        path = Path(verts, codes)
        patch = patches.PathPatch(path, facecolor='none', edgecolor=COLORS['red'], lw=LINE_WIDTH, alpha=LINE_ALPHA, zorder=1)
        ax.add_patch(patch)

    # 样式化图表
    # ax.set_title(title, fontsize=16, fontweight='bold', pad=10)
    ax.set_xlim(-2.2, 2.2)
    ax.set_ylim(-1.8, 1.8)
    ax.axis('off')

# --- 4. 主程序 ---
# 创建画布
fig, axs = plt.subplots(1, 1, figsize=(8, 5))

# 绘制第一个图: "Flow-Matching" - 两个右侧点云较近，曲线相对平缓
# plot_flow_matching(
#     axs[0], "Flow-Matching",
#     right_blue_center=[1.5, 0.3], right_red_center=[1.5, -0.3],  # 右侧点云距离较近
#     blue_curvature_factor=-0.15, red_curvature_factor=-0.15,  # 曲线弯曲较小
#     is_contrastive=False  # 普通流匹配
# )

# 绘制第二个图: "Contrastive Flow-Matching" - 流线从重叠开始然后分叉
plot_flow_matching(
    axs, "Contrastive Flow-Matching",
    right_blue_center=[1.5, 1.2], right_red_center=[1.5, -1.2], # 右侧点云分得更开
    blue_curvature_factor=-1.5, # 蓝色向上弯曲
    red_curvature_factor=1.5,   # 橘色向下弯曲
    is_contrastive=True  # 对比流匹配，从重叠开始分叉
)

# 调整布局并显示
plt.tight_layout()
# plt.show()
plt.savefig('aa.svg', format='svg', bbox_inches='tight')