import numpy as np
import matplotlib.pyplot as plt
import pickle
import os
from scipy.linalg import svd
import scipy.linalg as la

def project_stiefel(X):
    """Return the polar factor U @ Vt of X from its thin SVD X = U diag(s) Vt.

    For a tall or square X of full column rank this is the nearest matrix with
    orthonormal columns to X in Frobenius norm (a point on the Stiefel manifold).
    """
    left, _singular_values, right_t = svd(X, full_matrices=False)
    return np.matmul(left, right_t)


def project_stiefel_block(X, n, d, r):
    """
    Project each d-row block of X onto the Stiefel manifold independently.

    X is treated as n vertically stacked blocks X_i = X[i*d:(i+1)*d, :]; each
    block is replaced by project_stiefel(X_i), the polar factor U @ Vt of its
    thin SVD.

    Parameters:
    X: 2-D array with exactly n*d rows
    n: number of blocks
    d: number of rows per block
    r: column count; not used here (the column count is taken from X), kept so
       existing call sites keep working

    Returns:
    Y: array with the same shape as X holding the projected blocks. Integer or
       boolean inputs are promoted to float64; floating/complex dtypes are kept.

    Raises:
    ValueError: if X does not have exactly n*d rows.
    """
    X = np.asarray(X)
    if X.shape[0] != n * d:
        raise ValueError(
            f"expected {n * d} rows (n={n}, d={d}), got {X.shape[0]}"
        )
    # np.zeros_like would keep an integer dtype and truncate the orthonormal
    # blocks written back into it; promote non-floating inputs to float64.
    out_dtype = X.dtype if np.issubdtype(X.dtype, np.inexact) else np.float64
    Y = np.zeros(X.shape, dtype=out_dtype)
    for i in range(n):
        rows = slice(i * d, (i + 1) * d)
        Y[rows, :] = project_stiefel(X[rows, :])
    return Y

def generate_problem_partial(sigma, n, p, d, r, N):
    """
    Generate N partially observed synchronization problems.

    Parameters:
    sigma: noise level
    n: number of blocks
    p: probability that each block pair (i, j), i <= j, is observed
       (diagonal blocks are sampled too, so they may also be unobserved)
    d: row dimension of each block (all d_i assumed equal)
    r: column dimension of each block
    N: number of problems to generate

    Returns:
    problems: list of N dicts, each containing:
        - 'C_star': partially observed matrix (nd x nd); unobserved blocks are 0
        - 'C_full': fully observed matrix (nd x nd), kept for comparison
        - 'G_true': ground-truth block matrix (nd x r); each d x r block is the
                    polar factor of a Gaussian matrix (orthonormal columns when d >= r)
        - 'E': symmetric 0/1 block-observation mask (n x n)
        - 'sigma': noise level
        - 'p': observation probability
        - 'n', 'd', 'r': problem parameters

    Randomness comes from NumPy's global RNG (np.random); seed it for
    reproducible problems.
    """
    problems = []
    nd = n * d
    # Loop-invariant: index set for the upper triangle (incl. diagonal) of W.
    upper_tri_indices = np.triu_indices(nd)
    block_ones = np.ones((d, d))

    for _ in range(N):
        # Ground truth G: each d x r block is projected onto the Stiefel manifold.
        G_true = np.zeros((nd, r))
        for i in range(n):
            G_i = np.random.randn(d, r)
            U, _, Vt = svd(G_i, full_matrices=False)
            G_true[i * d:(i + 1) * d, :] = U @ Vt

        # Symmetric Gaussian noise W: i.i.d. N(0, 1) on and above the diagonal,
        # mirrored below it.
        W_full = np.zeros((nd, nd))
        W_full[upper_tri_indices] = np.random.randn(len(upper_tri_indices[0]))
        W_full = W_full + W_full.T - np.diag(np.diag(W_full))

        # Fully observed measurement matrix C.
        C_full = G_true @ G_true.T + sigma * W_full

        # Symmetric block mask E. Scalar draws are kept in this exact order so a
        # seeded run produces the same problems as before.
        E = np.zeros((n, n))
        for i in range(n):
            for j in range(i, n):
                e_ij = np.random.binomial(1, p)
                E[i, j] = e_ij
                E[j, i] = e_ij

        # Expand E to an (nd x nd) element mask and keep only observed blocks.
        block_mask = np.kron(E, block_ones) == 1
        C_star = np.where(block_mask, C_full, 0.0)

        problem = {
            'C_star': C_star,
            'C_full': C_full,  # full matrix kept for comparison
            'G_true': G_true,
            'E': E,  # observation mask
            'sigma': sigma,
            'p': p,
            'n': n,
            'd': d,
            'r': r
        }
        problems.append(problem)

    return problems

def distance_F(X, G):
    """
    Frobenius distance between X and G up to a right orthogonal transform.

    Computes d_F(X, G) = min_{Q in O(r)} ||X - G Q||_F. The minimiser is the
    orthogonal Procrustes solution: writing the thin SVD X^T G = U S V^T, we
    have G^T X = V S U^T, so the optimum is Q = V U^T = (U V^T)^T.

    Parameters:
    X: estimate matrix (n*d x r)
    G: reference matrix (n*d x r)

    Returns:
    distance: ||X - G Q_opt||_F, the minimal distance
    Q_opt: the minimising orthogonal matrix (r x r), so G @ Q_opt aligns with X
    """
    left, _sv, right_t = svd(X.T @ G, full_matrices=False)
    # Procrustes optimum V U^T, expressed as the transpose of U V^T.
    aligner = (left @ right_t).T
    residual = X - G @ aligner
    return np.linalg.norm(residual, 'fro'), aligner


def _num_stored_iterates(t_infty, t_graph):
    """Return how many iterates G^0, G^1, ... one GPM run keeps for plotting."""
    # G^0 is always kept; G^{t+1} is kept for each t < min(t_graph - 1, t_infty).
    return 1 + max(0, min(t_graph - 1, t_infty))


def _cached_entry_matches(entry, n, d, r, t_infty, t_graph):
    """Return True if a cached per-p entry is consistent with n, d, r, t_graph."""
    try:
        return (len(entry['distances']) == _num_stored_iterates(t_infty, t_graph)
                and np.shape(entry['G_infty']) == (n * d, r))
    except (KeyError, TypeError):
        return False


def _run_gpm_fixed_iterations(C_star, n, d, r, t_infty, t_graph):
    """Run GPM on C_star for exactly t_infty iterations (no tolerance stop).

    Returns (G_t_list, G_infty): copies of the first iterates G^0 ... G^{t_graph-1}
    (fewer if t_infty is smaller) and the final iterate, used as G^infty.
    """
    # Step 1: spectral initialisation from the eigenvectors of the r largest
    # eigenvalues of C_star.
    eigenvalues, eigenvectors = la.eigh(C_star)
    idx = np.argsort(eigenvalues)[::-1][:r]
    tilde_G = eigenvectors[:, idx]

    # Rescale so that ||tilde_G||_F = sqrt(n). Since eigh returns orthonormal
    # eigenvectors this gives tilde_G^T tilde_G = (n / r) I_r. The blockwise
    # projection below is invariant to positive scaling, so (up to rounding)
    # this does not affect any iterate.
    scale = np.sqrt(n) / la.norm(tilde_G, 'fro')
    tilde_G = tilde_G * scale

    # Step 2: initial point on the product of Stiefel manifolds.
    O_prev = project_stiefel_block(tilde_G, n, d, r)
    G_t_list = [O_prev.copy()]

    for t in range(t_infty):
        # GPM step: multiply by C_star, then project each block back.
        O_next = project_stiefel_block(C_star @ O_prev, n, d, r)
        if t < t_graph - 1:  # minus 1 because G^0 is already stored
            G_t_list.append(O_next.copy())
        O_prev = O_next

    # O_prev holds the last iterate; returning it (not O_next) keeps
    # t_infty == 0 from raising NameError.
    return G_t_list, O_prev


def _plot_convergence(data, p_list, n, sigma, t_infty):
    """Plot d_F(G^t, G^infty) against t (log scale) for each p and save a PNG."""
    plt.figure(figsize=(10, 7))

    for p in p_list:
        distances = data[p]['distances']
        t_values = list(range(len(distances)))
        plt.plot(t_values, distances, marker='o', markersize=8,
                 linewidth=4, label=f'p={p}')

    plt.xlabel('Iteration t', fontsize=25)
    plt.ylabel(r'$d_F(G^t, \hat{G}^\infty)$ (log-scale)', fontsize=25)
    plt.legend(fontsize=25, loc='upper right')
    plt.tick_params(axis='both', labelsize=25)  # both x and y axes
    #plt.title(f'Linear Convergence of GPM: $n={n}$, $\sigma={sigma}$, $t_\infty={t_infty}$', fontsize=16)
    plt.yscale('log')

    # Clip the y-axis at 1e-12 so near-zero distances do not stretch the plot.
    plt.ylim(bottom=1e-12)

    plt.grid(True, which="both", ls="-", alpha=0.2)
    plt.tight_layout()

    plot_file = f'convergence_plot_n{n}_sigma{sigma}_tinfty{t_infty}.png'
    plt.savefig(plot_file, dpi=300)
    print(f"Plot saved to {plot_file}")

    plt.show()


def verify_linear_convergence(n, sigma, p_list, t_infty, t_graph, d=5, r=3, N=1):
    """
    Empirically check linear convergence of the GPM algorithm.

    For each p in p_list, one partially observed problem is generated, GPM is
    run for exactly t_infty iterations, and d_F(G^t, G^infty) is recorded for
    the first t_graph iterates, with G^infty taken as the final iterate.

    Parameters:
    n: number of blocks
    sigma: noise level
    p_list: observation probabilities to sweep
    t_infty: total number of iterations; the last iterate is used as G^infty
    t_graph: number of leading iterates to plot
    d: row dimension of each block (all d_i assumed equal)
    r: column dimension of each block
    N: number of problems generated per p; only the first one is used

    Returns:
    data: dict mapping p -> {'distances', 'G_t_list', 'G_infty', 'problem'}.
    Results are cached in a pickle file and the plot is saved as a PNG.
    """
    # The cache file name encodes n, sigma and t_infty only. d, r, t_graph and
    # the requested p values are validated per entry below, and any missing or
    # inconsistent entries are recomputed instead of silently reused.
    data_file = f'convergence_data_n{n}_sigma{sigma}_tinfty{t_infty}.pkl'

    if os.path.exists(data_file):
        print(f"Loading data from {data_file}")
        # NOTE: pickle.load can execute arbitrary code; only load cache files
        # written by this script, never files from untrusted sources.
        with open(data_file, 'rb') as f:
            data = pickle.load(f)
    else:
        print(f"Data file {data_file} not found. Computing convergence data...")
        data = {}

    # dict.fromkeys de-duplicates p_list while preserving its order.
    to_compute = [
        p for p in dict.fromkeys(p_list)
        if p not in data
        or not _cached_entry_matches(data[p], n, d, r, t_infty, t_graph)
    ]
    if to_compute and data:
        print(f"Cached data missing or stale for p in {to_compute}; recomputing...")

    for p in to_compute:
        print(f"Processing p = {p}")
        problem = generate_problem_partial(sigma, n, p, d, r, N)[0]
        G_t_list, G_infty = _run_gpm_fixed_iterations(
            problem['C_star'], n, d, r, t_infty, t_graph)
        distances = [distance_F(G_t, G_infty)[0] for G_t in G_t_list]
        data[p] = {
            'distances': distances,
            'G_t_list': G_t_list,
            'G_infty': G_infty,
            'problem': problem
        }

    if to_compute:
        with open(data_file, 'wb') as f:
            pickle.dump(data, f)
        print(f"Data saved to {data_file}")

    _plot_convergence(data, p_list, n, sigma, t_infty)

    return data


# Example driver for the convergence experiment.
if __name__ == "__main__":
    # Experiment configuration.
    n, sigma = 100, 0.1                      # number of blocks, noise level
    p_list = [0.1, 0.2, 0.4, 0.6, 0.8, 1]    # observation probabilities to sweep
    t_infty, t_graph = 400, 50               # iterations defining G^infty, iterates plotted
    d, r = 10, 3                             # block row / column dimensions

    data = verify_linear_convergence(
        n, sigma, p_list, t_infty, t_graph, d=d, r=r,
    )