#!/usr/bin/env python3
"""
ex5.2
Spatial Resource Allocation Problem (Section 5.2)
根据推论4.1，计算y_i的梯度并进行梯度下降优化，同时给出可视化方案。
"""
import numpy as np
import matplotlib.pyplot as plt
import time


def _assign_to_cells(X, y, w, k, a1):
    """Assign each sample to its cheapest centre and flag the valid band.

    The effective cost is ``z(x, i) = k * ||x - y_i|| - w_i``.  A sample is
    "valid" when its minimal cost is strictly below ``a1``; all other samples
    form the residual set.

    Returns:
        (diff, dist, labels, z_min, mask_valid) where ``diff`` is (N, m, d),
        ``dist`` is (N, m) and the remaining three arrays have length N.
    """
    diff = X[:, None, :] - y[None, :, :]
    dist = np.linalg.norm(diff, axis=2)
    z_all = k * dist - w[None, :]
    labels = np.argmin(z_all, axis=1)
    z_min = z_all[np.arange(X.shape[0]), labels]
    mask_valid = z_min < a1
    return diff, dist, labels, z_min, mask_valid


def optimize_y_unbalanced(
        X, y_init, M_mu=1.0,
        a1=1.0, a2=1.0, b1=1.0, b2=1.0,
        alpha0=0.05, decay=0.6,
        T_max=1000, tol=1e-4, k=1.0
    ):
    """Gradient descent on the centre positions y_i (Corollary 4.1).

    Each iteration:
      1. assigns every sample to the centre minimising
         ``z = k * ||x - y_i|| - w_i``;
      2. marks samples with ``z_min >= a1`` as the residual set R and records
         the estimate ``mu(R) = (M_mu / N) * |R|``;
      3. resets the weights: ``w_i = -b2`` for centres that own no valid
         sample, ``w_i = 0`` otherwise (original note: Eq. (19));
      4. records the objective ``(M_mu / N) * sum(z_min over valid samples)
         + a1 * mu(R)`` (``z_min`` uses the weights from before step 3);
      5. takes a step ``y <- y - alpha_t * grad`` with
         ``alpha_t = alpha0 / (1 + t) ** decay``, where only valid samples
         contribute to the gradient.

    Args:
        X: (N, d) array of samples.
        y_init: (m, d) array-like of initial centres; it is copied and
            converted to float, the caller's array is never modified.
        M_mu: total mass; every sample carries ``M_mu / N``.
        a1: upper threshold on ``z_min`` for a sample to be valid; also the
            per-unit-mass penalty on the residual set.
        a2: accepted for interface compatibility; not used by this routine.
        b1: accepted for interface compatibility; not used by this routine
            (only the upper bound ``z_min < a1`` is tested).
        b2: magnitude of the negative weight given to empty cells.
        alpha0: initial step size.
        decay: exponent of the step-size decay.
        T_max: maximum number of iterations.
        tol: stop once the Frobenius norm of the gradient drops below this.
        k: multiplier applied to the Euclidean cost.

    Returns:
        (y, history).  ``y`` is the (m, d) float array of final centres.
        ``history`` is a dict with keys:
          'y'            list of iterates, one per iteration plus the final y;
          'obj', 'muR'   one float per iteration;
          'mask_initial' valid mask of the first partition;
          'mask_final'   valid mask of the last partition evaluated (i.e.
                         computed before the final update step).

    Raises:
        ValueError: if X / y_init are not 2-D with matching width, or if
            either is empty.
    """
    X = np.asarray(X)
    # Always work on a floating-point copy.  With an integer y_init the
    # gradient buffer created by zeros_like(y) would be integer as well and
    # every gradient entry would be truncated to 0 on assignment.
    y = np.array(y_init, dtype=float)
    if X.ndim != 2 or y.ndim != 2 or X.shape[1] != y.shape[1]:
        raise ValueError(
            "X and y_init must be 2-D arrays with the same number of columns"
        )
    N, m = X.shape[0], y.shape[0]
    if N == 0 or m == 0:
        raise ValueError("X and y_init must each contain at least one point")

    w = np.zeros(m)                       # inner weights w_i, start at 0
    history = {'y': [], 'obj': [], 'muR': []}
    mass = M_mu / N                       # mass carried by a single sample
    mask_valid = None                     # stays None only if T_max <= 0

    for t in range(T_max):
        history['y'].append(y.copy())

        # ---------- 1./2. partition and residual set ----------
        diff, dist, labels, z_min, mask_valid = _assign_to_cells(
            X, y, w, k, a1)
        if t == 0:
            history['mask_initial'] = mask_valid.copy()
        mu_R = mass * (~mask_valid).sum()
        history['muR'].append(mu_R)

        # ---------- 3. weight update ----------
        # Empty cells get w_i = -b2, which raises their effective cost in
        # the next assignment; populated cells are reset to 0.
        counts = np.bincount(labels[mask_valid], minlength=m)
        w[counts == 0] = -b2
        w[counts > 0] = 0

        # ---------- 4. objective = transport over valid band + a1*mu(R) ----
        obj_transport = mass * z_min[mask_valid].sum()
        obj_tv = a1 * mu_R
        history['obj'].append(obj_transport + obj_tv)

        # ---------- 5. gradient (valid samples only) ----------
        grad = np.zeros_like(y)
        for i in range(m):
            idx = (labels == i) & mask_valid
            if idx.any():
                di = diff[idx, i, :]                  # x - y_i, (n_i, d)
                ri = dist[idx, i][:, None]            # fancy index => copy
                ri[ri == 0] = 1e-8                    # sample sits on y_i
                # grad_i = -k * sum_valid (x - y_i) / ||x - y_i|| * mass
                grad[i] = -k * mass * np.sum(di / ri, axis=0)

        # ---------- 6. step with decaying learning rate ----------
        alpha = alpha0 / ((1 + t) ** decay)
        y = y - alpha * grad

        # ---------- 7. convergence test ----------
        if np.linalg.norm(grad) < tol:
            break

    if mask_valid is None:
        # The loop never ran (T_max <= 0): evaluate the partition once so
        # that both masks exist and callers such as plotting code still work.
        *_, mask_valid = _assign_to_cells(X, y, w, k, a1)
        history['mask_initial'] = mask_valid.copy()

    history['y'].append(y.copy())
    history['mask_final'] = mask_valid.copy()
    return y, history
    


def visualize(X, history):
    """
    Plot the optimisation history in four figures and show them.

    1. Objective and residual-mass curves on twin y-axes.
    2. Initial and final partitions: valid samples coloured by nearest
       centre, residual samples in gray, centres as black crosses.
    3. Trajectory of every centre y_i, coloured like its cell.

    Args:
        X: sample array; only the first two columns are plotted.
        history: dict read for the keys 'y' (list of (m, d) iterates),
            'obj', 'muR', 'mask_initial' and 'mask_final'.

    All scatter/trajectory axes are fixed to the unit square [0, 1] x [0, 1].
    Blocks in ``plt.show()`` until the windows are closed (backend
    dependent).
    """
    # ---------- 0. colours: one default-cycle colour per centre ----------
    m = history['y'][0].shape[0]
    colors = [f'C{i}' for i in range(m)]

    # ---------- 1. convergence curves ----------
    fig, ax1 = plt.subplots(figsize=(6, 6))
    ax1.plot(history['obj'], '-o', color='C0', label='Objective Q')
    ax1.set_xlabel('Iterations', fontsize=18)
    ax1.set_ylabel('Objective Q', color='C0', fontsize=18)
    ax1.tick_params(axis='x', labelsize=18)
    ax1.tick_params(axis='y', labelcolor='C0', labelsize=18)
    ax1.grid(True, linestyle='--', alpha=0.3)

    ax2 = ax1.twinx()
    ax2.plot(history['muR'], '-s', color='C1', label='Residual Mass μ(R)')
    ax2.set_ylabel('Residual Mass μ(R)', color='C1', fontsize=18)
    ax2.tick_params(axis='y', labelcolor='C1', labelsize=18)

    # Merge the handles of both axes into a single legend.
    lines1, labels1 = ax1.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax1.legend(lines1 + lines2, labels1 + labels2,
               fontsize=14, loc='upper right')
    plt.tight_layout()

    # ---------- 2. initial / final partition ----------
    for stage, y_pts in [('Initial', history['y'][0]),
                         ('Final',   history['y'][-1])]:

        mask_valid = (history['mask_initial']
                      if stage == 'Initial' else history['mask_final'])

        # NOTE: labels are recomputed here from plain (unweighted) Euclidean
        # distance to y_pts; the valid masks are taken as stored in history.
        diff = X[:, None, :] - y_pts[None, :, :]
        dist = np.linalg.norm(diff, axis=2)
        labels = np.argmin(dist, axis=1)

        plt.figure(figsize=(6, 6))
        # 2-A. valid samples, coloured by their centre
        for i, color in enumerate(colors):
            idx = (labels == i) & mask_valid
            plt.scatter(X[idx, 0], X[idx, 1],
                        s=10, alpha=0.5, color=color)

        # 2-B. residual set R
        plt.scatter(X[~mask_valid, 0], X[~mask_valid, 1],
                    c='gray', marker='x', s=20, label='Residual R')

        # 2-C. centres
        plt.scatter(y_pts[:, 0], y_pts[:, 1],
            c='k', marker='x', s=200, linewidths=4,
            label='$y_i$')

        plt.title(f'{stage} Partition', fontsize=18)
        plt.xlabel('$x_1$', fontsize=18)
        plt.ylabel('$x_2$', fontsize=18)
        plt.legend(fontsize=14, loc='lower right')
        plt.tick_params(axis='both', which='major', labelsize=18)
        plt.xlim(0, 1)
        plt.ylim(0, 1)
        plt.tight_layout()

    # ---------- 3. trajectories of y_i ----------
    ys = np.stack(history['y'], axis=0)           # (steps, m, d)
    plt.figure(figsize=(6, 6))
    for i, color in enumerate(colors):
        traj = ys[:, i, :]
        # Brace the subscript so that indices >= 10 render as one subscript
        # ("$y_10$" would typeset as y_1 followed by a plain 0).
        plt.plot(traj[:, 0], traj[:, 1], '-o',
                 color=color, label=f'$y_{{{i}}}$')
        # One arrow per step; the number of annotations therefore grows with
        # (iterations x m).
        for j in range(len(traj) - 1):
            plt.annotate('', xy=traj[j+1], xytext=traj[j],
                         arrowprops=dict(arrowstyle='->',
                                         color=color, lw=1))
    # Hollow black circles mark the final positions.
    finals = ys[-1]
    plt.scatter(finals[:, 0], finals[:, 1],
                facecolors='none', edgecolors='k',
                s=250, linewidths=2, label='Final $y$')

    plt.title('Trajectory of $y_i$', fontsize=18)
    plt.xlabel('$x_1$', fontsize=18)
    plt.ylabel('$x_2$', fontsize=18)
    plt.legend(fontsize=14)
    plt.tick_params(axis='both', which='major', labelsize=18)
    plt.xlim(0, 1)
    plt.ylim(0, 1)
    plt.tight_layout()
    plt.show()


def main():
    """Run the demo: sample data, optimise the centres, plot the history.

    The samples come from an equal-weight mixture of two isotropic Gaussians
    (means (0.3, 0.3) and (0.7, 0.7), covariance 0.02 * I) clipped to the unit
    square.  The initial centres are drawn from a separate, independently
    seeded generator so that changing them does not alter the data.
    """
    n_samples, n_centres, total_mass = 2000, 4, 1.0

    # --- data: two Gaussian clusters, first half / second half ---
    data_rng = np.random.default_rng(1)
    cluster_means = np.array([[0.3, 0.3], [0.7, 0.7]])
    cluster_cov = 0.02 * np.eye(2)
    n_first = n_samples // 2
    first_cluster = data_rng.multivariate_normal(
        cluster_means[0], cluster_cov, size=n_first)
    second_cluster = data_rng.multivariate_normal(
        cluster_means[1], cluster_cov, size=n_samples - n_first)
    # Clamp every coordinate into [0, 1].
    X = np.clip(np.vstack([first_cluster, second_cluster]), 0, 1)

    # --- initial centres: dedicated seed, affects only y_init ---
    init_rng = np.random.default_rng(10000)
    y_init = init_rng.random((n_centres, 2))

    # --- optimisation: a small a1 produces a non-empty residual set ---
    settings = dict(
        a1=0.3, a2=1.0,
        b1=0.1, b2=0.2,
        alpha0=0.3, decay=0.6,
        T_max=3000, tol=1e-3,
    )
    _, history = optimize_y_unbalanced(X, y_init, total_mass, **settings)

    # --- plots ---
    visualize(X, history)


# Run the demo only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()