import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch

def plot_zs(zhats, ys, d_name, nmi, sc, method_name, out_dir='./plot_res', show_legend=False):
    """Scatter-plot a 2-D embedding coloured by class and save it as a PDF.

    Args:
        zhats: embedding indexable as ``zhats[mask, 0]`` / ``zhats[mask, 1]``;
            only the first two columns are plotted. Presumably an
            (n_samples, >=2) array — TODO confirm with callers.
        ys: one class label per row of ``zhats``; must support ``.tolist()``.
        d_name: dataset name, used as the suffix of the output file name.
        nmi: NMI score as a fraction; shown multiplied by 100 in the title and
            in the file name.
        sc: SC score as a fraction; shown multiplied by 100 in the title.
        method_name: shown in the title and used as the file name prefix.
        out_dir: directory the PDF is written to; created if it is missing.
        show_legend: if True, draw a legend with one ``Class k`` entry per class.

    Returns:
        The path of the written PDF,
        ``{out_dir}/{method_name}-{nmi*100:.2f}-{d_name}.pdf``.
    """
    emb = np.asarray(zhats)
    # Map each distinct label to 0..k-1 in sorted label order. This keeps class
    # ids (and therefore colours/legend entries) deterministic, unlike relying
    # on set iteration order. reshape(-1) guards against numpy versions that
    # return the inverse with the input's shape.
    _, labels = np.unique(np.asarray(ys.tolist()), return_inverse=True)
    labels = labels.reshape(-1)

    # Font sizes for the title and (optional) legend.
    title_fs = 15
    legend_fs = 14

    fig, ax = plt.subplots(figsize=(4, 4))
    try:
        # One scatter call per class so each class gets its own colour.
        for lbl in np.unique(labels):
            idx = labels == lbl
            ax.scatter(
                emb[idx, 0],
                emb[idx, 1],
                label=f'Class {lbl}',
                s=20,
                linewidths=1,
            )

        ax.set_title(f'{method_name}, NMI: {nmi*100:.2f}, SC: {sc*100:.2f}', fontsize=title_fs)

        if show_legend:
            ax.legend(fontsize=legend_fs, frameon=True, ncol=1)

        # Embedding coordinates carry no meaning: hide ticks and tick labels
        # but keep the plot border visible.
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(True)

        fig.tight_layout()
        os.makedirs(out_dir, exist_ok=True)
        out_path = os.path.join(out_dir, f'{method_name}-{nmi*100:.2f}-{d_name}.pdf')
        fig.savefig(out_path)
    finally:
        # This function is called once per dataset in a loop; pyplot keeps every
        # figure alive until it is closed, so release it even if saving failed.
        plt.close(fig)
    return out_path


# Embeddings that can be plotted for each dataset:
# plot title -> (key in the saved dict, NMI column, SC column of the metrics CSV).
# SC columns follow the same ``*_sc`` naming as ``gt_tsne_sc``.
_EMBEDDINGS = {
    'AutoDV-tSNE': ('pred', 'pred_tsne_nmi', 'pred_tsne_sc'),
    'UMAP*': ('umap_gt', 'gt_umap_nmi', 'gt_umap_sc'),
    't-SNE*': ('tsne_gt', 'gt_tsne_nmi', 'gt_tsne_sc'),
}


def plot_saved_split(csv_path, tar_path, methods=('AutoDV-tSNE',)):
    """Plot the selected embeddings for every dataset stored in ``tar_path``.

    Args:
        csv_path: metrics CSV with a ``d_name`` column plus the NMI/SC columns
            named in ``_EMBEDDINGS``; the first row matching a dataset is used.
        tar_path: file read with ``torch.load``. Expected to hold a dict whose
            keys satisfy ``key[0] == dataset name`` (presumably 1-tuples from a
            batched loader — TODO confirm) and whose values are dicts with a
            ``'y'`` label tensor plus the embedding tensors named in
            ``_EMBEDDINGS``.
        methods: plot titles (keys of ``_EMBEDDINGS``) to render per dataset.
    """
    metrics_df = pd.read_csv(csv_path)
    # map_location='cpu' lets files saved from GPU tensors load on CPU-only
    # hosts; every tensor is moved to CPU below anyway.
    # NOTE(review): torch.load unpickles its input — only use on trusted files.
    saved = torch.load(tar_path, map_location='cpu')

    for key, save_dict in saved.items():
        print(key)
        d_name = key[0]
        y = save_dict['y'].cpu().numpy()
        # Filter the metrics table once per dataset instead of once per column.
        row = metrics_df.loc[metrics_df['d_name'] == d_name].iloc[0]

        for method_name in methods:
            z_key, nmi_col, sc_col = _EMBEDDINGS[method_name]
            z = save_dict[z_key].cpu().numpy()
            plot_zs(z, y, d_name, row[nmi_col], row[sc_col], method_name)


if __name__ == "__main__":
    # Test split. For the training split use
    # './for_plot_clip_500_train_df_tsne-knn-l2.csv' and
    # './for_plot_clip_500_train_df_tsne-knn-l2.tar' instead.
    # Add 'UMAP*' and/or 't-SNE*' to ``methods`` to also plot the references.
    plot_saved_split(
        './2for_plot_clip_500_test_df_tsne-knn-l2.csv',
        './2for_plot_clip_500_test_df_tsne-knn-l2.tar',
        methods=('AutoDV-tSNE',),
    )

