
import os, sys
import gzip
import pickle
import numpy as np
import torch
import e3nn
from e3nn import o3
import argparse
from tqdm import tqdm

sys.path.append(os.path.join(sys.path[0], '../../..'))
from utils.argparse_utils import *
from projections import real_sph_ft, real_sph_ift
from lie_learn.spaces.spherical_quadrature import estimate_spherical_quadrature_weights


def project_in_batches(all_signals, transform, batch_size, collect_norms, progress=None):
    """Apply a per-channel transform to a stack of signals, one batch at a time.

    Args:
        all_signals: Array indexed as ``[sample, channel, ...]``; it is sliced
            along the first axis into batches and each channel is passed to
            ``transform`` as ``signals[:, ch, :]``.
        transform: Callable taking one channel of a batch and returning a pair
            ``(projections, norm_factors)``. ``.numpy()`` is called on the
            first element always, and on the second only when
            ``collect_norms`` is true.
        batch_size: Maximum number of samples per call to ``transform``.
        collect_norms: Whether to gather the second return value of
            ``transform`` as well.
        progress: Optional wrapper applied to the iterable of batch start
            offsets (e.g. ``tqdm``) to report progress.

    Returns:
        Tuple ``(projections, norm_factors)``. ``projections`` is the
        concatenation over batches of the per-channel results stacked on
        axis 1. ``norm_factors`` is built the same way, or is ``None`` when
        ``collect_norms`` is false.

    Raises:
        ValueError: If ``all_signals`` has no samples or ``batch_size`` is not
            positive.
    """
    if batch_size < 1:
        raise ValueError('batch_size must be a positive integer')

    num_samples = all_signals.shape[0]
    num_channels = all_signals.shape[1]
    if num_samples == 0:
        raise ValueError('cannot project an empty set of signals')

    # Stepping by batch_size covers the full batches and the trailing partial
    # batch in one loop, and never yields an empty batch.
    batch_starts = range(0, num_samples, batch_size)
    if progress is not None:
        batch_starts = progress(batch_starts)

    projections, norm_factors = [], []
    for start in batch_starts:
        signals = all_signals[start:start + batch_size]

        proj_batch, norms_batch = [], []
        for ch in range(num_channels):
            # Distinct names here: the accumulators above must not be rebound
            # to the transform's return values.
            ch_proj, ch_norms = transform(signals[:, ch, :])
            proj_batch.append(ch_proj.numpy())
            if collect_norms:
                norms_batch.append(ch_norms.numpy())

        projections.append(np.stack(proj_batch, axis=1))
        if collect_norms:
            norm_factors.append(np.stack(norms_batch, axis=1))

    all_projections = np.concatenate(projections, axis=0)
    all_norms = np.concatenate(norm_factors, axis=0) if collect_norms else None
    return all_projections, all_norms


if __name__ == '__main__':
    # Script: load signals sampled on an S2 grid, project every channel of
    # every sample with real_sph_ft, and pickle the result per split.
    parser = argparse.ArgumentParser()

    parser.add_argument('--input_grid', type=str, default='../../spherical_grids/ba_grid-b=%d.gz',
                        help='Path to file containing S2 grid onto which signals are defined. Bandwidth must be inserted via string formatting. Must be in *.gz format.')

    parser.add_argument('--quad_weights', type=str, default='../../spherical_grids/quad_weights-b=%d-lmax=%d.npy',
                        help='Path to file containing quadrature weights for the grid onto which signals are defined. Bandwidth and lmax must be inserted via string formatting. Must be in *.npy format.')

    parser.add_argument('--input_data', type=str,
                        help='Path to file containing data on the S2 grid. Data_id must be inserted via string formatting.')

    parser.add_argument('--data_id', type=str, default='-b=%d-perturbed=True-random_rotations=False-random_translation=0.00',
                        help='Identifier for the data. Bandwidth must be inserted via string formatting.')

    parser.add_argument('--output_file', type=str,
                        help='Path to desired output file. Data_id, normalize, bandwidth lmax must be inserted via string formatting. Must be in *.gz format.')

    parser.add_argument('--bandwidth', type=int, default=90,
                        help='Bandwidth of the S2 grid onto which signals are defined.')

    parser.add_argument('--lmax', type=int, default=14,
                        help='Maximum value of l for which to compute spherical harmonics projections.')

    parser.add_argument('--normalize', type=optional_str, default=None,
                        help='Per-datapoint and per-channel normalization on the forward fourier transform.')

    parser.add_argument('--complex_sph', type=str_to_bool, default=False,
                        help='Whether to projct signals onto the complex spherical harmonics basis. Currently not implemented.')

    parser.add_argument('--batch_size', type=int, default=1000)

    args = parser.parse_args()

    if args.batch_size < 1:
        parser.error('--batch_size must be a positive integer')

    data_id = args.data_id % (args.bandwidth)

    # SECURITY NOTE: pickle.load executes arbitrary code embedded in the file;
    # only point these paths at trusted data.
    with gzip.open(args.input_grid % (args.bandwidth), 'rb') as f:
        ba_grid = pickle.load(f)

    quad_weights = np.load(args.quad_weights % (args.bandwidth, args.lmax))

    # First try a single gzipped pickle holding every split; if that cannot be
    # read, fall back to one .npy file per split and field.
    try:
        with gzip.open(args.input_data % (data_id), 'rb') as f:
            dataset = pickle.load(f)
    except Exception:
        # Deliberately best-effort, but no longer a bare `except:` so that
        # KeyboardInterrupt / SystemExit are not swallowed.
        print('gzip failed, using numpy')
        dataset = {}
        for split in ('train', 'valid', 'test'):
            dataset[split] = {
                'images': np.load(args.input_data % (('-IMAGES-split=%s' % split) + data_id)),
                'labels': np.load(args.input_data % (('-LABELS-split=%s' % split) + data_id)),
                'ids': None,
            }
        # Only the test split is loaded with ids.
        dataset['test']['ids'] = np.load(args.input_data % ('-IDS-split=test' + data_id))
    else:
        print('gzip succeeded')

    if args.complex_sph:
        raise NotImplementedError('Projection on complex SPH basis not yet implemented.')

    def sph_transform(channel_signals):
        # NOTE(review): normalization is hard-coded to None even when
        # --normalize is given; kept as in the original -- confirm whether
        # args.normalize was meant to be forwarded here.
        return real_sph_ft(channel_signals, ba_grid, args.lmax,
                           quad_weights_N=quad_weights, normalization=None)

    collect_norms = bool(args.normalize)

    projections_dataset = {}
    for split in ['train', 'valid', 'test']:
        split_projections, split_norms = project_in_batches(
            dataset[split]['images'],
            sph_transform,
            args.batch_size,
            collect_norms,
            progress=tqdm,
        )

        projections_dataset[split] = {
            'labels': dataset[split]['labels'],
            'projections': split_projections,
        }
        if collect_norms:
            projections_dataset[split]['norm_factors'] = split_norms

    # NOTE: the format arguments are supplied in the order
    # (data_id, bandwidth, lmax, normalize), which differs from the order
    # listed in the --output_file help text.
    with gzip.open(args.output_file % (data_id, args.bandwidth, args.lmax, args.normalize), 'wb') as f:
        pickle.dump(projections_dataset, f)


    
