import matplotlib
import matplotlib.pyplot as plt
import os
import numpy as np

# Force the non-interactive Agg backend: the functions below only render to
# PNG files via savefig and never open a window.
# NOTE(review): pyplot is already imported above, so matplotlib.use() alone may
# not take effect on every matplotlib version; switch_backend() then makes the
# selection explicit for pyplot. Keep this order.
matplotlib.use('Agg')
plt.switch_backend('agg')

def generate_scatter(reduced_dims, run_save_dir, name=""):
    """Save a scatter plot of the first two columns of ``reduced_dims``.

    The image is written to ``<run_save_dir>/scatter.png``.

    Args:
        reduced_dims: Array-like indexable as ``[:, 0]`` / ``[:, 1]``, i.e. of
            shape (n_points, >=2). Column 0 is plotted on the x-axis and
            column 1 on the y-axis; any further columns are ignored.
        run_save_dir: Directory the PNG is written into. It is not created
            here, so it must already exist.
        name: Title drawn above the plot (empty string means no visible title).
    """
    # Accept lists / other array-likes as well as ndarrays.
    points = np.asarray(reduced_dims)
    x = points[:, 0]
    y = points[:, 1]

    # Use an explicit figure/axes pair instead of pyplot's implicit "current
    # figure" so this function never draws onto someone else's figure.
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.scatter(x, y, s=0.5)  # s is the marker area in points**2
        ax.set_title(name)
        ax.set_xlabel("X-axis")
        ax.set_ylabel("Y-axis")
        fig.savefig(os.path.join(run_save_dir, "scatter.png"))
    finally:
        # Release the figure even if savefig fails, so failed calls don't
        # accumulate open figures in pyplot's global registry.
        plt.close(fig)
    
def generate_cluster_plot(reduced_dims, clusters, run_save_dir, name=""):
    # Generate cluster plot
    # Splitting the PCA data into x and y coordinates
    x = reduced_dims[:, 0]
    y = reduced_dims[:, 1]

    # Creating the scatter plot
    plt.figure(figsize=(10, 6))
    plt.scatter(x, y, s=0.5, c=clusters, cmap='viridis')  # Color by cluster labels
    
    # Minimize whitespace by tightening the layout
    plt.tight_layout(pad=0.5)
    
    # Setting the title and labels with bold title
    #plt.title(name, pad=20, fontweight='bold')  # Title is now bold
    plt.xlabel("X-axis")
    plt.ylabel("Y-axis")

    # Calculate the cluster count excluding noise if applicable (-1 label is often used for noise)
    cluster_count = len(np.unique(clusters)) - (1 if -1 in clusters else 0)

    # Annotate the cluster count in the bottom left corner
    plt.annotate(f'Cluster Count: {cluster_count}', xy=(0.01, 0.01), xycoords='axes fraction',
                 horizontalalignment='left', verticalalignment='bottom', fontsize=20, fontweight='bold')

    # Save the plot with minimal padding and without the colorbar
    plt.savefig(os.path.join(run_save_dir, "cluster.png"), bbox_inches='tight', pad_inches=0.1)
    plt.close()