import numpy as np
import matplotlib.pyplot as plt
import pickle
import os
from scipy.linalg import svd
import numpy.linalg as la
from tqdm import tqdm
import gc


def generate_one_problem_partial(sigma, n, p, d, r):
    """
    Build one partially observed synchronization instance.

    Exactly one instance is created per call, so callers never need to keep
    several problems in memory at the same time.

    Parameters:
    sigma: noise strength multiplying the symmetric Gaussian noise matrix
    n: number of blocks
    p: probability that a block pair (i, j) with i <= j is observed
    d: number of rows of every block (all blocks share the same d)
    r: number of columns of every block

    Returns:
    problem: dict with 'C_star' (masked measurements), 'C_full' (unmasked
             measurements), 'G_true' (ground truth, shape (n*d, r)),
             'E' (symmetric n x n 0/1 mask) and the scalar inputs
             'sigma', 'p', 'n', 'd', 'r'
    """
    total = n * d

    # Ground truth: block k is U @ Vt taken from the thin SVD of a Gaussian d x r draw.
    G_true = np.zeros((total, r))
    for block in range(n):
        left, _, right_t = svd(np.random.randn(d, r), full_matrices=False)
        G_true[block * d:(block + 1) * d, :] = left @ right_t

    # Symmetric noise: draw the upper triangle (diagonal included) and mirror
    # its strictly upper part below the diagonal.
    tri_rows, tri_cols = np.triu_indices(total)
    upper = np.zeros((total, total))
    upper[tri_rows, tri_cols] = np.random.randn(len(tri_rows))
    W_full = upper + np.triu(upper, 1).T

    # Fully observed noisy measurements.
    C_full = G_true @ G_true.T + sigma * W_full

    # Symmetric Bernoulli(p) mask; one draw per pair (i, j) with i <= j,
    # visited row by row.
    E = np.zeros((n, n))
    for i, j in zip(*np.triu_indices(n)):
        E[i, j] = E[j, i] = np.random.binomial(1, p)

    # Expand every mask entry to a d x d block and zero out unobserved blocks.
    observed = np.repeat(np.repeat(E, d, axis=0), d, axis=1) == 1
    C_star = np.where(observed, C_full, 0.0)

    return {
        'C_star': C_star,
        'C_full': C_full,  # kept so callers can compare against the unmasked data
        'G_true': G_true,
        'E': E,  # the observation mask itself
        'sigma': sigma,
        'p': p,
        'n': n,
        'd': d,
        'r': r
    }


def project_stiefel(X):
    """Project X onto the Stiefel manifold by returning U @ Vt from its thin SVD."""
    # The singular values are discarded; only the two orthonormal factors are used.
    left_vectors, _, right_vectors_t = svd(X, full_matrices=False)
    projected = left_vectors @ right_vectors_t
    return projected


def project_stiefel_block(X, n, d, r):
    """
    Project each of the n stacked d-row blocks of X with project_stiefel.

    Parameters:
    X: stacked matrix whose first n*d rows form n consecutive blocks of d rows
    n: number of blocks
    d: rows per block
    r: accepted for signature symmetry with the other helpers; not used here

    Returns:
    Y: array shaped like X with every block replaced by its projection
    """
    Y = np.zeros_like(X)
    for k in range(n):
        rows = slice(k * d, (k + 1) * d)
        Y[rows, :] = project_stiefel(X[rows, :])
    return Y


def GPM(problem, max_iter=400, tol=1e-8):
    """
    Solve a partially observed OTSM instance with the generalized power method.

    Parameters:
    problem: dict providing 'C_star' (the (n*d) x (n*d) measurement matrix),
             'n' (number of blocks), 'd' (rows per block) and 'r' (columns per block)
    max_iter: maximum number of power iterations
    tol: stop as soon as ||O_next - O_prev||_F / sqrt(n) falls below this value

    Returns:
    O: the last iterate, shape (n*d, r). When max_iter <= 0 no iteration is
       run and the spectral initialization is returned.
    """
    C_star = problem['C_star']
    n = problem['n']
    d = problem['d']
    r = problem['r']

    # Step 1: spectral initialization from the eigenvectors belonging to the
    # r largest eigenvalues of C_star.
    eigenvalues, eigenvectors = la.eigh(C_star)
    idx = np.argsort(eigenvalues)[::-1][:r]

    # eigh returns orthonormal eigenvectors, so multiplying by sqrt(n) yields
    # tilde_G.T @ tilde_G = n * I_r. (Dividing by the Frobenius norm, as was
    # done before, gives (n / r) * I_r instead.) A positive rescaling does not
    # change U @ Vt in the projection below, so only the normalization changes.
    tilde_G = np.sqrt(n) * eigenvectors[:, idx]

    # Step 2: project the spectral estimate block by block.
    O_prev = project_stiefel_block(tilde_G, n, d, r)

    # Bound up front so the return below is defined even if the loop never runs.
    O_next = O_prev

    # Step 3: power iterations.
    for _ in range(max_iter):
        # Multiply by C_star, then project every block back.
        O_next = project_stiefel_block(C_star @ O_prev, n, d, r)

        # Convergence test: ||O_next - O_prev||_F / sqrt(n) < tol
        diff_norm = la.norm(O_next - O_prev, 'fro') / np.sqrt(n)
        if diff_norm < tol:
            break

        O_prev = O_next

    return O_next


def is_global_optimal_partial(O, C_star, n, d, r):
    """
    Check whether candidate O passes the positive-semidefiniteness test used
    as the global-optimality certificate for the partially observed problem.

    Parameters:
    O: candidate solution, shape (n*d, r)
    C_star: partially observed measurement matrix, shape (n*d, n*d)
    n, d, r: problem parameters (r is not used by the computation)

    Returns:
    is_optimal: True when the smallest eigenvalue of the certificate matrix is
                at least -1e-5 times its largest absolute eigenvalue
    """
    size = n * d
    identity = np.eye(d)
    certificate = np.zeros((size, size))

    for k in range(n):
        rows = slice(k * d, (k + 1) * d)
        O_k = O[rows, :]

        # Lambda_k = O_k^T (C_star O)_k, symmetrized to remove round-off asymmetry.
        raw = O_k.T @ (C_star[rows, :] @ O)
        Lambda_k = (raw + raw.T) / 2

        # tau_k is the smallest eigenvalue of Lambda_k.
        tau_k = np.min(la.eigvalsh(Lambda_k))

        # Diagonal block: O_k Lambda_k O_k^T + tau_k (I - O_k O_k^T)
        certificate[rows, rows] = O_k @ Lambda_k @ O_k.T + tau_k * (identity - O_k @ O_k.T)

    # L(O, Lambda) = block-diagonal part minus C_star.
    certificate = certificate - C_star

    # Relative PSD test on the spectrum of L.
    spectrum = la.eigvalsh(certificate)
    smallest = np.min(spectrum)
    largest_abs = np.max(np.abs(spectrum))

    return smallest >= -1e-5 * largest_abs


def _dump_pickle_atomically(obj, path):
    """Pickle obj to path through a scratch file so path never holds a partial write."""
    scratch_path = path + ".partial"
    with open(scratch_path, 'wb') as f:
        pickle.dump(obj, f)
    # Swap the finished file into place; an interrupted dump only ever
    # damages the scratch file, never an existing checkpoint.
    os.replace(scratch_path, path)


def compute_success_grid_memory_efficient(d, r, N, sigma_low, sigma_high, sigma_number,
                                          n_low, n_high, n_number, constant, filename="success_grid_partial.pkl"):
    """
    Incrementally compute the success-rate grid for the partially observed
    problem, holding only one problem instance in memory at a time.

    If `filename` already exists its content is returned unchanged. Otherwise
    the grid is computed point by point; after every point the results so far
    are checkpointed to `filename + ".tmp"`, and an existing checkpoint is
    resumed from. On completion the results are written to `filename` and the
    checkpoint is removed.

    Parameters:
    d: number of rows of every block
    r: number of columns of every block
    N: number of random problems tested per grid point (at least 1)
    sigma_low, sigma_high: range of sigma (positive, sampled on a log scale)
    sigma_number: number of sigma samples
    n_low, n_high: range of n (positive, sampled on a log scale, rounded to
                   integers and de-duplicated)
    n_number: number of n samples before de-duplication
    constant: constant c in p = min(c * log(n) / sqrt(n), 1)
    filename: path of the final result file

    Returns:
    results: list of tuples (sigma, n*p, success_rate, n)

    Raises:
    ValueError: if N < 1 or any of the range bounds is not positive (only
                checked when a computation is actually needed)
    """
    # NOTE: pickle.load can execute arbitrary code contained in the file, so
    # `filename` must only ever point at files written by this function.
    if os.path.exists(filename):
        print(f"加载已保存的结果: {filename}")
        with open(filename, 'rb') as f:
            results = pickle.load(f)
        return results

    if N < 1:
        raise ValueError(f"N must be at least 1, got {N}")
    if min(sigma_low, sigma_high, n_low, n_high) <= 0:
        raise ValueError("sigma_low, sigma_high, n_low and n_high must be positive for a log-spaced grid")

    # Resume from a checkpoint left behind by an interrupted run, if any.
    temp_filename = filename + ".tmp"
    if os.path.exists(temp_filename):
        print(f"找到临时文件: {temp_filename}")
        with open(temp_filename, 'rb') as f:
            results = pickle.load(f)
        print(f"从临时文件加载了 {len(results)} 个已计算的点")
    else:
        results = []

    # Log-spaced sigma values.
    sigma_values = 10 ** np.linspace(np.log10(sigma_low), np.log10(sigma_high), sigma_number)

    # Log-spaced n values, rounded to integers >= 1, de-duplicated and sorted.
    n_candidates = 10 ** np.linspace(np.log10(n_low), np.log10(n_high), n_number)
    n_values = sorted({int(max(1, np.round(value))) for value in n_candidates})

    print(f"计算网格: σ从{sigma_low}到{sigma_high}，共{len(sigma_values)}个点")
    print(f"n从{n_low}到{n_high}，共{len(n_values)}个点(去重后)")
    print(f"常数c = {constant}")

    grid_points = [(sigma, n) for sigma in sigma_values for n in n_values]
    total_points = len(grid_points)

    # (sigma, n) pairs already present in the checkpoint. Only pairs that
    # belong to the current grid count towards the progress figures, so the
    # bar can never start beyond its total.
    computed_pairs = {(row[0], row[3]) for row in results}
    computed_count = sum(1 for point in grid_points if point in computed_pairs)
    remaining_points = total_points - computed_count

    print(f"总点数: {total_points}, 已计算: {computed_count}, 剩余: {remaining_points}")

    pbar = tqdm(total=total_points, desc="计算部分观测问题的成功率网格", initial=computed_count)

    try:
        for sigma, n in grid_points:
            # Skip points restored from the checkpoint.
            if (sigma, n) in computed_pairs:
                continue

            # Observation probability p = min(c * log(n) / sqrt(n), 1)
            p = min(constant * np.log(n) / np.sqrt(n), 1.0)
            n_p = n * p

            success_count = 0
            for _ in range(N):
                problem = generate_one_problem_partial(sigma, n, p, d, r)
                O_solution = GPM(problem)

                if is_global_optimal_partial(O_solution, problem['C_star'], n, d, r):
                    success_count += 1

                # Drop every reference to the large arrays before the next
                # instance is allocated, keeping peak memory at one problem.
                del problem, O_solution
                gc.collect()

            results.append((sigma, n_p, success_count / N, n))
            computed_pairs.add((sigma, n))
            pbar.update(1)

            # Checkpoint after every point so a crash loses at most one point.
            _dump_pickle_atomically(results, temp_filename)

        pbar.close()

        # Everything is computed: write the final file, then drop the checkpoint.
        _dump_pickle_atomically(results, filename)
        if os.path.exists(temp_filename):
            os.remove(temp_filename)

        print(f"结果已保存到: {filename}")

    except Exception as e:
        print(f"计算过程中发生错误: {e}")
        print("正在保存中间结果到临时文件...")
        # Actually perform the save announced above; a failure here must not
        # hide the original error, which is re-raised below.
        if results:
            try:
                _dump_pickle_atomically(results, temp_filename)
            except OSError as save_error:
                print(f"保存中间结果失败: {save_error}")
        raise
    finally:
        # Also reached on KeyboardInterrupt; closing twice is harmless.
        pbar.close()

    return results


def _log_cell_edges(log_centers):
    """Return the cell boundaries around sorted log-scale centers (one more edge than centers)."""
    if len(log_centers) > 1:
        # Outer edges extend half an average step beyond the first and last
        # center; inner edges sit midway between neighbouring centers.
        step = (log_centers[-1] - log_centers[0]) / (len(log_centers) - 1)
        return np.concatenate([
            [log_centers[0] - step / 2],
            (log_centers[:-1] + log_centers[1:]) / 2,
            [log_centers[-1] + step / 2],
        ])
    # A single center gets a cell one unit wide in log space.
    return np.array([log_centers[0] - 0.5, log_centers[0] + 0.5])


def visualize_success_grid_partial(results, d, r, constant, save_path="success_heatmap_partial.png"):
    """
    Draw the success-rate heat map of the partially observed problem.

    The x axis is log10(n*p), the y axis is log10(sigma), and the line
    sigma = sqrt(n*p) is overlaid in blue.

    Parameters:
    results: list of tuples whose first three entries are
             (sigma, n*p, success_rate); further entries such as n are ignored
    d, r: problem parameters, shown in the title
    constant: constant c, shown in the title
    save_path: path of the image file to write; a falsy value skips saving

    Returns:
    (success_grid, unique_sigmas, unique_n_ps) where success_grid is indexed
    as [n*p index, sigma index]; (None, None, None) when results is empty
    """
    if not results:
        print("没有结果可可视化")
        return None, None, None

    sigmas = np.array([row[0] for row in results])
    n_ps = np.array([row[1] for row in results])
    success_rates = np.array([row[2] for row in results])

    # Sorted unique axis values plus, for every result, its position on each
    # axis (replaces one np.where scan per result).
    unique_sigmas, sigma_idx = np.unique(sigmas, return_inverse=True)
    unique_n_ps, n_p_idx = np.unique(n_ps, return_inverse=True)

    # Rows follow n*p, columns follow sigma. Cells without a result stay 0;
    # if a cell occurs more than once, the last entry wins.
    success_grid = np.zeros((len(unique_n_ps), len(unique_sigmas)), dtype=float)
    for row_pos, col_pos, rate in zip(n_p_idx, sigma_idx, success_rates):
        success_grid[row_pos, col_pos] = rate

    plt.figure(figsize=(10, 8))

    # Cell boundaries on the log scale.
    log_sigma_edges = _log_cell_edges(np.log10(unique_sigmas))
    log_n_p_edges = _log_cell_edges(np.log10(unique_n_ps))

    # pcolormesh expects C with shape (len(y) - 1, len(x) - 1), hence the transpose.
    im = plt.pcolormesh(
        log_n_p_edges,
        log_sigma_edges,
        success_grid.T,
        cmap='gray_r',
        shading='flat',
        vmin=0,
        vmax=1,
        edgecolors='none'
    )

    cbar = plt.colorbar(im, label='Success Rate (%)')
    cbar.set_ticks([0, 0.25, 0.5, 0.75, 1.0])
    cbar.set_ticklabels(['0', '25', '50', '75', '100'])

    # Blue line: sigma = sqrt(n*p)  =>  log(sigma) = 0.5 * log(n*p)
    min_n_p, max_n_p = min(unique_n_ps), max(unique_n_ps)
    n_p_line = np.logspace(np.log10(min_n_p), np.log10(max_n_p), 100)
    sigma_line = np.sqrt(n_p_line)
    plt.plot(np.log10(n_p_line), np.log10(sigma_line), 'b-', linewidth=2.5,
             label=r'$\sigma = \sqrt{np}$')

    plt.xlabel('$np$ (log-scale)', fontsize=12)
    # Raw string: '\s' is not a valid escape sequence in a normal literal.
    plt.ylabel(r'Noise level $\sigma$ (log-scale)', fontsize=12)
    plt.title(f'Partial Observation: GPM Convergence Rate\n(d={d}, r={r}, c={constant})', fontsize=14)

    # X ticks at the powers of ten inside [min_n_p, 1.5 * max_n_p].
    n_p_ticks = []
    tick = 1
    while tick < min_n_p:
        tick *= 10
    while tick <= max_n_p * 1.5:
        n_p_ticks.append(tick)
        tick *= 10
    if n_p_ticks:
        n_p_labels = [f'{t:.0f}' for t in n_p_ticks]
    else:
        # No power of ten in range: mark the data extremes instead of
        # leaving the axis without any ticks.
        n_p_ticks = sorted({min_n_p, max_n_p})
        n_p_labels = [f'{t:g}' for t in n_p_ticks]
    plt.xticks(np.log10(n_p_ticks), n_p_labels)

    # Y ticks at 0.1 * 2^k inside [min sigma, 1.5 * max sigma].
    min_sigma, max_sigma = min(unique_sigmas), max(unique_sigmas)
    sigma_ticks = []
    tick = 0.1
    while tick < min_sigma:
        tick *= 2
    while tick <= max_sigma * 1.5:
        sigma_ticks.append(tick)
        tick *= 2
    if not sigma_ticks:
        sigma_ticks = sorted({min_sigma, max_sigma})
    # '{:g}' keeps the label equal to the tick position (1.6 -> '1.6');
    # truncating with int() would label the tick at 1.6 as '1'.
    plt.yticks(np.log10(sigma_ticks), [f'{t:g}' for t in sigma_ticks])

    plt.xlim(min(log_n_p_edges), max(log_n_p_edges))
    plt.ylim(min(log_sigma_edges), max(log_sigma_edges))

    plt.legend(fontsize=12, loc='upper left')
    plt.grid(True, alpha=0.3, linestyle='--', linewidth=0.5)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"图像已保存到: {save_path}")

    plt.show()

    return success_grid, unique_sigmas, unique_n_ps


def compute_and_visualize_partial_memory_efficient(d, r, N, constant, sigma_low, sigma_high, sigma_number,
                                                   n_low, n_high, n_number):
    """
    Compute (or load from cache) the success-rate grid of the partially
    observed problem and plot it.

    Parameters:
    d: number of rows of every block
    r: number of columns of every block
    N: number of random problems tested per grid point
    constant: constant c in p = min(c * log(n) / sqrt(n), 1)
    sigma_low, sigma_high: range of sigma
    sigma_number: number of sigma samples
    n_low, n_high: range of n
    n_number: number of n samples

    Returns:
    (results, success_grid, sigmas, n_ps): the raw result list followed by
    the three values returned by visualize_success_grid_partial
    """
    # The cache file name must encode every argument that influences the
    # results. `constant`, `sigma_low` and `n_low` used to be missing, so
    # changing any of them silently reloaded results computed for a different
    # grid. Files cached under the older, shorter name are not picked up.
    filename = (
        f"success_grid_partial_d{d}_r{r}_N{N}_c{constant}"
        f"_s{sigma_number}_n{n_number}"
        f"_sigma_low{sigma_low}_sigma_high{sigma_high}"
        f"_n_low{n_low}_n_high{n_high}.pkl"
    )
    results = compute_success_grid_memory_efficient(
        d, r, N, sigma_low, sigma_high, sigma_number,
        n_low, n_high, n_number, constant, filename
    )
    save_path = f"success_heatmap_partial_d{d}_r{r}_c{constant}.png"
    success_grid, sigmas, n_ps = visualize_success_grid_partial(results, d, r, constant, save_path)
    return results, success_grid, sigmas, n_ps


# Script entry point
if __name__ == "__main__":
    # Block shape (d x r), trials per grid point and the constant c of
    # p = min(c * log(n) / sqrt(n), 1).
    d, r = 10, 3
    N = 10
    constant = 2

    # Log-spaced sigma grid.
    sigma_low, sigma_high, sigma_number = 0.1, 40, 50

    # Log-spaced n grid.
    n_low, n_high, n_number = 10, 1000, 100

    # Compute the grid (or load it from cache) and draw the heat map.
    results, success_grid, sigmas, n_ps = compute_and_visualize_partial_memory_efficient(
        d=d,
        r=r,
        N=N,
        constant=constant,
        sigma_low=sigma_low,
        sigma_high=sigma_high,
        sigma_number=sigma_number,
        n_low=n_low,
        n_high=n_high,
        n_number=n_number,
    )

    print("\n计算完成!")
