
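'''
Runs the requested model on the requested splits of the data in inference mode.
Saves the invariants, the frames (the learned frames when the model learns them, otherwise the input frames),
the labels, the data ids, the input signals and the reconstructed signals in the model's directory, as numpy
arrays bundled in one `.npz` archive per split (numpy arrays should take up less space than the alternatives).
The arrays are parallel to one another: row i of every array refers to the same input neighborhood, which is
identified by its entry in the data ids array.
'''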


import os, sys
import gzip
import pickle
import argparse
import json
from tqdm import tqdm

import numpy as np
import torch
from torch.utils.data import DataLoader
import e3nn
from e3nn import o3
import matplotlib.pyplot as plt

from scipy.stats import spearmanr

sys.path.append('..')
from cgnet_fibers import ClebschGordanVAE_symmetric_simple_flexible
from projections import real_sph_ift
from utils.data_getter import get_data_mnist
from utils import get_wigner_D_from_rot_matrix, rotate_signal, orthonormalize_frame
from utils.argparse_utils import *
from utils.data_utils import NeighborhoodsDatasetWithConditioning__fibers, NeighborhoodsDatasetWithConditioningAndIds__fibers
from projections import ZernickeRadialFunctions, RadialSphericalTensor, MultiChannelRadialSphericalTensor

import umap

# Names of the hyperparameters that are encoded in the data file names.
# The script entry point drops the last three entries before building those names.
DATA_ARGS = ['rmax', 'lmax', 'n_channels', 'rcut', 'rst_normalization', 'get_H', 'get_SASA', 'get_charge']

def dict_to_device(adict, device):
    '''
    Casts every tensor stored in `adict` to float32 and moves it to `device`.

    The dictionary is updated in place (its keys are untouched) and the very
    same dictionary object is returned for convenience.
    '''
    for name, tensor in adict.items():
        # only values are replaced, so iterating over the items is safe
        adict[name] = tensor.float().to(device)
    return adict

def make_vector(x):
    '''
    Flattens a dictionary of tensors into a single 2D tensor.

    Every value is reshaped to (first dimension, everything else), and the
    results are concatenated along the last dimension in ascending key order.
    '''
    flattened = [x[key].reshape((x[key].shape[0], -1)) for key in sorted(x)]
    return torch.cat(flattened, dim=-1)

if __name__ == '__main__':
    # Script entry point: parses the command line, loads the hyperparameters and the
    # data of all three splits, rebuilds the model, optionally loads a checkpoint and
    # writes one `.npz` archive with the inference results for every requested split.

    # Checkpoint selector -> (suffix used in the output file name, checkpoint file name).
    # 'no_training' has no checkpoint: the freshly initialized weights are used.
    # NOTE(review): 'best' and 'best_04' share the empty suffix, so their result files
    # overwrite each other -- confirm whether that is intended.
    MODEL_TYPES = {
        'best': ('', 'best_model.pt'),
        'best_04': ('', 'best_model_04.pt'),
        'best_05': ('-best_05', 'best_model_05.pt'),
        'best_06': ('-best_06', 'best_model_06.pt'),
        'best_higher_kld': ('-best_model_higher_kld', 'best_model_higher_kld.pt'),
        'lowest_rec_loss': ('-lowest_rec_loss', 'lowest_rec_loss_model.pt'),
        'final': ('-final_model', 'final_model.pt'),
        'no_training': ('-no_training', None),
        'lowest_total_loss_with_final_kl_model': ('-lowest_total_loss_with_final_kl_model', 'lowest_total_loss_with_final_kl_model.pt'),
    }

    # Split name (as used on the command line and in hparams) -> tag used in the data file names.
    SPLIT_FILE_TAGS = {'train': 'train', 'valid': 'val', 'test': 'test'}

    parser = argparse.ArgumentParser()
    parser.add_argument('--w3j_filepath', type=str, default='../cg_coefficients/w3j_matrices-lmax=14-version=0.5.0.pkl')
    parser.add_argument('--data_dir', type=str, default='../data/neighborhoods/data')
    parser.add_argument('--model_dir', type=str, required=True)
    parser.add_argument('--splits', type=comma_sep_str_list, default='test')
    # Validated up front: an unknown model type used to surface only as a NameError
    # after all the data had been loaded (or even after inference had finished).
    parser.add_argument('--model_type', type=str, required=True, choices=sorted(MODEL_TYPES))
    parser.add_argument('--hash', type=str, required=True)
    # Kept for command-line compatibility; the generator below is seeded from hparams['seed'].
    parser.add_argument('--seed', type=int, default=1000005)

    args = parser.parse_args()

    unknown_splits = [split for split in args.splits if split not in SPLIT_FILE_TAGS]
    if unknown_splits:
        parser.error('unknown split(s) %s; choose from %s' % (', '.join(unknown_splits), ', '.join(SPLIT_FILE_TAGS)))

    model_type_str, model_name = MODEL_TYPES[args.model_type]

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print('Running on %s.' % device)

    with open(os.path.join(args.model_dir, args.hash, 'hparams.json'), 'r') as f:
        hparams = json.load(f)


    # data preparation and loading stuff
    rng = torch.Generator().manual_seed(hparams['seed'])

    # the last three data arguments are not part of the data file names
    DATA_ARGS = DATA_ARGS[:-3]
    data_id_by_split = {
        split: '-'.join(sorted(['%s=%s' % (arg, hparams[arg]) for arg in DATA_ARGS] + ['n_neigh=%s' % (hparams['n_%s_neigh' % split])]))
        for split in SPLIT_FILE_TAGS
    }

    normalize_str = '-normalize=%s' % (hparams['normalize']) if hparams['normalize'] is not None else ''

    def load_split_array(kind, split, with_normalize=False):
        '''Loads `<data_dir>/<kind>-<split tag>[normalize suffix]-complex_sph=False-<data id>.npy`.'''
        suffix = normalize_str if with_normalize else ''
        return np.load(args.data_dir + '/%s-%s%s-complex_sph=False-' % (kind, SPLIT_FILE_TAGS[split], suffix) + data_id_by_split[split] + '.npy')

    def stringify(data_id):
        '''Joins the utf-8 decoded fields of one data id with underscores.'''
        return '_'.join(field.decode('utf-8') for field in data_id)

    signals_by_split = {split: torch.tensor(load_split_array('projections', split, with_normalize=True)) for split in SPLIT_FILE_TAGS}
    labels_by_split = {split: torch.tensor(load_split_array('aa_labels', split)) for split in SPLIT_FILE_TAGS}

    try:
        frames_by_split = {split: torch.tensor(load_split_array('frames', split)).view(-1, 3, 3) for split in SPLIT_FILE_TAGS}
    except (OSError, ValueError, RuntimeError) as err:
        # dummy frames for compatibility: as soon as one frames file cannot be used,
        # all three splits fall back to zero frames
        print('Could not load the frames (%s); using dummy zero frames.' % (err), file=sys.stderr)
        frames_by_split = {split: torch.zeros((labels_by_split[split].shape[0], 3, 3)) for split in SPLIT_FILE_TAGS}

    ids_by_split = {split: np.array([stringify(data_id) for data_id in load_split_array('data_ids', split)]) for split in SPLIT_FILE_TAGS}


    # filter by desired lmax and channels
    OnRadialFunctions = ZernickeRadialFunctions(hparams['rcut'], hparams['rmax']+1, hparams['lmax'], complex_sph = False)
    rst = RadialSphericalTensor(hparams['rmax']+1, OnRadialFunctions, hparams['lmax'], 1, 1)
    mul_rst = MultiChannelRadialSphericalTensor(rst, hparams['n_channels'])
    data_irreps = o3.Irreps(str(mul_rst))

    print('Data irreps: {}'.format(data_irreps), file=sys.stderr)
    print('Data dim: {}'.format(data_irreps.dim), file=sys.stderr)

    datasets = {
        split: NeighborhoodsDatasetWithConditioningAndIds__fibers(signals_by_split[split], data_irreps, labels_by_split[split], frames_by_split[split], ids_by_split[split])
        for split in SPLIT_FILE_TAGS
    }

    print('%d training neighborhoods' % (len(datasets['train'])), file=sys.stderr)
    print('%d validation neighborhoods' % (len(datasets['valid'])), file=sys.stderr)

    # NOTE(review): drop_last=True discards the final incomplete batch of every split and
    # the train loader shuffles, so the saved arrays only cover the complete batches, in
    # loader order. Kept as-is to leave the outputs unchanged; rows can still be matched to
    # the inputs through data_ids_N. A split smaller than one batch yields no batch at all
    # and makes the stacking below fail.
    dataloaders = {
        'train': DataLoader(datasets['train'], batch_size=hparams['batch_size'], generator=rng, shuffle=True, drop_last=True),
        'valid': DataLoader(datasets['valid'], batch_size=hparams['batch_size'], generator=rng, shuffle=False, drop_last=True),
        'test': DataLoader(datasets['test'], batch_size=100, generator=rng, shuffle=False, drop_last=True)
    }

    ## get w3j matrices
    # NOTE: unpickling executes arbitrary code; only point --w3j_filepath at trusted files.
    with gzip.open(args.w3j_filepath, 'rb') as f:
        w3j_matrices = pickle.load(f)

    # only the matrices that the network can use are converted to tensors
    for key in w3j_matrices:
        if key[0] <= hparams['net_lmax'] and key[1] <= hparams['net_lmax'] and key[2] <= hparams['net_lmax']:
            w3j_matrices[key] = torch.tensor(w3j_matrices[key]).float().to(device)
            w3j_matrices[key].requires_grad = False

    if hparams['model'] == 'cgvae_symmetric_simple_flexible':
        for key, value in hparams.items():
            print('{}\t{}'.format(key, value))
        print()
        # hparams files without this entry get the default
        if 'norm_balanced' not in hparams:
            hparams['norm_balanced'] = False
        cgvae = ClebschGordanVAE_symmetric_simple_flexible(data_irreps,
                                                            hparams['latent_dim'],
                                                            hparams['net_lmax'],
                                                            hparams['n_cg_blocks'],
                                                            list(map(int, hparams['ch_size_list'].split(','))),
                                                            hparams['ls_nonlin_rule_list'].split(','),
                                                            hparams['ch_nonlin_rule_list'].split(','),
                                                            hparams['do_initial_linear_projection'],
                                                            hparams['ch_initial_linear_projection'],
                                                            w3j_matrices,
                                                            device,
                                                            lmax_list=list(map(int, hparams['lmax_list'].split(','))),
                                                            use_additive_skip_connections=hparams['use_additive_skip_connections'],
                                                            use_batch_norm=hparams['use_batch_norm'],
                                                            norm_type=hparams['norm_type'], # None, layer, signal
                                                            normalization=hparams['normalization'], # norm, component -> only considered if norm_type is not none
                                                            norm_balanced=hparams['norm_balanced'],
                                                            norm_affine=hparams['norm_affine'], # None, {True, False} -> for layer_norm, {unique, per_l, per_feature} -> for signal_norm
                                                            norm_nonlinearity=hparams['norm_nonlinearity'], # None (identity), identity, relu, swish, sigmoid -> only for layer_norm
                                                            norm_location=hparams['norm_location'], # first, between, last
                                                            linearity_first=hparams['linearity_first'], # currently only works with this being false
                                                            filter_symmetric=hparams['filter_symmetric'], # whether to exclude duplicate pairs of l's from the tensor product nonlinearity
                                                            x_rec_loss_fn=hparams['x_rec_loss_fn'], # mse, mse_normalized, cosine
                                                            do_final_signal_norm=hparams['do_final_signal_norm'],
                                                            learn_frame=hparams['learn_frame'],
                                                            is_vae=hparams['is_vae']).to(device)
    else:
        raise NotImplementedError('Unsupported model: %s' % (hparams['model']))

    if model_name is not None:
        cgvae.load_state_dict(torch.load(os.path.join(args.model_dir, args.hash, model_name), map_location=torch.device(device)))
    else:
        print('Using untrained model.', file=sys.stderr)

    cgvae = cgvae.float()
    cgvae.eval()

    results_dir = os.path.join(args.model_dir, args.hash, 'results_arrays')
    os.makedirs(results_dir, exist_ok=True)

    for split in args.splits:
        invariants, labels, learned_frames, images, rec_images, data_ids = [], [], [], [], [], []

        # pure inference: no autograd graph is needed, every output is converted to numpy
        with torch.no_grad():
            for X, X_rec, y, rot, data_id in tqdm(dataloaders[split]):
                X = dict_to_device(X, device)
                rot = rot.view(-1, 3, 3).float().to(device)

                (z_mean, z_log_var), _, learned_frame = cgvae.encode(X)

                # deterministic inference: the latent mean is used, no sampling
                z = z_mean

                # decode in the learned frame when the model learns one, otherwise in the
                # frame that comes with the data; the frame that was used is the one saved
                frame = learned_frame if hparams['learn_frame'] else rot
                x_reconst, _ = cgvae.decode(z, frame)

                invariants.append(z.detach().cpu().numpy())
                learned_frames.append(frame.reshape(-1, 9).detach().cpu().numpy())
                labels.append(y.cpu().numpy())
                images.append(make_vector(X).detach().cpu().numpy())
                rec_images.append(make_vector(x_reconst).detach().cpu().numpy())
                data_ids.append(data_id)

        invariants_ND = np.vstack(invariants)
        learned_frames_N9 = np.vstack(learned_frames)
        labels_N = np.hstack(labels)
        rotations_N9 = np.array([0]) # placeholder, kept so that the archive keeps the same keys
        data_ids_N = np.hstack(data_ids) # hstack because we assume the data_ids are represented in their stringified version. otherwise we would have to use vstack
        images_NF = np.vstack(images)
        rec_images_NF = np.vstack(rec_images)

        np.savez(os.path.join(results_dir, 'inference%s-split=%s.npz' % (model_type_str, split)),
                        invariants_ND = invariants_ND,
                        learned_frames_N9 = learned_frames_N9,
                        labels_N = labels_N,
                        rotations_N9 = rotations_N9,
                        data_ids_N = data_ids_N,
                        images_NF = images_NF,
                        rec_images_NF = rec_images_NF)