
import numpy as np
import torch
from e3nn import o3
import matplotlib.pyplot as plt
import matplotlib as mpl
import gzip, pickle
import os, sys
import json
import argparse

# Maps each accepted --model_type value to a pair:
#   (suffix embedded in the results file names, checkpoint file name).
# The checkpoint name is informational only: this script reads precomputed
# arrays and never loads the model itself.
# NOTE(review): 'best_04' shares the empty suffix with 'best', so both options
# read the same result files -- possibly a copy-paste slip for '-best_04'.
# Confirm against the script that writes the inference arrays before changing.
MODEL_TYPES = {
    'best': ('', 'best_model.pt'),
    'best_04': ('', 'best_model_04.pt'),
    'best_05': ('-best_05', 'best_model_05.pt'),
    'best_06': ('-best_06', 'best_model_06.pt'),
    'best_higher_kld': ('-best_model_higher_kld', 'best_model_higher_kld.pt'),
    'lowest_rec_loss': ('-lowest_rec_loss', 'lowest_rec_loss_model.pt'),
    'final': ('-final_model', 'final_model.pt'),
    'lowest_total_loss_with_final_kl_model': (
        '-lowest_total_loss_with_final_kl_model',
        'lowest_total_loss_with_final_kl_model.pt',
    ),
}

if __name__ == '__main__':
    # Plot a precomputed 2D UMAP embedding of the invariants, colored by label.
    #
    # Inputs (under <model_dir>/<hash>/):
    #   hparams.json
    #   results_arrays/inference<suffix>-split=<split>-input_type=<input_type>.npz
    #   results_arrays/umap_invariants<suffix>-split=<split>-input_type=<input_type>.npy
    # Outputs:
    #   legend.png / legend.pdf in the current working directory
    #   latent_space_viz/__pretty_umap[_with_legend]_invariants<...>.{png,pdf}
    from matplotlib.lines import Line2D

    parser = argparse.ArgumentParser()
    parser.add_argument('--input_type', type=str, default='NRR-avg_sqrt_power')
    parser.add_argument('--model_dir', type=str, default='../runs/mnist/local_equiv_fibers')
    # Restricting to the known choices turns a typo into a clear argparse error
    # instead of a NameError when the file names are built.
    parser.add_argument('--model_type', type=str, default='lowest_total_loss_with_final_kl_model',
                        choices=sorted(MODEL_TYPES))
    parser.add_argument('--split', type=str, default='test')
    parser.add_argument('--hash', type=str, required=True)
    args = parser.parse_args()

    model_type_str, model_name = MODEL_TYPES[args.model_type]

    run_dir = os.path.join(args.model_dir, args.hash)
    arrays_dir = os.path.join(run_dir, 'results_arrays')
    viz_dir = os.path.join(run_dir, 'latent_space_viz')
    # Common tail shared by every input and output file name of this run.
    tag = '%s-split=%s-input_type=%s' % (model_type_str, args.split, args.input_type)

    with open(os.path.join(run_dir, 'hparams.json'), 'r') as f:
        hparams = json.load(f)

    data_irreps = o3.Irreps.spherical_harmonics(hparams['net_lmax'], 1)

    # Use the archive as a context manager so the underlying file is closed.
    with np.load(os.path.join(arrays_dir, 'inference%s.npz' % tag)) as arrays:
        labels_N = arrays['labels_N']
        images_NF = torch.tensor(arrays['images_NF'])
        rec_images_NF = torch.tensor(arrays['rec_images_NF'])

    ## compute powers
    # Only consumed by the power-colored scatter plots that are currently
    # commented out at the end of the script; kept so they can be re-enabled.
    ls = torch.tensor(data_irreps.ls)
    ls_indices = torch.cat([ls[ls == l].repeat(2 * l + 1) for l in sorted(set(data_irreps.ls))]).type(torch.float)
    inv_degeneracy = 1.0 / (2 * ls_indices + 1)
    orig_powers_N = torch.einsum('bf,bf,f->b', images_NF, images_NF, inv_degeneracy).numpy()
    rec_powers_N = torch.einsum('bf,bf,f->b', rec_images_NF, rec_images_NF, inv_degeneracy).numpy()

    ## umap of invariants
    umap_path = os.path.join(arrays_dir, 'umap_invariants%s.npy' % tag)
    if not os.path.exists(umap_path):
        # Fail here with the offending path rather than with a NameError at
        # plotting time.
        raise FileNotFoundError('No precomputed UMAP embedding found at %s' % umap_path)
    lower_dim_invariants_N2 = np.load(umap_path)

    colors_10 = plt.get_cmap('tab10').colors
    colors_N = [colors_10[int(label)] for label in labels_N]

    legend_elements = [
        Line2D([0], [0], marker='o', markersize=12.0, ls='', color=colors_10[label],
               markerfacecolor=colors_10[label], label='%d' % (label))
        for label in sorted({int(label) for label in labels_N})
    ]

    MARKER_SCALING = 0.5
    golden_ratio = (1 + 5.0**0.5) / 2
    height = 6.0
    width = height * golden_ratio

    # create a figure for the data
    fig_data, ax = plt.subplots(figsize=(width, height))
    ax.scatter(lower_dim_invariants_N2[:, 0], lower_dim_invariants_N2[:, 1], c=colors_N,
               s=(mpl.rcParams['lines.markersize'] * MARKER_SCALING)**2)

    # create a standalone legend figure (written to the current working directory)
    fig_legend = plt.figure(figsize=(1.2, 3.8))  # (1.0, 3.3)
    fig_legend.legend(handles=legend_elements, loc='center', prop={'size': 17})
    fig_legend.savefig('legend.png')
    fig_legend.savefig('legend.pdf')
    plt.close(fig_legend)

    ax.axis('off')
    fig_data.tight_layout()

    os.makedirs(viz_dir, exist_ok=True)
    for ext in ('png', 'pdf'):
        fig_data.savefig(os.path.join(viz_dir, '__pretty_umap_invariants%s.%s' % (tag, ext)))

    # Attach the legend to the data axes explicitly. A pyplot-level legend call
    # here would target the current figure, which is the legend figure created
    # above, and leave the "with_legend" plots without any legend.
    ax.legend(handles=legend_elements, prop={'size': 12})
    for ext in ('png', 'pdf'):
        fig_data.savefig(os.path.join(viz_dir, '__pretty_umap_with_legend_invariants%s.%s' % (tag, ext)))
    plt.close(fig_data)

    # # color by power - orig power
    # plt.figure(figsize=(12, 8))
    # plt.scatter(lower_dim_invariants_N2[:, 0], lower_dim_invariants_N2[:, 1], c=orig_powers_N, s=(mpl.rcParams['lines.markersize']*MARKER_SCALING)**2)
    # plt.legend()
    # plt.tight_layout()
    # plt.savefig(os.path.join(args.model_dir, args.hash, 'latent_space_viz/umap_invariants_by_orig_power%s-split=%s-input_type=%s.png' % (model_type_str, args.split, args.input_type)))
    # plt.close()

    # # color by power - rec power
    # plt.figure(figsize=(12, 8))
    # plt.scatter(lower_dim_invariants_N2[:, 0], lower_dim_invariants_N2[:, 1], c=rec_powers_N, s=(mpl.rcParams['lines.markersize']*MARKER_SCALING)**2)
    # plt.legend()
    # plt.tight_layout()
    # plt.savefig(os.path.join(args.model_dir, args.hash, 'latent_space_viz/umap_invariants_by_rec_power%s-split=%s-input_type=%s.png' % (model_type_str, args.split, args.input_type)))
    # plt.close()

