
'''
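Runs the requested model checkpoint on the requested data splits in inference mode.
For each split, saves the latent invariants, the frames (the learned frame if the model learns one,
otherwise the dataset rotation), the labels, and the original and reconstructed signals to the model's
directory as NumPy arrays in an .npz archive (presumably more compact than the alternatives).
All saved arrays are parallel: row i of each array corresponds to the i-th example of the split,
in (unshuffled) dataloader order.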
'''


import os, sys
import gzip
import pickle
import argparse
import json
from tqdm import tqdm

import numpy as np
import torch
from torch.utils.data import DataLoader
import e3nn
from e3nn import o3
import matplotlib.pyplot as plt

from scipy.stats import spearmanr

sys.path.append('..')

from cgnet_fibers import ClebschGordanVAE_symmetric_simple_flexible
from projections import real_sph_ift
from utils.data_getter import get_data_mnist
from utils import get_wigner_D_from_rot_matrix, rotate_signal, orthonormalize_frame
from utils.argparse_utils import *
from utils.data_utils import MNISTDatasetWithConditioning__fibers
from projections import ZernickeRadialFunctions, RadialSphericalTensor, MultiChannelRadialSphericalTensor

import umap

# Directory intended for Zernike-projected data.
# NOTE(review): not referenced elsewhere in this script; presumably kept for parity with sibling scripts.
ZERNICKE_DATA_DIR = 'data/zernicke'
# Names of data-related hyperparameters.
# NOTE(review): not referenced elsewhere in this script; confirm whether still needed.
DATA_ARGS = ['rmax', 'lmax', 'n_channels', 'rcut', 'rst_normalization', 'mul_rst_normalization']

def dict_to_device(adict, device):
    """Cast every value of ``adict`` to float32 and move it to ``device``, in place.

    Args:
        adict: dict whose values are tensors (anything with ``.float()`` and ``.to()``).
        device: target device passed to ``.to()``.

    Returns:
        The same dict object, mutated, for call-chaining convenience.
    """
    # Rebinding existing keys does not change the dict's size, so iterating
    # over items() while assigning is safe.
    for name, tensor in adict.items():
        adict[name] = tensor.float().to(device)
    return adict

def make_vector(x):
    """Flatten a dict of tensors into one 2-D tensor.

    Each value is viewed as ``(x[key].shape[0], -1)``. The pieces are then
    concatenated along the last dimension in ascending key order.
    """
    flattened = [x[key].view((x[key].shape[0], -1)) for key in sorted(x)]
    return torch.cat(flattened, dim=-1)


# --model_type -> (suffix used in output filenames, checkpoint filename).
# 'no_training' evaluates the freshly initialised model, so it has no checkpoint.
# NOTE(review): 'best' and 'best_04' share the empty suffix, so their result files
# overwrite each other; confirm whether that is intended before changing it.
_CHECKPOINTS = {
    'best': ('', 'best_model.pt'),
    'best_04': ('', 'best_model_04.pt'),
    'best_05': ('-best_05', 'best_model_05.pt'),
    'best_06': ('-best_06', 'best_model_06.pt'),
    'best_higher_kld': ('-best_model_higher_kld', 'best_model_higher_kld.pt'),
    'lowest_rec_loss': ('-lowest_rec_loss', 'lowest_rec_loss_model.pt'),
    'final': ('-final_model', 'final_model.pt'),
    'no_training': ('-no_training', None),
    'lowest_total_loss_with_final_kl_model': ('-lowest_total_loss_with_final_kl_model',
                                              'lowest_total_loss_with_final_kl_model.pt'),
}

# Constructor keyword arguments that every supported model class reads verbatim from hparams.
_SHARED_MODEL_KWARGS = (
    'use_batch_norm',
    'norm_type',             # None, layer, signal
    'normalization',         # norm, component -> only considered if norm_type is not none
    'norm_affine',           # None, {True, False} -> for layer_norm, {unique, per_l, per_feature} -> for signal_norm
    'norm_nonlinearity',     # None (identity), identity, relu, swish, sigmoid -> only for layer_norm
    'norm_location',         # first, between, last
    'linearity_first',       # currently only works with this being false
    'filter_symmetric',      # whether to exclude duplicate pairs of l's from the tensor product nonlinearity
    'x_rec_loss_fn',         # mse, mse_normalized, cosine
    'do_final_signal_norm',
    'learn_frame',
)


def _resolve_checkpoint(model_type):
    """Return ``(output filename suffix, checkpoint filename or None)`` for ``model_type``.

    Raises:
        ValueError: if ``model_type`` is not a known checkpoint kind.
    """
    try:
        return _CHECKPOINTS[model_type]
    except KeyError:
        raise ValueError('Unknown --model_type %r; expected one of: %s.'
                         % (model_type, ', '.join(_CHECKPOINTS))) from None


def _build_mnist_dataloaders(input_type, net_lmax, batch_size):
    """Build unshuffled train/valid/test dataloaders for the MNIST data.

    Returns ``(data_irreps, {split: DataLoader})``. ``drop_last`` is False so every
    example is seen exactly once and the saved arrays stay parallel to the data.
    """
    data_irreps = o3.Irreps.spherical_harmonics(net_lmax, 1)
    data, _s2_data, _grids = get_data_mnist(input_type, get_grids=True, get_s2=False, lmax=net_lmax)

    dataloaders = {}
    for split in ('train', 'valid', 'test'):
        dataset = MNISTDatasetWithConditioning__fibers(data[split]['projections'],
                                                       data_irreps,
                                                       torch.tensor(data[split]['labels']),
                                                       data[split]['rotations'])
        dataloaders[split] = DataLoader(dataset, batch_size=batch_size, shuffle=False, drop_last=False)
    return data_irreps, dataloaders


def _load_w3j_matrices(w3j_filepath, lmax, device):
    """Load the gzipped, pickled w3j table; entries with all three l <= lmax become float tensors on device."""
    # pickle.load can execute arbitrary code: only point --w3j_filepath at trusted files.
    with gzip.open(w3j_filepath, 'rb') as f:
        w3j_matrices = pickle.load(f)

    for key in w3j_matrices:
        if max(key[0], key[1], key[2]) <= lmax:
            w3j_matrices[key] = torch.tensor(w3j_matrices[key]).float().to(device)
            w3j_matrices[key].requires_grad = False
    return w3j_matrices


def _default_do_final_signal_norm(input_type):
    """Infer ``do_final_signal_norm`` for old runs whose hparams predate the flag.

    The suffix after the first '-' in ``input_type`` decides: none or 'sqrt_power' -> True,
    'None' or 'avg_sqrt_power' -> False.

    Raises:
        ValueError: for any other suffix (the original script silently exited with status 1).
    """
    parts = input_type.split('-')
    if len(parts) == 1 or parts[1] == 'sqrt_power':
        return True
    if parts[1] in ('None', 'avg_sqrt_power'):
        return False
    raise ValueError('Cannot infer do_final_signal_norm from --input_type %r.' % input_type)


def _cg_hidden_irreps(hidden_dim, net_lmax, cg_width_type):
    """Hidden irreps for the CG blocks: constant width per l, or width shrinking as 1/sqrt(2l+1)."""
    if cg_width_type == 'constant':
        return (hidden_dim * o3.Irreps.spherical_harmonics(net_lmax, 1)).sort().irreps.simplify()
    if cg_width_type == 'reduce_with_l':
        return o3.Irreps('+'.join('%dx%de' % (hidden_dim // np.sqrt(2 * l + 1), l) for l in range(net_lmax + 1)))
    raise NotImplementedError('%s cg_width_type not implemented.' % (cg_width_type))


def _print_hparams(hparams):
    """Print hparams as tab-separated key/value lines followed by a blank line."""
    for key, value in hparams.items():
        print('{}\t{}'.format(key, value))
    print()


def _build_model(hparams, data_irreps, w3j_matrices, device):
    """Instantiate the model named by ``hparams['model']`` on ``device``.

    Fills in defaults for hparams missing from older runs; ``hparams`` is mutated in place.

    Raises:
        NotImplementedError: for 'cgvae' or an unsupported cg_width_type.
        ValueError: for an unknown model name.
    """
    model = hparams['model']
    if model == 'cgvae':
        raise NotImplementedError()

    if model in ('cgvae_symmetric', 'cgvae_symmetric_simple'):
        encoder_irreps_cg_hidden = _cg_hidden_irreps(hparams['encoder_hidden_dim'], hparams['net_lmax'],
                                                     hparams['cg_width_type'])
        hparams.setdefault('learn_frame', False)
        hparams.setdefault('teacher_forcing', False)
        hparams.setdefault('ch_nonlin_rule', 'full')
        if model == 'cgvae_symmetric':
            hparams.setdefault('lmax_list', None)
        _print_hparams(hparams)
        shared = {key: hparams[key] for key in _SHARED_MODEL_KWARGS}

        if model == 'cgvae_symmetric':
            # NOTE(review): this class was never imported at file level (NameError before);
            # presumably it lives in cgnet_fibers alongside the flexible variant — confirm.
            from cgnet_fibers import ClebschGordanVAE_symmetric
            return ClebschGordanVAE_symmetric(data_irreps,
                                              hparams['latent_dim'],
                                              hparams['encoder_n_cg_blocks'],
                                              encoder_irreps_cg_hidden,
                                              w3j_matrices,
                                              device,
                                              n_reconstruction_layers=hparams['n_reconstruction_layers'],
                                              # in order for encoder, get reversed for decoder
                                              bottleneck_hidden_dims=list(map(int, hparams['bottleneck_hidden_dims'].split(','))),
                                              dropout_rate=0.0,
                                              nonlinearity_rule=hparams['nonlinearity_rule'],
                                              ch_nonlin_rule=hparams['ch_nonlin_rule'],
                                              sf_rec_loss_fn=hparams['sf_rec_loss_fn'],  # mse, cosine
                                              teacher_forcing=bool(hparams['teacher_forcing']),
                                              **shared).to(device)

        # NOTE(review): same missing-import situation as above; presumably in cgnet_fibers — confirm.
        from cgnet_fibers import ClebschGordanVAE_symmetric_simple
        return ClebschGordanVAE_symmetric_simple(data_irreps,
                                                 hparams['latent_dim'],
                                                 hparams['encoder_n_cg_blocks'],
                                                 encoder_irreps_cg_hidden,
                                                 w3j_matrices,
                                                 device,
                                                 nonlinearity_rule=hparams['nonlinearity_rule'],
                                                 ch_nonlin_rule=hparams['ch_nonlin_rule'],
                                                 **shared).to(device)

    if model == 'cgvae_symmetric_simple_flexible':
        hparams.setdefault('norm_balanced', False)
        # The inference loop reads learn_frame for every model; default it like the other branches.
        hparams.setdefault('learn_frame', False)
        _print_hparams(hparams)
        shared = {key: hparams[key] for key in _SHARED_MODEL_KWARGS}
        return ClebschGordanVAE_symmetric_simple_flexible(data_irreps,
                                                          hparams['latent_dim'],
                                                          hparams['net_lmax'],
                                                          hparams['n_cg_blocks'],
                                                          list(map(int, hparams['ch_size_list'].split(','))),
                                                          hparams['ls_nonlin_rule_list'].split(','),
                                                          hparams['ch_nonlin_rule_list'].split(','),
                                                          hparams['do_initial_linear_projection'],
                                                          hparams['ch_initial_linear_projection'],
                                                          w3j_matrices,
                                                          device,
                                                          lmax_list=list(map(int, hparams['lmax_list'].split(','))),
                                                          use_additive_skip_connections=hparams['use_additive_skip_connections'],
                                                          norm_balanced=hparams['norm_balanced'],
                                                          is_vae=hparams['is_vae'],
                                                          **shared).to(device)

    raise ValueError('Unknown model %r in hparams.json.' % (model,))


def _run_split(cgvae, dataloader, hparams, device, desc=None):
    """Encode and decode every batch of ``dataloader``; return the arrays to save, keyed by .npz name."""
    invariants, frames, labels, images, rec_images = [], [], [], [], []

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        for X, _X_rec, y, rot in tqdm(dataloader, desc=desc):
            X = dict_to_device(X, device)

            (z_mean, _z_log_var), _scalar_features, learned_frame = cgvae.encode(X)
            # Deterministic embedding: use the posterior mean instead of sampling z.
            z = z_mean

            if hparams['learn_frame']:
                x_reconst, _ = cgvae.decode(z, learned_frame)
                frames.append(learned_frame.reshape(-1, 9).cpu().numpy())
            else:
                # rot comes off the DataLoader on the CPU while z lives on `device`.
                x_reconst, _ = cgvae.decode(z, rot.float().to(device))
                frames.append(rot.reshape(-1, 9).cpu().numpy())

            invariants.append(z.cpu().numpy())
            labels.append(y.cpu().numpy())
            images.append(make_vector(X).cpu().numpy())
            rec_images.append(make_vector(x_reconst).cpu().numpy())

    return {
        'invariants_ND': np.vstack(invariants),
        'learned_frames_N9': np.vstack(frames),
        'labels_N': np.hstack(labels),
        # Placeholder kept so the .npz layout stays unchanged; true rotations are not saved separately.
        'rotations_N9': np.array([0]),
        'images_NF': np.vstack(images),
        'rec_images_NF': np.vstack(rec_images),
    }


def main():
    """Parse CLI args, load the requested checkpoint and write one inference .npz per split."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--w3j_filepath', type=str, default='../cg_coefficients/w3j_matrices-lmax=14-version=0.5.0.pkl')
    parser.add_argument('--input_type', type=str, default='RR-avg_sqrt_power')
    parser.add_argument('--model_dir', type=str)
    parser.add_argument('--splits', type=comma_sep_str_list, default='train,valid,test')
    parser.add_argument('--model_type', type=str)
    parser.add_argument('--hash', type=str)
    parser.add_argument('--seed', type=int, default=1000005)
    args = parser.parse_args()

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print('Running on %s.' % device)

    # Fail fast on a bad --model_type instead of a NameError after loading everything.
    model_type_str, model_name = _resolve_checkpoint(args.model_type)
    run_dir = os.path.join(args.model_dir, args.hash)

    with open(os.path.join(run_dir, 'hparams.json'), 'r') as f:
        hparams = json.load(f)

    if 'mnist' in args.model_dir:
        data_irreps, dataloaders = _build_mnist_dataloaders(args.input_type, hparams['net_lmax'], hparams['batch_size'])
    elif 'zernicke' in args.model_dir:
        raise NotImplementedError()
    else:
        raise ValueError('Cannot infer dataset from --model_dir %r (expected "mnist" or "zernicke" in the path).'
                         % args.model_dir)

    unknown_splits = [s for s in args.splits if s not in dataloaders]
    if unknown_splits:
        raise ValueError('Unknown split(s) %s; available: %s.' % (unknown_splits, ', '.join(dataloaders)))

    w3j_matrices = _load_w3j_matrices(args.w3j_filepath, hparams['net_lmax'], device)

    if 'do_final_signal_norm' not in hparams:
        hparams['do_final_signal_norm'] = _default_do_final_signal_norm(args.input_type)

    cgvae = _build_model(hparams, data_irreps, w3j_matrices, device)
    if model_name is not None:
        cgvae.load_state_dict(torch.load(os.path.join(run_dir, model_name), map_location=torch.device(device)))
    cgvae = cgvae.float()
    cgvae.eval()

    results_dir = os.path.join(run_dir, 'results_arrays')
    os.makedirs(results_dir, exist_ok=True)
    for split in args.splits:
        arrays = _run_split(cgvae, dataloaders[split], hparams, device, desc=split)
        np.savez(os.path.join(results_dir, 'inference%s-split=%s-input_type=%s.npz' % (model_type_str, split, args.input_type)),
                 **arrays)


if __name__ == '__main__':
    main()
