import numpy as np
from sklearn.preprocessing import normalize
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt

def ExtractUserFeatures(num_users, d, filename):
    """
    Build user feature vectors from a matrix stored on disk.

    The array saved at ``filename`` is loaded, its first ``num_users`` rows are
    kept, and the leading ``d - 1`` left singular vectors of that slice are
    taken.  Each row of that block is rescaled to unit L2 norm, a constant
    column of ones is appended, and the whole matrix is divided by sqrt(2).

    Args:
        num_users: number of leading rows of the loaded matrix to use.
        d: target feature width (``d - 1`` singular directions plus one
            constant column).
        filename: path handed to ``np.load``.

    Returns:
        U: array with ``num_users`` rows; it has ``d`` columns whenever the
            decomposition provides at least ``d - 1`` left singular vectors.

    Raises:
        ValueError: if the loaded matrix has fewer than ``num_users`` rows.
    """
    X = np.load(filename)

    A1 = X[:num_users, :]
    if A1.shape[0] != num_users:
        # Fail early with a readable message instead of the shape-mismatch
        # error np.concatenate would raise further down.
        raise ValueError(
            f"Requested {num_users} users but the matrix in {filename!r} "
            f"only provides {A1.shape[0]} rows"
        )

    # Only the first d - 1 left singular vectors are used.  The reduced
    # factorisation already contains them whenever d - 1 <= min(A1.shape) and
    # avoids building the full square factors; the full factorisation is kept
    # as a fallback so unusual d values still see the same columns as before.
    n_components = d - 1
    needs_full = not (0 <= n_components <= min(A1.shape))
    u = np.linalg.svd(A1, full_matrices=needs_full)[0]

    u = u[:, :n_components]
    u = normalize(u, axis=1, norm='l2')

    # Append the constant column; dividing by sqrt(2) rescales a row whose
    # normalised part has unit norm back to unit norm overall.
    U = np.concatenate((u, np.ones((num_users, 1))), axis=1) / np.sqrt(2)

    return U
def calculate_min_gap(thetam):
    """
    Return the smallest Euclidean distance between any two entries of thetam.

    Args:
        thetam: indexable collection of vectors supporting ``len`` and
            element-wise subtraction (e.g. a 2-D array of cluster centres).

    Returns:
        The minimum pairwise distance, or ``float('inf')`` when fewer than two
        entries are supplied.
    """
    count = len(thetam)

    # Distances for every unordered pair (a, b) with a < b, in the same
    # order a nested index loop would visit them.
    pair_distances = [
        np.linalg.norm(thetam[a] - thetam[b])
        for a in range(count)
        for b in range(a + 1, count)
    ]

    # Seeding the candidates with +inf covers the "no pairs" case and keeps
    # the running-minimum comparison order unchanged.
    return min([float('inf'), *pair_distances])

def kmeans_thetas(num_users, d, n_clusters, filename, random_state=None):
    """
    Cluster stored user features with k-means and assign each user its centre.

    Args:
        num_users: number of leading rows (users) to produce a vector for.
        d: width of each returned vector; must be compatible with the width
            of the cluster centres.
        n_clusters: number of k-means clusters.
        filename: path handed to ``np.load``; k-means is fitted on every row
            of the loaded array.
        random_state: optional seed forwarded to ``KMeans`` so a run can be
            reproduced.  The default ``None`` keeps the unseeded behaviour.

    Returns:
        thetas: ``(num_users, d)`` array whose row i is the centre of the
            cluster that row i of the loaded array was assigned to.
        cluster_centers: the fitted ``kmeans.cluster_centers_`` array.

    Raises:
        IndexError: if the loaded array has fewer than ``num_users`` rows.
    """
    U = np.load(filename)
    kmeans = KMeans(n_clusters=n_clusters, random_state=random_state).fit(U)
    centers = kmeans.cluster_centers_

    thetas = np.zeros((num_users, d))

    labels = kmeans.labels_[:num_users]
    if labels.shape[0] < num_users:
        raise IndexError(
            f"Requested {num_users} users but only {labels.shape[0]} rows were clustered"
        )

    # One fancy-indexing gather replaces the per-user copy loop.
    thetas[...] = centers[labels]
    return thetas, centers

if __name__ == "__main__":
    # Extract 20-dimensional features for the first 1000 rows of the stored
    # matrix, report their shape, and save them (np.save adds ".npy").
    U = ExtractUserFeatures(
        num_users=1000,
        d=20,
        filename='OffClusBandit/data/datasets/ml_1000user_1000item.npy',
    )
    print(U.shape)
    np.save(file='OffClusBandit/data/datasets/ml_1000user_d20', arr=U)

# U = ExtractUserFeatures(num_users=1000, d=20, filename='yelp_2000user_2000item.npy')
# print(U.shape)
# np.save('yelp_1000user_d20', U)

# U = np.load('ml_1000user_d20.npy')
# print(U)
#
# thetas, thetam = kmeans_thetas(num_users=1000, d=20, n_clusters=10, filename='ml_1000user_d20.npy')
# np.save('ml_1000user_d20_m10', thetas)
# np.save('ml_1000user_d20_m10_thetam', thetam)
# thetas = np.load('ml_1000user_d20_m10.npy')
# print(thetas)
# #
# thetas, thetam = kmeans_thetas(num_users=1000, d=20, n_clusters=10, filename='yelp_1000user_d20.npy')
# np.save('yelp_1000user_d20_m10', thetas)
# np.save('yelp_1000user_d20_m10_thetam', thetam)
# thetas = np.load('yelp_1000user_d20_m10.npy')
# print(thetas)



#
# def elbow_method(U, max_clusters=20):
#     inertias = []
#     cluster_range = range(1, max_clusters + 1)
#
#     for k in cluster_range:
#         kmeans = KMeans(n_clusters=k, random_state=42)
#         kmeans.fit(U)
#         inertias.append(kmeans.inertia_)  
#
#     
#     plt.figure(figsize=(8, 5))
#     plt.plot(cluster_range, inertias, 'o-', markersize=6)
#     plt.xlabel('Number of Clusters (k)', fontsize=12)
#     plt.ylabel('Inertia', fontsize=12)
#     plt.title('Elbow Method for Optimal Clusters', fontsize=14)
#     plt.grid()
#     plt.show()
#
#
# # 
# U = np.load('ml_1000user_d50.npy')
# print(U)
# elbow_method(U, max_clusters=50)



def measure_svd_error(original_matrix, reconstructed_matrix, method='frobenius'):
    """
    Measure the error between a reconstruction and the original matrix.

    Args:
        original_matrix: 2-D array-like with the reference values.
        reconstructed_matrix: array-like approximation of ``original_matrix``,
            or ``None``.  With ``None`` the matrix is rebuilt from its own
            full-rank SVD, so the reported errors only reflect floating-point
            round-off.
        method: which metrics to compute
            - 'frobenius': Frobenius norm error (absolute and relative)
            - 'spectral': spectral norm error (absolute and relative)
            - 'relative': element-wise relative error over non-zero entries
            - 'rmse': root mean square error
            - 'mae': mean absolute error
            - 'cosine': row-wise cosine similarity
            - 'all': every metric above

    Returns:
        error_metrics: dictionary of the requested metrics.  The cosine keys
            are omitted when no row pair has two non-zero norms, and the
            element-wise relative keys are omitted when the original matrix
            has no non-zero entry.

    Raises:
        ValueError: if ``method`` is not one of the names listed above.
    """
    valid_methods = ('frobenius', 'spectral', 'relative', 'rmse', 'mae', 'cosine', 'all')
    if method not in valid_methods:
        # A misspelled metric name used to yield an empty dict silently.
        raise ValueError(f"Unknown method {method!r}; expected one of {valid_methods}")

    original_matrix = np.asarray(original_matrix)

    if reconstructed_matrix is None:
        # Full-rank reconstruction: scale the columns of u by the singular
        # values, then multiply by vt.
        u, s, vt = np.linalg.svd(original_matrix, full_matrices=False)
        reconstructed_matrix = (u * s) @ vt
    else:
        reconstructed_matrix = np.asarray(reconstructed_matrix)

    error_matrix = original_matrix - reconstructed_matrix

    def wanted(name):
        return method in ('all', name)

    def relative_to(error_norm, reference_norm):
        # A zero reference norm would divide by zero; report 0 when the error
        # vanishes as well and infinity otherwise.
        if reference_norm == 0:
            return 0.0 if error_norm == 0 else float('inf')
        return error_norm / reference_norm

    error_metrics = {}

    if wanted('frobenius'):
        frob_error = np.linalg.norm(error_matrix, 'fro')
        frob_norm_original = np.linalg.norm(original_matrix, 'fro')
        error_metrics['frobenius_error'] = frob_error
        error_metrics['relative_frobenius_error'] = relative_to(frob_error, frob_norm_original)

    if wanted('spectral'):
        spectral_error = np.linalg.norm(error_matrix, 2)
        spectral_norm_original = np.linalg.norm(original_matrix, 2)
        error_metrics['spectral_error'] = spectral_error
        error_metrics['relative_spectral_error'] = relative_to(spectral_error, spectral_norm_original)

    if wanted('rmse'):
        error_metrics['rmse'] = np.sqrt(np.mean(error_matrix**2))

    if wanted('mae'):
        error_metrics['mae'] = np.mean(np.abs(error_matrix))

    if wanted('cosine'):
        # Compare matching rows; rows with a zero norm on either side have no
        # defined angle and are left out.
        n_rows = min(original_matrix.shape[0], reconstructed_matrix.shape[0])
        orig_rows = original_matrix[:n_rows]
        recon_rows = reconstructed_matrix[:n_rows]
        orig_norms = np.linalg.norm(orig_rows, axis=1)
        recon_norms = np.linalg.norm(recon_rows, axis=1)
        usable = (orig_norms > 0) & (recon_norms > 0)

        if usable.any():
            dots = (orig_rows[usable] * recon_rows[usable]).sum(axis=1)
            cosine_similarities = dots / (orig_norms[usable] * recon_norms[usable])
            error_metrics['mean_cosine_similarity'] = np.mean(cosine_similarities)
            error_metrics['min_cosine_similarity'] = np.min(cosine_similarities)

    if wanted('relative'):
        # Element-wise relative error, restricted to non-zero reference entries.
        non_zero_mask = original_matrix != 0
        if np.any(non_zero_mask):
            relative_errors = np.abs(error_matrix[non_zero_mask] / original_matrix[non_zero_mask])
            error_metrics['mean_relative_error'] = np.mean(relative_errors)
            error_metrics['max_relative_error'] = np.max(relative_errors)
            error_metrics['median_relative_error'] = np.median(relative_errors)

    return error_metrics

def analyze_svd_quality(original_matrix, d_values=None):
    """
    Analyze SVD reconstruction quality under different dimensions d.

    The matrix is factorised once; for each requested d the rank-d
    truncation is rebuilt, scored with ``measure_svd_error(..., method='all')``
    and a short summary is printed.

    Args:
        original_matrix: 2-D array to decompose.
        d_values: list of dimensions to test, default
            [5, 10, 15, 20, 25, 30, 40, 50].  Values larger than the smaller
            side of the matrix are skipped.

    Returns:
        results: dictionary mapping each tested d to its error metrics.
    """
    if d_values is None:
        d_values = [5, 10, 15, 20, 25, 30, 40, 50]

    max_rank = min(original_matrix.shape)

    # The factorisation does not depend on d, so compute it once instead of
    # once per tested dimension.
    u, s, vt = np.linalg.svd(original_matrix, full_matrices=False)

    results = {}

    for d in d_values:
        if d > max_rank:
            continue

        # Rank-d truncation: the first d columns of u scaled by the first d
        # singular values, times the first d rows of vt.
        reconstructed_matrix = (u[:, :d] * s[:d]) @ vt[:d, :]

        error_metrics = measure_svd_error(original_matrix, reconstructed_matrix, method='all')
        results[d] = error_metrics

        print(f"Dimension d={d}:")
        print(f"  Frobenius error: {error_metrics['frobenius_error']:.6f}")
        print(f"  Relative Frobenius error: {error_metrics['relative_frobenius_error']:.6f}")
        print(f"  RMSE: {error_metrics['rmse']:.6f}")
        print(f"  MAE: {error_metrics['mae']:.6f}")
        if 'mean_cosine_similarity' in error_metrics:
            print(f"  Mean cosine similarity: {error_metrics['mean_cosine_similarity']:.6f}")
        print()

    return results

def plot_svd_error_analysis(results):
    """
    Plot SVD error analysis figures.

    Draws a 2x2 grid with one curve per metric (Frobenius error, relative
    Frobenius error, RMSE, mean cosine similarity) against the dimension d
    and shows the figure.

    Args:
        results: results returned by analyze_svd_quality, i.e. a dictionary
            mapping each dimension d to a dictionary of metrics.  An empty
            dictionary yields empty panels instead of an error.
    """
    import matplotlib.pyplot as plt

    d_values = list(results.keys())

    fig, axes = plt.subplots(2, 2, figsize=(12, 10))

    def draw_panel(ax, metric_key, y_label, title, color=None):
        # One metric curve over the tested dimensions.
        values = [results[d][metric_key] for d in d_values]
        line_style = {} if color is None else {'color': color}
        ax.plot(d_values, values, 'o-', markersize=6, **line_style)
        ax.set_xlabel('Dimension d')
        ax.set_ylabel(y_label)
        ax.set_title(title)
        ax.grid(True)

    draw_panel(axes[0, 0], 'frobenius_error', 'Frobenius error',
               'Frobenius error vs Dimension')
    draw_panel(axes[0, 1], 'relative_frobenius_error', 'Relative Frobenius error',
               'Relative Frobenius error vs Dimension', color='red')
    draw_panel(axes[1, 0], 'rmse', 'RMSE',
               'RMSE vs Dimension', color='green')

    # The cosine metric is optional per entry.  Checking every entry (rather
    # than only the first) avoids an IndexError on empty results and a
    # KeyError when only some dimensions carry the metric.
    if all('mean_cosine_similarity' in results[d] for d in d_values):
        draw_panel(axes[1, 1], 'mean_cosine_similarity', 'Mean cosine similarity',
                   'Cosine similarity vs Dimension', color='orange')
    else:
        # Hide the unused panel instead of showing an empty frame.
        axes[1, 1].set_visible(False)

    plt.tight_layout()
    plt.show()

def measure_extract_user_features_error(original_matrix, U_matrix, method='all'):
    """
    Measure how well a user feature matrix explains the original reward matrix.

    Args:
        original_matrix: original reward matrix (num_users, num_items)
        U_matrix: user feature matrix (num_users, d)
        method: accepted for signature compatibility; it is not consulted and
            every metric is always computed.

    Returns:
        error_metrics: dictionary with
            - ``*_cosine_similarity``: row-wise cosine between the original
              rows and the feature rows.  Only present when ``d == num_items``
              (rows of different length have no dot product) and at least
              one row pair has two non-zero norms.
            - ``projection_*``: error of the orthogonal (least-squares)
              projection of the original matrix onto the column space of U.
            - ``variance_retention_rate`` / ``information_loss_rate``:
              variance of the projection relative to the original.
            - ``reconstruction_*``: error of ``U @ V`` with ``V`` the
              least-squares solution ``pinv(U) @ X``.  This is the same fit
              as the projection, reported under the historical key names.
              If the pseudo-inverse cannot be computed, the single key
              ``reconstruction_error`` holds a message instead and the
              projection metrics are NaN.

    Raises:
        ValueError: if the two matrices have a different number of rows.
    """
    original_matrix = np.asarray(original_matrix)
    U_matrix = np.asarray(U_matrix)

    num_users, num_items = original_matrix.shape
    num_users_u, d = U_matrix.shape

    if num_users != num_users_u:
        raise ValueError(f"User count mismatch: original {num_users} vs U {num_users_u}")

    error_metrics = {}

    # Method 1: row-wise cosine similarity between original rows and U rows.
    # The dot product only exists when both rows have the same length; the
    # unconditional np.dot raised for every d != num_items.
    if d == num_items:
        orig_norms = np.linalg.norm(original_matrix, axis=1)
        u_norms = np.linalg.norm(U_matrix, axis=1)
        usable = (orig_norms > 0) & (u_norms > 0)
        if usable.any():
            dots = (original_matrix[usable] * U_matrix[usable]).sum(axis=1)
            cosine_similarities = dots / (orig_norms[usable] * u_norms[usable])
            error_metrics['mean_cosine_similarity'] = np.mean(cosine_similarities)
            error_metrics['min_cosine_similarity'] = np.min(cosine_similarities)
            error_metrics['max_cosine_similarity'] = np.max(cosine_similarities)
            error_metrics['std_cosine_similarity'] = np.std(cosine_similarities)

    # Method 2: orthogonal projection onto the column space of U.
    # U @ U.T is a projector only when U has orthonormal columns, so the
    # projection is computed as U @ pinv(U) @ X, which is valid for any U.
    try:
        V_estimated = np.linalg.pinv(U_matrix) @ original_matrix
    except np.linalg.LinAlgError:
        V_estimated = None

    if V_estimated is None:
        U_projection = np.full(original_matrix.shape, np.nan)
    else:
        U_projection = U_matrix @ V_estimated
    projection_error = original_matrix - U_projection

    frob_error = np.linalg.norm(projection_error, 'fro')
    frob_norm_original = np.linalg.norm(original_matrix, 'fro')
    if frob_norm_original == 0:
        # An all-zero original would divide by zero.
        relative_frob_error = 0.0 if frob_error == 0 else float('inf')
    else:
        relative_frob_error = frob_error / frob_norm_original

    error_metrics['projection_frobenius_error'] = frob_error
    error_metrics['relative_projection_frobenius_error'] = relative_frob_error
    error_metrics['projection_rmse'] = np.sqrt(np.mean(projection_error**2))
    error_metrics['projection_mae'] = np.mean(np.abs(projection_error))

    # Method 3: information retention rate, measured as the ratio of variances.
    original_variance = np.var(original_matrix)
    projection_variance = np.var(U_projection)
    variance_retention = projection_variance / original_variance if original_variance > 0 else 0.0

    error_metrics['variance_retention_rate'] = variance_retention
    error_metrics['information_loss_rate'] = 1 - variance_retention

    # Method 4: least-squares reconstruction X ~ U @ V, i.e. the same fit as
    # the projection above, kept under its own keys for existing callers.
    if V_estimated is None:
        error_metrics['reconstruction_error'] = "Reconstruction error cannot be computed (singular matrix)"
    else:
        error_metrics['reconstruction_frobenius_error'] = frob_error
        error_metrics['relative_reconstruction_frobenius_error'] = relative_frob_error
        error_metrics['reconstruction_rmse'] = error_metrics['projection_rmse']
        error_metrics['reconstruction_mae'] = error_metrics['projection_mae']

    return error_metrics

def compare_svd_methods(original_matrix, num_users, d_values=None):
    """
    Compare standard SVD and ExtractUserFeatures across dimensions.

    For each tested d the first ``num_users`` rows are approximated two ways:
    a rank-d truncated SVD reconstruction, and the feature construction used
    by ExtractUserFeatures (first d - 1 left singular vectors, rows scaled
    to unit L2 norm, a ones column appended, everything divided by sqrt(2)).
    Both are scored and a summary is printed.

    Args:
        original_matrix: original matrix
        num_users: number of leading rows to use
        d_values: list of dimensions to test, default [5, 10, 15, 20, 25, 30].
            Values larger than the smaller side of the row slice are skipped.

    Returns:
        comparison_results: dictionary mapping each tested d to
            ``{'standard_svd': ..., 'extract_user_features': ...}``.
    """
    if d_values is None:
        d_values = [5, 10, 15, 20, 25, 30]

    # Slice once and size everything from the slice itself, so a matrix with
    # fewer rows than num_users neither breaks the ones column nor lets a d
    # beyond the slice's rank through the check below.
    user_matrix = original_matrix[:num_users, :]
    n_rows = user_matrix.shape[0]
    max_rank = min(user_matrix.shape)

    # Both methods start from the same factorisation and it does not depend
    # on d: compute it a single time rather than twice per dimension.
    u, s, vt = np.linalg.svd(user_matrix, full_matrices=False)
    ones_column = np.ones((n_rows, 1))

    comparison_results = {}

    for d in d_values:
        if d > max_rank:
            continue

        print(f"\n=== Test dimension d={d} ===")

        # Method 1: standard rank-d SVD reconstruction
        svd_reconstructed = (u[:, :d] * s[:d]) @ vt[:d, :]

        # Method 2: ExtractUserFeatures approach
        u_feat_d = normalize(u[:, :d - 1], axis=1, norm='l2')
        U_feat = np.concatenate((u_feat_d, ones_column), axis=1) / np.sqrt(2)

        # Compute errors for both methods
        svd_errors = measure_svd_error(user_matrix, svd_reconstructed, method='all')
        feat_errors = measure_extract_user_features_error(user_matrix, U_feat, method='all')

        comparison_results[d] = {
            'standard_svd': svd_errors,
            'extract_user_features': feat_errors
        }

        print(f"Standard SVD:")
        print(f"  Frobenius error: {svd_errors['frobenius_error']:.6f}")
        print(f"  Relative Frobenius error: {svd_errors['relative_frobenius_error']:.6f}")
        print(f"  RMSE: {svd_errors['rmse']:.6f}")

        print(f"ExtractUserFeatures:")
        print(f"  Projection Frobenius error: {feat_errors['projection_frobenius_error']:.6f}")
        print(f"  Relative projection Frobenius error: {feat_errors['relative_projection_frobenius_error']:.6f}")
        print(f"  Projection RMSE: {feat_errors['projection_rmse']:.6f}")
        if 'mean_cosine_similarity' in feat_errors:
            print(f"  Mean cosine similarity: {feat_errors['mean_cosine_similarity']:.6f}")
        print(f"  Variance retention rate: {feat_errors['variance_retention_rate']:.6f}")

    return comparison_results

# Main section includes error analysis
if __name__ == "__main__":
    dataset_path = 'OffClusBandit/data/datasets/ml_1000user_1000item.npy'

    # Step 1: load the stored matrix and report its shape.
    X = np.load(dataset_path)
    print(f"Original matrix shape: {X.shape}")

    # Step 2: truncated-SVD quality over the default dimension grid.
    print("=== SVD Quality Analysis ===")
    results = analyze_svd_quality(X)

    # Step 3: side-by-side comparison of the two constructions.
    print("\n=== Method Comparison Analysis ===")
    comparison_results = compare_svd_methods(
        X,
        num_users=1000,
        d_values=[10, 15, 20, 25],
    )

    # Step 4: run the feature extractor on the same file.
    print("\n=== ExtractUserFeatures Method Test ===")
    U = ExtractUserFeatures(num_users=1000, d=20, filename=dataset_path)
    print(f"U matrix shape: {U.shape}")

    # Step 5: score the extracted features against the first 1000 rows.
    error_metrics = measure_extract_user_features_error(X[:1000, :], U, method='all')
    print("Error metrics of ExtractUserFeatures method:")
    for metric, value in error_metrics.items():
        # String entries are messages and are printed verbatim; everything
        # else is formatted with six decimals.
        rendered = value if isinstance(value, str) else f"{value:.6f}"
        print(f"  {metric}: {rendered}")

    # Step 6: plot the curves collected in step 2.
    plot_svd_error_analysis(results)


