
import gzip
import os
import pickle

import numpy as np
import torch


# Maps an ``input_type`` name (as passed to get_data_mnist) to a two-element
# list of gzipped-pickle paths, relative to the current working directory:
#   [0] the S2 data file, loaded by get_data_mnist when get_s2=True.
#       NOTE: get_data_mnist may write the split/flattened S2 data back to
#       this path, and every entry below shares the same S2 file.
#   [1] a %-format template for the FT-transformed data file; get_data_mnist
#       fills the single ``%d`` with ``lmax``.
INPUT_OPTIONS = {
    # NOTE(review): "RR" vs "NRR" presumably distinguishes rotated vs
    # non-rotated training data (cf. "no_rotate_train" in the NRR file names),
    # and the "-1,7" variants presumably restrict to digit labels 1 and 7
    # (cf. "labels=1,7") — confirm against the data-generation scripts.
    'RR-avg_sqrt_power': ['../data/mnist/data/s2_mnist-cz=10-b=30.gz', '../data/mnist/data/real_sph_mnist-cz=10-b=30-lmax=%d-normalize=avg_sqrt_power-quad_weights=True.gz'],
    'RR-avg_sqrt_power-1,7': ['../data/mnist/data/s2_mnist-cz=10-b=30.gz', '../data/mnist/data/real_sph_mnist-cz=10-b=30-lmax=%d-normalize=avg_sqrt_power-quad_weights=True-labels=1,7.gz'],

    'NRR-avg_sqrt_power': ['../data/mnist/data/s2_mnist-cz=10-b=30.gz', '../data/mnist/data/real_sph_mnist-no_rotate_train-cz=10-b=30-lmax=%d-normalize=avg_sqrt_power-quad_weights=True.gz'],
    'NRR-avg_sqrt_power-1,7': ['../data/mnist/data/s2_mnist-cz=10-b=30.gz', '../data/mnist/data/real_sph_mnist-no_rotate_train-cz=10-b=30-lmax=%d-normalize=avg_sqrt_power-quad_weights=True-labels=1,7.gz']
}

def get_grids_mnist(*, grid_dir='../data/spherical_grids'):
    """Load the b=30 ``ba`` and ``xyz`` spherical grids from gzipped pickles.

    Parameters
    ----------
    grid_dir : str, keyword-only
        Directory containing ``ba_grid-b=30.gz`` and ``xyz_grid-b=30.gz``.
        Defaults to the historical hard-coded location, which is relative to
        the current working directory.

    Returns
    -------
    (ba_grid, xyz_grid)
        ``ba_grid`` is the unpickled ba grid converted to a tuple whose first
        two components have been flattened (any further components are passed
        through unchanged). ``xyz_grid`` is returned exactly as unpickled.

    NOTE: pickle.load can execute arbitrary code; only load trusted files.
    """
    print('Getting grids...')
    with gzip.open(os.path.join(grid_dir, 'ba_grid-b=30.gz'), 'rb') as f:
        raw_ba_grid = pickle.load(f)
    # Post-process after the file is closed; only the first two components
    # are flattened, the rest are kept as-is.
    ba_grid = tuple(
        component.flatten() if i < 2 else component
        for i, component in enumerate(raw_ba_grid)
    )

    with gzip.open(os.path.join(grid_dir, 'xyz_grid-b=30.gz'), 'rb') as f:
        xyz_grid = pickle.load(f)
    print('Done getting grids.')
    return ba_grid, xyz_grid

def _split_train_valid(s2_data, seed=420420420, train_frac=0.9):
    """Carve a validation split out of ``s2_data['train']``, in place.

    A permutation of the N training examples drawn from a generator seeded
    with ``seed`` is split into its first ``int(train_frac * N)`` indices
    (kept in ``'train'``) and the rest (stored in a new ``'valid'`` dict).
    Only the ``images``, ``unprojected_images``, ``rotations`` and ``labels``
    entries are split; any other keys of ``s2_data['train']`` are untouched.
    The default seed is the one used in the projections, so the split is
    reproducible.
    """
    train = s2_data['train']
    keys = ('images', 'unprojected_images', 'rotations', 'labels')
    # Read every column before mutating, so a missing key leaves s2_data intact.
    columns = {key: train[key] for key in keys}

    n_total = columns['images'].shape[0]
    idxs = np.random.default_rng(seed).permutation(n_total)
    n_train = int(train_frac * n_total)
    train_idxs, valid_idxs = idxs[:n_train], idxs[n_train:]

    valid = {}
    for key, values in columns.items():
        train[key] = values[train_idxs]
        valid[key] = values[valid_idxs]
    s2_data['valid'] = valid

def _dump_gzip_pickle_atomic(obj, path):
    """Pickle ``obj`` into the gzip file ``path`` via a temporary file.

    The data is written to ``path + '.tmp'`` and then moved over ``path`` with
    os.replace, so an interrupted write leaves the existing file intact.
    """
    tmp_path = path + '.tmp'
    try:
        with gzip.open(tmp_path, 'wb') as f:
            pickle.dump(obj, f)
        os.replace(tmp_path, path)
    except BaseException:
        # Best-effort cleanup of the partial temp file; re-raise the real error.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise

def get_data_mnist(input_type, get_grids=False, get_s2=False, lmax=10):
    """Load the FT data for ``input_type``, plus optional grids and S2 data.

    Parameters
    ----------
    input_type : str
        Key into ``INPUT_OPTIONS``.
    get_grids : bool
        If True, also load the grids via ``get_grids_mnist()``.
    get_s2 : bool
        If True, also load the S2 data. If it has no ``'valid'`` split, a
        seeded 90/10 train/valid split is carved out of ``'train'``. Every
        split's ``images`` is reshaped to 2-D ``(N, -1)``. When either step
        changes the data, the result is written back (atomically) to the S2
        source file so later calls reuse the same split.
    lmax : int
        Substituted into the ``%d`` of the FT data path template.

    Returns
    -------
    (ft_data, s2_data, (ba_grid, xyz_grid))
        Items that were not requested are None.

    Raises
    ------
    KeyError
        If ``input_type`` is not a key of ``INPUT_OPTIONS``.

    NOTE: every file is read with pickle.load, which can execute arbitrary
    code; only point this at trusted, locally generated data.
    """
    try:
        paths = INPUT_OPTIONS[input_type]
    except KeyError:
        raise KeyError(
            'Unknown input_type %r; expected one of %s'
            % (input_type, sorted(INPUT_OPTIONS))
        ) from None

    ## get ft transformed data
    print('Getting ft data...')
    ft_data_path = paths[1] % (lmax)
    with gzip.open(ft_data_path, 'rb') as f:
        ft_data = pickle.load(f)

    ## get grids if requested (shared loader instead of a duplicated copy)
    if get_grids:
        ba_grid, xyz_grid = get_grids_mnist()
    else:
        ba_grid = None
        xyz_grid = None

    ## get s2_data if requested
    if get_s2:
        print('Getting s2 data...')
        s2_data_path = paths[0]
        with gzip.open(s2_data_path, 'rb') as f:
            s2_data = pickle.load(f)

        modified = False
        # split data in train and valid with the seed used in the projections, unless already done
        if 'valid' not in s2_data:
            _split_train_valid(s2_data)
            modified = True

        # flatten data
        for split in ('train', 'valid', 'test'):
            images = s2_data[split]['images']
            if images.ndim != 2:
                # Explicit trailing size: reshape((0, -1)) is ambiguous and
                # raises when a split is empty.
                flat_size = int(np.prod(images.shape[1:]))
                s2_data[split]['images'] = images.reshape((images.shape[0], flat_size))
                modified = True

        # Persist only when something changed, after the read handle is closed.
        if modified:
            _dump_gzip_pickle_atomic(s2_data, s2_data_path)
    else:
        s2_data = None

    print('Done getting data.')

    return ft_data, s2_data, (ba_grid, xyz_grid)
    
        

