
import numpy as np
import torch
import e3nn
from e3nn import o3

import pickle
import matplotlib.pyplot as plt

import argparse
import sys, os
from tqdm import tqdm
sys.path.append('../..')
from utils.argparse_utils import *
from projections import ZernickeRadialFunctions, RadialSphericalTensor, MultiChannelRadialSphericalTensor

# projections-train-complex_sph=False-get_H=False-get_SASA=False-get_charge=False-lmax=6-n_channels=4-n_neigh=126025-rcut=10.0-rmax=40-rst_normalization=square
# projections-val-complex_sph=False-get_H=False-get_SASA=False-get_charge=False-lmax=6-n_channels=4-n_neigh=30553-rcut=10.0-rmax=40-rst_normalization=square

def mean_sqrt_power(signals, ls_indices, batch_size=100, progress=None):
    """Return the mean, over signals, of the square root of each signal's power.

    The power of one signal is ``sum_f signal[f]**2 / (2 * ls_indices[f] + 1)``,
    i.e. every coefficient is weighted by the inverse degeneracy of its degree.

    Args:
        signals: 2-D float tensor with one signal per row.
        ls_indices: 1-D float tensor giving the degree ``l`` of every column
            of ``signals``.
        batch_size: number of rows processed per einsum call (bounds memory).
        progress: optional wrapper applied to the batch iterable, e.g. ``tqdm``.

    Returns:
        0-dim tensor holding the average square-root power.

    Raises:
        ValueError: if ``signals`` has no rows (the mean would be NaN).
    """
    if signals.shape[0] == 0:
        raise ValueError('cannot compute a normalization factor from an empty set of signals')

    inverse_degeneracy = 1.0 / (2 * ls_indices + 1)

    # torch.split yields consecutive chunks covering every row exactly once,
    # including the first chunk and a shorter trailing one.
    batches = torch.split(signals, batch_size)
    if progress is not None:
        batches = progress(batches)

    per_signal = [
        torch.sqrt(torch.einsum('bf,bf,f->b', batch, batch, inverse_degeneracy))
        for batch in batches
    ]
    return torch.mean(torch.cat(per_signal, dim=-1))


def _load_projections(data_dir, split_tag, data_id):
    """Load ``projections-<split_tag>-<data_id>.npy`` from ``data_dir`` as a float32 tensor."""
    path = os.path.join(data_dir, 'projections-{}-{}.npy'.format(split_tag, data_id))
    return torch.Tensor(np.load(path)).float()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', type=str)
    parser.add_argument('--lmax', type=int, default=4)
    parser.add_argument('--n_channels', type=int, default=4)
    parser.add_argument('--n_train_neigh', type=int, default=5000)
    parser.add_argument('--n_valid_neigh', type=int, default=5000)
    parser.add_argument('--n_test_neigh', type=int, default=10000)
    parser.add_argument('--rmax', type=int, default=20)
    parser.add_argument('--get_H', type=str_to_bool, default=False)
    parser.add_argument('--get_SASA', type=str_to_bool, default=False)
    parser.add_argument('--get_charge', type=str_to_bool, default=False)

    parser.add_argument('--rst_normalization', type=str, default='square')
    parser.add_argument('--rcut', type=float, default=10.0)
    parser.add_argument('--complex_sph', type=str_to_bool, default=False)

    parser.add_argument('--balanced', type=str_to_bool_or_float, default=False)

    args = parser.parse_args()

    # File tag of every split -> number of neighborhoods encoded in its file name.
    n_neigh_per_split = {
        'train': args.n_train_neigh,
        'val': args.n_valid_neigh,
        'test': args.n_test_neigh,
    }

    def current_data_id(n_neigh):
        """Data id used by the current file naming scheme."""
        return 'complex_sph=%s-get_H=%s-get_SASA=%s-get_charge=%s-lmax=%d-n_channels=%d-n_neigh=%d-rcut=%.1f-rmax=%d-rst_normalization=%s' % (args.complex_sph, args.get_H, args.get_SASA, args.get_charge, args.lmax, args.n_channels, n_neigh, args.rcut, args.rmax, args.rst_normalization)

    def legacy_data_id(n_neigh):
        """Data id used by the older naming scheme (no get_H/get_SASA/get_charge fields)."""
        return 'complex_sph=%s-lmax=%d-n_channels=%d-n_neigh=%d-rcut=%.1f-rmax=%d-rst_normalization=%s' % (args.complex_sph, args.lmax, args.n_channels, n_neigh, args.rcut, args.rmax, args.rst_normalization)

    # Try the current naming scheme first; fall back to the legacy names for
    # ALL splits only when a file is missing. Any other failure (corrupt file,
    # Ctrl-C, ...) propagates instead of being silently retried.
    try:
        data_ids = {tag: current_data_id(n) for tag, n in n_neigh_per_split.items()}
        data_dict = {tag: _load_projections(args.data_dir, tag, data_ids[tag]) for tag in n_neigh_per_split}
    except FileNotFoundError:
        data_ids = {tag: legacy_data_id(n) for tag, n in n_neigh_per_split.items()}
        data_dict = {tag: _load_projections(args.data_dir, tag, data_ids[tag]) for tag in n_neigh_per_split}

    # Rebuild the irreps layout of the projections to recover the degree l of
    # every coefficient.
    OnRadialFunctions = ZernickeRadialFunctions(args.rcut, args.rmax+1, args.lmax, complex_sph=args.complex_sph)
    rst = RadialSphericalTensor(args.rmax+1, OnRadialFunctions, args.lmax, 1, 1)
    mul_rst = MultiChannelRadialSphericalTensor(rst, args.n_channels)
    data_irreps = o3.Irreps(str(mul_rst))
    # For each l, emit l once per coefficient (multiplicity * (2l+1) entries).
    # NOTE(review): this matches the data columns only if the irreps are laid
    # out in increasing l - confirm against the projection code.
    ls = torch.tensor(data_irreps.ls)
    ls_indices = torch.cat([ls[ls == l].repeat(2*l+1) for l in sorted(set(data_irreps.ls))]).type(torch.float)

    if args.balanced:
        multiplicative_constant = 1.0 if isinstance(args.balanced, bool) else args.balanced
        # compute avg number of channels (not constant for these projections)
        avg_n_channels = np.mean([mul for mul, _ in data_irreps])
        # NOTE(review): balancing_constant is computed but never applied to the
        # data below, so "balanced" runs currently differ from unbalanced ones
        # only in the output file name. Confirm the intended formula before
        # wiring it in.
        balancing_constant = avg_n_channels * (args.lmax+1) / multiplicative_constant
        normalize_str_for_data_id = 'normalize=avg_sqrt_power_balanced_times_%.2f-n_train_neigh=%d' % (multiplicative_constant, args.n_train_neigh)
    else:
        balancing_constant = 1.0
        normalize_str_for_data_id = 'normalize=avg_sqrt_power-n_train_neigh=%d' % (args.n_train_neigh)

    # The normalization factor is estimated on the training split only and then
    # applied to every split.
    train_norm_factor = mean_sqrt_power(data_dict['train'], ls_indices, batch_size=100, progress=tqdm)
    normalized_data_dict = {tag: signals / train_norm_factor for tag, signals in data_dict.items()}

    # save normalized data
    for tag, normalized in normalized_data_dict.items():
        out_path = os.path.join(args.data_dir, 'projections-{}-{}-{}.npy'.format(tag, normalize_str_for_data_id, data_ids[tag]))
        np.save(out_path, normalized.numpy())
