import numpy as np
import scipy.linalg as la
from scipy.sparse.linalg import eigs
from scipy.linalg import svd
import numpy as np
import matplotlib.pyplot as plt
import pickle
import os
import gc
from tqdm import tqdm


def generate_one_problem(sigma, n, d, r):
    """
    Draw a single random synchronization-problem instance.

    Only one instance is built per call so that callers never hold several
    large (n*d x n*d) matrices at once.

    Parameters
    ----------
    sigma : float
        Noise strength.
    n : int
        Number of blocks (groups).
    d : int
        Row dimension of each block.
    r : int
        Column dimension (rank) of each block.

    Returns
    -------
    dict
        Keys 'C' (observation), 'G' (ground truth), 'W' (noise) plus the
        scalar parameters 'n', 'd', 'r', 'sigma'.
    """
    nd = n * d

    # Ground truth: stack n blocks, each a random d x r Gaussian matrix
    # mapped onto the Stiefel manifold through its polar factor U V^T.
    G = np.zeros((nd, r))
    for block_idx in range(n):
        rows = slice(block_idx * d, (block_idx + 1) * d)
        left, _, right_t = svd(np.random.randn(d, r), full_matrices=False)
        G[rows] = left @ right_t

    # Symmetric noise: keep the upper triangle (diagonal included) of an
    # i.i.d. standard-normal draw and mirror the strict upper part below it.
    noise = np.random.randn(nd, nd)
    upper = np.triu(noise)
    W = upper + np.triu(noise, 1).T

    # Observation C = G G^T + sigma * W.
    C = G @ G.T + sigma * W

    return dict(C=C, G=G, W=W, n=n, d=d, r=r, sigma=sigma)


def project_stiefel(X):
    """
    Project a matrix onto the Stiefel manifold (orthonormal columns).

    Uses the polar factor of the thin SVD: if X = U S V^T then the result is
    U V^T, which satisfies Q^T Q = I_r whenever X has full column rank.

    Parameters
    ----------
    X : ndarray, shape (d, r)

    Returns
    -------
    ndarray, shape (d, r)
    """
    left_vecs, _singular_values, right_vecs_t = svd(X, full_matrices=False)
    polar_factor = left_vecs @ right_vecs_t
    return polar_factor


def GPM(problem, max_iter=400, tol=1e-8):
    """
    Solve the OTSM problem with the Generalized Power Method.

    1. Spectral initialization: take the top-r eigenvectors of C, rescale each
       column to norm sqrt(n), and project every d x r block onto the Stiefel
       manifold.
    2. Iterate O <- blockwise_Stiefel_projection(C @ O) until the scaled
       Frobenius change ||O_t - O_{t-1}||_F / sqrt(n) drops below `tol` or
       `max_iter` iterations have run.

    Parameters
    ----------
    problem : dict
        Must contain 'C' (n*d x n*d), 'n', 'd', 'r'.
    max_iter : int
        Maximum number of power iterations (0 returns the spectral
        initialization unchanged).
    tol : float
        Convergence tolerance on the scaled Frobenius change.

    Returns
    -------
    O : ndarray, shape (n*d, r)
        Solution with each d x r block on the Stiefel manifold.
    convergence_info : dict
        'iterations' (iterations actually run), 'final_diff' (last scaled
        change; inf if no iteration ran) and 'converged' (final_diff < tol).
    """
    C = problem['C']
    n = problem['n']
    d = problem['d']
    r = problem['r']
    nd = n * d

    # Step 1: only the r largest eigenpairs are needed, so ask LAPACK for just
    # that index range (ascending order, same as the tail of a full eigh).
    _, eigenvectors = la.eigh(C, subset_by_index=[max(nd - r, 0), nd - 1])

    # Rescale so that tilde_G^T tilde_G = n I_r.
    scale = np.sqrt(n) / la.norm(eigenvectors, axis=0)
    tilde_G = eigenvectors * scale

    # Step 2: blockwise projection of the spectral estimate.
    O = np.zeros_like(tilde_G)
    for i in range(n):
        rows = slice(i * d, (i + 1) * d)
        O[rows, :] = project_stiefel(tilde_G[rows, :])

    # Step 3: power iterations. Pre-bind the bookkeeping so max_iter == 0
    # yields a valid (non-converged) report instead of an UnboundLocalError.
    iterations = 0
    diff_norm = np.inf
    prev_O = O.copy()
    for t in range(max_iter):
        Y = C @ O  # computed fully before O is overwritten below

        for i in range(n):
            rows = slice(i * d, (i + 1) * d)
            O[rows, :] = project_stiefel(Y[rows, :])

        iterations = t + 1
        diff_norm = la.norm(O - prev_O, 'fro') / np.sqrt(n)
        if diff_norm < tol:
            break

        np.copyto(prev_O, O)  # reuse the buffer instead of reallocating

    convergence_info = {
        'iterations': iterations,
        'final_diff': diff_norm,
        'converged': diff_norm < tol
    }

    return O, convergence_info


def is_global_optimal(O, C, n, d, r):
    """
    Decide whether the candidate solution O is certifiably globally optimal.

    For every block i the Lagrange multiplier
        Lambda_i = sym(O_i^T (C O)_i),   tau_i = lambda_min(Lambda_i)
    is formed, and the certificate matrix
        L = blkdiag(O_i Lambda_i O_i^T + tau_i (I_d - O_i O_i^T)) - C
    is assembled. O is declared optimal when L is positive semidefinite up to
    a relative tolerance: lambda_min(L) >= -1e-5 * |lambda_max(L)|.

    Parameters
    ----------
    O : ndarray, shape (n*d, r)
        Candidate solution made of n stacked d x r blocks.
    C : ndarray, shape (n*d, n*d)
        Observation matrix (expected symmetric; not modified).
    n, d, r : int
        Problem parameters.

    Returns
    -------
    is_optimal : bool
        Whether L passes the relative PSD test.
    L_min_eig : float
        Smallest (algebraic) eigenvalue of L.
    L_max_eig : float
        Absolute value of the largest (algebraic) eigenvalue of L. Both the
        ARPACK path and the dense fallback report this same quantity.
    """
    # Imported up front so an import problem is not mistaken for an ARPACK
    # failure by the fallback handlers below.
    from scipy.sparse.linalg import ArpackError, eigsh

    # One BLAS product instead of an O(n^2) Python loop over C_ij @ O_j.
    CO = C @ O

    # Start from -C (a fresh array; C itself is untouched) and add the
    # diagonal certificate blocks in place: L = blkdiag(...) - C.
    L = -np.asarray(C, dtype=float)
    for i in range(n):
        rows = slice(i * d, (i + 1) * d)
        O_i = O[rows, :]

        # Lambda_i = O_i^T (CO)_i, symmetrized against round-off asymmetry.
        Lambda_i = O_i.T @ CO[rows, :]
        Lambda_i = (Lambda_i + Lambda_i.T) / 2

        # tau_i = smallest eigenvalue of Lambda_i (eigvalsh sorts ascending).
        tau_i = la.eigvalsh(Lambda_i)[0]

        L[rows, rows] += O_i @ Lambda_i @ O_i.T + tau_i * (np.eye(d) - O_i @ O_i.T)

    # Extreme eigenvalues: try ARPACK first; on failure fall back to a dense
    # eigendecomposition, computed at most once and shared by both ends.
    arpack_failures = (ArpackError, ValueError, TypeError)
    full_spectrum = None

    try:
        L_min_eig = eigsh(L, k=1, which='SA', return_eigenvectors=False)[0]
    except arpack_failures:
        full_spectrum = la.eigvalsh(L)
        L_min_eig = full_spectrum[0]

    if full_spectrum is not None:
        L_max_eig = np.abs(full_spectrum[-1])
    else:
        try:
            L_max_eig = np.abs(eigsh(L, k=1, which='LA', return_eigenvectors=False)[0])
        except arpack_failures:
            full_spectrum = la.eigvalsh(L)
            L_max_eig = np.abs(full_spectrum[-1])

    # Treat L as PSD when lambda_min >= -1e-5 * |lambda_max|.
    is_optimal = L_min_eig >= -1e-5 * L_max_eig

    return is_optimal, L_min_eig, L_max_eig


def _atomic_pickle_dump(obj, path):
    """Pickle obj to path via a scratch file + os.replace, so a crash mid-write never truncates path."""
    scratch_path = path + ".part"
    with open(scratch_path, 'wb') as f:
        pickle.dump(obj, f)
    os.replace(scratch_path, path)


def _log_spaced(low, high, count):
    """Return `count` values evenly spaced in log10 between low and high (inclusive)."""
    return 10 ** np.linspace(np.log10(low), np.log10(high), count)


def _evaluate_point(sigma, n, d, r, N):
    """Solve N random instances at (sigma, n); return (success_ratio, per-instance details)."""
    success_count = 0
    details = []

    for _ in range(N):
        problem = generate_one_problem(sigma, n, d, r)
        O, conv_info = GPM(problem)
        is_optimal, L_min_eig, L_max_eig = is_global_optimal(
            O, problem['C'], n, d, r)

        details.append({
            'is_optimal': is_optimal,
            'L_min_eig': L_min_eig,
            'L_max_eig': L_max_eig,
            'convergence_info': conv_info
        })
        if is_optimal:
            success_count += 1

        # Drop the large matrices before the next instance is generated.
        del problem, O
        gc.collect()

    return success_count / N, details


def compute_success_grid_memory_efficient(d, r, N, sigma_low, sigma_high, sigma_number,
                                          n_low, n_high, n_number, filename="success_grid.pkl"):
    """
    Incrementally compute the GPM success-rate grid over (sigma, n).

    Only one problem instance is alive at a time. After every grid point the
    accumulated results are checkpointed (atomically) to `filename + ".tmp"`,
    and a later call resumes from that checkpoint, or from `filename` itself,
    skipping points that were already computed. When the grid is finished the
    results are written to `filename` and the checkpoint is removed.

    Parameters
    ----------
    d, r : int
        Row / column dimension of each block.
    N : int
        Number of random instances per grid point.
    sigma_low, sigma_high, sigma_number
        Range and number of log-spaced sigma values.
    n_low, n_high, n_number
        Range and number of log-spaced n values (rounded to int >= 1 and
        de-duplicated, so fewer than n_number may remain).
    filename : str
        Final results file; the checkpoint is `filename + ".tmp"`.

    Returns
    -------
    list of dict
        One entry per grid point with keys 'sigma', 'n', 'success_rate',
        'details'. Entries loaded from an earlier run are kept as well.

    Note: the result files are unpickled, so only load files you produced.
    """
    # Resume: final results first, then the (newer) checkpoint if present.
    results = []
    if os.path.exists(filename):
        print(f"加载已保存的结果: {filename}")
        with open(filename, 'rb') as f:
            results = pickle.load(f)

    temp_filename = filename + ".tmp"
    if os.path.exists(temp_filename):
        print(f"找到临时文件: {temp_filename}")
        with open(temp_filename, 'rb') as f:
            results = pickle.load(f)
        print(f"从临时文件加载了 {len(results)} 个已计算的点")

    # Log-spaced grids; n must be an integer >= 1, duplicates removed.
    sigma_values = _log_spaced(sigma_low, sigma_high, sigma_number)
    n_values = sorted({int(max(1, np.round(v))) for v in _log_spaced(n_low, n_high, n_number)})

    print(f"计算网格: σ从{sigma_low}到{sigma_high}，共{sigma_number}个点")
    print(f"n从{n_low}到{n_high}，共{len(n_values)}个点(去重后)")
    print(f"σ值(前5个): {sigma_values[:5]}")
    print(f"n值(前5个): {n_values[:5]}")

    total_points = len(sigma_values) * len(n_values)

    # Only saved points that belong to *this* grid count as done; results
    # from a different grid would otherwise skew the progress numbers.
    computed_pairs = {(entry['sigma'], entry['n']) for entry in results}
    grid_pairs = {(s, v) for s in sigma_values for v in n_values}
    computed_count = len(computed_pairs & grid_pairs)
    remaining_points = total_points - computed_count

    print(f"总点数: {total_points}, 已计算: {computed_count}, 剩余: {remaining_points}")

    pbar = tqdm(total=total_points, desc="计算成功率网格", initial=computed_count)

    try:
        for sigma in sigma_values:
            for n in n_values:
                # Already computed points are covered by the bar's initial value.
                if (sigma, n) in computed_pairs:
                    continue

                success_ratio, details_for_n = _evaluate_point(sigma, n, d, r, N)

                results.append({
                    'sigma': sigma,
                    'n': n,
                    'success_rate': success_ratio,
                    'details': details_for_n
                })
                computed_pairs.add((sigma, n))
                pbar.update(1)

                # Checkpoint after every point so a crash loses at most one.
                _atomic_pickle_dump(results, temp_filename)
    except Exception as e:
        print(f"计算过程中发生错误: {e}")
        print("已保存中间结果到临时文件")
        raise
    finally:
        pbar.close()

    # Grid complete: persist the final results, then drop the checkpoint.
    _atomic_pickle_dump(results, filename)
    if os.path.exists(temp_filename):
        os.remove(temp_filename)

    print(f"结果已保存到: {filename}")

    return results


def _log_cell_edges(values):
    """Return heatmap cell edges in log10 space: neighbour midpoints padded by half the mean step (+-0.5 for one value)."""
    logs = np.log10(values)
    if len(logs) > 1:
        step = (logs[-1] - logs[0]) / (len(logs) - 1)
        return np.concatenate([
            [logs[0] - step / 2],
            (logs[:-1] + logs[1:]) / 2,
            [logs[-1] + step / 2]
        ])
    return np.array([logs[0] - 0.5, logs[0] + 0.5])


def visualize_success_grid(results, d, r, save_path="success_heatmap.png"):
    """
    Plot the success-rate grid as a log-log heatmap (n on x, sigma on y).

    Also overlays the reference line sigma = 10**-0.8 * sqrt(n) (about
    0.158 * sqrt(n)), saves the figure when save_path is truthy, and shows it.

    Parameters
    ----------
    results : list of dict
        Entries with keys 'sigma', 'n', 'success_rate'.
    d, r : int
        Problem parameters (currently unused; kept for interface stability).
    save_path : str or None
        Output image path; falsy to skip saving.

    Returns
    -------
    success_grid : ndarray, shape (len(ns), len(sigmas))
        success_grid[j, i] is the rate at (ns[j], sigmas[i]); grid points
        missing from `results` stay 0.
    sigmas, ns : ndarray
        Sorted unique sigma and n values.
    (None, None, None) when `results` is empty.
    """
    if not results:
        print("没有结果可可视化")
        return None, None, None

    unique_sigmas = np.unique([entry['sigma'] for entry in results])
    unique_ns = np.unique([entry['n'] for entry in results])

    # O(1) index lookup instead of an np.where scan per result.
    sigma_index = {s: i for i, s in enumerate(unique_sigmas)}
    n_index = {v: j for j, v in enumerate(unique_ns)}

    # Rows follow n, columns follow sigma.
    success_grid = np.zeros((len(unique_ns), len(unique_sigmas)), dtype=float)
    for entry in results:
        success_grid[n_index[entry['n']], sigma_index[entry['sigma']]] = entry['success_rate']

    plt.figure(figsize=(10, 6))

    log_sigma_edges = _log_cell_edges(unique_sigmas)
    log_n_edges = _log_cell_edges(unique_ns)

    # Transpose so sigma runs along y and n along x.
    im = plt.pcolormesh(
        log_n_edges,
        log_sigma_edges,
        success_grid.T,
        cmap='gray_r',
        shading='flat',
        vmin=0,
        vmax=1,
        edgecolors='none'
    )

    cbar = plt.colorbar(im)
    cbar.set_label('Success Rate', fontsize=25)
    cbar.set_ticks([0, 0.25, 0.5, 0.75, 1.0])
    cbar.set_ticklabels(['0%', '25%', '50%', '75%', '100%'], fontsize=25)

    # Red reference line: log10(sigma) = 0.5 * log10(n) - 0.8,
    # i.e. sigma = 10**-0.8 * sqrt(n) ~= 0.158 * sqrt(n).
    n_line = np.logspace(np.log10(min(unique_ns)), np.log10(max(unique_ns)), 100)
    sigma_line = np.sqrt(n_line)
    plt.plot(np.log10(n_line), np.log10(sigma_line) - 0.8, 'r-', linewidth=3.0,
             label=r'$\sigma = 0.158\sqrt{n}$')
    plt.legend(prop={'size': 25})

    # Raw string: '\s' is an invalid escape sequence in a normal literal.
    plt.xlabel('Number of phases n (log-scale)', fontsize=25)
    plt.ylabel(r'Noise level $\sigma$ (log-scale)', fontsize=25)

    plt.xticks(np.log10([10, 100, 1000]), ['10', '100', '1000'])

    # y ticks: 0.1, 0.2, 0.4, ... doubling up to ~1.1 * max sigma.
    sigma_ticks = []
    current = 0.1
    while current <= max(unique_sigmas) * 1.1:
        sigma_ticks.append(current)
        current *= 2
    plt.yticks(np.log10(sigma_ticks), [f'{t:.1f}' for t in sigma_ticks])

    plt.tick_params(axis='both', labelsize=25)

    plt.xlim(min(log_n_edges), max(log_n_edges))
    plt.ylim(min(log_sigma_edges), max(log_sigma_edges))

    plt.grid(True, alpha=0.3, linestyle='--', linewidth=0.5)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"图像已保存到: {save_path}")

    plt.show()

    return success_grid, unique_sigmas, unique_ns


def compute_and_visualize_memory_efficient(d, r, N, sigma_low, sigma_high, sigma_number,
                                           n_low, n_high, n_number, resume=True):
    """
    Incrementally compute the success-rate grid, then plot it.

    Results are cached in a pickle whose name encodes d, r, N, sigma_number
    and n_number. compute_success_grid_memory_efficient reloads that file and
    its ".tmp" checkpoint on its own when they exist.

    Parameters
    ----------
    d, r : int
        Row / column dimension of each block.
    N : int
        Number of random instances per grid point.
    sigma_low, sigma_high, sigma_number
        Range and number of log-spaced sigma values.
    n_low, n_high, n_number
        Range and number of log-spaced n values.
    resume : bool
        True: continue from previously saved results / checkpoint.
        False: delete them first so every grid point is recomputed
        (previously this flag was silently ignored).

    Returns
    -------
    (results, success_grid, sigmas, ns)
    """
    filename = f"success_grid_d{d}_r{r}_N{N}_s{sigma_number}_n{n_number}.pkl"

    if not resume:
        # Remove cached output and checkpoint so nothing is reloaded.
        for stale_path in (filename, filename + ".tmp"):
            if os.path.exists(stale_path):
                print(f"resume=False，删除已有结果文件: {stale_path}")
                os.remove(stale_path)

    results = compute_success_grid_memory_efficient(
        d, r, N, sigma_low, sigma_high, sigma_number,
        n_low, n_high, n_number, filename
    )

    save_path = f"success_heatmap_d{d}_r{r}.png"
    success_grid, sigmas, ns = visualize_success_grid(results, d, r, save_path)

    return results, success_grid, sigmas, ns


# Main program: run the full (sigma, n) sweep and plot the result.
if __name__ == "__main__":
    # Block size d, rank r, and number of random trials N per grid point.
    d, r, N = 10, 3, 10
    # Noise levels: sigma_number log-spaced values in [sigma_low, sigma_high].
    sigma_low, sigma_high, sigma_number = 0.1, 100, 100
    # Problem sizes: n_number log-spaced values in [n_low, n_high].
    n_low, n_high, n_number = 10, 1000, 200

    results, success_grid, sigmas, ns = compute_and_visualize_memory_efficient(
        d=d, r=r, N=N,
        sigma_low=sigma_low, sigma_high=sigma_high, sigma_number=sigma_number,
        n_low=n_low, n_high=n_high, n_number=n_number,
        resume=True,
    )

    print("\n计算完成!")