import numpy as np
import matplotlib.pyplot as plt
import pickle
import os
from scipy.linalg import svd
import numpy.linalg as la
from tqdm import tqdm


def generate_problem_partial(sigma, n, p, d, r, N):
    """
    Draw N random synchronization problems with block-wise partial observations.

    Parameters:
    sigma: noise level multiplying the symmetric Gaussian noise matrix
    n: number of blocks (matrices)
    p: probability that a (i, j) block pair is observed
    d: number of rows of every block (all blocks share the same d)
    r: number of columns of every block
    N: number of problems to generate

    Returns:
    problems: list of N dicts, each holding
        - 'C_star': partially observed matrix (nd x nd); unobserved d x d
                    blocks are zero
        - 'C_full': fully observed matrix G G^T + sigma * W (nd x nd)
        - 'G_true': ground truth (nd x r); each d x r block is U @ Vt from
                    the thin SVD of a standard Gaussian matrix
        - 'E': symmetric n x n 0/1 observation mask (the diagonal is sampled
               as well)
        - 'sigma', 'p', 'n', 'd', 'r': the input parameters
    """
    total = n * d

    # These index sets do not depend on the random draws, so build them once.
    tri_rows, tri_cols = np.triu_indices(total)
    block_pairs = list(zip(*np.triu_indices(n)))  # (i, j) with i <= j, row-major

    problems = []
    for _ in range(N):
        # Ground truth: one orthogonal polar factor per block.
        G_true = np.zeros((total, r))
        for block in range(n):
            left, _sing, right_t = svd(np.random.randn(d, r), full_matrices=False)
            G_true[block * d:(block + 1) * d, :] = left @ right_t

        # Symmetric noise: sample the upper triangle (diagonal included) and
        # mirror the strictly upper part below the diagonal.
        upper = np.zeros((total, total))
        upper[tri_rows, tri_cols] = np.random.randn(tri_rows.size)
        W_full = np.triu(upper) + np.triu(upper, 1).T

        C_full = G_true @ G_true.T + sigma * W_full

        # Symmetric Bernoulli(p) mask over block pairs.
        E = np.zeros((n, n))
        for i, j in block_pairs:
            observed = np.random.binomial(1, p)
            E[i, j] = observed
            E[j, i] = observed

        # Expand the block mask to entry level and zero out unobserved blocks.
        entry_mask = np.repeat(np.repeat(E == 1, d, axis=0), d, axis=1)
        C_star = np.where(entry_mask, C_full, 0.0)

        problems.append({
            'C_star': C_star,
            'C_full': C_full,
            'G_true': G_true,
            'E': E,
            'sigma': sigma,
            'p': p,
            'n': n,
            'd': d,
            'r': r,
        })

    return problems


def project_stiefel(X):
    """Project X onto the Stiefel manifold by returning U @ Vt from its thin SVD."""
    left_vecs, _singular_values, right_vecs_t = svd(X, full_matrices=False)
    polar_factor = np.matmul(left_vecs, right_vecs_t)
    return polar_factor


def project_stiefel_block(X, n, d, r):
    """
    Project every d-row block of a stacked matrix onto the Stiefel manifold.

    X is treated as n consecutive blocks of d rows; each block is passed
    through project_stiefel independently and written to the same rows of the
    result. The argument r is accepted for interface symmetry but not used.
    """
    projected = np.zeros_like(X)
    for block in range(n):
        rows = slice(block * d, (block + 1) * d)
        projected[rows, :] = project_stiefel(X[rows, :])
    return projected


def GPM(problem, max_iter=400, tol=1e-8):
    """
    Solve a partially observed problem with the generalized power method.

    The iterate is initialised from the r leading eigenvectors of C_star
    (projected block-wise onto the Stiefel manifold) and then updated by
    O <- project_stiefel_block(C_star @ O) until the change between two
    consecutive iterates is below tol or max_iter updates have been made.

    Parameters:
    problem: dict providing 'C_star' (nd x nd matrix) and the integers
             'n', 'd', 'r'
    max_iter: maximum number of power iterations; with max_iter <= 0 the
              projected spectral initialisation is returned unchanged
    tol: stop once ||O_next - O_prev||_F / sqrt(n) < tol

    Returns:
    O: the last iterate computed (nd x r)
    """
    C_star = problem['C_star']
    n = problem['n']
    d = problem['d']
    r = problem['r']

    # Step 1: spectral initialisation from the r largest eigenvalues.
    eigenvalues, eigenvectors = la.eigh(C_star)
    leading = np.argsort(eigenvalues)[::-1][:r]

    # eigh returns orthonormal eigenvectors, so scaling by sqrt(n) gives
    # tilde_G^T tilde_G = n * I_r. (Dividing by the Frobenius norm, as was
    # done before, yields (n / r) * I_r instead.) The scale is only a
    # normalisation: U @ Vt of a matrix is unchanged by a positive rescaling.
    tilde_G = np.sqrt(n) * eigenvectors[:, leading]

    # Step 2: project the initial guess block by block.
    O_current = project_stiefel_block(tilde_G, n, d, r)

    # Step 3: power iterations. O_current always holds the newest iterate, so
    # the return value is well defined even when the loop body never runs.
    for _ in range(max_iter):
        O_next = project_stiefel_block(C_star @ O_current, n, d, r)
        step_size = la.norm(O_next - O_current, 'fro') / np.sqrt(n)
        O_current = O_next
        if step_size < tol:
            break

    return O_current


def is_global_optimal_partial(O, C_star, n, d, r):
    """
    Test whether the candidate O passes the global-optimality check.

    For every block i the function forms
        Lambda_i = sym(O_i^T (C_star O)_i),   tau_i = lambda_min(Lambda_i),
    assembles the block-diagonal matrix with blocks
        O_i Lambda_i O_i^T + tau_i (I_d - O_i O_i^T),
    subtracts C_star, and accepts O when the smallest eigenvalue of the result
    is at least -1e-5 times its largest absolute eigenvalue.

    Parameters:
    O: candidate solution (nd x r)
    C_star: partially observed matrix (nd x nd)
    n, d, r: problem sizes (r is accepted but not used here)

    Returns:
    is_optimal: truth value of the approximate positive-semidefiniteness test
    """
    total = n * d
    identity_d = np.eye(d)
    certificate = np.zeros((total, total))

    for block in range(n):
        rows = slice(block * d, (block + 1) * d)
        O_blk = O[rows, :]

        # Multiplier of block i, symmetrised to remove round-off asymmetry.
        multiplier = O_blk.T @ (C_star[rows, :] @ O)
        multiplier = (multiplier + multiplier.T) / 2

        # Smallest eigenvalue of the multiplier.
        tau = np.min(la.eigvalsh(multiplier))

        certificate[rows, rows] = (
            O_blk @ multiplier @ O_blk.T + tau * (identity_d - O_blk @ O_blk.T)
        )

    certificate = certificate - C_star

    # Relative tolerance: allow a slightly negative smallest eigenvalue.
    spectrum = la.eigvalsh(certificate)
    smallest = np.min(spectrum)
    largest_magnitude = np.max(np.abs(spectrum))
    return smallest >= -1e-5 * largest_magnitude


def success_rate(sigma, n, p, d, r, N):
    """
    Estimate how often GPM reaches a solution that passes the optimality check.

    N random problems are generated, each is solved with GPM, and the result
    is tested with is_global_optimal_partial. Problems are generated one at a
    time rather than all up front: every problem carries two dense nd x nd
    matrices, so keeping N of them alive at once multiplies peak memory by N.
    Neither the solver nor the check draws from np.random, so for a fixed
    seed the sequence of problems is the same as when generating them in one
    batch.

    Parameters:
    sigma: noise level
    n: number of blocks (matrices)
    p: observation probability
    d: number of rows of every block
    r: number of columns of every block
    N: number of trial problems; must be positive

    Returns:
    success_ratio: fraction of the N trials that passed the check. A trial
                   whose solve or check raises is reported on stdout and
                   counted as a failure.

    Raises:
    ValueError: if N is not positive (the ratio would be undefined).
    """
    if N <= 0:
        raise ValueError("N must be a positive integer")

    success_count = 0

    for i in range(N):
        problem = generate_problem_partial(sigma, n, p, d, r, 1)[0]

        try:
            O_solution = GPM(problem)
            if is_global_optimal_partial(O_solution, problem['C_star'], n, d, r):
                success_count += 1
        except Exception as e:
            # Best effort: a numerical failure on one instance must not abort
            # the whole sweep; it simply does not count as a success.
            print(f"问题 {i + 1} 求解出错: {e}")

        # Drop the large matrices before the next problem is allocated.
        del problem

    return success_count / N


def compute_success_grid(d, r, N, sigma_low, sigma_high, sigma_number,
                         n_low, n_high, n_number, constant, filename="success_grid_partial.pkl"):
    """
    Compute (or load from cache) the success-rate grid over sigma and n.

    If `filename` already exists its pickled content is returned as is; the
    arguments of this call are not compared against whatever produced the
    file. Otherwise sigma and n are sampled on logarithmic grids, the
    observation probability is set to p = min(c * log(n) / sqrt(n), 1) for
    each n, success_rate is evaluated at every (sigma, n) pair, and the list
    of results is pickled to `filename`.

    NOTE: pickle.load can execute arbitrary code; only point `filename` at
    cache files produced by this function.

    Parameters:
    d: number of rows of every block
    r: number of columns of every block
    N: number of trial problems per grid point
    sigma_low, sigma_high: range of sigma
    sigma_number: number of sigma samples
    n_low, n_high: range of n
    n_number: number of n samples before rounding and de-duplication
    constant: the constant c in the formula for p
    filename: path of the pickle cache

    Returns:
    results: list of (sigma, n * p, success_rate) tuples
    """
    # Reuse a previously stored grid when one is available.
    if os.path.exists(filename):
        print(f"加载已保存的结果: {filename}")
        with open(filename, 'rb') as cache_file:
            return pickle.load(cache_file)

    # Log-spaced sigma samples.
    sigma_exponents = np.linspace(np.log10(sigma_low), np.log10(sigma_high), sigma_number)
    sigma_values = 10 ** sigma_exponents

    # Log-spaced n samples, rounded to integers >= 1, de-duplicated and sorted.
    n_exponents = np.linspace(np.log10(n_low), np.log10(n_high), n_number)
    n_values = sorted({int(max(1, np.round(raw))) for raw in 10 ** n_exponents})

    print(f"计算网格: σ从{sigma_low}到{sigma_high}，共{len(sigma_values)}个点")
    print(f"n从{n_low}到{n_high}，共{len(n_values)}个点(去重后)")
    print(f"常数c = {constant}")

    results = []
    grid_size = len(sigma_values) * len(n_values)

    with tqdm(total=grid_size, desc="计算部分观测问题的成功率网格") as pbar:
        for sigma in sigma_values:
            for n in n_values:
                # Observation probability p = min(c * log(n) / sqrt(n), 1).
                p = min(constant * np.log(n) / np.sqrt(n), 1.0)
                ratio = success_rate(sigma, n, p, d, r, N)
                results.append((sigma, n * p, ratio))
                pbar.update(1)

    # Persist the grid so later calls can skip the computation.
    with open(filename, 'wb') as cache_file:
        pickle.dump(results, cache_file)

    print(f"结果已保存到: {filename}")
    return results


def visualize_success_grid_partial(results, d, r, constant, save_path="success_heatmap_partial.png"):
    """
    Draw the success-rate grid as a heat map on log-log axes.

    The x axis is log10(n*p), the y axis is log10(sigma); every grid cell is
    shaded by its success rate (0 = white, 1 = black). A red reference line
    sigma = 10**(-0.8) * sqrt(n*p) (about 0.158 * sqrt(n*p)) is overlaid. The
    figure is optionally saved and then shown with plt.show().

    Parameters:
    results: list of (sigma, n*p, success_rate) tuples
    d, r: problem sizes (accepted for interface compatibility; not used)
    constant: the constant c (accepted for interface compatibility; not used)
    save_path: where to save the figure; a falsy value skips saving

    Returns:
    (success_grid, unique_sigmas, unique_n_ps), where success_grid has shape
    (len(unique_n_ps), len(unique_sigmas)) and cells without a result are 0;
    (None, None, None) when `results` is empty.
    """
    if not results:
        print("没有结果可可视化")
        return None, None, None

    def _cell_edges(log_centers):
        """Return the cell boundaries around the given sorted log10 centers."""
        if len(log_centers) == 1:
            return np.array([log_centers[0] - 0.5, log_centers[0] + 0.5])
        step = (log_centers[-1] - log_centers[0]) / (len(log_centers) - 1)
        return np.concatenate([
            [log_centers[0] - step / 2],
            (log_centers[:-1] + log_centers[1:]) / 2,
            [log_centers[-1] + step / 2]
        ])

    # Split the result tuples into columns.
    sigmas = np.array([row[0] for row in results])
    n_ps = np.array([row[1] for row in results])
    success_rates = np.array([row[2] for row in results])

    # Sorted unique axis values plus, for every result, its position on each axis.
    unique_sigmas, sigma_pos = np.unique(sigmas, return_inverse=True)
    unique_n_ps, n_p_pos = np.unique(n_ps, return_inverse=True)

    # Rows index n*p, columns index sigma. Filling in result order keeps the
    # "last entry wins" behaviour should a (sigma, n*p) pair occur twice.
    success_grid = np.zeros((len(unique_n_ps), len(unique_sigmas)), dtype=float)
    for row, col, rate in zip(n_p_pos, sigma_pos, success_rates):
        success_grid[row, col] = rate

    plt.figure(figsize=(10, 6))

    # Cell boundaries on the log scale.
    log_sigma_edges = _cell_edges(np.log10(unique_sigmas))
    log_n_p_edges = _cell_edges(np.log10(unique_n_ps))

    # Heat map: pcolormesh expects C with shape (len(y) - 1, len(x) - 1),
    # i.e. (sigma, n*p), hence the transpose.
    im = plt.pcolormesh(
        log_n_p_edges,
        log_sigma_edges,
        success_grid.T,
        cmap='gray_r',
        shading='flat',
        vmin=0,
        vmax=1,
        edgecolors='none'
    )

    # Colour bar labelled in percent.
    cbar = plt.colorbar(im)
    cbar.set_label('Success Rate', fontsize=25)
    cbar.set_ticks([0, 0.25, 0.5, 0.75, 1.0])
    cbar.set_ticklabels(['0', '25', '50', '75', '100'], fontsize=25)

    # Reference line: sqrt(n*p) shifted down by 0.8 on the log10 axis, which
    # is the factor 10**(-0.8) ~= 0.158 quoted in the legend.
    n_p_line = np.logspace(np.log10(min(unique_n_ps)), np.log10(max(unique_n_ps)), 100)
    sigma_line = np.sqrt(n_p_line)
    plt.plot(np.log10(n_p_line), np.log10(sigma_line) - 0.8, 'r-', linewidth=3,
             label=r'$\sigma = 0.158\sqrt{np}$')
    plt.legend(prop={'size': 25})

    # Axis labels. The y label is a raw string so that "\s" is not parsed as
    # an (invalid) escape sequence.
    plt.xlabel('$np$ (log-scale)', fontsize=25)
    plt.ylabel(r'Noise level $\sigma$ (log-scale)', fontsize=25)

    # X ticks: first power of ten >= min(n*p), then multiplied by 5 repeatedly.
    min_n_p, max_n_p = min(unique_n_ps), max(unique_n_ps)
    n_p_ticks = []
    tick = 1
    while tick < min_n_p:
        tick *= 10
    while tick <= max_n_p * 1.5:
        n_p_ticks.append(tick)
        tick *= 5

    plt.xticks(np.log10(n_p_ticks), [f'{t:.0f}' for t in n_p_ticks])

    # Y ticks: 0.1 doubled until it reaches min(sigma), then doubled repeatedly.
    sigma_ticks = []
    tick = 0.1
    while tick < min(unique_sigmas):
        tick *= 2
    while tick <= max(unique_sigmas) * 1.5:
        sigma_ticks.append(tick)
        tick *= 2

    plt.yticks(np.log10(sigma_ticks), [f'{t:.1f}' if t < 1 else f'{int(t)}' for t in sigma_ticks])
    plt.tick_params(axis='both', labelsize=25)

    # Clip the axes to the extent of the grid.
    plt.xlim(min(log_n_p_edges), max(log_n_p_edges))
    plt.ylim(min(log_sigma_edges), max(log_sigma_edges))

    plt.grid(True, alpha=0.3, linestyle='--', linewidth=0.5)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"图像已保存到: {save_path}")

    plt.show()

    return success_grid, unique_sigmas, unique_n_ps


def compute_and_visualize_partial(d, r, N, constant, sigma_low, sigma_high, sigma_number,
                                  n_low, n_high, n_number):
    """
    Compute the success-rate grid, plot it, and print a short summary.

    NOTE: the cache file name encodes only d, r, N and c. compute_success_grid
    returns an existing cache file without inspecting the sigma / n ranges, so
    changing those ranges alone reuses the previously stored grid.

    Parameters:
    d: number of rows of every block
    r: number of columns of every block
    N: number of trial problems per grid point
    constant: the constant c in p = min(c * log(n) / sqrt(n), 1)
    sigma_low, sigma_high: range of sigma
    sigma_number: number of sigma samples
    n_low, n_high: range of n
    n_number: number of n samples

    Returns:
    (results, success_grid, sigmas, n_ps); the last three are None when there
    were no results to visualise.
    """
    filename = f"success_grid_partial_d{d}_r{r}_N{N}_c{constant}.pkl"

    results = compute_success_grid(d, r, N, sigma_low, sigma_high, sigma_number,
                                   n_low, n_high, n_number, constant, filename)

    success_grid, sigmas, n_ps = visualize_success_grid_partial(results, d, r, constant)

    # The visualiser hands back (None, None, None) for an empty result list
    # (and has already said so); there is nothing to summarise in that case,
    # and min()/max()/.shape on None would raise.
    if success_grid is None:
        return results, success_grid, sigmas, n_ps

    print("\n结果分析:")
    print(f"σ范围: {min(sigmas):.2f} 到 {max(sigmas):.2f}，共{len(sigmas)}个点")
    print(f"np范围: {min(n_ps):.2f} 到 {max(n_ps):.2f}，共{len(n_ps)}个点")
    print(f"成功率网格形状: {success_grid.shape}")

    return results, success_grid, sigmas, n_ps


# 使用示例
if __name__ == "__main__":
    # Block shape: d rows by r columns.
    d, r = 10, 3
    # Trial problems per grid point (kept small for test runs).
    N = 10
    # Constant c in p = min(c * log(n) / sqrt(n), 1).
    constant = 2.0

    # Sampling ranges and sample counts for sigma and n.
    sigma_low, sigma_high, sigma_number = 0.1, 100, 100
    n_low, n_high, n_number = 10, 1000, 200

    print("开始计算部分观测问题的成功率网格...")
    print(f"参数: d={d}, r={r}, N={N}, c={constant}")
    print(f"σ范围: {sigma_low} 到 {sigma_high}, 点数: {sigma_number}")
    print(f"n范围: {n_low} 到 {n_high}, 点数: {n_number}")
    print(f"观测概率公式: p = min({constant} * log(n)/√n, 1)")

    # Run the sweep and show the heat map.
    results, success_grid, sigmas, n_ps = compute_and_visualize_partial(
        d=d,
        r=r,
        N=N,
        constant=constant,
        sigma_low=sigma_low,
        sigma_high=sigma_high,
        sigma_number=sigma_number,
        n_low=n_low,
        n_high=n_high,
        n_number=n_number,
    )