import matplotlib.pyplot as plt
import numpy as np

# Matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8-*' and
# 3.8 removed the old names; use whichever name this installation provides so
# importing this module does not fail on newer matplotlib.
if 'seaborn-v0_8-dark-palette' in plt.style.available:
    plt.style.use('seaborn-v0_8-dark-palette')
else:
    plt.style.use('seaborn-dark-palette')
# DejaVu Serif ships with matplotlib, so it is always available.
# NOTE(review): 'font.serif' below is only consulted when 'font.family' is the
# generic 'serif'; with a concrete family name Times New Roman is not used.
# Switch font.family to 'serif' if Times New Roman was the intent.
plt.rcParams['font.family'] = 'DejaVu Serif'
plt.rcParams['font.serif'] = ['Times New Roman']

from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import os
import os.path as osp

def _to_numpy(x):
    """Return ``x`` as a NumPy array, detaching and moving torch tensors to CPU first."""
    # Duck-typed so torch tensors (which expose ``detach``) and plain
    # arrays/lists are both accepted.
    if hasattr(x, 'detach'):
        return x.detach().cpu().numpy()
    return np.asarray(x)


def embedding(feats, labels, n_pca=10, perplexity=30.0, random_state=None):
    """Project features to 2-D with PCA followed by t-SNE.

    Args:
        feats: Feature matrix with one row per sample; a torch tensor (detached
            and moved to CPU) or anything ``np.asarray`` accepts.
        labels: Per-sample labels with the same number of rows as ``feats``.
        n_pca: Target dimensionality of the PCA pre-reduction. It is clipped to
            ``min(n_samples, n_features)``, the most PCA can produce. ``None``
            keeps all components.
        perplexity: Requested t-SNE perplexity (30 is scikit-learn's default).
            It is clipped to ``n_samples - 1`` because t-SNE requires
            ``perplexity < n_samples``.
        random_state: Seed forwarded to t-SNE for reproducible layouts.

    Returns:
        ``(feature_embedded, label)``: the ``(n_samples, 2)`` t-SNE embedding
        and the labels as a NumPy array.

    Raises:
        ValueError: if the row counts of ``feats`` and ``labels`` differ, or
            fewer than 2 samples are given.
    """
    feature = _to_numpy(feats)
    label = _to_numpy(labels)

    n_samples = feature.shape[0]
    if n_samples != label.shape[0]:
        raise ValueError(
            f'feats and labels must have the same number of rows, '
            f'got {n_samples} and {label.shape[0]}')
    if n_samples < 2:
        raise ValueError(f't-SNE needs at least 2 samples, got {n_samples}')

    # Using PCA to reduce dimension to a reasonable dimension as recommended in
    # https://scikit-learn.org/stable/modules/generated/sklearn.manifold.TSNE.html
    if n_pca is not None:
        n_pca = min(n_pca, *feature.shape)
    feature = PCA(n_components=n_pca).fit_transform(feature)

    perplexity = min(perplexity, n_samples - 1)
    feature_embedded = TSNE(
        n_components=2, perplexity=perplexity, random_state=random_state
    ).fit_transform(feature)
    return feature_embedded, label


def plot_features(featuresList, labelsList, num_classes, titles, dirname, filename):
    """Draw one scatter panel per embedding and save them side by side as a PNG.

    In every panel, each class ``0 .. num_classes - 1`` gets its own colour and
    marker. Each class also gets a translucent circle centred on the class mean,
    whose radius is the 90th percentile of the members' distances to that mean.

    Args:
        featuresList: Sequence of per-panel point arrays with two columns;
            column 0 is plotted as x and column 1 as y (e.g. the first output
            of ``embedding``).
        labelsList: Sequence of per-panel label arrays aligned row-for-row with
            ``featuresList``.
        num_classes: Number of classes to draw. Colours and markers cycle after
            10 classes, and classes with no members in a panel are skipped.
        titles: Per-panel titles, aligned with ``featuresList``.
        dirname: Output directory, created if missing ('' = current directory).
        filename: Output file stem; ``.png`` is appended.
    """
    colors = ['C0', 'C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'C7', 'C8', 'C9']
    markers = ['o', 's', '^', 'v', 'D', 'p', '*', 'X', '+', 'H']

    def plot_features_meta(ax, features, labels, title):
        """Scatter one (features, labels) pair onto ``ax`` with per-class circles."""
        features = np.asarray(features)
        labels = np.asarray(labels)
        for label_idx in range(num_classes):
            class_instances = features[labels == label_idx]
            if len(class_instances) == 0:
                # No members means no centre or radius (np.quantile raises on
                # an empty array), so leave this class out of the panel.
                continue
            color = colors[label_idx % len(colors)]
            marker = markers[label_idx % len(markers)]

            # Centre of mass of this class's points.
            center = class_instances.mean(axis=0)
            center_xy = (center[0], center[1])
            # Radius is the 90th percentile of member-to-centre distances (not
            # the maximum), so a few outliers do not inflate the circle.
            distances = np.linalg.norm(class_instances - center, axis=1)
            radius = np.quantile(distances, 0.9)

            # Faint outline plus a very light fill marking the class region.
            ax.add_patch(plt.Circle(center_xy, radius, color=color, fill=False, alpha=0.3))
            ax.add_patch(plt.Circle(center_xy, radius, color=color, fill=True, alpha=0.05))

            ax.scatter(
                class_instances[:, 0],
                class_instances[:, 1],
                c=color,
                marker=marker,
                s=50,  # marker area in points**2
                label=str(label_idx),  # legend entry text
                alpha=0.6,  # translucency so overlapping points stay visible
            )
        ax.set_xticks([])  # hide x-axis ticks and tick labels
        ax.set_yticks([])  # hide y-axis ticks and tick labels
        ax.set_title(title, fontsize=22)

    # squeeze=False always returns a 2-D array of Axes, so a single panel
    # (len(featuresList) == 1) is iterable just like several panels.
    fig, axes = plt.subplots(nrows=1, ncols=len(featuresList), figsize=(32, 8), squeeze=False)
    for idx, ax in enumerate(axes[0]):
        plot_features_meta(ax, featuresList[idx], labelsList[idx], title=titles[idx])

    # The legend is attached to the last panel (the axes plt.legend() would
    # target). Its lower-right corner sits at (0, -0.20) in that panel's axes
    # coordinates, i.e. below the panel and extending to its left.
    legend = axes[0, -1].legend(loc='lower right', title='Classes', bbox_to_anchor=(0, -0.20),
                                ncol=num_classes, fontsize=18, title_fontsize=20)
    # `legend_handles` exists in matplotlib >= 3.7; older releases only
    # provide `legendHandles`.
    handles = getattr(legend, 'legend_handles', None)
    if handles is None:
        handles = legend.legendHandles
    for artist in handles:
        # Scatter entries are PathCollections (not Line2D) and already carry the
        # class marker, so only their size needs enlarging; collection sizes
        # are areas in points**2.
        if hasattr(artist, 'set_sizes'):
            artist.set_sizes([20 ** 2])
        elif hasattr(artist, 'set_markersize'):
            artist.set_markersize(20)

    # exist_ok avoids the exists/makedirs race; '' means the current directory.
    if dirname:
        os.makedirs(dirname, exist_ok=True)
    save_name = osp.join(dirname, f'{filename}.png')
    try:
        fig.savefig(save_name, bbox_inches='tight')
    finally:
        # Close even if saving fails so repeated calls do not leak figures.
        plt.close(fig)
    print(f'Pic saved to {save_name}.')