
import numpy as np
import torch
import e3nn
from e3nn import o3

import gzip, pickle
import matplotlib.pyplot as plt

import argparse
import sys, os
from tqdm import tqdm
sys.path.append('../../..')
from utils.argparse_utils import *

def combine(tensors, irreps):
    """Flatten channel-separated irrep coefficients into one feature row per sample.

    ``tensors`` is indexed as ``tensors[batch, channel, feature]``. The feature
    axis holds the ``2l+1`` coefficients of each irrep in ``irreps``,
    concatenated in order. For each irrep, its slice across all channels is
    flattened channel-major. The flattened slices are then concatenated along
    the last axis.

    Args:
        tensors: 3-D tensor ``(batch, channels, features)``.
        irreps: iterable of ``(mul, ir)`` pairs (e.g. an ``e3nn.o3.Irreps``);
            only ``ir.l`` is read.

    Returns:
        2-D tensor ``(batch, sum_i channels * (2 * l_i + 1))``.
    """
    blocks = []
    lower_bound = 0
    for _mul, ir in irreps:
        num_values = 2 * ir.l + 1
        # flatten(start_dim=1) always keeps both the batch axis and the feature
        # axis, even when either has size 1 (batch of one, or l=0 with a single
        # channel). squeeze() would drop those axes, and torch.cat needs every
        # block to be 2-D.
        blocks.append(tensors[:, :, lower_bound:lower_bound + num_values].flatten(start_dim=1))
        lower_bound += num_values

    return torch.cat(blocks, dim=-1)


def _per_signal_norm_factors(signals, ls_indices, balancing_constant):
    """Return one normalization factor per row of ``signals``.

    For each row ``x`` the factor is
    ``sqrt(sum_f x[f]**2 / (2 * ls_indices[f] + 1) / balancing_constant)``.

    Args:
        signals: 2-D array/tensor ``(num_signals, num_features)``.
        ls_indices: 1-D tensor holding the degree ``l`` of each feature.
        balancing_constant: scalar divisor applied to the weighted power.

    Returns:
        1-D tensor of length ``num_signals``.
    """
    signals = torch.as_tensor(signals)
    # Match dtype/device so einsum never mixes e.g. float64 signals with float32 weights.
    ls = ls_indices.to(dtype=signals.dtype, device=signals.device)
    weights = 1.0 / (2 * ls + 1)
    return torch.sqrt(torch.einsum('bf,bf,f->b', signals, signals, weights) / balancing_constant)


def _mean_norm_factor(all_signals, ls_indices, balancing_constant, batch_size=1000):
    """Return the mean of the per-row norm factors of ``all_signals``.

    Rows are processed ``batch_size`` at a time to bound peak memory.

    Raises:
        ValueError: if ``all_signals`` has no rows (the mean would be undefined).
    """
    num_signals = all_signals.shape[0]
    if num_signals == 0:
        raise ValueError('cannot compute a normalization factor from an empty training split')
    # Stepping the start index by batch_size covers a final partial batch, and
    # splits smaller than one batch, without any special-casing.
    norm_factors = [
        _per_signal_norm_factors(all_signals[start:start + batch_size], ls_indices, balancing_constant)
        for start in tqdm(range(0, num_signals, batch_size))
    ]
    return torch.mean(torch.cat(norm_factors, dim=-1))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', type=str, required=True)
    parser.add_argument('--data_file', type=str, default='real_sph_mnist-no_rotate_train-no_rotate_test-b=30-lmax=10-normalize=None-quad_weights=True.gz')
    parser.add_argument('--lmax', type=int, default=10)
    parser.add_argument('--balanced', type=str_to_bool_or_float, default=False)
    args = parser.parse_args()

    # NOTE: pickle.load can execute arbitrary code; only run this on trusted data files.
    with gzip.open(os.path.join(args.data_dir, args.data_file), 'rb') as f:
        data_dict = pickle.load(f)

    N_CHANNELS = 1
    data_irreps = (N_CHANNELS*o3.Irreps.spherical_harmonics(args.lmax, 1)).sort().irreps.simplify()
    # Degree l of every feature in the combined layout, in ascending order of l.
    # Each entry of data_irreps.ls equal to l is repeated 2l+1 times.
    ls = torch.tensor(data_irreps.ls)
    ls_indices = torch.cat([ls[ls == l].repeat(2*l+1) for l in sorted(set(data_irreps.ls))]).type(torch.float)

    if args.balanced:
        multiplicative_constant = 1.0 if isinstance(args.balanced, bool) else args.balanced
        balancing_constant = N_CHANNELS * (args.lmax+1) / multiplicative_constant
        normalized_data_file = args.data_file.replace('normalize=None', 'normalize=avg_sqrt_power_balanced_times_%.2f' % (multiplicative_constant))
    else:
        balancing_constant = 1.0
        normalized_data_file = args.data_file.replace('normalize=None', 'normalize=avg_sqrt_power')

    if normalized_data_file == args.data_file:
        # No 'normalize=None' tag to replace: refuse rather than overwrite the input file.
        raise ValueError("data_file must contain 'normalize=None' so the output gets a distinct name: %s" % args.data_file)

    # Projections stored with separate channels (3-D) are flattened to one
    # feature row per sample for every split before any statistics are computed.
    for split in data_dict:
        if len(data_dict[split]['projections'].shape) == 3:
            data_dict[split]['projections'] = combine(torch.tensor(data_dict[split]['projections']), data_irreps)

    # The normalization factor comes from the training split only. It is computed
    # before any split is normalized, so the result does not depend on dict order.
    if 'train' not in data_dict:
        raise KeyError("expected a 'train' split in %s" % args.data_file)
    batch_size = 1000
    all_signals = data_dict['train']['projections']
    print(all_signals.shape)
    train_norm_factors = _mean_norm_factor(all_signals, ls_indices, balancing_constant, batch_size=batch_size)

    normalized_data_dict = {}
    for split in data_dict:
        normalized_data_dict[split] = {
            'labels': data_dict[split]['labels'],
            'rotations': data_dict[split]['rotations'],
            'projections': data_dict[split]['projections'] / train_norm_factors,
        }

    # save normalized data
    with gzip.open(os.path.join(args.data_dir, normalized_data_file), 'wb') as f:
        pickle.dump(normalized_data_dict, f)
