
import os, sys
import gzip
import pickle
import argparse
import json
from tqdm import tqdm

import numpy as np
import torch
from torch.utils.data import DataLoader
import e3nn
from e3nn import o3
import matplotlib.pyplot as plt

from scipy.stats import spearmanr

sys.path.append('..')
import cgnet_fibers
from projections import real_sph_ift
from utils.data_getter import get_data_mnist, get_grids_mnist
from utils import get_wigner_D_from_rot_matrix, rotate_signal, orthonormalize_frame
from utils.argparse_utils import *
from utils.data_utils import MNISTDatasetWithConditioning

import umap


def dict_to_device(adict, device):
    """Cast tensor data to float32 and move it to ``device``.

    ``adict`` is either a dict whose values are tensors or a single tensor.
    A dict is updated in place (every value is replaced by its converted
    copy) and the same dict object is returned; a bare tensor is converted
    and the converted tensor is returned.
    """
    if not isinstance(adict, dict):
        return adict.float().to(device)

    # Only existing keys are reassigned, so iterating while updating is safe.
    for key, value in adict.items():
        adict[key] = value.float().to(device)
    return adict

def make_dict(tensor, irreps):
    """Split a flat per-sample feature tensor into a dict keyed by degree l.

    If ``tensor`` is already a dict it is returned untouched. Otherwise its
    columns are grouped by the degree listed in ``irreps.ls`` (each entry l
    owning 2l+1 consecutive columns), and every group is returned as a
    float64 tensor viewed as (batch, -1, 2l+1).
    """
    if isinstance(tensor, dict):
        return tensor

    n_batch = tensor.shape[0]
    # One degree label per column of ``tensor``.
    column_degree = torch.cat([torch.tensor([l] * (2*l + 1)) for l in irreps.ls])
    return {
        l: tensor[:, column_degree == l].view(n_batch, -1, 2*l + 1).double()
        for l in sorted(set(irreps.ls))
    }

def make_vector(x):
    """Flatten a dict of per-degree tensors into one (batch, features) tensor.

    Values are flattened to (batch, -1) and concatenated along the last
    dimension in ascending key order. Anything that is not a dict is
    returned unchanged.
    """
    if not isinstance(x, dict):
        return x

    flat_parts = [x[l].view((x[l].shape[0], -1)) for l in sorted(x)]
    return torch.cat(flat_parts, dim=-1)


# Maps each supported --model_type value to
# (suffix used in result/figure filenames, checkpoint filename).
_MODEL_TYPES = {
    'best': ('', 'best_model.pt'),
    'best_04': ('', 'best_model_04.pt'),
    'best_05': ('-best_05', 'best_model_05.pt'),
    'best_06': ('-best_06', 'best_model_06.pt'),
    'best_higher_kld': ('-best_model_higher_kld', 'best_model_higher_kld.pt'),
    'lowest_rec_loss': ('-lowest_rec_loss', 'lowest_rec_loss_model.pt'),
    'final': ('-final_model', 'final_model.pt'),
    'lowest_total_loss_with_final_kl_model': ('-lowest_total_loss_with_final_kl_model', 'lowest_total_loss_with_final_kl_model.pt'),
}


def _cg_hidden_irreps(hidden_dim, lmax, width_type):
    """Build the hidden irreps of the CG blocks for the given width rule.

    'constant' uses ``hidden_dim`` channels for every l up to ``lmax``;
    'reduce_with_l' uses ``hidden_dim // sqrt(2l + 1)`` channels for degree l.
    Raises NotImplementedError for any other ``width_type``.
    """
    if width_type == 'constant':
        return (hidden_dim * o3.Irreps.spherical_harmonics(lmax, 1)).sort().irreps.simplify()
    if width_type == 'reduce_with_l':
        return o3.Irreps('+'.join(['%dx%de' % (hidden_dim // np.sqrt(2*l + 1), l) for l in range(lmax + 1)]))
    raise NotImplementedError('%s cg_width_type not implemented.' % (width_type))


def _save_image(img, path_stem, size):
    """Save ``img`` (viewed as 60x60) without axes to ``path_stem`` + '.png' and '.pdf'."""
    plt.figure(figsize=(size, size))
    plt.imshow(img.reshape(60, 60))
    plt.axis('off')
    plt.tight_layout()
    plt.savefig(path_stem + '.png')
    plt.savefig(path_stem + '.pdf')
    plt.close()


# Script entry point: loads a trained model and its saved inference arrays from
# <model_dir>/<hash>/ and writes figures to <model_dir>/<hash>/reconstructions/:
#   * per-label true vs. reconstructed images,
#   * a visual rotation-equivariance check,
#   * (when hparams['is_vae'] is set) decoded samples from a standard normal latent.
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_class', type=str, default='fibers')
    parser.add_argument('--w3j_filepath', type=str, default='../cg_coefficients/w3j_matrices-lmax=14-version=0.5.0.pkl')
    parser.add_argument('--input_type', type=str, default='NRR-avg_sqrt_power')
    # Both are needed to build every path below; without them os.path.join would fail on None.
    parser.add_argument('--model_dir', type=str, required=True)
    parser.add_argument('--split', type=str, default='test')
    parser.add_argument('--model_type', type=str, default='lowest_total_loss_with_final_kl_model')
    parser.add_argument('--hash', type=str, required=True)
    parser.add_argument('--n_samples', type=int, default=4)
    parser.add_argument('--seed', type=int, default=88888)

    args = parser.parse_args()

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print('Running on %s.' % device)

    # Fail early on an unknown model type instead of hitting an undefined name later.
    if args.model_type not in _MODEL_TYPES:
        raise ValueError('Unknown --model_type %r; expected one of: %s' % (args.model_type, ', '.join(sorted(_MODEL_TYPES))))
    model_type_str, model_name = _MODEL_TYPES[args.model_type]

    run_dir = os.path.join(args.model_dir, args.hash)

    with open(os.path.join(run_dir, 'hparams.json'), 'r') as f:
        hparams = json.load(f)

    net_lmax = hparams['net_lmax']
    data_irreps = o3.Irreps.spherical_harmonics(net_lmax, 1)
    data_lmax = max(data_irreps.ls)

    # Only the 'fibers' implementation is wired up in this script; anything else
    # would leave the model undefined, so stop here with a clear error.
    if args.model_class != 'fibers':
        raise NotImplementedError('%s model_class not implemented.' % (args.model_class))

    ## get w3j matrices
    # NOTE: pickle can execute arbitrary code on load; only point --w3j_filepath at trusted files.
    with gzip.open(args.w3j_filepath, 'rb') as f:
        w3j_matrices = pickle.load(f)

    # Only the matrices with all three degrees within net_lmax are converted to tensors.
    for key in w3j_matrices:
        if key[0] <= net_lmax and key[1] <= net_lmax and key[2] <= net_lmax:
            w3j_matrices[key] = torch.tensor(w3j_matrices[key]).float().to(device)
            w3j_matrices[key].requires_grad = False

    if hparams['model'] == 'cgvae':
        raise NotImplementedError()
    elif hparams['model'] == 'cgvae_symmetric':

        # Defaults for hyperparameters that may be absent from the saved hparams.json.
        hparams.setdefault('learn_frame', False)
        hparams.setdefault('teacher_forcing', False)
        hparams.setdefault('ch_nonlin_rule', 'full')

        # Only the encoder hidden irreps are passed to the constructor below.
        encoder_irreps_cg_hidden = _cg_hidden_irreps(hparams['encoder_hidden_dim'], net_lmax, hparams['cg_width_type'])

        cgvae = cgnet_fibers.ClebschGordanVAE_symmetric(data_irreps,
                                            hparams['latent_dim'],
                                            hparams['encoder_n_cg_blocks'],
                                            encoder_irreps_cg_hidden,
                                            w3j_matrices,
                                            device,
                                            n_reconstruction_layers=hparams['n_reconstruction_layers'],
                                            bottleneck_hidden_dims=list(map(int, hparams['bottleneck_hidden_dims'].split(','))), # in order for encoder, get reversed for decoder
                                            dropout_rate=0.0,
                                            nonlinearity_rule=hparams['nonlinearity_rule'],
                                            ch_nonlin_rule=hparams['ch_nonlin_rule'],
                                            use_batch_norm=hparams['use_batch_norm'],
                                            norm_type=hparams['norm_type'], # None, layer, signal
                                            normalization=hparams['normalization'], # norm, component -> only considered if norm_type is not none
                                            norm_affine=hparams['norm_affine'], # None, {True, False} -> for layer_norm, {unique, per_l, per_feature} -> for signal_norm
                                            norm_nonlinearity=hparams['norm_nonlinearity'], # None (identity), identity, relu, swish, sigmoid -> only for layer_norm
                                            norm_location=hparams['norm_location'], # first, between, last
                                            linearity_first=hparams['linearity_first'], # currently only works with this being false
                                            filter_symmetric=hparams['filter_symmetric'], # whether to exclude duplicate pairs of l's from the tensor product nonlinearity
                                            sf_rec_loss_fn=hparams['sf_rec_loss_fn'], # mse, cosine
                                            x_rec_loss_fn=hparams['x_rec_loss_fn'], # mse, mse_normalized, cosine
                                            teacher_forcing=bool(hparams['teacher_forcing']),
                                            do_final_signal_norm=hparams['do_final_signal_norm'],
                                            learn_frame=hparams['learn_frame']).to(device)
    elif hparams['model'] == 'cgvae_symmetric_simple':

        # Only the encoder hidden irreps are passed to the constructor below.
        encoder_irreps_cg_hidden = _cg_hidden_irreps(hparams['encoder_hidden_dim'], net_lmax, hparams['cg_width_type'])

        cgvae = cgnet_fibers.ClebschGordanVAE_symmetric_simple(data_irreps,
                                            hparams['latent_dim'],
                                            hparams['encoder_n_cg_blocks'],
                                            encoder_irreps_cg_hidden,
                                            w3j_matrices,
                                            device,
                                            nonlinearity_rule=hparams['nonlinearity_rule'],
                                            ch_nonlin_rule=hparams['ch_nonlin_rule'],
                                            use_batch_norm=hparams['use_batch_norm'],
                                            norm_type=hparams['norm_type'], # None, layer, signal
                                            normalization=hparams['normalization'], # norm, component -> only considered if norm_type is not none
                                            norm_affine=hparams['norm_affine'], # None, {True, False} -> for layer_norm, {unique, per_l, per_feature} -> for signal_norm
                                            norm_nonlinearity=hparams['norm_nonlinearity'], # None (identity), identity, relu, swish, sigmoid -> only for layer_norm
                                            norm_location=hparams['norm_location'], # first, between, last
                                            linearity_first=hparams['linearity_first'], # currently only works with this being false
                                            filter_symmetric=hparams['filter_symmetric'], # whether to exclude duplicate pairs of l's from the tensor product nonlinearity
                                            x_rec_loss_fn=hparams['x_rec_loss_fn'], # mse, mse_normalized, cosine
                                            do_final_signal_norm=hparams['do_final_signal_norm'],
                                            learn_frame=hparams['learn_frame']).to(device)
    elif hparams['model'] == 'cgvae_symmetric_simple_flexible':
        hparams.setdefault('norm_balanced', False)

        cgvae = cgnet_fibers.ClebschGordanVAE_symmetric_simple_flexible(data_irreps,
                                            hparams['latent_dim'],
                                            hparams['net_lmax'],
                                            hparams['n_cg_blocks'],
                                            list(map(int, hparams['ch_size_list'].split(','))),
                                            hparams['ls_nonlin_rule_list'].split(','),
                                            hparams['ch_nonlin_rule_list'].split(','),
                                            hparams['do_initial_linear_projection'],
                                            hparams['ch_initial_linear_projection'],
                                            w3j_matrices,
                                            device,
                                            lmax_list=list(map(int, hparams['lmax_list'].split(','))),
                                            use_additive_skip_connections=hparams['use_additive_skip_connections'],
                                            use_batch_norm=hparams['use_batch_norm'],
                                            norm_type=hparams['norm_type'], # None, layer, signal
                                            normalization=hparams['normalization'], # norm, component -> only considered if norm_type is not none
                                            norm_balanced=hparams['norm_balanced'],
                                            norm_affine=hparams['norm_affine'], # None, {True, False} -> for layer_norm, {unique, per_l, per_feature} -> for signal_norm
                                            norm_nonlinearity=hparams['norm_nonlinearity'], # None (identity), identity, relu, swish, sigmoid -> only for layer_norm
                                            norm_location=hparams['norm_location'], # first, between, last
                                            linearity_first=hparams['linearity_first'], # currently only works with this being false
                                            filter_symmetric=hparams['filter_symmetric'], # whether to exclude duplicate pairs of l's from the tensor product nonlinearity
                                            x_rec_loss_fn=hparams['x_rec_loss_fn'], # mse, mse_normalized, cosine
                                            do_final_signal_norm=hparams['do_final_signal_norm'],
                                            learn_frame=hparams['learn_frame'],
                                            is_vae=hparams['is_vae']).to(device)
    else:
        # Without this, an unknown model name would surface as a NameError on `cgvae`.
        raise NotImplementedError('%s model not implemented.' % (hparams['model']))


    cgvae.load_state_dict(torch.load(os.path.join(run_dir, model_name), map_location=torch.device(device)))
    cgvae = cgvae.float()
    cgvae.eval()

    orig_grid, _ = get_grids_mnist()

    # Read only the arrays used below; the context manager closes the .npz archive afterwards.
    arrays_path = os.path.join(run_dir, 'results_arrays/inference%s-split=%s-input_type=%s.npz' % (model_type_str, args.split, args.input_type))
    with np.load(arrays_path) as arrays:
        learned_frames_N9 = arrays['learned_frames_N9']
        labels_N = arrays['labels_N']
        images_NF = arrays['images_NF']
        rec_images_NF = arrays['rec_images_NF']

    N = labels_N.shape[0]

    # make the output directory if it does not exist
    rec_dir = os.path.join(run_dir, 'reconstructions')
    os.makedirs(rec_dir, exist_ok=True)


    IMAGE_SIZE = 4

    ## per-label examples: true image (top row) vs. reconstruction (bottom row)
    for label in np.unique(labels_N):
        idxs_label = np.arange(labels_N.shape[0])[labels_N == label]
        # The generator is re-seeded for every label, so each label's selection
        # depends only on --seed. Sampling is with replacement (numpy's default).
        idxs_to_show = np.random.default_rng(args.seed).choice(idxs_label, size=args.n_samples)

        imgs_coeffs = images_NF[idxs_to_show]
        imgs_rec_coeffs = rec_images_NF[idxs_to_show]
        frames = learned_frames_N9[idxs_to_show]

        # Rotate each signal with the Wigner-D matrix built from its learned frame
        # (presumably undoing the input rotation, per the variable name), then map
        # the coefficients back onto the image grid.
        imgs = []
        imgs_rec = []
        for i in range(args.n_samples):
            undo_rot_wigner = get_wigner_D_from_rot_matrix(data_lmax, torch.tensor(frames[i]).view(3, 3))
            imgs.append(real_sph_ift(rotate_signal(torch.tensor(imgs_coeffs[i]).unsqueeze(0), data_irreps, undo_rot_wigner), orig_grid, data_lmax)[0])
            imgs_rec.append(real_sph_ift(rotate_signal(torch.tensor(imgs_rec_coeffs[i]).unsqueeze(0), data_irreps, undo_rot_wigner), orig_grid, data_lmax)[0])

        # squeeze=False keeps `axs` two-dimensional even when n_samples == 1.
        fig, axs = plt.subplots(figsize=(args.n_samples*IMAGE_SIZE, 2*IMAGE_SIZE), ncols=args.n_samples, nrows=2, sharex=True, sharey=True, squeeze=False)

        for i, (img, img_rec) in enumerate(zip(imgs, imgs_rec)):
            axs[0][i].imshow(img.reshape(60, 60))
            axs[1][i].imshow(img_rec.reshape(60, 60))

        plt.tight_layout()
        plt.savefig(os.path.join(rec_dir, 'example_reconstructions%s-label=%d-split=%s-input_type=%s-seed=%d.png' % (model_type_str, label, args.split, args.input_type, args.seed)))
        plt.close()

        # The same images again, one file (png + pdf) per image.
        for i, (img, img_rec) in enumerate(zip(imgs, imgs_rec)):
            _save_image(img, os.path.join(rec_dir, '__true_image-n=%d-label=%d-split=%s-seed=%d' % (i, label, args.split, args.seed)), IMAGE_SIZE)
            _save_image(img_rec, os.path.join(rec_dir, '__rec_image-n=%d-label=%d-split=%s-seed=%d' % (i, label, args.split, args.seed)), IMAGE_SIZE)


    ## example reconstruction of test data, with visual equivariance tests
    print('Plotting some reconstructions...')

    # (0) get some test signal
    torch.manual_seed(args.seed)
    idxs = torch.randint(N, size=(args.n_samples,))
    signal_orig = torch.tensor(images_NF)[idxs].float()

    # get some random rotation
    rot_matrix = e3nn.o3.rand_matrix(1).float().view(-1, 1, 9).squeeze().unsqueeze(0)
    wigner = get_wigner_D_from_rot_matrix(data_lmax, rot_matrix[0].view(3, 3))

    # (1) rotate original signal
    signal_rot = rotate_signal(signal_orig, data_irreps, wigner)

    # Only the latent means and the learned frames are used below.
    (z_mean_orig, _), _, learned_frame_orig = cgvae.encode(dict_to_device(make_dict(signal_orig, data_irreps), device))
    (z_mean_rot, _), _, learned_frame_rot = cgvae.encode(dict_to_device(make_dict(signal_rot, data_irreps), device))

    # rotate learned frame of original input
    learned_frame_orig_rot = rotate_signal(learned_frame_orig.reshape(-1, 1, 9).squeeze(1).detach().cpu(), o3.Irreps('3x1e'), wigner)

    # model_class is guaranteed to be 'fibers' here, which takes frames as (batch, 3, 3)
    learned_frame_orig = learned_frame_orig.reshape(-1, 3, 3)
    learned_frame_rot = learned_frame_rot.reshape(-1, 3, 3)
    learned_frame_orig_rot = learned_frame_orig_rot.reshape(-1, 3, 3)


    # (2) original reconstruction
    x_reconst_orig, _ = cgvae.decode(z_mean_orig.to(device), learned_frame_orig.to(device))

    # (3) reconstruction with rotated input
    x_reconst_rot, _ = cgvae.decode(z_mean_rot.to(device), learned_frame_rot.to(device))

    # (4) original reconstruction, rotated afterwards
    x_reconst_orig_rot = rotate_signal(make_vector(x_reconst_orig).detach().cpu(), data_irreps, wigner)

    # (5) reconstruction with rotated input, but using latent vector of original signal (check that latents are invariant)
    x_reconst_rot_with_orig_z, _ = cgvae.decode(z_mean_orig.to(device), learned_frame_rot.to(device))

    # (6) reconstruction of original input, with rotated frame
    x_reconst_orig_rot_frame, _ = cgvae.decode(z_mean_orig.to(device), learned_frame_orig_rot.to(device))

    # (0) and (2) should be as close as possible
    # (3), (4), (5), (6) should be identical to one another up to numerical error, and as close as possible to (1)


    signal_rot_inv = real_sph_ift(make_vector(signal_rot).cpu().detach(), orig_grid, data_lmax)

    rec_rot_inv = real_sph_ift(make_vector(x_reconst_rot).cpu().detach(), orig_grid, data_lmax)
    rec_orig_rot_inv = real_sph_ift(make_vector(x_reconst_orig_rot).cpu().detach(), orig_grid, data_lmax)
    rec_rot_with_orig_z_inv = real_sph_ift(make_vector(x_reconst_rot_with_orig_z).cpu().detach(), orig_grid, data_lmax)
    rec_orig_rot_frame_inv = real_sph_ift(make_vector(x_reconst_orig_rot_frame).cpu().detach(), orig_grid, data_lmax)

    # reconstruction of rotated signal, for visual equivariance test
    # squeeze=False keeps `axs` two-dimensional even when n_samples == 1.
    fig, axs = plt.subplots(figsize=(20, 4*args.n_samples), ncols=5, nrows=args.n_samples, sharex=True, sharey=True, squeeze=False)
    axs[0][0].set_title('Rotated\n(fwd/inv)')
    axs[0][1].set_title('Rotating input')
    axs[0][2].set_title('Rotating output')
    axs[0][3].set_title('Rotating input,\nz from unrotated')
    axs[0][4].set_title('Rotating frame')

    for i in range(args.n_samples):
        axs[i][0].imshow(signal_rot_inv[i].reshape(60, 60))
        axs[i][1].imshow(rec_rot_inv[i].reshape(60, 60))
        axs[i][2].imshow(rec_orig_rot_inv[i].reshape(60, 60))
        axs[i][3].imshow(rec_rot_with_orig_z_inv[i].reshape(60, 60))
        axs[i][4].imshow(rec_orig_rot_frame_inv[i].reshape(60, 60))

    plt.savefig(os.path.join(rec_dir, 'test_reconstructions_with_rotations%s-split=%s-input_type=%s-seed=%d.png' % (model_type_str, args.split, args.input_type, args.seed)))
    plt.close()


    ## plot some samples in the canonical frame
    ## sample some vectors in the latent space
    ## sample around the prior (zero mean, unit variance)
    if hparams['is_vae']:
        print('Plotting some samples...')

        n_total = args.n_samples*args.n_samples

        ## initialize generator
        generator = torch.Generator().manual_seed(args.seed)

        z = torch.normal(torch.zeros((n_total, hparams['latent_dim'])), torch.ones((n_total, hparams['latent_dim'])), generator=generator)
        # One identity frame per sample, shape (n_total, 3, 3). No squeeze here:
        # it would drop the batch dimension when n_total == 1.
        frame = torch.eye(3).repeat(n_total, 1).float().view(-1, 3, 3).to(device)
        x_reconst, _ = cgvae.decode(z.to(device), frame.to(device))
        rec_inv_NF = real_sph_ift(make_vector(x_reconst).cpu().detach(), orig_grid, data_lmax)

        # one file (png + pdf) per sample
        for n in range(n_total):
            _save_image(rec_inv_NF[n], os.path.join(rec_dir, 'sample_n%d-seed=%d' % (n, args.seed)), IMAGE_SIZE)

        # all samples in a single n_samples x n_samples grid
        # squeeze=False keeps `axs` an array even when n_samples == 1.
        fig, axs = plt.subplots(figsize=(4*args.n_samples, 4*args.n_samples), nrows=args.n_samples, ncols=args.n_samples, sharex=True, sharey=True, squeeze=False)
        for n, ax in enumerate(axs.flatten()):
            ax.imshow(rec_inv_NF[n].reshape(60, 60))
            ax.axis('off')

        plt.tight_layout()
        plt.savefig(os.path.join(rec_dir, 'samples-n_samples=%d-seed=%d.png' % (args.n_samples, args.seed)))
        plt.savefig(os.path.join(rec_dir, 'samples-n_samples=%d-seed=%d.pdf' % (args.n_samples, args.seed)))
        plt.close()
