import numpy as np
import pickle
import os
from scipy.linalg import svd
import matplotlib.pyplot as plt
import numpy.linalg as la


def _dump_pickle_atomic(obj, path):
    """Pickle *obj* to *path* through a temporary sibling file.

    The data is first written to ``<path>.part`` and then renamed over
    *path*, so an interrupted write never leaves a truncated pickle under
    the real name.
    """
    partial_path = f"{path}.part"
    with open(partial_path, 'wb') as f:
        pickle.dump(obj, f)
    os.replace(partial_path, path)


def verify_d_F(n_low, n_high, n_num, constant_p, constant_sigma, d, r, N):
    """
    Empirically check how d_F(G^inf, G) scales relative to sigma / sqrt(p).

    For every sampled n the function sets
    ``p = clip(constant_p * log(n) / sqrt(n), 0.001, 0.99)`` and
    ``sigma = constant_sigma * sqrt(n * p)``, draws N random problems,
    runs GPM on each one, and averages the distances of the GPM output and
    of the spectral initialisation to the ground truth, both multiplied by
    ``sqrt(p) / sigma``.

    Results are cached on disk: a ``<params>.tempt.pkl`` checkpoint is
    rewritten after every n, and ``<params>.final.pkl`` is written at the
    end (the checkpoint is then removed). An existing final file is loaded
    and plotted without recomputation; an existing checkpoint is resumed.

    Parameters:
    n_low, n_high: range of n
    n_num: number of n values sampled (linearly spaced, truncated to int)
    constant_p: constant in the formula for p
    constant_sigma: constant in the formula for sigma; a zero value makes
        the scaling factor sqrt(p) / sigma divide by zero
    d, r: block dimensions forwarded to generate_problem_partial
    N: number of problems generated for each n

    Returns:
    results: dict with keys 'n_values', 'avg_dist_infinity',
        'avg_dist_init' (parallel lists) and 'params' (the arguments).
    """

    # Cache file names encode every argument, including constant_sigma.
    params_str = f"{n_low}_{n_high}_{n_num}_{constant_p}_{constant_sigma}_{d}_{r}_{N}"
    tempt_file = f"{params_str}.tempt.pkl"
    final_file = f"{params_str}.final.pkl"

    # NOTE: pickle.load can execute arbitrary code stored in the file; the
    # cache files read below are expected to be ones this function wrote.

    # A final result already exists: load, plot and return it.
    if os.path.exists(final_file):
        print(f"找到最终结果文件 {final_file}，直接读取...")
        with open(final_file, 'rb') as f:
            results = pickle.load(f)

        plot_results(results)
        return results

    # Resume from a checkpoint if there is one, otherwise start fresh.
    if os.path.exists(tempt_file):
        print(f"找到中间结果文件 {tempt_file}，继续计算...")
        with open(tempt_file, 'rb') as f:
            results = pickle.load(f)
    else:
        print("未找到结果文件，开始新计算...")
        results = {
            'n_values': [],
            'avg_dist_infinity': [],
            'avg_dist_init': [],
            'params': {
                'n_low': n_low,
                'n_high': n_high,
                'n_num': n_num,
                'constant_p': constant_p,
                'constant_sigma': constant_sigma,
                'd': d,
                'r': r,
                'N': N
            }
        }

    # Linearly spaced n values. A log-spaced alternative would be
    # np.logspace(np.log10(n_low), np.log10(n_high), n_num, dtype=int).
    if n_num == 1:
        n_values = [n_low]
    else:
        n_values = np.linspace(n_low, n_high, n_num, dtype=int)

    # Values of n that are already done (from a checkpoint, or duplicates
    # produced by the integer truncation of linspace).
    computed_n = set(results['n_values'])

    for i, n in enumerate(n_values):
        # Store plain Python ints so the pickled results do not depend on
        # NumPy scalar types.
        n = int(n)
        if n in computed_n:
            print(f"n={n} 已计算，跳过...")
            continue

        print(f"计算 n={n} ({i + 1}/{len(n_values)})")

        # Observation probability, clamped to [0.001, 0.99].
        p = constant_p * np.log(n) / np.sqrt(n)
        p = max(0.001, min(0.99, p))

        # Noise level. (Alternative considered in the original code:
        # constant_sigma * sqrt(n * p / log(n)).)
        sigma = constant_sigma * np.sqrt(n * p)

        # Distances are reported multiplied by sqrt(p) / sigma.
        # (Alternative considered in the original code: sqrt(p) / (1 + sigma).)
        scale_factor = np.sqrt(p) / sigma

        dist_infinity_list = []
        dist_init_list = []

        for problem_idx in range(N):
            print(f"  处理问题 {problem_idx + 1}/{N}")

            # Generate one problem at a time: each problem carries two
            # (n*d x n*d) matrices, so holding all N at once multiplies
            # the peak memory by N. The random draws happen in the same
            # order as a single call with N problems.
            problem = generate_problem_partial(sigma, n, p, d, r, 1)[0]
            G_true = problem['G_true']

            # Spectral initialisation G^0: top-r eigenvectors of C_star,
            # rescaled to Frobenius norm sqrt(n * r), then projected block
            # by block.
            eigenvalues, eigenvectors = np.linalg.eigh(problem['C_star'])
            top = np.argsort(eigenvalues)[::-1][:r]
            tilde_G = eigenvectors[:, top]
            tilde_G = tilde_G * (np.sqrt(n * r) / np.linalg.norm(tilde_G, 'fro'))
            G0 = project_stiefel_block(tilde_G, n, d, r)

            # GPM output G^inf.
            G_infinity = GPM(problem, max_iter=400, tol=1e-8)

            dist_infinity, _ = distance_F(G_infinity, G_true)
            dist_init, _ = distance_F(G0, G_true)

            dist_infinity_list.append(dist_infinity * scale_factor)
            dist_init_list.append(dist_init * scale_factor)

        # Average over the N problems.
        avg_dist_infinity = np.mean(dist_infinity_list)
        avg_dist_init = np.mean(dist_init_list)

        results['n_values'].append(n)
        results['avg_dist_infinity'].append(avg_dist_infinity)
        results['avg_dist_init'].append(avg_dist_init)
        computed_n.add(n)

        print(f"  n={n}: avg_dist_infinity={avg_dist_infinity:.4f}, avg_dist_init={avg_dist_init:.4f}")

        # Checkpoint after every n so an interrupted run can be resumed.
        _dump_pickle_atomic(results, tempt_file)
        print(f"  中间结果已保存到 {tempt_file}")

    # Everything is computed: write the final file.
    _dump_pickle_atomic(results, final_file)
    print(f"最终结果已保存到 {final_file}")

    # The checkpoint is no longer needed.
    if os.path.exists(tempt_file):
        os.remove(tempt_file)
        print(f"已删除中间文件 {tempt_file}")

    plot_results(results)

    return results


def plot_results(results):
    """Plot both averaged, scaled distance curves against n on one figure.

    The figure is saved as a PNG and then shown. The file name is built
    from ``results['params']`` (same ``<n_low>_..._<N>`` pattern as the
    pickle cache files); if any of those entries is missing the figure is
    saved as ``results.png`` instead. Finally the stored parameters are
    printed.

    Parameters:
    results: dict with parallel sequences 'n_values', 'avg_dist_infinity'
        and 'avg_dist_init', and optionally a 'params' dict.
    """
    # Sort all three series by increasing n.
    n_values = np.array(results['n_values'])
    order = np.argsort(n_values)
    n_values_sorted = n_values[order]
    avg_dist_infinity_sorted = np.array(results['avg_dist_infinity'])[order]
    avg_dist_init_sorted = np.array(results['avg_dist_init'])[order]

    plt.figure(figsize=(12, 8))

    # Both curves share the same axes.
    plt.plot(n_values_sorted, avg_dist_infinity_sorted, 'b-o', linewidth=2, markersize=8,
             label=r'$\frac{\sqrt{p}}{\sigma} \cdot d_F(\widehat{G}^\infty,G^\star)$')
    plt.plot(n_values_sorted, avg_dist_init_sorted, 'g-s', linewidth=2, markersize=8,
             label=r'$\frac{\sqrt{p}}{\sigma} \cdot d_F(G^0,G^\star)$')

    plt.xlabel('n', fontsize=25)
    plt.ylabel(r'$d_F \cdot \frac{\sqrt{p}}{\sigma}$', fontsize=25)
    plt.grid(True, alpha=0.3)
    plt.legend(fontsize=25, loc='upper right')
    plt.tick_params(axis='both', labelsize=25)  # applies to both axes

    # Build the output name from the parameters stored with the results
    # rather than from module-level globals, so the function also works
    # when it is imported or fed a cached result.
    params = results.get('params', {})
    name_keys = ('n_low', 'n_high', 'n_num', 'constant_p', 'constant_sigma', 'd', 'r', 'N')
    if all(key in params for key in name_keys):
        save_path = "_".join(str(params[key]) for key in name_keys) + ".png"
    else:
        save_path = "results.png"

    # Lay the figure out before saving so the saved image matches what is shown.
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"图像已保存到: {save_path}")

    plt.show()

    # Print the experiment parameters.
    print("\n实验参数:")
    for key, value in params.items():
        print(f"  {key}: {value}")


# Helper functions used by verify_d_F, included here so the file is self-contained.
def distance_F(X, G):
    """
    Frobenius distance between two matrices up to an orthogonal transform.

    The distance is d_F(X, G) = min_{Q in O(r)} ||X - G Q||_F. The
    minimiser is obtained from the SVD of X^T G.

    Parameters:
    X: estimated matrix
    G: reference matrix with the same shape as X

    Returns:
    distance: the minimal Frobenius distance
    Q: the orthogonal matrix applied to G to reach that distance
    """
    # With X^T G = U S V^T, the best alignment of G onto X is (U V^T)^T.
    left, _, right_t = svd(X.T @ G, full_matrices=False)
    rotation = (left @ right_t).T

    return np.linalg.norm(X - G @ rotation, 'fro'), rotation


def GPM(problem, max_iter=400, tol=1e-8):
    """
    Run the generalized power method on a partially observed problem.

    Starting from a spectral initialisation (top-r eigenvectors of C_star,
    projected block by block), the iterate is repeatedly multiplied by
    C_star and re-projected until two consecutive iterates differ by less
    than ``tol`` (Frobenius norm divided by sqrt(n)) or ``max_iter``
    iterations have been performed.

    Parameters:
    problem: dict providing 'C_star' (symmetric matrix) and the integers
        'n', 'd', 'r'
    max_iter: maximum number of power iterations; with 0 the spectral
        initialisation is returned unchanged
    tol: convergence tolerance

    Returns:
    O: the final iterate, of shape (n*d, r) when C_star is (n*d, n*d)
    """
    C_star = problem['C_star']
    n = problem['n']
    d = problem['d']
    r = problem['r']

    # Step 1: eigenvectors belonging to the r largest eigenvalues of C_star.
    eigenvalues, eigenvectors = la.eigh(C_star)
    top = np.argsort(eigenvalues)[::-1][:r]
    tilde_G = eigenvectors[:, top]

    # Rescale so that tilde_G^T tilde_G = n I_r, i.e. Frobenius norm
    # sqrt(n * r), the same normalisation verify_d_F uses for G^0.
    tilde_G = tilde_G * (np.sqrt(n * r) / la.norm(tilde_G, 'fro'))

    # Step 2: initial iterate.
    O_curr = project_stiefel_block(tilde_G, n, d, r)

    # Step 3: power iterations.
    for _ in range(max_iter):
        # Multiply by C_star and project every block back.
        O_next = project_stiefel_block(C_star @ O_curr, n, d, r)

        # Convergence measure: ||O_next - O_curr||_F / sqrt(n).
        diff_norm = la.norm(O_next - O_curr, 'fro') / np.sqrt(n)
        O_curr = O_next
        if diff_norm < tol:
            break

    # O_curr is always bound, so max_iter == 0 no longer fails.
    return O_curr



def project_stiefel(X):
    """Return U @ V^T from the thin SVD X = U S V^T (singular values dropped)."""
    factors = svd(X, full_matrices=False)
    left, right_t = factors[0], factors[2]
    return left @ right_t


def project_stiefel_block(X, n, d, r):
    """Apply project_stiefel to each of the n consecutive d-row blocks of X.

    Returns an array shaped like X whose rows [i*d, (i+1)*d) hold the
    projection of the corresponding rows of X. The parameter r is accepted
    but not used by this function.
    """
    projected = np.zeros_like(X)
    for block in range(n):
        rows = slice(block * d, block * d + d)
        projected[rows, :] = project_stiefel(X[rows, :])
    return projected


def generate_problem_partial(sigma, n, p, d, r, N):
    """
    Generate N partially observed synchronization problems.

    Parameters:
    sigma: noise level
    n: number of blocks
    p: probability that a block pair (i, j), i <= j, is observed
    d: number of rows of each block (all blocks share the same d)
    r: number of columns of each block
    N: number of problems to generate

    Returns:
    problems: list of N dicts, each containing
        - 'C_star': partially observed matrix (nd x nd); unobserved
          d x d blocks are zero
        - 'C_full': fully observed matrix (nd x nd), kept for comparison
        - 'G_true': ground-truth matrix (nd x r), stacked d x r blocks
        - 'E': symmetric 0/1 observation mask (n x n)
        - 'sigma': noise level
        - 'p': observation probability
        - 'n', 'd', 'r': problem sizes
    """
    problems = []
    nd = n * d

    # Upper-triangle index sets do not depend on the problem; build once.
    entry_upper = np.triu_indices(nd)
    block_upper = np.triu_indices(n)

    for _ in range(N):
        # Ground truth: every d x r block is U @ V^T of a Gaussian matrix.
        G_true = np.zeros((nd, r))
        for i in range(n):
            U, _sv, Vt = svd(np.random.randn(d, r), full_matrices=False)
            G_true[i * d:(i + 1) * d, :] = U @ Vt

        # Symmetric Gaussian noise: fill the upper triangle and mirror it.
        W_full = np.zeros((nd, nd))
        W_full[entry_upper] = np.random.randn(len(entry_upper[0]))
        W_full = W_full + W_full.T - np.diag(np.diag(W_full))

        # Fully observed matrix.
        C_full = G_true @ G_true.T + sigma * W_full

        # Symmetric Bernoulli(p) mask over block pairs (diagonal included),
        # drawn in a single vectorized call instead of one call per entry.
        E = np.zeros((n, n))
        E[block_upper] = np.random.binomial(1, p, size=len(block_upper[0]))
        E = E + E.T - np.diag(np.diag(E))

        # Expand the block mask to entry level and zero the unobserved blocks.
        observed = np.repeat(np.repeat(E == 1, d, axis=0), d, axis=1)
        C_star = np.where(observed, C_full, 0.0)

        problems.append({
            'C_star': C_star,
            'C_full': C_full,  # full matrix kept for comparison
            'G_true': G_true,
            'E': E,  # observation mask
            'sigma': sigma,
            'p': p,
            'n': n,
            'd': d,
            'r': r
        })

    return problems

# Example usage
if __name__ == "__main__":
    # Experiment configuration.
    n_low, n_high = 10, 1000
    n_num = 100  # number of n values
    constant_p = 2
    constant_sigma = 0.15
    d, r = 10, 3
    N = 10  # problems generated per n

    # Run the experiment (or load its cached result).
    results = verify_d_F(n_low, n_high, n_num, constant_p, constant_sigma, d, r, N)

    # Print a tab-separated summary, one row per n.
    print("\n结果摘要:")
    print("n\tavg_dist_infinity\tavg_dist_init")
    summary_rows = zip(results['n_values'],
                       results['avg_dist_infinity'],
                       results['avg_dist_init'])
    for n_value, dist_inf, dist_init in summary_rows:
        print(f"{n_value}\t{dist_inf:.4f}\t\t\t{dist_init:.4f}")