import numpy as np
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
from scipy.spatial.distance import pdist, squareform

def load_and_process_prototypes(filenames):
    """Average prototype matrices from several runs and row-normalize the result.

    Every file is loaded with ``np.load``; the arrays are averaged element-wise
    across files, and each row of the averaged matrix is then divided by its
    own sum so that every row of the result sums to 1.

    Args:
        filenames: Iterable of paths to ``.npy`` files holding arrays of the
            same shape. The averaged array must be 2-D, because rows are
            normalized along axis 1.

    Returns:
        The row-normalized mean matrix, or ``None`` on failure. Failure covers:
        no file names were given, a file could not be loaded or averaged, or
        some averaged row sums to zero. A zero-sum row would otherwise yield
        NaN/inf. In every failure case a message is printed.
    """
    try:
        paths = list(filenames)
        if not paths:
            raise ValueError("no prototype files were given")
        matrices = [np.load(path) for path in paths]
        mean_matrix = np.mean(matrices, axis=0)
        row_sums = mean_matrix.sum(axis=1, keepdims=True)
        # Dividing by a zero row sum produces NaN/inf rows that only blow up
        # later inside PCA; report it here instead.
        if np.any(row_sums == 0):
            raise ValueError(
                "at least one averaged prototype row sums to zero; cannot normalize"
            )
        return mean_matrix / row_sums
    except Exception as e:
        # Deliberate best-effort boundary: the caller checks for None and
        # skips plotting rather than crashing.
        print(f"Error in loading or processing files: {e}")
        return None

# Experiment directory and which prototype set to visualize. Set `embedding` to
# "_Semantic" to plot the semantic-embedding prototypes; any other value is
# titled as the Russel distribution further down.
# NOTE(review): "Russel" presumably refers to Russell's circumplex model; the
# spelling is kept because it is part of the .npy/.png file names used here.
directory = "SEED_V_result/PLL_confusion/main/PGNA_PL/scheduler_True/optimizer_sgd/lr_0.01/confidence_False/beta_parameter_3.0"
embedding = "_Russel"

# One prototype file per run, run_1 through run_5.
name = f"prototypes{embedding}.npy"
filenames = [f"{directory}/run_{run_idx}/{name}" for run_idx in range(1, 6)]

X = load_and_process_prototypes(filenames)
if X is None:
    print("Failed to process data, check file paths and data integrity.")
else:
    # Project the averaged, row-normalized prototypes onto two principal components.
    pca = PCA(n_components=2)
    X_2d = pca.fit_transform(X)

    # Per-prototype styling, matched to the rows of X_2d by position.
    categories = ['Disgust', 'Fear', 'Sad', 'Neutral', 'Happy']
    colors = ['purple', 'blue', 'green', 'orange', 'red']
    line_colors = ['magenta', 'darkcyan', 'darkgreen', 'gold', 'deepskyblue']

    plt.figure(figsize=(10, 6))
    ax = plt.gca()
    ax.set_facecolor('whitesmoke')
    ax.tick_params(axis='both', which='major', labelsize=16)

    # Connect every pair of projected prototypes with a dashed segment and
    # annotate its midpoint with the Euclidean distance measured in the 2-D
    # PCA space (not in the original feature space).
    distances = squareform(pdist(X_2d))
    n = len(X_2d)
    pairs = [(src, dst) for src in range(n) for dst in range(src + 1, n)]
    for src, dst in pairs:
        start, end = X_2d[src], X_2d[dst]
        # Segment and its distance label use the line_colors entry of the
        # lower-index endpoint.
        edge_color = line_colors[src]
        plt.plot([start[0], end[0]], [start[1], end[1]], linestyle='--', color=edge_color, alpha=0.5)
        mid_x, mid_y = (start + end) / 2
        plt.text(mid_x, mid_y, f"{distances[src, dst]:.2f}", fontsize=16, color=edge_color)

    # Mark and name each prototype. Labels are right-aligned for points with
    # x > 0 and left-aligned otherwise, so the text extends toward x = 0.
    for idx, (px, py) in enumerate(X_2d):
        shade = colors[idx]
        align = 'right' if px > 0 else 'left'
        plt.scatter(px, py, color=shade, s=100)
        plt.text(px, py, categories[idx], color=shade, fontsize=16,
                 horizontalalignment=align, fontweight='bold', zorder=3)

    title_text = (
        'PCA Visualization of Prototype Categories under Semantic Distribution'
        if embedding == "_Semantic"
        else 'PCA Visualization of Prototype Categories under Russel Distribution'
    )
    plt.title(title_text, fontsize=16)

    # Save next to the run folders before handing the figure to the GUI.
    plt.savefig(f"{directory}/visualization_color{embedding}.png", dpi=300)
    plt.show()
