import torch
import sys
# sys.path.append('./src')
# sys.path.append('../..')
# from utils import *
# from model import *
# from data import *
import numpy as np
import matplotlib.pyplot as plt
import argparse
# from model.GPT2_prenorm_RoPE_onehot import attn_mask
# def get_unmasked_QK_of_models(model, X_input):
#     """
#     For each decoder layer of `model`, compute the unmasked, pre-softmax attention scores QK^T / sqrt(d_k).
#     """
#     hidden_state = model.embedding(X_input)
#     dec_self_attn_mask = attn_mask(X_input, device)
#     dec_self_attns = []
#     for layer in model.decoder.layers:
#         hidden_state, _ = layer(hidden_state, dec_self_attn_mask)
#         X_Q = layer.dec_self_attn.W_Q(hidden_state).view(-1, args.seq_len, args.n_heads, args.d_k).transpose(1, 2)
#         X_K = layer.dec_self_attn.W_K(hidden_state).view(-1, args.seq_len, args.n_heads, args.d_k).transpose(1, 2)
#         attn = torch.matmul(X_Q, X_K.transpose(-1, -2)) / np.sqrt(args.d_k)
#         dec_self_attns.append(attn)
#     return dec_self_attns

def cosine_similarity_array(X):
    """Return the pairwise cosine-similarity matrix of the rows of ``X``.

    Every row is scaled to unit L2 norm, and the Gram matrix of the scaled
    rows is returned. A row with zero norm yields NaN entries (0 / 0).
    """
    row_norms = np.linalg.norm(X, axis=1, keepdims=True)
    unit_rows = X / row_norms
    return np.dot(unit_rows, unit_rows.T)

def cosine_similarity_and_singular(X):
    """Return the row cosine-similarity matrix and the leading left singular vector.

    Rows of ``X`` are first scaled to unit L2 norm; both outputs are computed
    from that normalised matrix.

    Returns:
        (cosine_sim, left_singular): the Gram matrix of the unit rows, and the
        first column of ``U`` from the thin SVD of the unit rows (its overall
        sign is arbitrary, as for any singular vector).
    """
    unit_rows = X / np.linalg.norm(X, axis=1, keepdims=True)
    left_vectors, _, _ = np.linalg.svd(unit_rows, full_matrices=False)
    leading_left = left_vectors[:, 0]
    return np.dot(unit_rows, unit_rows.T), leading_left

def normalize_vectorgroup(vector_group):
    """Drop zero-norm rows and scale the remaining rows to unit L2 norm.

    Returns:
        (vector_normalized, count): the unit-norm rows, in their original
        relative order, and the number of rows that were kept.
    """
    row_norms = np.linalg.norm(vector_group, axis=1)
    keep = row_norms > 0
    kept_rows = vector_group[keep]
    kept_norms = row_norms[keep][:, np.newaxis]
    return kept_rows / kept_norms, kept_rows.shape[0]


def seperate_vectors_by_eigenvector(vector_group):
    """Reorder vectors by their loading on the leading eigenvector of their Gram matrix.

    Zero-norm rows are dropped first. The Gram matrix ``G = V V^T`` of the
    remaining rows is symmetric, so ``np.linalg.eigh`` is used: it always
    returns real eigenvalues/eigenvectors (``np.linalg.eig`` may return
    complex ones for symmetric input because of rounding) and is faster.

    Returns:
        (similarity_matrix, order_mask): ``G`` with rows and columns permuted
        by ``order_mask``, and ``order_mask`` itself. ``order_mask`` indexes the
        rows that survive the zero-norm filter, not the original input rows.
        The direction of the ordering is arbitrary, because an eigenvector is
        only defined up to sign.
    """
    nonzero = np.linalg.norm(vector_group, axis=1) > 0
    vector_group = vector_group[nonzero]
    similarity_matrix = np.dot(vector_group, vector_group.transpose())
    eigenvalues, eigenvectors = np.linalg.eigh(similarity_matrix)
    leading = eigenvectors[:, np.argmax(eigenvalues)]
    order_mask = np.argsort(leading)
    # Permute rows and columns together so the matrix stays symmetric.
    similarity_matrix = similarity_matrix[np.ix_(order_mask, order_mask)]
    return similarity_matrix, order_mask


def plot_weight_heatmap_eigen(weight, save_path):
    """Save a heatmap of the cosine-similarity matrix of ``weight``'s rows.

    Rows are normalised with ``normalize_vectorgroup`` (zero rows dropped) and
    reordered with ``seperate_vectors_by_eigenvector``. The reordered matrix is
    drawn with ``pcolormesh`` (colour range [-1, 1]) on a frameless, axis-free
    figure and written to ``save_path``.

    Args:
        weight: 2-D array-like; only ``weight.data.shape`` is read directly here.
            NOTE(review): non-NumPy inputs (e.g. torch tensors) rely on the
            helpers' NumPy interop -- confirm for your callers.
        save_path: destination passed to ``plt.savefig``.

    Returns:
        The row permutation produced by ``seperate_vectors_by_eigenvector``, or
        ``None`` when ``weight`` has a single column (nothing is plotted).
    """
    # A single column has no meaningful row-similarity structure; return
    # explicitly instead of falling through to an unbound ``order``.
    if weight.data.shape[1] == 1:
        return None

    weight_normalized, _ = normalize_vectorgroup(weight)
    similarity_matrix, order = seperate_vectors_by_eigenvector(weight_normalized)

    fig = plt.figure(frameon=False)  # no figure frame
    ax = plt.Axes(fig, [0, 0, 1, 1])  # axes fill the whole canvas (no margins)
    ax.set_axis_off()
    fig.add_axes(ax)
    ax.pcolormesh(similarity_matrix, vmin=-1, vmax=1, cmap='YlGnBu')
    plt.savefig(save_path, dpi=50, bbox_inches='tight', pad_inches=0, transparent=True)
    plt.close(fig)
    return order

import numpy as np
from scipy.sparse.linalg import svds
from scipy.linalg import svd
import time

def weighted_cosine_similarity(X1, X2, weights1, weights2):
    """Weighted mean cosine similarity between every column of X1 and every column of X2.

    For columns ``a_i`` of ``X1`` and ``b_j`` of ``X2`` this returns

        sum_ij w1_i * w2_j * cos(a_i, b_j) / sum_ij w1_i * w2_j

    Args:
        X1: (d, k1) array; its columns are the vectors being compared.
        X2: (d, k2) array with the same number of rows as ``X1``.
        weights1: length-k1 weights for the columns of ``X1``.
        weights2: length-k2 weights for the columns of ``X2``.

    Returns:
        The weighted average cosine similarity. Zero-norm columns yield NaN.
    """
    # Scale every column to unit length so X1_unit.T @ X2_unit holds all
    # pairwise cosines in one matrix product.
    X1_unit = X1 / np.linalg.norm(X1, axis=0)
    X2_unit = X2 / np.linalg.norm(X2, axis=0)
    cosine_matrix = X1_unit.T @ X2_unit

    # weight_matrix[i, j] = weights1[i] * weights2[j]
    weight_matrix = np.outer(weights1, weights2)
    return np.sum(cosine_matrix * weight_matrix) / np.sum(weight_matrix)


def compute_partial_svd(matrix, ratio=0.9, use_sparse=True):
    """Truncated SVD keeping the fewest components that explain ``ratio`` of the energy.

    The "variance" (energy) of a matrix is the sum of ALL its squared singular
    values, i.e. its squared Frobenius norm. That norm is used as the
    denominator, so the ratio stays correct even when only the leading singular
    values are computed.

    Args:
        matrix: 2-D dense array.
        ratio: target fraction of the total squared singular-value mass.
        use_sparse: if True and ``min(matrix.shape) > 200``, first try
            ``scipy.sparse.linalg.svds`` for the top 200 triplets. If those do
            not reach ``ratio``, fall back to a full dense SVD.

    Returns:
        (U, s, Vt) truncated to ``k_keep`` components, with ``s`` in descending
        order. For an all-zero matrix a single component is returned.
    """
    sparse_k = 200  # number of leading triplets requested from svds
    k = min(matrix.shape)
    total_variance = np.linalg.norm(matrix) ** 2

    decomposition = None
    # ARPACK-based svds requires k < min(matrix.shape); only use it when it
    # genuinely truncates (k > sparse_k), otherwise a dense SVD is cheap enough.
    if use_sparse and k > sparse_k:
        U, s, Vt = svds(matrix, k=sparse_k)
        # svds does not guarantee descending order.
        order = np.argsort(s)[::-1]
        U, s, Vt = U[:, order], s[order], Vt[order, :]
        if np.sum(s ** 2) >= ratio * total_variance:
            decomposition = (U, s, Vt)
    if decomposition is None:
        decomposition = svd(matrix, full_matrices=False)
    U, s, Vt = decomposition

    if total_variance == 0:
        k_keep = 1  # zero matrix: no energy to explain; keep one component
    else:
        cumulative_variance = np.cumsum(s ** 2) / total_variance
        # First index whose cumulative share reaches ``ratio``; the clamp keeps
        # every computed component when rounding leaves the last value just
        # below ``ratio`` (e.g. ratio=1.0) instead of collapsing to 1.
        k_keep = min(int(np.searchsorted(cumulative_variance, ratio)) + 1, len(s))

    return U[:, :k_keep], s[:k_keep], Vt[:k_keep, :]


def weighted_singular_vector_cosine(A, B, variance_ratio=0.9, use_sparse_svd=True):
    """Compare two matrices through their dominant singular vectors.

    Each matrix is reduced with ``compute_partial_svd`` to the leading
    components selected for ``variance_ratio``. Left singular vectors, and
    separately right singular vectors, are then compared with
    ``weighted_cosine_similarity``. Each vector pair is weighted by the product
    of the corresponding singular values.

    Args:
        A, B: input matrices.
        variance_ratio: fraction of squared singular-value mass to keep (default 0.9).
        use_sparse_svd: forwarded to ``compute_partial_svd`` to allow sparse SVD.

    Returns:
        (weighted_cosine, left_cosine, right_cosine), where ``weighted_cosine``
        is the mean of the left and right scores.
    """
    U_a, s_a, Vt_a = compute_partial_svd(A, variance_ratio, use_sparse_svd)
    U_b, s_b, Vt_b = compute_partial_svd(B, variance_ratio, use_sparse_svd)

    left_cosine = weighted_cosine_similarity(U_a, U_b, s_a, s_b)
    # Right singular vectors are the rows of Vt; transpose them into columns.
    right_cosine = weighted_cosine_similarity(Vt_a.T, Vt_b.T, s_a, s_b)

    return (left_cosine + right_cosine) / 2, left_cosine, right_cosine

def optimized_weighted_cosine(A, B, variance_ratio=0.9):
    """Memory-lean version of the all-pairs singular-vector weighted cosine.

    This computes the same quantity as ``weighted_singular_vector_cosine``:
    every pair of kept singular vectors is compared and weighted by the product
    of their singular values. It avoids materialising the ``k_A x k_B`` cosine
    matrix by using, for unit-norm columns ``u_i`` / ``v_j``,

        sum_ij s_i t_j <u_i, v_j> = <U s, V t>,   sum_ij s_i t_j = sum(s) * sum(t)

    This form also works when ``A`` and ``B`` keep different numbers of
    components.

    Args:
        A, B: 2-D dense arrays of the same shape.
        variance_ratio: fraction of squared singular-value mass to keep.

    Returns:
        (weighted_cosine, left_cosine, right_cosine), where ``weighted_cosine``
        is the mean of the left and right scores.
    """

    def fast_partial_svd(matrix, ratio=0.9):
        """Leading singular triplets explaining ``ratio`` of the squared Frobenius norm."""
        k = min(matrix.shape)
        # Over-compute: the top c singular values carry at least c/k of the
        # energy, so c >= ratio * k suffices. max(1, ...) guards tiny ratios.
        k_compute = min(max(1, 3 * int(k * ratio)), k)

        if k > 500:  # large matrices: randomized SVD
            from sklearn.utils.extmath import randomized_svd
            U, s, Vt = randomized_svd(matrix, n_components=k_compute, random_state=42)
        else:
            U, s, Vt = svd(matrix, full_matrices=False)
            U, s, Vt = U[:, :k_compute], s[:k_compute], Vt[:k_compute, :]

        # Normalise by the full energy, not by the truncated subset.
        total_variance = np.linalg.norm(matrix) ** 2
        if total_variance == 0:
            k_keep = 1
        else:
            cumulative_variance = np.cumsum(s ** 2) / total_variance
            # Clamp so rounding below ``ratio`` keeps all computed components.
            k_keep = min(int(np.searchsorted(cumulative_variance, ratio)) + 1, len(s))

        return U[:, :k_keep], s[:k_keep], Vt[:k_keep, :]

    def efficient_weighted_cosine(X1, X2, w1, w2):
        """All-pairs weighted cosine between columns of X1 and X2 without a k1 x k2 matrix."""
        X1_unit = X1 / np.linalg.norm(X1, axis=0, keepdims=True)
        X2_unit = X2 / np.linalg.norm(X2, axis=0, keepdims=True)
        # (X1_unit @ w1) . (X2_unit @ w2) == sum_ij w1_i w2_j cos(x1_i, x2_j)
        numerator = np.dot(X1_unit @ w1, X2_unit @ w2)
        return numerator / (np.sum(w1) * np.sum(w2))

    U1, s1, Vt1 = fast_partial_svd(A, variance_ratio)
    U2, s2, Vt2 = fast_partial_svd(B, variance_ratio)

    left_cosine = efficient_weighted_cosine(U1, U2, s1, s2)
    right_cosine = efficient_weighted_cosine(Vt1.T, Vt2.T, s1, s2)

    return (left_cosine + right_cosine) / 2, left_cosine, right_cosine

# Demo / benchmark: compare the standard and the optimized implementations.
if __name__ == "__main__":
    # Reproducible test pair: B is a slightly noisy copy of A.
    np.random.seed(42)
    m, n = 1000, 500
    A = np.random.randn(m, n)
    B = A + 0.1 * np.random.randn(m, n)

    print("矩阵形状:", A.shape)

    # Run both methods first (standard, then optimized), timing each call.
    results = []
    timings = []
    for method in (weighted_singular_vector_cosine, optimized_weighted_cosine):
        started = time.time()
        results.append(method(A, B))
        timings.append(time.time() - started)

    headers = ("\n标准方法结果:", "\n优化方法结果:")
    for header, res, elapsed in zip(headers, results, timings):
        print(header)
        print(f"总加权余弦相似度: {res[0]:.6f}")
        print(f"左奇异向量相似度: {res[1]:.6f}")
        print(f"右奇异向量相似度: {res[2]:.6f}")
        print(f"计算时间: {elapsed:.4f}秒")

    # Relative speed of the optimized method versus the standard one.
    print(f"\n优化方法比标准方法快 {timings[0] / timings[1]:.2f} 倍")