

import os, sys
import shutil
import gzip
import pickle
import argparse
import json
from tqdm import tqdm
from subprocess import check_output

import numpy as np
import torch
from torch.utils.data import DataLoader
import e3nn
from e3nn import o3
import matplotlib.pyplot as plt
import pandas as pd

from scipy.stats import spearmanr

sys.path.append('..')

from projections import real_sph_ift, ZernickeRadialFunctions, RadialSphericalTensor, MultiChannelRadialSphericalTensor
from utils.data_getter import get_data_mnist
from utils import get_wigner_D_from_rot_matrix, rotate_signal, orthonormalize_frame
from utils.argparse_utils import *
from utils.protein import *
from utils.data_utils import MNISTDatasetWithConditioning

from loss_functions import *

import umap
import matplotlib as mlp

import sklearn
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import classification_report
from sklearn.cluster import KMeans
from sklearn.metrics import homogeneity_score, completeness_score, silhouette_score

# Camera angles (degrees) passed to `view_init` for the 3D scatter plots.
ELEV, AZIM = 17.0, -60

# Palette of 55 distinct colours, indexed by label for the shrec17 plots; generated by
# distinctipy with a fixed rng seed so the same colours come out on every run.
from distinctipy import distinctipy
COLORMAP_55 = distinctipy.get_colors(
    55,
    exclude_colors=None,
    return_excluded=False,
    pastel_factor=0,
    n_attempts=1000,
    colorblind_type=None,
    rng=42,
)

def latent_space_prediction(train_invariants_ND, train_labels_N, valid_invariants_ND, valid_labels_N, classifier='LR', optimize_hyps=False, data_percentage=100):
    """Fit a classifier on latent invariants and report its validation performance.

    Args:
        train_invariants_ND: training features, one row per sample.
        train_labels_N: training class labels.
        valid_invariants_ND: validation features (same columns as the training ones).
        valid_labels_N: validation class labels.
        classifier: 'LR' (multinomial logistic regression) or 'RF' (random forest).
        optimize_hyps: if True, wrap the estimator in ``GridSearchCV`` over a small
            hyperparameter grid.
        data_percentage: currently unused; kept for interface compatibility.

    Returns:
        ``sklearn.metrics.classification_report`` of the validation predictions, as a dict.

    Raises:
        ValueError: if ``classifier`` is neither 'LR' nor 'RF'.
    """
    if classifier == 'LR':
        estimator = LogisticRegression(multi_class='multinomial', solver='lbfgs', max_iter=1000)
        hyperparams = {'C': [0.1, 0.5, 1.0, 5.0, 10.0]}
    elif classifier == 'RF':
        estimator = RandomForestClassifier()
        hyperparams = {'min_samples_leaf': [1, 2, 5, 10, 20, 50, 100]}
    else:
        raise ValueError("Unknown classifier %r; expected 'LR' or 'RF'." % (classifier,))

    model = GridSearchCV(estimator, hyperparams) if optimize_hyps else estimator
    model = model.fit(train_invariants_ND, train_labels_N)

    # predict_proba's columns follow model.classes_, so the argmax column index must be
    # mapped back to the actual label; a bare argmax is only right for labels 0..K-1.
    probabilities = model.predict_proba(valid_invariants_ND)
    predicted_labels = np.asarray(model.classes_)[np.argmax(probabilities, axis=1)]

    return classification_report(valid_labels_N, predicted_labels, output_dict=True)

def latent_space_model_retrieval_shrec17(train_invariants_ND, train_labels_N, test_invariants_ND, test_ids_N, resdir, classifier='LR', optimize_hyps=False, data_percentage=100, log_dir=None):
    """Write per-model retrieval lists for the test set and run the external evaluator.

    A classifier is fit on the training invariants. Every test model is assigned its
    most probable class; its retrieval list is every test model predicted to be in that
    same class, ordered by decreasing probability of that class (ties broken by id,
    descending). One file per test id, containing the retrieved ids separated by
    newlines, is written to ``resdir``. Then ``nodejs evaluate.js <log_dir>/`` is run
    with ``cwd="evaluator"`` and its output is printed.

    Args:
        train_invariants_ND: training features, one row per sample.
        train_labels_N: training class labels.
        test_invariants_ND: test features.
        test_ids_N: array of test model ids (used as file names and list entries).
        resdir: existing directory that receives one retrieval file per test id.
        classifier: 'LR' (multinomial logistic regression) or 'RF' (random forest).
        optimize_hyps: if True, wrap the estimator in ``GridSearchCV``.
        data_percentage: currently unused; kept for interface compatibility.
        log_dir: directory handed to the evaluator, relative to the current working
            directory (it is prefixed with '..' because the evaluator runs inside
            ``evaluator/``) or absolute. Defaults to the parent of ``resdir``; presumably
            the evaluator scans it for the results sub-directory -- TODO confirm.

    Raises:
        ValueError: if ``classifier`` is neither 'LR' nor 'RF'.
    """
    if classifier == 'LR':
        estimator = LogisticRegression(multi_class='multinomial', solver='lbfgs', max_iter=1000)
        hyperparams = {'C': [0.1, 0.5, 1.0, 5.0, 10.0]}
    elif classifier == 'RF':
        estimator = RandomForestClassifier()
        hyperparams = {'min_samples_leaf': [1, 2, 5, 10, 20, 50, 100]}
    else:
        raise ValueError("Unknown classifier %r; expected 'LR' or 'RF'." % (classifier,))

    model = GridSearchCV(estimator, hyperparams) if optimize_hyps else estimator
    model = model.fit(train_invariants_ND, train_labels_N)

    predictions = model.predict_proba(test_invariants_ND)
    predictions_class = np.argmax(predictions, axis=1)
    # Probability each model assigns to its own predicted class (the ranking key).
    confidences = predictions[np.arange(predictions.shape[0]), predictions_class]

    n_test = test_ids_N.shape[0]

    # All models predicted as class c share the same ranked list, so build it once per
    # class instead of rescanning every test model for every query (O(N^2)).
    retrieved_by_class = {}
    for c in np.unique(predictions_class):
        members = np.flatnonzero(predictions_class == c)
        ranked = sorted(((confidences[j], test_ids_N[j]) for j in members), reverse=True)
        retrieved_by_class[c] = "\n".join(str(model_id) for _, model_id in ranked)

    for i in range(n_test):
        if i % 100 == 0:
            print("{}/{}    ".format(i, n_test), end="\r")
        with open(os.path.join(resdir, str(test_ids_N[i])), "w") as f:
            f.write(retrieved_by_class[predictions_class[i]])

    if log_dir is None:
        log_dir = os.path.dirname(os.path.normpath(resdir))
    print(check_output(["nodejs", "evaluate.js", os.path.join("..", log_dir) + "/"], cwd="evaluator").decode("utf-8"))
    

def purity_score(labels_true, labels_pred):
    """Return the purity of the clustering ``labels_pred`` w.r.t. ``labels_true``.

    Each predicted cluster is credited with the number of samples carrying its most
    common true label. Purity is the sum of those counts divided by the number of
    samples, so 1.0 means every cluster contains a single true label.
    """
    # Re-encode both 1-D labelings as dense codes 0..K-1 so they can index a count table.
    _, true_codes = np.unique(labels_true, return_inverse=True)
    _, pred_codes = np.unique(labels_pred, return_inverse=True)

    # contingency[t, p] = number of samples with true code t and predicted code p.
    contingency = np.zeros((true_codes.max() + 1, pred_codes.max() + 1), dtype=np.int64)
    np.add.at(contingency, (true_codes, pred_codes), 1)

    majority_counts = contingency.max(axis=0)
    return majority_counts.sum() / contingency.sum()

if __name__ == '__main__':
    # --model_type -> (suffix used in result file names, checkpoint file name).
    # NOTE(review): 'best' and 'best_04' share the empty suffix, so both read/write the
    # same result files -- confirm this is intended.
    MODEL_TYPE_TO_FILES = {
        'best': ('', 'best_model.pt'),
        'best_04': ('', 'best_model_04.pt'),
        'best_05': ('-best_05', 'best_model_05.pt'),
        'best_06': ('-best_06', 'best_model_06.pt'),
        'best_higher_kld': ('-best_model_higher_kld', 'best_model_higher_kld.pt'),
        'lowest_rec_loss': ('-lowest_rec_loss', 'lowest_rec_loss_model.pt'),
        'final': ('-final_model', 'final_model.pt'),
        'lowest_total_loss_with_final_kl_model': ('-lowest_total_loss_with_final_kl_model', 'lowest_total_loss_with_final_kl_model.pt'),
    }

    parser = argparse.ArgumentParser()
    parser.add_argument('--model_dir', type=str, required=True)
    # Validated up front: an unknown value would otherwise only fail much later.
    parser.add_argument('--model_type', type=str, required=True, choices=sorted(MODEL_TYPE_TO_FILES))
    parser.add_argument('--split', type=str, default='test')
    parser.add_argument('--hash', type=str, required=True)
    parser.add_argument('--seed', type=int, default=1000005)

    args = parser.parse_args()

    # Scatter-plot marker size multiplier: smaller markers for the training split.
    MARKER_SCALING = 0.2 if args.split == 'train' else 0.75

    model_type_str, model_name = MODEL_TYPE_TO_FILES[args.model_type]

    device = 'cpu'

    with open(os.path.join(args.model_dir, args.hash, 'hparams.json'), 'r') as f:
        hparams = json.load(f)

    # The dataset is inferred from the model directory name; it fixes the irreps
    # layout of the input signals.
    if 'mnist' in args.model_dir:
        data_irreps = o3.Irreps.spherical_harmonics(hparams['net_lmax'], 1)
    elif 'shrec17' in args.model_dir:
        # 6 copies of every l (presumably one per input channel), sorted and merged.
        data_irreps = (6*o3.Irreps.spherical_harmonics(hparams['lmax'], 1)).sort().irreps.simplify()
    elif 'zernicke' in args.model_dir:
        # Irreps of a multi-channel Zernike radial-spherical tensor built from the
        # model's rcut, rmax, lmax and channel count.
        OnRadialFunctions = ZernickeRadialFunctions(hparams['rcut'], hparams['rmax']+1, hparams['lmax'], complex_sph = False)
        rst = RadialSphericalTensor(hparams['rmax']+1, OnRadialFunctions, hparams['lmax'], 1, 1)
        mul_rst = MultiChannelRadialSphericalTensor(rst, hparams['n_channels'])
        data_irreps = o3.Irreps(str(mul_rst))
    else:
        raise NotImplementedError('Cannot infer the dataset (mnist / shrec17 / zernicke) from --model_dir %r' % (args.model_dir,))

    # Cosine reconstruction loss on the full signal, plus one loss per irrep so the
    # reconstruction quality of each l can be inspected separately.
    cosine_loss_ctor = eval(NAME_TO_LOSS_FN['cosine'])  # trusted, project-internal name lookup
    full_cosine_loss_fn = cosine_loss_ctor(data_irreps, device)

    per_l_cosine_loss_fn_dict = {str(irr.ir): cosine_loss_ctor(o3.Irreps(str(irr)), device) for irr in data_irreps}

    # Inference outputs for the chosen checkpoint and split; the context manager
    # closes the underlying .npz file once every array has been read.
    inference_path = os.path.join(args.model_dir, args.hash, 'results_arrays/inference%s-split=%s.npz' % (model_type_str, args.split))
    with np.load(inference_path) as arrays:
        invariants_ND = arrays['invariants_ND']
        learned_frames_N9 = arrays['learned_frames_N9']
        labels_N = arrays['labels_N']
        rotations_N9 = arrays['rotations_N9']
        images_NF = torch.tensor(arrays['images_NF'])
        rec_images_NF = torch.tensor(arrays['rec_images_NF'])
        ids_N = arrays['ids_N']

    N = labels_N.shape[0]

    ## compute powers
    # Degree l of every signal coefficient; each copy of l spans 2l+1 entries. Assumes
    # the coefficients are ordered by increasing l with all copies of one l contiguous
    # -- TODO confirm for the zernicke layout.
    ls_tensor = torch.tensor(data_irreps.ls)
    ls_indices = torch.cat([ls_tensor[ls_tensor == l].repeat(2*l+1) for l in sorted(set(data_irreps.ls))]).type(torch.float)
    # Power = sum of squared coefficients, each weighted by 1 / (2l + 1).
    inv_dims_F = 1.0 / (2*ls_indices + 1)
    orig_powers_N = torch.einsum('bf,bf,f->b', images_NF, images_NF, inv_dims_F).numpy()
    rec_powers_N = torch.einsum('bf,bf,f->b', rec_images_NF, rec_images_NF, inv_dims_F).numpy()

    cosines_path = os.path.join(args.model_dir, args.hash, 'results_arrays/cosines%s-split=%s.npz' % (model_type_str, args.split))
    if not os.path.exists(cosines_path):
        # compute full signal cosine loss, one sample at a time
        full_cosines_N = np.array([full_cosine_loss_fn(rec_images_NF[n, :], images_NF[n, :]).item() for n in tqdm(range(N))])

        # compute per-l-value cosine loss
        per_l_cosines_N_dict = {}
        for irr in data_irreps:
            l_loss_fn = per_l_cosine_loss_fn_dict[str(irr.ir)]
            l_mask_F = ls_indices == irr.ir.l  # loop-invariant: computed once per irrep
            per_l_cosines_N_dict[str(irr.ir)] = np.array([l_loss_fn(rec_images_NF[n, l_mask_F], images_NF[n, l_mask_F]).item() for n in tqdm(range(N))])

        # NOTE: saving is disabled, so the cache file above is only read if it was
        # produced elsewhere.
        # # save everything
        # np.savez(cosines_path,
        #                 full_cosines_N = full_cosines_N,
        #                 **per_l_cosines_N_dict)
    else:
        with np.load(cosines_path) as cosines:
            full_cosines_N = cosines['full_cosines_N']
            per_l_cosines_N_dict = {str(irr.ir): cosines[str(irr.ir)] for irr in data_irreps}
    
    # make output directories if they do not exist; exist_ok avoids the
    # check-then-create race of os.path.exists + os.mkdir
    for output_subdir in ('loss_distributions', 'latent_space_viz', 'latent_space_classification'):
        os.makedirs(os.path.join(args.model_dir, args.hash, output_subdir), exist_ok=True)
    

    ## plot the powers in a scatterplot
    print('Plotting correlation of real vs. reconstructed powers', file=sys.stderr)
    from scipy.stats import pearsonr
    from matplotlib.lines import Line2D

    # Least-squares line rec ~ slope * orig + intercept. Both coefficients are evaluated
    # (not slope * x alone), so the drawn line is the actual fit.
    fit_coeffs = np.polyfit(orig_powers_N, rec_powers_N, deg=1)
    x_fit = np.sort(orig_powers_N)
    y_fit = np.polyval(fit_coeffs, x_fit)

    sp_r = spearmanr(orig_powers_N, rec_powers_N)
    pe_r = pearsonr(orig_powers_N, rec_powers_N)

    # Proxy artists (never drawn on the axes) so the legend lists the correlation stats.
    legend_elements = [Line2D([0], [0], marker='o', ls='', color='darkblue', markerfacecolor='darkblue', label='SP-R: %.3f, p-val: %.3f' % (sp_r[0], sp_r[1])),
                    Line2D([0], [0], marker='o', ls='', color='darkblue', markerfacecolor='darkblue', label='PE-R: %.3f, p-val: %.3f' % (pe_r[0], pe_r[1]))]

    plt.scatter(orig_powers_N, rec_powers_N, color='darkblue')
    plt.plot(x_fit, y_fit, color='darkblue')
    plt.title('Total Power')
    plt.xlabel('Original Signal')
    plt.ylabel('Reconstructed Signal')
    plt.legend(handles=legend_elements)
    plt.savefig(os.path.join(args.model_dir, args.hash, 'latent_space_viz/total_power_comparison%s-split=%s.png' % (model_type_str, args.split)))
    plt.close()

    # Histogram of the full-signal cosine loss over all samples of the split.
    print('Plotting distribution of cosine loss', file=sys.stderr)
    mean_full_cosine = np.mean(full_cosines_N)
    full_hist_path = os.path.join(args.model_dir, args.hash, 'loss_distributions/full_signal_cosines%s-split=%s.png' % (model_type_str, args.split))

    plt.figure(figsize=(10, 6))
    plt.hist(full_cosines_N, label='Mean = %.3f' % (mean_full_cosine))
    plt.xlabel('Cosine loss')
    plt.ylabel('Count')
    plt.title('Full signal')
    plt.xlim([0, 2])  # fixed axis range so plots from different runs are comparable
    plt.legend()
    plt.tight_layout()
    plt.savefig(full_hist_path)
    plt.close()

    # plot the distribution of cosine loss per label (one panel per label, 5 per row)
    print('Plotting distributions of cosine losses per label', file=sys.stderr)
    unique_labels = sorted(set(labels_N))
    ncols = 5
    if 'mnist' in args.model_dir:
        nrows = 2
    elif 'shrec17' in args.model_dir:
        nrows = 11
    elif 'zernicke' in args.model_dir:
        nrows = 4
    else:
        nrows = 1
    # Grow the grid when the split has more labels than the preset layout holds.
    nrows = max(nrows, -(-len(unique_labels) // ncols))

    fig, axs = plt.subplots(figsize=(20, nrows*4), nrows=nrows, ncols=ncols, sharex=True, sharey=True)
    axs = axs.flatten()

    for ax, label in zip(axs, unique_labels):
        label_cosines = full_cosines_N[labels_N == label]
        ax.hist(label_cosines, label='Mean = %.3f' % (np.mean(label_cosines)))
        ax.set_xlabel('Cosine loss')
        ax.set_ylabel('Count')
        if 'mnist' in args.model_dir or 'shrec17' in args.model_dir:
            ax.set_title('Label = %d' % (label))
        else:
            # ind_to_aa comes from utils.protein; presumably maps label index -> amino acid.
            ax.set_title('Label = %s' % (ind_to_aa[label]))
        ax.legend()

    plt.tight_layout()
    plt.savefig(os.path.join(args.model_dir, args.hash, 'loss_distributions/full_signal_cosines_per_label%s-split=%s.png' % (model_type_str, args.split)))
    plt.close(fig)  # release the figure; it would otherwise stay open for the whole run


    # plot the distribution of cosine loss per l-value (one histogram per irrep)
    print('Plotting distributions of cosine losses per l-value', file=sys.stderr)
    # 4 columns and at least 4 rows; extra rows are added when there are >16 irreps.
    ncols_per_l = 4
    nrows_per_l = max(4, -(-len(data_irreps) // ncols_per_l))
    fig, axs = plt.subplots(figsize=(16, 3*nrows_per_l), nrows=nrows_per_l, ncols=ncols_per_l, sharex=True, sharey=True)
    axs = axs.flatten()

    for ax, irr in zip(axs, data_irreps):
        l_cosines_N = per_l_cosines_N_dict[str(irr.ir)]
        ax.hist(l_cosines_N, label='Mean = %.3f' % (np.mean(l_cosines_N)))
        ax.set_xlabel('Cosine loss')
        ax.set_ylabel('Count')
        ax.set_title('l = %d' % (irr.ir.l))
        ax.legend()

    plt.tight_layout()
    plt.savefig(os.path.join(args.model_dir, args.hash, 'loss_distributions/per_l_cosines%s-split=%s.png' % (model_type_str, args.split)))
    plt.close(fig)


    ## umap of invariants
    print('Computing UMAP of invariants', file=sys.stderr)
    umap_2d_path = os.path.join(args.model_dir, args.hash, 'results_arrays/umap_invariants%s-split=%s.npy' % (model_type_str, args.split))
    if os.path.exists(umap_2d_path):
        # Reuse the cached embedding instead of refitting UMAP.
        lower_dim_invariants_N2 = np.load(umap_2d_path)
    else:
        lower_dim_invariants_N2 = umap.UMAP(random_state=42).fit_transform(invariants_ND)
        np.save(umap_2d_path, lower_dim_invariants_N2)

    # Labels index directly into the palette, so it needs at least as many colours as labels.
    if 'mnist' in args.model_dir:
        colors_10 = plt.get_cmap('tab10').colors
    elif 'shrec17' in args.model_dir:
        colors_10 = COLORMAP_55
    elif 'zernicke' in args.model_dir:
        colors_10 = plt.get_cmap('tab20').colors
    colors_N = [colors_10[label] for label in labels_N]

    from matplotlib.lines import Line2D
    # Handles for the (currently disabled) label legend below.
    if 'mnist' in args.model_dir or 'shrec17' in args.model_dir:
        legend_elements = [Line2D([0], [0], marker='o', ls='', color=colors_10[label], markerfacecolor=colors_10[label], label='%d' % (label)) for label in sorted(set(labels_N))]
    elif 'zernicke' in args.model_dir:
        legend_elements = [Line2D([0], [0], marker='o', ls='', color=colors_10[label], markerfacecolor=colors_10[label], label='%s' % (ind_to_aa[label])) for label in sorted(set(labels_N))]

    marker_size = (mlp.rcParams['lines.markersize']*MARKER_SCALING)**2

    plt.figure(figsize=(12, 8))
    plt.scatter(lower_dim_invariants_N2[:, 0], lower_dim_invariants_N2[:, 1], c=colors_N, s=marker_size)
    # plt.legend(handles=legend_elements)
    plt.tight_layout()
    plt.savefig(os.path.join(args.model_dir, args.hash, 'latent_space_viz/umap_invariants%s-split=%s.png' % (model_type_str, args.split)))
    plt.close()

    # color by power (original, then reconstructed). The scatter has no labelled
    # artists, so a colorbar (not a legend) is what maps colour back to power.
    for power_N, power_tag in ((orig_powers_N, 'orig'), (rec_powers_N, 'rec')):
        plt.figure(figsize=(12, 8))
        plt.scatter(lower_dim_invariants_N2[:, 0], lower_dim_invariants_N2[:, 1], c=power_N, s=marker_size)
        plt.colorbar()
        plt.tight_layout()
        plt.savefig(os.path.join(args.model_dir, args.hash, 'latent_space_viz/umap_invariants_by_%s_power%s-split=%s.png' % (power_tag, model_type_str, args.split)))
        plt.close()


    ## 3D umap of invariants
    print('Computing 3D UMAP of invariants', file=sys.stderr)
    umap_3d_path = os.path.join(args.model_dir, args.hash, 'results_arrays/umap_invariants_3D%s-split=%s.npy' % (model_type_str, args.split))
    if os.path.exists(umap_3d_path):
        # Reuse the cached embedding instead of refitting UMAP.
        lower_dim_invariants_N3 = np.load(umap_3d_path)
    else:
        lower_dim_invariants_N3 = umap.UMAP(n_components=3, random_state=42).fit_transform(invariants_ND)
        np.save(umap_3d_path, lower_dim_invariants_N3)

    # Labels index directly into the palette, so it needs at least as many colours as labels.
    if 'mnist' in args.model_dir:
        colors_10 = plt.get_cmap('tab10').colors
    elif 'shrec17' in args.model_dir:
        colors_10 = COLORMAP_55
    elif 'zernicke' in args.model_dir:
        colors_10 = plt.get_cmap('tab20').colors
    colors_N = [colors_10[label] for label in labels_N]

    from matplotlib.lines import Line2D
    # Handles for the (currently disabled) label legend below.
    if 'mnist' in args.model_dir or 'shrec17' in args.model_dir:
        legend_elements = [Line2D([0], [0], marker='o', ls='', color=colors_10[label], markerfacecolor=colors_10[label], label='%d' % (label)) for label in sorted(set(labels_N))]
    elif 'zernicke' in args.model_dir:
        legend_elements = [Line2D([0], [0], marker='o', ls='', color=colors_10[label], markerfacecolor=colors_10[label], label='%s' % (ind_to_aa[label])) for label in sorted(set(labels_N))]

    marker_size = (mlp.rcParams['lines.markersize']*MARKER_SCALING)**2

    fig = plt.figure(figsize=(16, 12))
    ax = fig.add_subplot(projection='3d')
    ax.scatter(lower_dim_invariants_N3[:, 0], lower_dim_invariants_N3[:, 1], lower_dim_invariants_N3[:, 2], c=colors_N, s=marker_size)
    ax.view_init(elev=ELEV, azim=AZIM)
    # plt.legend(handles=legend_elements)
    plt.tight_layout()
    plt.savefig(os.path.join(args.model_dir, args.hash, 'latent_space_viz/umap_invariants_3D%s-elev=%.1f-azim=%d-split=%s.png' % (model_type_str, ELEV, AZIM, args.split)))
    plt.close()

    # color by power (original, then reconstructed); a colorbar maps colour to power.
    for power_N, power_tag in ((orig_powers_N, 'orig'), (rec_powers_N, 'rec')):
        fig = plt.figure(figsize=(16, 12))
        ax = fig.add_subplot(projection='3d')
        power_points = ax.scatter(lower_dim_invariants_N3[:, 0], lower_dim_invariants_N3[:, 1], lower_dim_invariants_N3[:, 2], c=power_N, s=marker_size)
        ax.view_init(elev=ELEV, azim=AZIM)
        fig.colorbar(power_points, ax=ax)
        plt.tight_layout()
        plt.savefig(os.path.join(args.model_dir, args.hash, 'latent_space_viz/umap_invariants_3D_by_%s_power%s-elev=%.1f-azim=%d-split=%s.png' % (power_tag, model_type_str, ELEV, AZIM, args.split)))
        plt.close()


    if 'mnist' in args.model_dir:
        # Every plot in this section colours points by label with tab10.
        colors_10 = plt.get_cmap('tab10').colors
        colors_N = [colors_10[label] for label in labels_N]

        from matplotlib.lines import Line2D
        legend_elements = [Line2D([0], [0], marker='o', ls='', color=colors_10[label], markerfacecolor=colors_10[label], label='%d' % (label)) for label in sorted(set(labels_N))]

        marker_size = (mlp.rcParams['lines.markersize']*MARKER_SCALING)**2
        # Azimuths of the three side-by-side views used by the learned-vector plots.
        vector_view_azims = (-60, -30, 0)

        ## umap of frames
        print('Computing UMAP of learned frames...', file=sys.stderr)
        frames_umap_path = os.path.join(args.model_dir, args.hash, 'results_arrays/umap_frames%s-split=%s.npy' % (model_type_str, args.split))
        if os.path.exists(frames_umap_path):
            lower_dim_learned_frames_N2 = np.load(frames_umap_path)
        else:
            lower_dim_learned_frames_N2 = umap.UMAP(random_state=42).fit_transform(learned_frames_N9)
            np.save(frames_umap_path, lower_dim_learned_frames_N2)

        plt.figure(figsize=(12, 8))
        plt.scatter(lower_dim_learned_frames_N2[:, 0], lower_dim_learned_frames_N2[:, 1], c=colors_N, s=marker_size)
        plt.legend(handles=legend_elements)
        plt.tight_layout()
        plt.savefig(os.path.join(args.model_dir, args.hash, 'latent_space_viz/umap_frames%s-split=%s.png' % (model_type_str, args.split)))
        plt.close()

        ## UMAP of invariants plus frames (only with NRNR arbitrary human frame)
        print('Computing UMAP of invariants plus learned frames...', file=sys.stderr)
        inv_frames_umap_2d_path = os.path.join(args.model_dir, args.hash, 'results_arrays/umap_invariants_plus_frames%s-split=%s.npy' % (model_type_str, args.split))
        if os.path.exists(inv_frames_umap_2d_path):
            lower_dim_learned_inv_plus_frames_N2 = np.load(inv_frames_umap_2d_path)
        else:
            lower_dim_learned_inv_plus_frames_N2 = umap.UMAP(random_state=42).fit_transform(np.hstack([invariants_ND, learned_frames_N9]))
            np.save(inv_frames_umap_2d_path, lower_dim_learned_inv_plus_frames_N2)

        plt.figure(figsize=(12, 8))
        plt.scatter(lower_dim_learned_inv_plus_frames_N2[:, 0], lower_dim_learned_inv_plus_frames_N2[:, 1], c=colors_N, s=marker_size)
        plt.legend(handles=legend_elements)
        plt.tight_layout()
        plt.savefig(os.path.join(args.model_dir, args.hash, 'latent_space_viz/umap_invariants_plus_frames%s-split=%s.png' % (model_type_str, args.split)))
        plt.close()

        ## UMAP of invariants plus frames in 3D (only with NRNR arbitrary human frame)
        print('Computing 3D UMAP of invariants plus learned frames...', file=sys.stderr)
        inv_frames_umap_3d_path = os.path.join(args.model_dir, args.hash, 'results_arrays/umap_invariants_plus_frames_3D%s-split=%s.npy' % (model_type_str, args.split))
        if os.path.exists(inv_frames_umap_3d_path):
            lower_dim_learned_inv_plus_frames_N3 = np.load(inv_frames_umap_3d_path)
        else:
            lower_dim_learned_inv_plus_frames_N3 = umap.UMAP(n_components=3, random_state=42).fit_transform(np.hstack([invariants_ND, learned_frames_N9]))
            np.save(inv_frames_umap_3d_path, lower_dim_learned_inv_plus_frames_N3)

        fig = plt.figure(figsize=(16, 12))
        ax = fig.add_subplot(projection='3d')
        ax.scatter(lower_dim_learned_inv_plus_frames_N3[:, 0], lower_dim_learned_inv_plus_frames_N3[:, 1], lower_dim_learned_inv_plus_frames_N3[:, 2], c=colors_N, s=marker_size)
        ax.view_init(elev=ELEV, azim=AZIM)
        plt.legend(handles=legend_elements)
        plt.tight_layout()
        plt.savefig(os.path.join(args.model_dir, args.hash, 'latent_space_viz/umap_invariants_plus_frames_3D%s-elev=%.1f-azim=%d-split=%s.png' % (model_type_str, ELEV, AZIM, args.split)))
        plt.close()

        ## plot individual vectors of learned frames (only with NRNR arbitrary human frame)
        print('Plotting learned frames in 3D...', file=sys.stderr)
        # Assumes each 9-dim frame row stacks three 3-vectors; only the first two
        # (columns 0-2 and 3-5) are plotted.
        x_learned, y_learned = learned_frames_N9[:, :3], learned_frames_N9[:, 3:6]

        fig = plt.figure(figsize=(24, 8))
        for panel, view_azim in enumerate(vector_view_azims, start=1):
            ax = fig.add_subplot(1, 3, panel, projection='3d')
            ax.scatter(x_learned[:, 0], x_learned[:, 1], x_learned[:, 2], label='1st l1 vectors', s=marker_size)
            ax.scatter(y_learned[:, 0], y_learned[:, 1], y_learned[:, 2], label='2nd l1 vectors', s=marker_size)
            ax.view_init(elev=10.0, azim=view_azim)
            ax.legend()

        plt.tight_layout()
        plt.savefig(os.path.join(args.model_dir, args.hash, 'latent_space_viz/learned_frames_3d_vectors%s-split=%s.png' % (model_type_str, args.split)))
        plt.close()

        ## plot individual second vector of learned frames (only with NRNR arbitrary human frame)
        print('Plotting second l1 vector of learned frames in 3D...', file=sys.stderr)
        fig = plt.figure(figsize=(24, 8))
        for panel, view_azim in enumerate(vector_view_azims, start=1):
            ax = fig.add_subplot(1, 3, panel, projection='3d')
            ax.scatter(y_learned[:, 0], y_learned[:, 1], y_learned[:, 2], c=colors_N, s=marker_size)
            ax.view_init(elev=10.0, azim=view_azim)

        # Attached to the current (last) panel.
        plt.legend(handles=legend_elements)
        plt.tight_layout()
        plt.savefig(os.path.join(args.model_dir, args.hash, 'latent_space_viz/learned_second_l1_3d_vectors%s-split=%s.png' % (model_type_str, args.split)))
        plt.close()




    ## umap of invariants for valid/test data fitted to training data
    if args.split != 'train':
        fitted_umap_2d_path = os.path.join(args.model_dir, args.hash, 'results_arrays/umap_invariants_fitted_on_train%s-split=%s.npy' % (model_type_str, args.split))
        if os.path.exists(fitted_umap_2d_path):
            lower_dim_invariants_N2 = np.load(fitted_umap_2d_path)
        else:
            # Fit UMAP on the training-split invariants only, then project this split.
            with np.load(os.path.join(args.model_dir, args.hash, 'results_arrays/inference%s-split=train.npz' % (model_type_str))) as arrays_train:
                invariants_train_MD = arrays_train['invariants_ND']
            lower_dim_invariants_N2 = umap.UMAP(random_state=42).fit(invariants_train_MD).transform(invariants_ND)
            np.save(fitted_umap_2d_path, lower_dim_invariants_N2)

        # Labels index directly into the palette.
        if 'mnist' in args.model_dir:
            colors_10 = plt.get_cmap('tab10').colors
        elif 'shrec17' in args.model_dir:
            colors_10 = COLORMAP_55
        elif 'zernicke' in args.model_dir:
            colors_10 = plt.get_cmap('tab20').colors
        colors_N = [colors_10[label] for label in labels_N]

        from matplotlib.lines import Line2D
        # Handles for the (currently disabled) label legend below.
        if 'mnist' in args.model_dir or 'shrec17' in args.model_dir:
            legend_elements = [Line2D([0], [0], marker='o', ls='', color=colors_10[label], markerfacecolor=colors_10[label], label='%d' % (label)) for label in sorted(set(labels_N))]
        elif 'zernicke' in args.model_dir:
            legend_elements = [Line2D([0], [0], marker='o', ls='', color=colors_10[label], markerfacecolor=colors_10[label], label='%s' % (ind_to_aa[label])) for label in sorted(set(labels_N))]

        plt.figure(figsize=(12, 8))
        plt.scatter(lower_dim_invariants_N2[:, 0], lower_dim_invariants_N2[:, 1], c=colors_N, s=(mlp.rcParams['lines.markersize']*MARKER_SCALING)**2)
        # plt.legend(handles=legend_elements)
        plt.tight_layout()
        plt.savefig(os.path.join(args.model_dir, args.hash, 'latent_space_viz/umap_invariants_fitted_on_train%s-split=%s.png' % (model_type_str, args.split)))
        plt.close()


    ## umap of invariants for valid/test data fitted to training data, in 3D
    if args.split != 'train':
        print('Computing UMAP of invariants fitted to training data', file=sys.stderr)
        fitted_umap_3d_path = os.path.join(args.model_dir, args.hash, 'results_arrays/umap_invariants_fitted_on_train_3D%s-split=%s.npy' % (model_type_str, args.split))
        if os.path.exists(fitted_umap_3d_path):
            lower_dim_invariants_N3 = np.load(fitted_umap_3d_path)
        else:
            # Fit UMAP on the training-split invariants only, then project this split.
            with np.load(os.path.join(args.model_dir, args.hash, 'results_arrays/inference%s-split=train.npz' % (model_type_str))) as arrays_train:
                invariants_train_MD = arrays_train['invariants_ND']
            lower_dim_invariants_N3 = umap.UMAP(n_components=3, random_state=42).fit(invariants_train_MD).transform(invariants_ND)
            np.save(fitted_umap_3d_path, lower_dim_invariants_N3)

        # Labels index directly into the palette.
        if 'mnist' in args.model_dir:
            colors_10 = plt.get_cmap('tab10').colors
        elif 'shrec17' in args.model_dir:
            colors_10 = COLORMAP_55
        elif 'zernicke' in args.model_dir:
            colors_10 = plt.get_cmap('tab20').colors
        colors_N = [colors_10[label] for label in labels_N]

        from matplotlib.lines import Line2D
        # Handles for the (currently disabled) label legend below.
        if 'mnist' in args.model_dir or 'shrec17' in args.model_dir:
            legend_elements = [Line2D([0], [0], marker='o', ls='', color=colors_10[label], markerfacecolor=colors_10[label], label='%d' % (label)) for label in sorted(set(labels_N))]
        elif 'zernicke' in args.model_dir:
            legend_elements = [Line2D([0], [0], marker='o', ls='', color=colors_10[label], markerfacecolor=colors_10[label], label='%s' % (ind_to_aa[label])) for label in sorted(set(labels_N))]

        fig = plt.figure(figsize=(16, 12))
        ax = fig.add_subplot(projection='3d')
        ax.scatter(lower_dim_invariants_N3[:, 0], lower_dim_invariants_N3[:, 1], lower_dim_invariants_N3[:, 2], c=colors_N, s=(mlp.rcParams['lines.markersize']*MARKER_SCALING)**2)
        ax.view_init(elev=ELEV, azim=AZIM)
        # plt.legend(handles=legend_elements)
        plt.tight_layout()
        plt.savefig(os.path.join(args.model_dir, args.hash, 'latent_space_viz/umap_invariants_fitted_on_train_3D%s-split=%s.png' % (model_type_str, args.split)))
        plt.close()
    

    ## linear classification in the latent space
    # Train on the training-split invariants, evaluate on this split; skipped when the
    # report CSV already exists.
    if args.split != 'train':
        print('Performing linear classification in the latent space', file=sys.stderr)
        output_path = os.path.join(args.model_dir, args.hash, 'latent_space_classification/classificaton_on_latent_space_default_classes%s-split=%s.csv' % (model_type_str, args.split))
        if not os.path.exists(output_path):
            with np.load(os.path.join(args.model_dir, args.hash, 'results_arrays/inference%s-split=train.npz' % (model_type_str))) as arrays_train:
                train_invariants_ND = arrays_train['invariants_ND']
                train_labels_N = arrays_train['labels_N']

            report = latent_space_prediction(train_invariants_ND, train_labels_N, invariants_ND, labels_N, classifier='LR', optimize_hyps=False)
            pd.DataFrame(report).to_csv(output_path)
    

    ## Quality of clustering metrics
    print('Computing quality of clustering metrics', file=sys.stderr)
    # k-means cluster count, chosen per dataset from the model directory name
    # (checked in this order; the first match wins).
    for dataset_key, dataset_n_clusters in (('mnist', 10), ('shrec17', 55), ('zernicke', 20)):
        if dataset_key in args.model_dir:
            n_clusters = dataset_n_clusters
            break
    else:
        raise NotImplementedError()

    labels_pred_N = KMeans(n_clusters=n_clusters, random_state=args.seed, verbose=0).fit_predict(invariants_ND)

    # Agreement between k-means clusters and true labels, plus label-free silhouette.
    metrics_table = {
        'Homogeneity': [homogeneity_score(labels_N, labels_pred_N)],
        'Completeness': [completeness_score(labels_N, labels_pred_N)],
        'Purity': [purity_score(labels_N, labels_pred_N)],
        'Silhouette': [silhouette_score(invariants_ND, labels_pred_N)],
    }

    metrics_path = os.path.join(args.model_dir, args.hash, 'latent_space_classification/quality_of_clustering_metrics_default_classes%s-split=%s.csv' % (model_type_str, args.split))
    pd.DataFrame(metrics_table).to_csv(metrics_path, index=None)


    # ## SHREC17 standard evaluation metrics
    # ## Following the same procedure as Spherical CNNs to generate the lists of retrieved models
    # if args.split == 'test':

    #     arrays_train = np.load(os.path.join(args.model_dir, args.hash, 'results_arrays/inference%s-split=train.npz' % (model_type_str)))
    #     train_invariants_ND = arrays_train['invariants_ND']
    #     train_labels_N = arrays_train['labels_N']

    #     if hparams['perturbed']:
    #         resdir = os.path.join(args.model_dir, args.hash, args.split + '_perturbed')
    #     else:
    #         resdir = os.path.join(args.model_dir, args.hash, args.split + '_normal')
    #     if os.path.isdir(resdir):
    #         shutil.rmtree(resdir)
    #     os.mkdir(resdir)

    #     latent_space_model_retrieval_shrec17(train_invariants_ND, train_labels_N, invariants_ND, ids_N, resdir, classifier='LR', optimize_hyps=False)
