import os, sys
import gzip
import pickle
import argparse
import json
from tqdm import tqdm

import numpy as np
import torch
from torch.utils.data import DataLoader
import e3nn
from e3nn import o3
import matplotlib.pyplot as plt

from scipy.stats import spearmanr

sys.path.append('..')
import cgnet_fibers
from projections import real_sph_ift
from utils.data_getter import get_data_mnist, get_grids_mnist
from utils import get_wigner_D_from_rot_matrix, rotate_signal, orthonormalize_frame
from utils.argparse_utils import *
from utils.data_utils import MNISTDatasetWithConditioning

import umap

from scipy.interpolate import interp1d

def dict_to_device(adict, device):
    """Apply ``.float().to(device)`` to a tensor or to every value of a dict.

    A dict is updated in place (each value is replaced by its converted
    counterpart) and the very same dict object is returned. Anything that
    is not a dict is treated as a single tensor-like object and the
    converted result is returned.
    """
    if not isinstance(adict, dict):
        return adict.float().to(device)

    # Only values are reassigned, so the key set is stable while iterating.
    for name, value in adict.items():
        adict[name] = value.float().to(device)
    return adict

def make_dict(tensor, irreps):
    """Split a flat 2-D feature tensor into a dict keyed by degree ``l``.

    ``irreps.ls`` lists the degree of each irrep in column order; every
    entry ``l`` owns ``2*l + 1`` consecutive columns of ``tensor``. The
    result maps each distinct ``l`` (in ascending order) to a float64
    tensor viewed as ``(batch, -1, 2*l + 1)``.

    If ``tensor`` is already a dict it is returned untouched.
    """
    if isinstance(tensor, dict):
        return tensor

    n_batch = tensor.shape[0]
    # Degree label of every column of `tensor`.
    column_degree = torch.cat([torch.full((2*l + 1,), l, dtype=torch.long) for l in irreps.ls])
    return {
        l: tensor[:, column_degree == l].view(n_batch, -1, 2*l + 1).double()
        for l in sorted(set(irreps.ls))
    }

def make_vector(x):
    """Flatten a dict of per-key tensors into one 2-D tensor.

    For a dict, each value is viewed as ``(value.shape[0], -1)`` and the
    pieces are concatenated along the last dimension in ascending key
    order. Any non-dict input is returned unchanged.
    """
    if not isinstance(x, dict):
        return x

    flattened = [x[l].view((x[l].shape[0], -1)) for l in sorted(x)]
    return torch.cat(flattened, dim=-1)

# Size of one subplot cell: multiplied by the number of columns / rows of each
# figure below to build the `figsize` passed to `plt.subplots`.
IMAGE_SIZE = 6

if __name__ == '__main__':
    # Command-line interface. Options are registered group by group; the
    # registration order determines the order shown by --help.
    parser = argparse.ArgumentParser()

    # String options that fall back to a default value.
    for flag, fallback in (
        ('--model_class', 'fibers'),
        ('--w3j_filepath', '../cg_coefficients/w3j_matrices-lmax=14-version=0.5.0.pkl'),
        ('--input_type', 'NRR-avg_sqrt_power'),
        ('--model_dir', '../runs/mnist/local_equiv_fibers'),
        ('--split', 'test'),
        ('--model_type', 'lowest_total_loss_with_final_kl_model'),
    ):
        parser.add_argument(flag, type=str, default=fallback)

    # Mandatory options: the run hash (a sub-directory of --model_dir) and
    # the two labels to interpolate between.
    parser.add_argument('--hash', type=str, required=True)
    for flag in ('--start_label', '--end_label'):
        parser.add_argument(flag, type=int, required=True)

    # Integer options that fall back to a default value.
    for flag, fallback in (
        ('--n_samples', 5),
        ('--n_interpolations', 10),
        ('--seed', 1000000),
    ):
        parser.add_argument(flag, type=int, default=fallback)

    args = parser.parse_args()

    # Use the GPU when one is visible to torch, otherwise stay on the CPU.
    if torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'
    print('Running on %s.' % device)

    # Map --model_type to a pair:
    #   model_type_str -> suffix embedded in the names of the inference-results
    #                     file that is read and of the figures that are written
    #   model_name     -> checkpoint file loaded from <model_dir>/<hash>/
    # NOTE(review): 'best' and 'best_04' both use an empty suffix, unlike every
    # other entry; kept as-is since it decides which results file is read, but
    # confirm this is intended.
    checkpoint_table = {
        'best': ('', 'best_model.pt'),
        'best_04': ('', 'best_model_04.pt'),
        'best_05': ('-best_05', 'best_model_05.pt'),
        'best_06': ('-best_06', 'best_model_06.pt'),
        'best_higher_kld': ('-best_model_higher_kld', 'best_model_higher_kld.pt'),
        'lowest_rec_loss': ('-lowest_rec_loss', 'lowest_rec_loss_model.pt'),
        'final': ('-final_model', 'final_model.pt'),
        'lowest_total_loss_with_final_kl_model': ('-lowest_total_loss_with_final_kl_model', 'lowest_total_loss_with_final_kl_model.pt'),
    }

    # Fail right away on an unknown value instead of hitting a NameError on
    # `model_name` / `model_type_str` much later in the script.
    if args.model_type not in checkpoint_table:
        raise ValueError('Unknown --model_type %r. Valid options: %s' % (args.model_type, ', '.join(checkpoint_table)))

    model_type_str, model_name = checkpoint_table[args.model_type]

    # Hyperparameters saved next to the checkpoint drive the model construction below.
    with open(os.path.join(args.model_dir, args.hash, 'hparams.json'), 'r') as f:
        hparams = json.load(f)

    data_irreps = o3.Irreps.spherical_harmonics(hparams['net_lmax'], 1)

    def _reduce_with_l_irreps(hidden_dim):
        # Irreps '<mul>x<l>e' for l in [0, net_lmax], with mul = hidden_dim // sqrt(2l + 1).
        return o3.Irreps('+'.join(['%dx%de' % (hidden_dim // np.sqrt(2*l + 1), l) for l in range(hparams['net_lmax'] + 1)]))

    def _constant_width_irreps(hidden_dim):
        # The same multiplicity `hidden_dim` for every l in [0, net_lmax].
        return (hidden_dim * o3.Irreps.spherical_harmonics(hparams['net_lmax'], 1)).sort().irreps.simplify()

    if args.model_class == 'e3nn':
        # NOTE(review): this branch uses `cgnet`, which is not explicitly imported
        # in this file (only `cgnet_fibers` is). Unless a star import provides it,
        # the calls below raise NameError -- confirm the intended module.

        # older runs may not have stored these flags
        hparams.setdefault('learn_frame', False)
        hparams.setdefault('teacher_forcing', False)

        # (alternative kept from the original code: _constant_width_irreps(...) for both)
        encoder_irreps_cg_hidden = _reduce_with_l_irreps(hparams['encoder_hidden_dim'])
        decoder_irreps_cg_hidden = _reduce_with_l_irreps(hparams['decoder_hidden_dim'])

        # Stored as a comma-separated string. Parsed once and shared by both model
        # variants (the symmetric variant previously referenced an undefined name).
        bottleneck_hidden_dims = list(map(int, hparams['bottleneck_hidden_dims'].split(',')))

        if hparams['model'] == 'cgvae':
            cgvae = cgnet.ClebschGordanVAE(data_irreps,
                                        hparams['latent_dim'],
                                        hparams['net_lmax'],
                                        hparams['encoder_n_cg_blocks'],
                                        encoder_irreps_cg_hidden,
                                        hparams['decoder_n_cg_blocks'],
                                        decoder_irreps_cg_hidden,
                                        device=device,
                                        bottleneck_hidden_dims=bottleneck_hidden_dims,
                                        softmax_before_sf_mse=hparams['softmax_before_sf_mse'],
                                        gate_nonlinearity_block=hparams['gate_nonlinearity_block'],
                                        n_reconstruction_layers=hparams['n_reconstruction_layers'],
                                        use_batch_norm=hparams['use_batch_norm'],
                                        nonlinearity_rule=hparams['nonlinearity_rule'],
                                        linearity_first=hparams['linearity_first'],
                                        signal_norm=hparams['signal_norm'],
                                        use_skip_connections_in_decoder=hparams['use_skip_connections_in_decoder'],
                                        sf_rec_loss_fn=hparams['sf_rec_loss_fn'],
                                        x_rec_loss_fn=hparams['x_rec_loss_fn'],
                                        learn_frame=hparams['learn_frame']).to(device)
        elif hparams['model'] == 'cgvae_symmetric':
            cgvae = cgnet.ClebschGordanVAE_symmetric(data_irreps,
                                            hparams['latent_dim'],
                                            hparams['net_lmax'],
                                            hparams['encoder_n_cg_blocks'],
                                            encoder_irreps_cg_hidden,
                                            device=device,
                                            bottleneck_hidden_dims=bottleneck_hidden_dims,
                                            softmax_before_sf_mse=hparams['softmax_before_sf_mse'],
                                            gate_nonlinearity_block=hparams['gate_nonlinearity_block'],
                                            n_reconstruction_layers=hparams['n_reconstruction_layers'],
                                            use_batch_norm=hparams['use_batch_norm'],
                                            nonlinearity_rule=hparams['nonlinearity_rule'],
                                            linearity_first=hparams['linearity_first'],
                                            signal_norm=hparams['signal_norm'],
                                            use_skip_connections_in_decoder=hparams['use_skip_connections_in_decoder'],
                                            sf_rec_loss_fn=hparams['sf_rec_loss_fn'],
                                            x_rec_loss_fn=hparams['x_rec_loss_fn'],
                                            learn_frame=hparams['learn_frame']).to(device)
        else:
            # without this, `cgvae` would be undefined and fail later with a NameError
            raise NotImplementedError('%s model not implemented for model_class e3nn.' % (hparams['model']))
    elif args.model_class == 'fibers':

        ## get w3j matrices
        # NOTE: unpickling can execute arbitrary code; only point --w3j_filepath at trusted files.
        with gzip.open(args.w3j_filepath, 'rb') as f:
            w3j_matrices = pickle.load(f)

        # Only the entries whose three l's fit within net_lmax are converted to
        # float tensors on `device`; the remaining entries are left as loaded.
        for key in w3j_matrices:
            if key[0] <= hparams['net_lmax'] and key[1] <= hparams['net_lmax'] and key[2] <= hparams['net_lmax']:
                w3j_matrices[key] = torch.tensor(w3j_matrices[key]).float().to(device)
                w3j_matrices[key].requires_grad = False

        if hparams['model'] == 'cgvae':
            raise NotImplementedError()
        elif hparams['model'] == 'cgvae_symmetric':

            # older runs may not have stored these flags
            hparams.setdefault('learn_frame', False)
            hparams.setdefault('teacher_forcing', False)

            if hparams['cg_width_type'] == 'constant':
                make_hidden_irreps = _constant_width_irreps
            elif hparams['cg_width_type'] == 'reduce_with_l':
                make_hidden_irreps = _reduce_with_l_irreps
            else:
                raise NotImplementedError('%s cg_width_type not implemented.' % (hparams['cg_width_type']))

            encoder_irreps_cg_hidden = make_hidden_irreps(hparams['encoder_hidden_dim'])
            decoder_irreps_cg_hidden = make_hidden_irreps(hparams['decoder_hidden_dim'])

            cgvae = cgnet_fibers.ClebschGordanVAE_symmetric(data_irreps,
                                                hparams['latent_dim'],
                                                hparams['encoder_n_cg_blocks'],
                                                encoder_irreps_cg_hidden,
                                                w3j_matrices,
                                                device,
                                                n_reconstruction_layers=hparams['n_reconstruction_layers'],
                                                bottleneck_hidden_dims=list(map(int, hparams['bottleneck_hidden_dims'].split(','))), # in order for encoder, get reversed for decoder
                                                dropout_rate=0.0,
                                                nonlinearity_rule=hparams['nonlinearity_rule'],
                                                use_batch_norm=hparams['use_batch_norm'],
                                                norm_type=hparams['norm_type'], # None, layer, signal
                                                normalization=hparams['normalization'], # norm, component -> only considered if norm_type is not none
                                                norm_affine=hparams['norm_affine'], # None, {True, False} -> for layer_norm, {unique, per_l, per_feature} -> for signal_norm
                                                norm_nonlinearity=hparams['norm_nonlinearity'], # None (identity), identity, relu, swish, sigmoid -> only for layer_norm
                                                norm_location=hparams['norm_location'], # first, between, last
                                                linearity_first=hparams['linearity_first'], # currently only works with this being false
                                                filter_symmetric=hparams['filter_symmetric'], # whether to exclude duplicate pairs of l's from the tensor product nonlinearity
                                                sf_rec_loss_fn=hparams['sf_rec_loss_fn'], # mse, cosine
                                                x_rec_loss_fn=hparams['x_rec_loss_fn'], # mse, mse_normalized, cosine
                                                teacher_forcing=bool(hparams['teacher_forcing']),
                                                do_final_signal_norm=hparams['do_final_signal_norm'],
                                                learn_frame=hparams['learn_frame']).to(device)
        elif hparams['model'] == 'cgvae_symmetric_simple_flexible':
            # dump the hyperparameters for reference
            for key, value in hparams.items():
                print('{}\t{}'.format(key, value))
            print()
            cgvae = cgnet_fibers.ClebschGordanVAE_symmetric_simple_flexible(data_irreps,
                                                hparams['latent_dim'],
                                                hparams['net_lmax'],
                                                hparams['n_cg_blocks'],
                                                list(map(int, hparams['ch_size_list'].split(','))),
                                                hparams['ls_nonlin_rule_list'].split(','),
                                                hparams['ch_nonlin_rule_list'].split(','),
                                                hparams['do_initial_linear_projection'],
                                                hparams['ch_initial_linear_projection'],
                                                w3j_matrices,
                                                device,
                                                use_additive_skip_connections=hparams['use_additive_skip_connections'],
                                                use_batch_norm=hparams['use_batch_norm'],
                                                norm_type=hparams['norm_type'], # None, layer, signal
                                                normalization=hparams['normalization'], # norm, component -> only considered if norm_type is not none
                                                norm_affine=hparams['norm_affine'], # None, {True, False} -> for layer_norm, {unique, per_l, per_feature} -> for signal_norm
                                                norm_nonlinearity=hparams['norm_nonlinearity'], # None (identity), identity, relu, swish, sigmoid -> only for layer_norm
                                                norm_location=hparams['norm_location'], # first, between, last
                                                linearity_first=hparams['linearity_first'], # currently only works with this being false
                                                filter_symmetric=hparams['filter_symmetric'], # whether to exclude duplicate pairs of l's from the tensor product nonlinearity
                                                x_rec_loss_fn=hparams['x_rec_loss_fn'], # mse, mse_normalized, cosine
                                                do_final_signal_norm=hparams['do_final_signal_norm'],
                                                learn_frame=hparams['learn_frame'],
                                                is_vae=hparams['is_vae']).to(device)
        else:
            # without this, `cgvae` would be undefined and fail later with a NameError
            raise NotImplementedError('%s model not implemented for model_class fibers.' % (hparams['model']))
    else:
        raise NotImplementedError('%s model_class not implemented.' % (args.model_class))

    # Restore the trained weights and switch to inference mode.
    cgvae.load_state_dict(torch.load(os.path.join(args.model_dir, args.hash, model_name), map_location=torch.device(device)))
    cgvae = cgvae.float()
    cgvae.eval()

    orig_grid, xyz_grid = get_grids_mnist()

    # Arrays stored under <model_dir>/<hash>/results_arrays for this model type,
    # split and input type (presumably produced by a separate inference script).
    results_path = os.path.join(args.model_dir, args.hash, 'results_arrays/inference%s-split=%s-input_type=%s.npz' % (model_type_str, args.split, args.input_type))

    # np.load on an .npz returns an archive object that keeps the file open and
    # reads members on access; the context manager closes it once everything
    # needed has been pulled out.
    with np.load(results_path) as arrays:
        invariants_ND = arrays['invariants_ND']
        learned_frames_N9 = arrays['learned_frames_N9']
        labels_N = arrays['labels_N']
        rotations_N9 = arrays['rotations_N9']
        images_NF = arrays['images_NF']
        rec_images_NF = arrays['rec_images_NF']

    N = labels_N.shape[0]

    # Make sure the output directory exists. exist_ok avoids the
    # check-then-create race of `os.path.exists` followed by `os.mkdir`.
    os.makedirs(os.path.join(args.model_dir, args.hash, 'reconstructions'), exist_ok=True)

    # NOTE(review): this condition is a substring test, while the plotting code
    # further down compares `args.input_type == 'NRNR'` exactly. For an input
    # type such as 'NRNR-<suffix>' the rotation is computed here but would not
    # be used there -- confirm which check is intended.
    if 'NRNR' in args.input_type:
        # Re-seeds torch's global RNG so that the random rotation is reproducible.
        torch.manual_seed(1000006) # this seed just has a nice rotation from NRNR frame to a frame that makes the number visible with the 2D projection
        nice_rot_matrix = e3nn.o3.rand_matrix(1).float().view(-1, 1, 9).squeeze().unsqueeze(0)
        nice_wigner = get_wigner_D_from_rot_matrix(max(data_irreps.ls), nice_rot_matrix[0].view(3, 3))


    ## Workflow:
    ## start_points --> pick args.n_samples starting datapoints of label args.start_label
    ## end_points --> pick args.n_samples ending datapoints of label args.end_label
    ## Get the invariants of all points, and (linearly) interpolate between each of the start_points to each of the end_points.
    ## Plot the reconstructions from left (start_points) to right (end_points)
    ## Also plot the original images --> maybe only select images with reconstruction loss below a certain threshold,
    ## to show prettier reconstructions
    '''
    Issue: interpolating in the invariants' latent space automatically changes the orientation of the reconstructions,
    which makes visualization hard, since every reconstruction is going to be located somewhere else
    on the sphere. This is not a bug, it's a feature: the invariant latent space is agnostic to the orientation of
    signals on the sphere.

    To start, I am going to plot every trajectory twice: on the learned frames of the starting images, and on the learned
    frames of the ending images. Hopefully there's going to be something distinguishable between the two sets of plots.
    Ideas (not necessarily joint) to make it better in the future:
        - Make 3D trajectory plots, oriented in such a way as to see what's going on through the entire trajectory
        - Interpolate in the frame space --> vector/angle interpolation
            --> goal: every digit reconstructed at the same location on the sphere --> probably not going to happen since
                      there are no constraints on the learning of frames and how they relate to the invariants.
        - Rotate everything so that the center of mass is at a predefined location that is nice to visualize in the 2D projection.
          Other rotations would be arbitrary and may cause some confusion, but it would be something semantically meaningful at least.
    '''

    # Shorthand sizes: M sampled pairs, I interpolation steps, D latent dimensions.
    rng = np.random.default_rng(args.seed)
    idxs = np.arange(N)
    M, I = args.n_samples, args.n_interpolations
    D = invariants_ND.shape[1]

    # Draw the start indices first and the end indices second, so the random
    # stream is consumed in a fixed order. rng.choice samples with replacement
    # by default, so the same datapoint can be drawn more than once.
    start_label_idxs = rng.choice(idxs[labels_N == args.start_label], args.n_samples)
    end_label_idxs = rng.choice(idxs[labels_N == args.end_label], args.n_samples)

    start_latent_MD, end_latent_MD = invariants_ND[start_label_idxs], invariants_ND[end_label_idxs]
    start_frames_M9, end_frames_M9 = learned_frames_N9[start_label_idxs], learned_frames_N9[end_label_idxs]
    orig_start_images, orig_end_images = images_NF[start_label_idxs], images_NF[end_label_idxs]

    # Linearly interpolate every latent dimension from the start value (at 0.0)
    # to the end value (at 1.0), sampled at I evenly spaced positions.
    # The m-th start point is paired with the m-th end point.
    endpoints = np.array([0.0, 1.0])
    interp_x = np.linspace(0.0, 1.0, args.n_interpolations)

    trajectories_MID = []
    for m in range(M):
        per_dimension = [
            interp1d(endpoints, np.array([start_latent_MD[m, d], end_latent_MD[m, d]]))(interp_x)
            for d in range(D)
        ]
        # stack to (D, I), then transpose to (I, D)
        trajectories_MID.append(np.vstack(per_dimension).T)

    # How the single 3x3 identity frame is tiled / viewed before being handed
    # to the decoder, per model class.
    # NOTE(review): the frame is built below as torch.eye(3).unsqueeze(0), a
    # 3-D tensor, and Tensor.repeat needs at least as many repeat dims as tensor
    # dims -- so the 2-element 'e3nn' setting looks incompatible with it; confirm.
    if args.model_class == 'e3nn':
        model_class_repeat, model_class_view = (I, 1), (-1, 9)
    elif args.model_class == 'fibers':
        model_class_repeat, model_class_view = (I, 1, 1), (-1, 3, 3)


    # ## two trajectories in different frames
    # nrows = M*2
    # ncols = I+2
    # fig, axs = plt.subplots(figsize=(ncols*IMAGE_SIZE, nrows*IMAGE_SIZE), ncols=ncols, nrows=nrows, sharex=True, sharey=True)

    # for m, trajectories_ID in enumerate(trajectories_MID):

    #     reconstructions_start_frame_IF = make_vector(cgvae.decode(torch.tensor(trajectories_ID).float().to(device), torch.tensor(start_frames_M9[m]).repeat(model_class_repeat).view(model_class_view).to(device))[0]).detach().cpu()
    #     if args.input_type == 'NRNR':
    #         # rotate every image in what I found to be a good location on the sphere for the purposes of visualization
    #         recs_inv_start_frame_IG = real_sph_ift(rotate_signal(reconstructions_start_frame_IF, data_irreps, nice_wigner), orig_grid, max(data_irreps.ls))
    #     else:
    #         # rotate every image to the learned identity frame
    #         start_wigner = get_wigner_D_from_rot_matrix(max(data_irreps.ls), torch.tensor(start_frames_M9[m]).view(3, 3)) # torch.transpose(torch.tensor(start_frames_M9[m]).view(3, 3), 0, 1))
    #         recs_inv_start_frame_IG = real_sph_ift(rotate_signal(reconstructions_start_frame_IF, data_irreps, start_wigner), orig_grid, max(data_irreps.ls))


    #     reconstructions_end_frame_IF = make_vector(cgvae.decode(torch.tensor(trajectories_ID).float().to(device), torch.tensor(end_frames_M9[m]).repeat(model_class_repeat).view(model_class_view).to(device))[0]).detach().cpu()
    #     if args.input_type == 'NRNR':
    #         # rotate every image in what I found to be a good location on the sphere for the purposes of visualization
    #         recs_inv_end_frame_IG = real_sph_ift(rotate_signal(reconstructions_end_frame_IF, data_irreps, nice_wigner), orig_grid, max(data_irreps.ls))
    #     else:
    #         # rotate every image to the learned identity frame
    #         end_wigner = get_wigner_D_from_rot_matrix(max(data_irreps.ls), torch.tensor(end_frames_M9[m]).view(3, 3)) # torch.transpose(torch.tensor(end_frames_M9[m]).view(3, 3), 0, 1))
    #         recs_inv_end_frame_IG = real_sph_ift(rotate_signal(reconstructions_end_frame_IF, data_irreps, end_wigner), orig_grid, max(data_irreps.ls))

    #     if args.input_type == 'NRNR':
    #         orig_start_inv_1G = real_sph_ift(rotate_signal(torch.tensor(orig_start_images[m]).unsqueeze(0), data_irreps, nice_wigner), orig_grid, max(data_irreps.ls))
    #         orig_end_inv_1G = real_sph_ift(rotate_signal(torch.tensor(orig_end_images[m]).unsqueeze(0), data_irreps, nice_wigner), orig_grid, max(data_irreps.ls))
    #     else:
    #         orig_start_inv_1G = real_sph_ift(rotate_signal(torch.tensor(orig_start_images[m]).unsqueeze(0), data_irreps, start_wigner), orig_grid, max(data_irreps.ls))
    #         orig_end_inv_1G = real_sph_ift(rotate_signal(torch.tensor(orig_end_images[m]).unsqueeze(0), data_irreps, end_wigner), orig_grid, max(data_irreps.ls))

    #     axs[2*m][0].imshow(orig_start_inv_1G[0].reshape(60, 60))
    #     axs[2*m+1][I+1].imshow(orig_end_inv_1G[0].reshape(60, 60))

    #     for i in range(I):
    #         axs[2*m][i+1].imshow(recs_inv_start_frame_IG[i].reshape(60, 60))
    #         axs[2*m+1][i+1].imshow(recs_inv_end_frame_IG[i].reshape(60, 60))
    
    # plt.tight_layout()
    # plt.savefig(os.path.join(args.model_dir, args.hash, 'reconstructions/trajectories_two_frames%s-start=%d-end=%d-seed=%d-split=%s-input_type=%s.png' % (model_type_str, args.start_label, args.end_label, args.seed, args.split, args.input_type)))
    # plt.close()


    ## one trajectory in the same identity frame
    ## (one row per sampled start/end pair, one column per interpolation step)
    nrows = M
    ncols = I
    # squeeze=False makes plt.subplots always return a 2-D (nrows, ncols) array
    # of axes. With the default squeezing the array collapses whenever a
    # dimension is 1, which made `axs[m][i]` / `axs[i]` fail for
    # n_interpolations == 1.
    fig, axs = plt.subplots(figsize=(ncols*IMAGE_SIZE, nrows*IMAGE_SIZE), ncols=ncols, nrows=nrows, sharex=True, sharey=True, squeeze=False)

    # The identity frame, tiled once per interpolation step; it is the same for
    # every trajectory, so it is built once outside the loop.
    identity_frames = torch.eye(3).unsqueeze(0).repeat(model_class_repeat).view(model_class_view).to(device)

    for m, trajectories_ID in enumerate(trajectories_MID):

        latents_ID = torch.tensor(trajectories_ID).float().to(device)
        # only the first element of what decode() returns is used
        reconstructions_id_frame_IF = make_vector(cgvae.decode(latents_ID, identity_frames)[0]).detach().cpu()

        # NOTE(review): exact comparison here, whereas `nice_wigner` is created
        # under `'NRNR' in args.input_type` -- confirm which check is intended.
        if args.input_type == 'NRNR':
            # rotate every image in what I found to be a good location on the sphere for the purposes of visualization
            recs_inv_id_frame_IG = real_sph_ift(rotate_signal(reconstructions_id_frame_IF, data_irreps, nice_wigner), orig_grid, max(data_irreps.ls))
        else:
            # somehow, no rotation needed in the learned_frame case with RR data...
            recs_inv_id_frame_IG = real_sph_ift(reconstructions_id_frame_IF, orig_grid, max(data_irreps.ls))

        for i in range(I):
            # each reconstruction is displayed as a 60x60 image
            axs[m][i].imshow(recs_inv_id_frame_IG[i].reshape(60, 60))
            axs[m][i].axis('off')

    plt.axis('off')
    plt.tight_layout()

    # The same fields fill in both output file names (png and pdf).
    name_fields = (model_type_str, args.n_samples, args.start_label, args.end_label, args.seed, args.split, args.input_type)
    plt.savefig(os.path.join(args.model_dir, args.hash, 'reconstructions/trajectories%s-n_samples=%d-start=%d-end=%d-seed=%d-split=%s-input_type=%s.png' % name_fields))
    plt.savefig(os.path.join(args.model_dir, args.hash, 'reconstructions/trajectories%s-n_samples=%d-start=%d-end=%d-seed=%d-split=%s-input_type=%s.pdf' % name_fields))
    plt.close()
    

    # Original start images, one per row, in a single column.
    nrows, ncols = M, 1
    # squeeze=False always yields a 2-D (nrows, 1) array of axes, so the same
    # indexing works whether M is 1 or larger.
    fig, axs = plt.subplots(figsize=(ncols*IMAGE_SIZE, nrows*IMAGE_SIZE), ncols=ncols, nrows=nrows, sharex=True, sharey=True, squeeze=False)

    for m in range(M):
        # Pick the rotation applied before projecting back onto the grid.
        if args.input_type == 'NRNR':
            rotation_wigner = nice_wigner
        else:
            # rotation built from this image's learned frame
            # (alternative left by the author: the transpose of the frame matrix)
            start_wigner = get_wigner_D_from_rot_matrix(max(data_irreps.ls), torch.tensor(start_frames_M9[m]).view(3, 3))
            rotation_wigner = start_wigner

        start_signal_1F = torch.tensor(orig_start_images[m]).unsqueeze(0)
        orig_start_inv_1G = real_sph_ift(rotate_signal(start_signal_1F, data_irreps, rotation_wigner), orig_grid, max(data_irreps.ls))

        cell = axs[m][0]
        cell.imshow(orig_start_inv_1G[0].reshape(60, 60))
        cell.axis('off')

    plt.axis('off')
    plt.tight_layout()

    # The same fields fill in both output file names (png and pdf).
    name_fields = (model_type_str, args.n_samples, args.start_label, args.end_label, args.seed, args.split, args.input_type)
    plt.savefig(os.path.join(args.model_dir, args.hash, 'reconstructions/start%s-n_samples=%d-start=%d-end=%d-seed=%d-split=%s-input_type=%s.png' % name_fields))
    plt.savefig(os.path.join(args.model_dir, args.hash, 'reconstructions/start%s-n_samples=%d-start=%d-end=%d-seed=%d-split=%s-input_type=%s.pdf' % name_fields))
    plt.close()


    # Original end images, one per row, in a single column.
    nrows, ncols = M, 1
    # squeeze=False always yields a 2-D (nrows, 1) array of axes, so the same
    # indexing works whether M is 1 or larger.
    fig, axs = plt.subplots(figsize=(ncols*IMAGE_SIZE, nrows*IMAGE_SIZE), ncols=ncols, nrows=nrows, sharex=True, sharey=True, squeeze=False)

    for m in range(M):
        # Pick the rotation applied before projecting back onto the grid.
        if args.input_type == 'NRNR':
            rotation_wigner = nice_wigner
        else:
            # rotation built from this image's learned frame
            # (alternative left by the author: the transpose of the frame matrix)
            end_wigner = get_wigner_D_from_rot_matrix(max(data_irreps.ls), torch.tensor(end_frames_M9[m]).view(3, 3))
            rotation_wigner = end_wigner

        end_signal_1F = torch.tensor(orig_end_images[m]).unsqueeze(0)
        orig_end_inv_1G = real_sph_ift(rotate_signal(end_signal_1F, data_irreps, rotation_wigner), orig_grid, max(data_irreps.ls))

        cell = axs[m][0]
        cell.imshow(orig_end_inv_1G[0].reshape(60, 60))
        cell.axis('off')

    plt.axis('off')
    plt.tight_layout()

    # The same fields fill in both output file names (png and pdf).
    name_fields = (model_type_str, args.n_samples, args.start_label, args.end_label, args.seed, args.split, args.input_type)
    plt.savefig(os.path.join(args.model_dir, args.hash, 'reconstructions/end%s-n_samples=%d-start=%d-end=%d-seed=%d-split=%s-input_type=%s.png' % name_fields))
    plt.savefig(os.path.join(args.model_dir, args.hash, 'reconstructions/end%s-n_samples=%d-start=%d-end=%d-seed=%d-split=%s-input_type=%s.pdf' % name_fields))
    plt.close()

