

import os, sys
import gzip
import pickle
import argparse
import json
from tqdm import tqdm

import numpy as np
import torch
from torch.utils.data import DataLoader
import e3nn
from e3nn import o3
import matplotlib.pyplot as plt

from scipy.stats import spearmanr

sys.path.append('..')

from cgnet_fibers import ClebschGordanVAE_symmetric, ClebschGordanVAE_symmetric_simple, ClebschGordanVAE_symmetric_simple_flexible
from projections import real_sph_ift
from utils.data_getter import get_data_mnist
from utils import get_wigner_D_from_rot_matrix, rotate_signal, orthonormalize_frame
from utils.argparse_utils import *
from utils.data_utils import Shrec17Dataset__fibers

import umap


def dict_to_device(adict, device):
    """Cast every value of ``adict`` to float32 and move it onto ``device``.

    The dictionary is modified in place (each value is replaced by its
    converted copy) and the same dictionary object is returned, so callers
    can either use the return value or keep their original reference.
    """
    for name, tensor in list(adict.items()):
        adict[name] = tensor.float().to(device)
    return adict

def make_vector(x):
    """Concatenate a dict of per-degree tensors into one 2-D tensor.

    Each value is flattened to ``(batch, -1)`` (the leading dimension is kept
    as the batch axis) and the pieces are joined along the last dimension in
    ascending key order.
    """
    flattened = [x[degree].reshape(x[degree].shape[0], -1) for degree in sorted(x)]
    return torch.cat(flattened, dim=-1)


# Maps a --model_type value to (suffix used in the results filename, checkpoint filename).
MODEL_FILES = {
    'best': ('', 'best_model.pt'),
    # NOTE(review): 'best_04' shares the empty suffix with 'best', so both write the
    # same results file. Kept as-is in case downstream readers rely on the name.
    'best_04': ('', 'best_model_04.pt'),
    'best_05': ('-best_05', 'best_model_05.pt'),
    'best_06': ('-best_06', 'best_model_06.pt'),
    'best_higher_kld': ('-best_model_higher_kld', 'best_model_higher_kld.pt'),
    'lowest_rec_loss': ('-lowest_rec_loss', 'lowest_rec_loss_model.pt'),
    'final': ('-final_model', 'final_model.pt'),
    'lowest_total_loss_with_final_kl_model': ('-lowest_total_loss_with_final_kl_model', 'lowest_total_loss_with_final_kl_model.pt'),
}


def resolve_model_files(model_type):
    """Return ``(model_type_str, model_name)`` for a ``--model_type`` value.

    Raises:
        ValueError: if ``model_type`` is not a key of ``MODEL_FILES`` (previously an
            unknown value left both names unbound and crashed later with NameError).
    """
    try:
        return MODEL_FILES[model_type]
    except KeyError:
        raise ValueError('Unknown --model_type %r; expected one of %s.' % (model_type, sorted(MODEL_FILES))) from None


def build_shrec17_dataloaders(data_dir, hparams):
    """Load the preprocessed SHREC17 projections and wrap each split in a DataLoader.

    Returns:
        ``(data_irreps, dataloaders)`` where ``dataloaders`` maps 'train', 'valid'
        and 'test' to a DataLoader over ``Shrec17Dataset__fibers``.
    """
    rng = torch.Generator().manual_seed(hparams['seed'])
    # 6 copies of each spherical-harmonic irrep up to lmax, sorted and merged.
    data_irreps = (6*o3.Irreps.spherical_harmonics(hparams['lmax'], 1)).sort().irreps.simplify()

    data_id = '-b=%d-perturbed=%s-random_rotations=%s-random_translation=%.2f' % (hparams['bandwidth'], hparams['perturbed'], hparams['random_rotations'], hparams['random_translation'])
    data_id_2 = '-b=%d-lmax=%d-normalize=%s' % (hparams['bandwidth'], hparams['lmax'], hparams['normalize'])

    # NOTE: unpickling can execute arbitrary code; only point --data_dir at trusted files.
    with gzip.open(os.path.join(data_dir, 'shrec17_real_sph_ft%s%s.gz' % (data_id, data_id_2)), 'rb') as f:
        data = pickle.load(f)

    def make_dataset(split, ids):
        # One dataset per split; train/valid get placeholder zero ids, test gets the stored ids.
        return Shrec17Dataset__fibers(torch.tensor(data[split]['projections']), data_irreps, torch.tensor(data[split]['labels']), ids)

    datasets = {
        'train': make_dataset('train', torch.zeros(data['train']['labels'].shape[0])),
        'valid': make_dataset('valid', torch.zeros(data['valid']['labels'].shape[0])),
        'test': make_dataset('test', data['test']['ids']),
    }

    print('%d training neighborhoods' % (len(datasets['train'])), file=sys.stderr)
    print('%d validation neighborhoods' % (len(datasets['valid'])), file=sys.stderr)
    print('%d test neighborhoods' % (len(datasets['test'])), file=sys.stderr)

    batch_sizes = {'train': hparams['batch_size'], 'valid': hparams['batch_size'], 'test': 32}
    # drop_last=False: this is inference, so every sample must end up in the saved
    # arrays (drop_last=True silently discarded the final partial batch of each split).
    dataloaders = {
        split: DataLoader(datasets[split], batch_size=batch_sizes[split], generator=rng,
                          shuffle=(split == 'train'), drop_last=False)
        for split in ('train', 'valid', 'test')
    }
    return data_irreps, dataloaders


def load_w3j_matrices(w3j_filepath, net_lmax, device):
    """Load the gzipped, pickled dict of w3j coefficient matrices.

    Entries whose first three key components (indexed as ``key[0..2]``) are all
    ``<= net_lmax`` are converted to float tensors (moved to ``device`` when it is
    not None) with ``requires_grad = False``; all other entries are left as loaded.
    """
    with gzip.open(w3j_filepath, 'rb') as f:
        w3j_matrices = pickle.load(f)

    for key in w3j_matrices:
        if key[0] <= net_lmax and key[1] <= net_lmax and key[2] <= net_lmax:
            matrix = torch.tensor(w3j_matrices[key]).float()
            if device is not None:
                matrix = matrix.to(device)
            matrix.requires_grad = False
            w3j_matrices[key] = matrix
    return w3j_matrices


def build_model(hparams, data_irreps, w3j_matrices, device):
    """Instantiate the VAE described by ``hparams`` on ``device``.

    Prints the hyperparameters and fills in ``hparams['norm_balanced'] = False``
    when it is missing (older runs did not record it).

    Raises:
        NotImplementedError: if ``hparams['model']`` is not
            'cgvae_symmetric_simple_flexible'.
    """
    if hparams['model'] != 'cgvae_symmetric_simple_flexible':
        raise NotImplementedError('Unsupported model type %r.' % hparams['model'])

    for key, value in hparams.items():
        print('{}\t{}'.format(key, value))
    print()

    hparams.setdefault('norm_balanced', False)

    return ClebschGordanVAE_symmetric_simple_flexible(data_irreps,
                                        hparams['latent_dim'],
                                        hparams['net_lmax'],
                                        hparams['n_cg_blocks'],
                                        list(map(int, hparams['ch_size_list'].split(','))),
                                        hparams['ls_nonlin_rule_list'].split(','),
                                        hparams['ch_nonlin_rule_list'].split(','),
                                        hparams['do_initial_linear_projection'],
                                        hparams['ch_initial_linear_projection'],
                                        w3j_matrices,
                                        device,
                                        lmax_list=list(map(int, hparams['lmax_list'].split(','))),
                                        use_additive_skip_connections=hparams['use_additive_skip_connections'],
                                        use_batch_norm=hparams['use_batch_norm'],
                                        norm_type=hparams['norm_type'], # None, layer, signal
                                        normalization=hparams['normalization'], # norm, component -> only considered if norm_type is not none
                                        norm_balanced=hparams['norm_balanced'],
                                        norm_affine=hparams['norm_affine'], # None, {True, False} -> for layer_norm, {unique, per_l, per_feature} -> for signal_norm
                                        norm_nonlinearity=hparams['norm_nonlinearity'], # None (identity), identity, relu, swish, sigmoid -> only for layer_norm
                                        norm_location=hparams['norm_location'], # first, between, last
                                        linearity_first=hparams['linearity_first'], # currently only works with this being false
                                        filter_symmetric=hparams['filter_symmetric'], # whether to exclude duplicate pairs of l's from the tensor product nonlinearity
                                        x_rec_loss_fn=hparams['x_rec_loss_fn'], # mse, mse_normalized, cosine
                                        do_final_signal_norm=hparams['do_final_signal_norm'],
                                        learn_frame=hparams['learn_frame'],
                                        is_vae=hparams['is_vae']).to(device)


def run_inference(cgvae, dataloader, device):
    """Encode and reconstruct every batch of ``dataloader`` with the learned frame.

    The latent used is the encoder mean (no sampling). Returns a dict of numpy
    arrays keyed exactly as stored in the results ``.npz`` file.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    invariants, labels, learned_frames, images, rec_images, ids = [], [], [], [], [], []
    # no_grad: inference only, so skip building autograd graphs.
    with torch.no_grad():
        for X, _X_rec, y, idd in tqdm(dataloader):
            X = dict_to_device(X, device)

            (z_mean, _z_log_var), _scalar_features, learned_frame = cgvae.encode(X)
            z = z_mean
            x_reconst, _scalar_features_reconst = cgvae.decode(z, learned_frame)

            invariants.append(z.cpu().numpy())
            learned_frames.append(learned_frame.reshape(-1, 1, 9).squeeze(1).cpu().numpy())
            labels.append(y.cpu().numpy())
            images.append(make_vector(X).cpu().numpy())
            rec_images.append(make_vector(x_reconst).cpu().numpy())
            ids.append(np.array(idd))

    if not invariants:
        raise ValueError('The dataloader yielded no batches; nothing to save.')

    return {
        'invariants_ND': np.vstack(invariants),
        'learned_frames_N9': np.vstack(learned_frames),
        'labels_N': np.hstack(labels),
        'rotations_N9': np.array([0]),  # placeholder: ground-truth rotations are not available here
        'images_NF': np.vstack(images),
        'rec_images_NF': np.vstack(rec_images),
        'ids_N': np.hstack(ids),
    }


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--w3j_filepath', type=str, default='../cg_coefficients/w3j_matrices-lmax=14-version=0.5.0.pkl')
    parser.add_argument('--data_dir')
    parser.add_argument('--model_dir', type=str)
    parser.add_argument('--splits', type=comma_sep_str_list, default='train,valid,test')
    parser.add_argument('--model_type', type=str, default='best', choices=sorted(MODEL_FILES))
    parser.add_argument('--hash', type=str)
    parser.add_argument('--seed', type=int, default=1000005)

    args = parser.parse_args()

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print('Running on %s.' % device)

    model_type_str, model_name = resolve_model_files(args.model_type)
    run_dir = os.path.join(args.model_dir, args.hash)

    with open(os.path.join(run_dir, 'hparams.json'), 'r') as f:
        hparams = json.load(f)

    # Fail fast with clear messages instead of NameErrors deep in the script.
    if 'shrec17' not in args.model_dir:
        raise NotImplementedError('Only SHREC17 models are supported (--model_dir must contain "shrec17").')
    if not hparams['learn_frame']:
        raise NotImplementedError('Inference requires a model trained with learn_frame=True; '
                                  'the dataset does not provide rotation matrices.')

    data_irreps, dataloaders = build_shrec17_dataloaders(args.data_dir, hparams)
    unknown_splits = [split for split in args.splits if split not in dataloaders]
    if unknown_splits:
        raise ValueError('Unknown split(s) %s; expected a subset of %s.' % (unknown_splits, sorted(dataloaders)))

    w3j_matrices = load_w3j_matrices(args.w3j_filepath, hparams['net_lmax'], device)

    cgvae = build_model(hparams, data_irreps, w3j_matrices, device)
    cgvae.load_state_dict(torch.load(os.path.join(run_dir, model_name), map_location=torch.device(device)))
    cgvae = cgvae.float()
    cgvae.eval()

    results_dir = os.path.join(run_dir, 'results_arrays')
    os.makedirs(results_dir, exist_ok=True)

    for split in args.splits:
        arrays = run_inference(cgvae, dataloaders[split], device)

        print(arrays['ids_N'].shape)
        print(arrays['ids_N'][:10])

        np.savez(os.path.join(results_dir, 'inference%s-split=%s.npz' % (model_type_str, split)), **arrays)
