
'''
Compute the mean of the signals decoded in the canonical (identity) frame for the specified labels, and plot it to show the overlap between them
'''

import argparse
import gzip
import json
import os
import pickle
import sys
from typing import Dict

from tqdm import tqdm

import numpy as np
import torch
from torch.utils.data import DataLoader
import e3nn
from e3nn import o3
import matplotlib.pyplot as plt

from scipy.stats import spearmanr

sys.path.append('..')

import cgnet_fibers
from projections import real_sph_ift, ZernickeRadialFunctions, RadialSphericalTensor, MultiChannelRadialSphericalTensor
from utils.data_getter import get_data_mnist, get_grids_mnist
from utils import get_wigner_D_from_rot_matrix, rotate_signal, orthonormalize_frame
from utils.argparse_utils import *
from utils.protein import *
from utils.data_utils import MNISTDatasetWithConditioning

from loss_functions import *

import umap
import matplotlib as mlp

from torch import Tensor

def make_vector(x: Dict[int, Tensor]) -> Tensor:
    '''
    Flatten a dict of per-key tensors into a single 2-D tensor.

    Each value ``x[l]`` is reshaped to ``(x[l].shape[0], -1)``, keeping its first
    dimension and flattening the rest. The pieces are then concatenated along
    the last dimension in increasing order of the integer key ``l``
    (presumably the irrep degree -- TODO confirm against the decoder output).

    Parameters
    ----------
    x : Dict[int, Tensor]
        Mapping from integer key to tensor. All values must share the same
        first dimension (the batch dimension), otherwise ``torch.cat`` raises.

    Returns
    -------
    Tensor
        Tensor of shape ``(batch, total number of flattened trailing elements)``.
    '''
    # Sorting the keys makes the column layout deterministic, independent of the
    # dict's insertion order.
    return torch.cat([x[l].reshape((x[l].shape[0], -1)) for l in sorted(x)], dim=-1)

# Maps each accepted --model_type value to
# (suffix embedded in the inference/output file names, checkpoint file name).
# NOTE: 'best' and 'best_04' share the empty suffix, so their outputs share file names.
MODEL_TYPE_TO_FILES = {
    'best': ('', 'best_model.pt'),
    'best_04': ('', 'best_model_04.pt'),
    'best_05': ('-best_05', 'best_model_05.pt'),
    'best_06': ('-best_06', 'best_model_06.pt'),
    'best_higher_kld': ('-best_model_higher_kld', 'best_model_higher_kld.pt'),
    'lowest_rec_loss': ('-lowest_rec_loss', 'lowest_rec_loss_model.pt'),
    'final': ('-final_model', 'final_model.pt'),
    'lowest_total_loss_with_final_kl_model': ('-lowest_total_loss_with_final_kl_model', 'lowest_total_loss_with_final_kl_model.pt'),
}

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_type', type=str, default='NRR-avg_sqrt_power')
    parser.add_argument('--w3j_filepath', type=str, default='../cg_coefficients/w3j_matrices-lmax=14-version=0.5.0.pkl')
    parser.add_argument('--model_dir', type=str, required=True)
    parser.add_argument('--model_type', type=str, default='lowest_total_loss_with_final_kl_model',
                        choices=sorted(MODEL_TYPE_TO_FILES))
    parser.add_argument('--split', type=str, default='test')
    parser.add_argument('--hash', type=str, required=True)
    parser.add_argument('--comma_sep_labels', type=str, default='0,1,2,3,4,5,6,7,8,9')
    parser.add_argument('--seed', type=int, default=1000005)  # not read by this script

    args = parser.parse_args()

    print(args.hash, file=sys.stderr)
    print(args.comma_sep_labels, file=sys.stderr)

    # argparse `choices` guarantees the key exists.
    model_type_str, model_name = MODEL_TYPE_TO_FILES[args.model_type]

    device = 'cpu'
    model_path = os.path.join(args.model_dir, args.hash)

    with open(os.path.join(model_path, 'hparams.json'), 'r') as f:
        hparams = json.load(f)

    # The kind of input signal (and hence its irreps) is inferred from the model directory name.
    if 'mnist' in args.model_dir:
        data_irreps = o3.Irreps.spherical_harmonics(hparams['net_lmax'], 1)
    elif 'zernicke' in args.model_dir:
        # filter by desired lmax and channels
        OnRadialFunctions = ZernickeRadialFunctions(hparams['rcut'], hparams['rmax']+1, hparams['lmax'], complex_sph=False)
        rst = RadialSphericalTensor(hparams['rmax']+1, OnRadialFunctions, hparams['lmax'], 1, 1)
        mul_rst = MultiChannelRadialSphericalTensor(rst, hparams['n_channels'])
        data_irreps = o3.Irreps(str(mul_rst))
    else:
        parser.error("cannot infer the data type: --model_dir must contain 'mnist' or 'zernicke' (got %r)" % args.model_dir)

    # All outputs share this directory and file-name stem; the cached mean signal is <stem>.pkl.
    out_dir = os.path.join(model_path, 'sum_of_images')
    out_stem = 'rec_images_unrotated_by_inverse_of_frame%s-split=%s-input_type=%s-labels=%s' % (model_type_str, args.split, args.input_type, args.comma_sep_labels)
    cache_path = os.path.join(out_dir, out_stem + '.pkl')

    if os.path.exists(cache_path):
        # NOTE: pickle.load can execute arbitrary code; only use self-generated caches from trusted directories.
        with open(cache_path, 'rb') as f:
            reconstructions_in_canonical_frame = pickle.load(f)
    else:
        ## get w3j matrices
        with gzip.open(args.w3j_filepath, 'rb') as f:
            w3j_matrices = pickle.load(f)

        # Convert to float tensors on `device` only the entries whose first three key components are all <= net_lmax.
        for key in w3j_matrices:
            if key[0] <= hparams['net_lmax'] and key[1] <= hparams['net_lmax'] and key[2] <= hparams['net_lmax']:
                w3j_matrices[key] = torch.tensor(w3j_matrices[key]).float().to(device)
                w3j_matrices[key].requires_grad = False

        cgvae = cgnet_fibers.ClebschGordanVAE_symmetric_simple_flexible(data_irreps,
                                            hparams['latent_dim'],
                                            hparams['net_lmax'],
                                            hparams['n_cg_blocks'],
                                            list(map(int, hparams['ch_size_list'].split(','))),
                                            hparams['ls_nonlin_rule_list'].split(','),
                                            hparams['ch_nonlin_rule_list'].split(','),
                                            hparams['do_initial_linear_projection'],
                                            hparams['ch_initial_linear_projection'],
                                            w3j_matrices,
                                            device,
                                            lmax_list=list(map(int, hparams['lmax_list'].split(','))),
                                            use_additive_skip_connections=hparams['use_additive_skip_connections'],
                                            use_batch_norm=hparams['use_batch_norm'],
                                            norm_type=hparams['norm_type'], # None, layer, signal
                                            normalization=hparams['normalization'], # norm, component -> only considered if norm_type is not none
                                            norm_balanced=hparams['norm_balanced'],
                                            norm_affine=hparams['norm_affine'], # None, {True, False} -> for layer_norm, {unique, per_l, per_feature} -> for signal_norm
                                            norm_nonlinearity=hparams['norm_nonlinearity'], # None (identity), identity, relu, swish, sigmoid -> only for layer_norm
                                            norm_location=hparams['norm_location'], # first, between, last
                                            linearity_first=hparams['linearity_first'], # currently only works with this being false
                                            filter_symmetric=hparams['filter_symmetric'], # whether to exclude duplicate pairs of l's from the tensor product nonlinearity
                                            x_rec_loss_fn=hparams['x_rec_loss_fn'], # mse, mse_normalized, cosine
                                            do_final_signal_norm=hparams['do_final_signal_norm'],
                                            learn_frame=hparams['learn_frame'],
                                            is_vae=hparams['is_vae']).to(device)

        cgvae.load_state_dict(torch.load(os.path.join(model_path, model_name), map_location=torch.device(device)))
        cgvae = cgvae.float()
        cgvae.eval()

        orig_grid, _ = get_grids_mnist()

        # Only labels and invariants (latent vectors) are needed; the context manager closes the .npz file.
        with np.load(os.path.join(model_path, 'results_arrays/inference%s-split=%s-input_type=%s.npz' % (model_type_str, args.split, args.input_type))) as arrays:
            labels_N = arrays['labels_N']
            invariants_ND = arrays['invariants_ND']

        labels_to_show = sorted(set(map(int, args.comma_sep_labels.split(','))))
        mask = np.isin(labels_N, labels_to_show)
        invariants_ND = invariants_ND[mask]

        print('%d datapoints' % (np.sum(mask)), file=sys.stderr)
        if invariants_ND.shape[0] == 0:
            # np.vstack([]) below would otherwise fail with an opaque error.
            sys.exit('No datapoints with labels %s in split %r; nothing to plot.' % (args.comma_sep_labels, args.split))

        reconstructions_in_canonical_frame = []
        batch_size = 2000
        signal_lmax = max(data_irreps.ls)
        with torch.no_grad():  # inference only: no autograd graph needed
            for start in tqdm(range(0, invariants_ND.shape[0], batch_size)):
                z = torch.tensor(invariants_ND[start : start + batch_size]).float().to(device)
                # Decode with the identity frame so every reconstruction is expressed in the canonical frame.
                frame = torch.eye(3).repeat(z.shape[0], 1, 1).to(device)
                x_reconst, _ = cgvae.decode(z, frame)
                x_rec_inv = real_sph_ift(make_vector(x_reconst).detach(), orig_grid, signal_lmax).squeeze().numpy()
                reconstructions_in_canonical_frame.append(x_rec_inv)

        reconstructions_in_canonical_frame = np.vstack(reconstructions_in_canonical_frame)
        print(reconstructions_in_canonical_frame.shape, file=sys.stderr)

        # Mean over all selected datapoints (proportional to their sum).
        reconstructions_in_canonical_frame = np.mean(reconstructions_in_canonical_frame, axis=0)

        os.makedirs(out_dir, exist_ok=True)
        with open(cache_path, 'wb') as f:
            pickle.dump(reconstructions_in_canonical_frame, f)

    # NOTE(review): assumes the grid from get_grids_mnist() has 60*60 points -- TODO confirm.
    plt.imshow(reconstructions_in_canonical_frame.reshape(60, 60), cmap='viridis')
    plt.axis('off')
    plt.tight_layout()
    plt.savefig(os.path.join(out_dir, out_stem + '-viridis.png'))
    plt.savefig(os.path.join(out_dir, out_stem + '-viridis.pdf'))
    plt.close()


