
'''
General model evaluation
'''

import os, sys
import gzip
import pickle
import argparse
import json
from tqdm import tqdm

import numpy as np
import torch
from torch.utils.data import DataLoader
import e3nn
from e3nn import o3
import matplotlib.pyplot as plt
import pandas as pd

from scipy.stats import spearmanr

sys.path.append('..')

from projections import real_sph_ift, ZernickeRadialFunctions, RadialSphericalTensor, MultiChannelRadialSphericalTensor
from utils.data_getter import get_data_mnist
from utils import get_wigner_D_from_rot_matrix, rotate_signal, orthonormalize_frame
from utils.argparse_utils import *
from utils.protein import *
from utils.data_utils import MNISTDatasetWithConditioning

from loss_functions import *

import umap
import matplotlib as mlp

import sklearn
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import classification_report
from sklearn.cluster import KMeans
from sklearn.metrics import homogeneity_score, completeness_score, silhouette_score

ELEV = 17.0
AZIM = -60

def latent_space_prediction(train_invariants_ND, train_labels_N, valid_invariants_ND, valid_labels_N, classifier='LR', optimize_hyps=False, data_percentage=100):
    '''
    Fit a classifier on the training invariants and score it on held-out invariants.

    Parameters
    ----------
    train_invariants_ND, train_labels_N
        Features and labels the classifier is fitted on.
    valid_invariants_ND, valid_labels_N
        Features and labels the fitted classifier is evaluated on.
    classifier
        'LR' for multinomial logistic regression, 'RF' for a random forest.
    optimize_hyps
        If True, wrap the estimator in a GridSearchCV over a small
        classifier-specific grid; otherwise use the estimator as configured.
    data_percentage
        Currently unused; kept so existing callers keep working.

    Returns
    -------
    dict
        Output of sklearn's classification_report(..., output_dict=True).

    Raises
    ------
    ValueError
        If `classifier` is not one of the supported names.
    '''
    if classifier == 'LR':
        estimator = LogisticRegression(multi_class='multinomial', solver='lbfgs', max_iter=1000)
        hyperparams = {'C': [0.1, 0.5, 1.0, 5.0, 10.0]}
    elif classifier == 'RF':
        estimator = RandomForestClassifier()
        hyperparams = {'min_samples_leaf': [1, 2, 5, 10, 20, 50, 100]}
    else:
        raise ValueError("Unknown classifier %r: expected 'LR' or 'RF'" % (classifier,))

    model = GridSearchCV(estimator, hyperparams) if optimize_hyps else estimator
    model = model.fit(train_invariants_ND, train_labels_N)

    probabilities = model.predict_proba(valid_invariants_ND)
    # Columns of predict_proba follow model.classes_, so the argmax is a column
    # index, not a label. Map it back so the report stays correct even when
    # the training labels are not exactly 0..K-1.
    predicted_labels = model.classes_[np.argmax(probabilities, axis=1)]

    return classification_report(valid_labels_N, predicted_labels, output_dict=True)

def purity_score(labels_true, labels_pred):
    '''
    Purity of a clustering with respect to reference labels.

    Builds the contingency table of (true label, predicted cluster) counts,
    takes the largest count along axis 0 for every column, and returns the sum
    of those maxima divided by the total number of counted samples.
    '''
    table = sklearn.metrics.cluster.contingency_matrix(labels_true, labels_pred)
    dominant_counts = np.amax(table, axis=0)
    n_total = np.sum(table)
    return np.sum(dominant_counts) / n_total

# Suffix that each --model_type contributes to the names of the result files.
# Only the suffix is needed: this script reads precomputed inference arrays and
# never opens a model checkpoint itself.
_MODEL_TYPE_TO_SUFFIX = {
    'best': '',
    'best_04': '',
    'best_05': '-best_05',
    'best_06': '-best_06',
    'best_higher_kld': '-best_model_higher_kld',
    'lowest_rec_loss': '-lowest_rec_loss',
    'final': '-final_model',
    'no_training': '-no_training',
    'lowest_total_loss_with_final_kl_model': '-lowest_total_loss_with_final_kl_model',
}

# Azimuths (degrees) of the three side-by-side views of the learned frame vectors.
_FRAME_VIEW_AZIMUTHS = (-60, -30, 0)


def _cached_array(cache_path, compute_fn):
    '''Load the array stored at `cache_path`, or compute it with `compute_fn()` and store it there.'''
    if os.path.exists(cache_path):
        return np.load(cache_path)
    result = compute_fn()
    np.save(cache_path, result)
    return result


def _grid_rows(n_panels, ncols, min_rows):
    '''Number of subplot rows: at least `min_rows`, and enough to hold `n_panels` panels.'''
    return max(min_rows, -(-n_panels // ncols))


def _label_legend_handles(labels_N, palette, label_text):
    '''One round legend marker per distinct label, coloured by `palette[label]` and captioned by `label_text(label)`.'''
    from matplotlib.lines import Line2D
    return [Line2D([0], [0], marker='o', ls='', color=palette[label], markerfacecolor=palette[label], label=label_text(label))
            for label in sorted(set(labels_N))]


def _scatter_embedding(embedding, colors, out_path, marker_size, legend_handles=None):
    '''
    Save a scatter plot of a 2- or 3-column embedding to `out_path`.

    With `legend_handles` the points are assumed to be coloured per class and
    the handles are shown as legend. Without it `colors` is treated as a
    continuous quantity and a colorbar is drawn (a bare legend() has no
    labelled artists here and would render nothing).
    '''
    is_3d = embedding.shape[1] == 3
    fig = plt.figure(figsize=(16, 12) if is_3d else (12, 8))
    if is_3d:
        ax = fig.add_subplot(projection='3d')
        points = ax.scatter(embedding[:, 0], embedding[:, 1], embedding[:, 2], c=colors, s=marker_size)
        ax.view_init(elev=ELEV, azim=AZIM)
    else:
        ax = fig.add_subplot(1, 1, 1)
        points = ax.scatter(embedding[:, 0], embedding[:, 1], c=colors, s=marker_size)

    if legend_handles is not None:
        ax.legend(handles=legend_handles)
    else:
        fig.colorbar(points, ax=ax)

    fig.tight_layout()
    fig.savefig(out_path)
    plt.close(fig)


def _compute_cosine_losses(data_irreps, ls_indices, images_NF, rec_images_NF, n_samples, device):
    '''
    Cosine loss between every reconstruction and its original signal.

    Returns the per-sample loss over the full signal, and a dict keyed by
    str(irr.ir) holding the per-sample loss restricted to the components whose
    entry in `ls_indices` equals that irrep's l.
    '''
    # NOTE(review): eval() of a string taken from NAME_TO_LOSS_FN; this is only
    # safe as long as that table stays project-controlled.
    cosine_loss_cls = eval(NAME_TO_LOSS_FN['cosine'])

    full_loss_fn = cosine_loss_cls(data_irreps, device)
    full_cosines_N = np.array([full_loss_fn(rec_images_NF[n, :], images_NF[n, :]).item()
                               for n in tqdm(range(n_samples))])

    per_l_cosines_N_dict = {}
    for irr in data_irreps:
        loss_fn = cosine_loss_cls(o3.Irreps(str(irr)), device)
        mask = ls_indices == irr.ir.l  # the same for every sample, so build it once
        per_l_cosines_N_dict[str(irr.ir)] = np.array([loss_fn(rec_images_NF[n, mask], images_NF[n, mask]).item()
                                                      for n in tqdm(range(n_samples))])

    return full_cosines_N, per_l_cosines_N_dict


def main():
    '''
    Evaluate one trained model from its precomputed inference arrays.

    Reads `results_arrays/inference<tag>.npz` under <model_dir>/<hash> and writes:
    reconstruction cosine-loss histograms, a total-power comparison plot, UMAP
    visualisations of the latent invariants (and of the learned frames for
    MNIST runs), a logistic-regression classification report on the invariants
    (non-train splits only) and k-means clustering quality metrics.
    '''
    from scipy.stats import pearsonr

    parser = argparse.ArgumentParser()
    parser.add_argument('--input_type', type=str, default='RR-avg_sqrt_power')
    parser.add_argument('--model_dir', type=str)
    parser.add_argument('--model_type', type=str)
    parser.add_argument('--split', type=str, default='test')
    parser.add_argument('--hash', type=str)
    parser.add_argument('--seed', type=int, default=1000005) # 1000005, 1000006

    args = parser.parse_args()

    # Fail early instead of hitting a NameError at the first path that needs the suffix.
    if args.model_type not in _MODEL_TYPE_TO_SUFFIX:
        parser.error('--model_type must be one of: %s' % ', '.join(_MODEL_TYPE_TO_SUFFIX))
    model_type_str = _MODEL_TYPE_TO_SUFFIX[args.model_type]

    # The dataset is inferred from the model directory name.
    is_mnist = 'mnist' in args.model_dir
    if not is_mnist and 'zernicke' not in args.model_dir:
        raise NotImplementedError('Cannot infer the dataset: --model_dir must contain "mnist" or "zernicke"')

    marker_scaling = 0.2 if args.split == 'train' else 0.75
    marker_size = (mlp.rcParams['lines.markersize'] * marker_scaling) ** 2

    device = 'cpu'

    run_dir = os.path.join(args.model_dir, args.hash)
    # Tags shared by all file names of this (model type, split, input type) run.
    run_tag = '%s-split=%s-input_type=%s' % (model_type_str, args.split, args.input_type)
    view_tag = '%s-elev=%.1f-azim=%d-split=%s-input_type=%s' % (model_type_str, ELEV, AZIM, args.split, args.input_type)

    def results_path(stem, ext='.npy'):
        return os.path.join(run_dir, 'results_arrays/' + stem + run_tag + ext)

    def viz_path(stem, tag=run_tag):
        return os.path.join(run_dir, 'latent_space_viz/' + stem + tag + '.png')

    def loss_path(stem):
        return os.path.join(run_dir, 'loss_distributions/' + stem + run_tag + '.png')

    def classification_path(stem):
        return os.path.join(run_dir, 'latent_space_classification/' + stem + run_tag + '.csv')

    with open(os.path.join(run_dir, 'hparams.json'), 'r') as f:
        hparams = json.load(f)

    if is_mnist:
        data_irreps = o3.Irreps.spherical_harmonics(hparams['net_lmax'], 1)
    else:
        # filter by desired lmax and channels
        OnRadialFunctions = ZernickeRadialFunctions(hparams['rcut'], hparams['rmax']+1, hparams['lmax'], complex_sph = False)
        rst = RadialSphericalTensor(hparams['rmax']+1, OnRadialFunctions, hparams['lmax'], 1, 1)
        mul_rst = MultiChannelRadialSphericalTensor(rst, hparams['n_channels'])
        data_irreps = o3.Irreps(str(mul_rst))

    arrays = np.load(results_path('inference', ext='.npz'))
    invariants_ND = arrays['invariants_ND']
    learned_frames_N9 = arrays['learned_frames_N9']
    labels_N = arrays['labels_N']
    images_NF = torch.tensor(arrays['images_NF'])
    rec_images_NF = torch.tensor(arrays['rec_images_NF'])

    N = labels_N.shape[0]

    ## compute powers
    # l value attached to every feature component, grouped by increasing l.
    all_ls = torch.tensor(data_irreps.ls)
    ls_indices = torch.cat([all_ls[all_ls == l].repeat(2*l+1) for l in sorted(set(data_irreps.ls))]).type(torch.float)
    inverse_dims = 1.0 / (2*ls_indices + 1)
    orig_powers_N = torch.einsum('bf,bf,f->b', images_NF, images_NF, inverse_dims).numpy()
    rec_powers_N = torch.einsum('bf,bf,f->b', rec_images_NF, rec_images_NF, inverse_dims).numpy()

    ## cosine losses
    # NOTE: freshly computed losses are not written back to disk (saving was
    # already disabled before this rewrite); the cached branch only triggers
    # when the file was produced elsewhere.
    cosines_path = results_path('cosines', ext='.npz')
    if os.path.exists(cosines_path):
        cosines = np.load(cosines_path)
        full_cosines_N = cosines['full_cosines_N']
        per_l_cosines_N_dict = {str(irr.ir): cosines[str(irr.ir)] for irr in data_irreps}
    else:
        full_cosines_N, per_l_cosines_N_dict = _compute_cosine_losses(data_irreps, ls_indices, images_NF, rec_images_NF, N, device)

    # make directories if they do not exist
    for subdir in ('loss_distributions', 'latent_space_viz', 'latent_space_classification'):
        os.makedirs(os.path.join(run_dir, subdir), exist_ok=True)

    ## plot the powers in a scatterplot
    # Draw the actual least-squares line: the slope alone (without the
    # intercept) is not the line that polyfit fitted.
    slope, intercept = np.polyfit(orig_powers_N, rec_powers_N, deg=1)
    y_fit = slope * orig_powers_N + intercept

    sp_r = spearmanr(orig_powers_N, rec_powers_N)
    pe_r = pearsonr(orig_powers_N, rec_powers_N)

    from matplotlib.lines import Line2D
    correlation_handles = [Line2D([0], [0], marker='o', ls='', color='darkblue', markerfacecolor='darkblue', label='SP-R: %.3f, p-val: %.3f' % (sp_r[0], sp_r[1])),
                           Line2D([0], [0], marker='o', ls='', color='darkblue', markerfacecolor='darkblue', label='PE-R: %.3f, p-val: %.3f' % (pe_r[0], pe_r[1]))]

    plt.figure()
    plt.scatter(orig_powers_N, rec_powers_N, color='darkblue')
    plt.plot(orig_powers_N, y_fit, color='darkblue')
    plt.title('Total Power')
    plt.xlabel('Original Signal')
    plt.ylabel('Reconstructed Signal')
    plt.legend(handles=correlation_handles)
    plt.savefig(viz_path('total_power_comparison'))
    plt.close()

    # plot the distribution of cosine loss (histogram)
    plt.figure(figsize=(10, 6))
    plt.hist(full_cosines_N, label='Mean = %.3f' % (np.mean(full_cosines_N)))
    plt.xlabel('Cosine loss')
    plt.ylabel('Count')
    plt.title('Full signal')
    plt.xlim([0, 2])
    plt.legend()
    plt.tight_layout()
    plt.savefig(loss_path('full_signal_cosines'))
    plt.close()

    # Caption of a class label: the digit for MNIST, the amino-acid name otherwise.
    def label_text(label):
        if is_mnist:
            return '%d' % (label)
        return '%s' % (ind_to_aa[label])

    # plot the distribution of cosine loss per label (one histogram per label)
    unique_labels = sorted(set(labels_N))
    nrows = _grid_rows(len(unique_labels), ncols=5, min_rows=2 if is_mnist else 4)
    fig, axs = plt.subplots(figsize=(20, nrows*4), nrows=nrows, ncols=5, sharex=True, sharey=True)
    axs = axs.flatten()

    for l_i, label in enumerate(unique_labels):
        label_cosines = full_cosines_N[labels_N == label]
        axs[l_i].hist(label_cosines, label='Mean = %.3f' % (np.mean(label_cosines)))
        axs[l_i].set_xlabel('Cosine loss')
        axs[l_i].set_ylabel('Count')
        axs[l_i].set_title('Label = ' + label_text(label))
        axs[l_i].legend()

    plt.tight_layout()
    plt.savefig(loss_path('full_signal_cosines_per_label'))
    plt.close(fig)

    # plot the distribution of cosine loss per l-value (one histogram per irrep)
    nrows = _grid_rows(len(data_irreps), ncols=4, min_rows=3)
    fig, axs = plt.subplots(figsize=(16, 4*nrows), nrows=nrows, ncols=4, sharex=True, sharey=True)
    axs = axs.flatten()

    for l_i, irr in enumerate(data_irreps):
        l_cosines = per_l_cosines_N_dict[str(irr.ir)]
        axs[l_i].hist(l_cosines, label='Mean = %.3f' % (np.mean(l_cosines)))
        axs[l_i].set_xlabel('Cosine loss')
        axs[l_i].set_ylabel('Count')
        axs[l_i].set_title('l = %d' % (irr.ir.l))
        axs[l_i].legend()

    plt.tight_layout()
    plt.savefig(loss_path('per_l_cosines'))
    plt.close(fig)

    # Per-class colours and legend shared by all label-coloured scatter plots.
    palette = plt.get_cmap('tab10').colors if is_mnist else plt.get_cmap('tab20').colors
    colors_N = [palette[label] for label in labels_N]
    legend_handles = _label_legend_handles(labels_N, palette, label_text)

    def fit_umap(features, n_components):
        return umap.UMAP(n_components=n_components, random_state=42).fit_transform(features)

    ## umap of invariants
    lower_dim_invariants_N2 = _cached_array(results_path('umap_invariants'), lambda: fit_umap(invariants_ND, 2))
    _scatter_embedding(lower_dim_invariants_N2, colors_N, viz_path('umap_invariants'), marker_size, legend_handles)
    # color by power - original and reconstructed
    _scatter_embedding(lower_dim_invariants_N2, orig_powers_N, viz_path('umap_invariants_by_orig_power'), marker_size)
    _scatter_embedding(lower_dim_invariants_N2, rec_powers_N, viz_path('umap_invariants_by_rec_power'), marker_size)

    ## 3D umap of invariants
    lower_dim_invariants_N3 = _cached_array(results_path('umap_invariants_3D'), lambda: fit_umap(invariants_ND, 3))
    _scatter_embedding(lower_dim_invariants_N3, colors_N, viz_path('umap_invariants_3D', view_tag), marker_size, legend_handles)
    _scatter_embedding(lower_dim_invariants_N3, orig_powers_N, viz_path('umap_invariants_3D_by_orig_power', view_tag), marker_size)
    _scatter_embedding(lower_dim_invariants_N3, rec_powers_N, viz_path('umap_invariants_3D_by_rec_power', view_tag), marker_size)

    if is_mnist:
        ## umap of frames
        print('Computing UMAP of learned frames...')
        lower_dim_learned_frames_N2 = _cached_array(results_path('umap_frames'), lambda: fit_umap(learned_frames_N9, 2))
        _scatter_embedding(lower_dim_learned_frames_N2, colors_N, viz_path('umap_frames'), marker_size, legend_handles)

        ## UMAP of invariants plus frames (only with NRNR arbitrary human frame)
        print('Computing UMAP of invariants plus learned frames...')
        lower_dim_learned_inv_plus_frames_N2 = _cached_array(results_path('umap_invariants_plus_frames'),
                                                             lambda: fit_umap(np.hstack([invariants_ND, learned_frames_N9]), 2))
        _scatter_embedding(lower_dim_learned_inv_plus_frames_N2, colors_N, viz_path('umap_invariants_plus_frames'), marker_size, legend_handles)

        ## UMAP of invariants plus frames in 3D (only with NRNR arbitrary human frame)
        print('Computing UMAP of invariants plus learned frames...')
        lower_dim_learned_inv_plus_frames_N3 = _cached_array(results_path('umap_invariants_plus_frames_3D'),
                                                             lambda: fit_umap(np.hstack([invariants_ND, learned_frames_N9]), 3))
        _scatter_embedding(lower_dim_learned_inv_plus_frames_N3, colors_N, viz_path('umap_invariants_plus_frames_3D', view_tag), marker_size, legend_handles)

        # Columns 0:3 and 3:6 of the learned frames are the two vectors plotted
        # below. The remaining columns are not visualised.
        x_learned = learned_frames_N9[:, :3]
        y_learned = learned_frames_N9[:, 3:6]

        ## plot individual vectors of learned frames (only with NRNR arbitrary human frame)
        print('Plotting learned frames in 3D...')
        fig = plt.figure(figsize=(24, 8))
        for panel, azim in enumerate(_FRAME_VIEW_AZIMUTHS, start=1):
            ax = fig.add_subplot(1, 3, panel, projection='3d')
            ax.scatter(x_learned[:, 0], x_learned[:, 1], x_learned[:, 2], label='1st l1 vectors', s=marker_size)
            ax.scatter(y_learned[:, 0], y_learned[:, 1], y_learned[:, 2], label='2nd l1 vectors', s=marker_size)
            ax.view_init(elev=10.0, azim=azim)
            ax.legend()
        plt.tight_layout()
        plt.savefig(viz_path('learned_frames_3d_vectors'))
        plt.close(fig)

        ## plot individual second vector of learned frames (only with NRNR arbitrary human frame)
        print('Plotting second l1 vector of learned frames in 3D...')
        fig = plt.figure(figsize=(24, 8))
        for panel, azim in enumerate(_FRAME_VIEW_AZIMUTHS, start=1):
            ax = fig.add_subplot(1, 3, panel, projection='3d')
            ax.scatter(y_learned[:, 0], y_learned[:, 1], y_learned[:, 2], c=colors_N, s=marker_size)
            ax.view_init(elev=10.0, azim=azim)
        # The class legend is attached to the last panel only.
        ax.legend(handles=legend_handles)
        plt.tight_layout()
        plt.savefig(viz_path('learned_second_l1_3d_vectors'))
        plt.close(fig)

    if args.split != 'train':
        train_arrays_path = os.path.join(run_dir, 'results_arrays/inference%s-split=train-input_type=%s.npz' % (model_type_str, args.input_type))

        def fit_umap_on_train(n_components):
            # Fit the projection on the training invariants, then embed this split with it.
            invariants_train_MD = np.load(train_arrays_path)['invariants_ND']
            return umap.UMAP(n_components=n_components, random_state=42).fit(invariants_train_MD).transform(invariants_ND)

        ## umap of invariants for valid/test data fitted to training data
        fitted_invariants_N2 = _cached_array(results_path('umap_invariants_fitted_on_train'), lambda: fit_umap_on_train(2))
        _scatter_embedding(fitted_invariants_N2, colors_N, viz_path('umap_invariants_fitted_on_train'), marker_size, legend_handles)

        ## umap of invariants for valid/test data fitted to training data, in 3D
        fitted_invariants_N3 = _cached_array(results_path('umap_invariants_fitted_on_train_3D'), lambda: fit_umap_on_train(3))
        _scatter_embedding(fitted_invariants_N3, colors_N, viz_path('umap_invariants_fitted_on_train_3D'), marker_size, legend_handles)

        ## linear classification in the latent space
        # The report is only produced once; an existing csv is left untouched.
        report_path = classification_path('classificaton_on_latent_space_default_classes')
        if not os.path.exists(report_path):
            arrays_train = np.load(train_arrays_path)
            report = latent_space_prediction(arrays_train['invariants_ND'], arrays_train['labels_N'],
                                             invariants_ND, labels_N, classifier='LR', optimize_hyps=False)
            pd.DataFrame(report).to_csv(report_path)

    ## quality of k-means clustering of the invariants, with one cluster per class
    n_clusters = 10 if is_mnist else 20
    kmeans = KMeans(n_clusters=n_clusters, random_state=args.seed, verbose=0)
    labels_pred_N = kmeans.fit_predict(invariants_ND)

    table = {
        'Homogeneity': [homogeneity_score(labels_N, labels_pred_N)],
        'Completeness': [completeness_score(labels_N, labels_pred_N)],
        'Purity': [purity_score(labels_N, labels_pred_N)],
        'Silhouette': [silhouette_score(invariants_ND, labels_pred_N)]
    }

    pd.DataFrame(table).to_csv(classification_path('quality_of_clustering_metrics_default_classes'), index=False)


if __name__ == '__main__':
    main()
