import os
import bisect
import numpy as np
from kamitani_data_handler import kamitani_data_handler as data_handler


def get_data(subject_file, project_dir="/hdd1/pkyriakis/eppe/NeurIPs22/dataset1"):
    """Load one subject's fMRI responses, voxel coordinates and paired images.

    @param subject_file str, file name of the subject's .mat file inside project_dir
    @param project_dir str, directory holding the subject .mat files,
        images_112.npz, imageID_training.csv and imageID_test.csv
        (defaults to the original hard-coded location)
    @returns (X_train, X_test, Y_train, Y_test, coords) where X_* come from
        handler.get_data(roi='ROI_VC'), Y_* from images_112.npz, and coords
        stacks the 'voxel_x', 'voxel_y', 'voxel_z' meta fields along axis 0
    """
    images_npz = os.path.join(project_dir, "images_112.npz")
    kamitani_data_mat = os.path.join(project_dir, subject_file)

    handler = data_handler(
        matlab_file=kamitani_data_mat,
        train_img_csv=os.path.join(project_dir, 'imageID_training.csv'),
        test_img_csv=os.path.join(project_dir, 'imageID_test.csv'),
    )
    # The handler returns three splits; the middle one is not used here.
    X_train, _, X_test = handler.get_data(roi='ROI_VC')
    labels_train, _ = handler.get_labels()

    coords = np.array(
        [handler.get_meta_field(field) for field in ('voxel_x', 'voxel_y', 'voxel_z')]
    )

    # np.load on an .npz keeps the archive open; the context manager closes it.
    # Indexing the NpzFile reads each array fully into memory, so the results
    # remain valid after the archive is closed.
    with np.load(images_npz) as images:
        # Training images are reordered/expanded with labels_train to line up with X_train.
        Y_train = images['train_images'][labels_train]
        # NOTE(review): test images are used as stored, presumably already in
        # X_test order — confirm against the handler's test labels.
        Y_test = images['test_images']

    return X_train, X_test, Y_train, Y_test, coords

def coord2grid(coords, intensity, samples_per_axis=64, verbose=True):
    """Scatter per-voxel intensities into a dense grid using voxel coordinates.

    Along every axis d, the range [min(coords[d]), max(coords[d])] is cut by
    ``samples_per_axis - 1`` evenly spaced edges. A voxel's cell index on that
    axis is the bisect-right position of its coordinate among those edges
    (same as ``bisect.bisect``). When several voxels fall into the same cell,
    the voxel appearing last in ``coords`` provides the value.

    NOTE(review): with this binning, index 0 on every axis is never populated
    (the axis minimum maps to 1), and the top index only receives voxels lying
    exactly on the axis maximum. This is kept unchanged so grids stay
    layout-compatible with previously saved data.

    @param coords np.array of shape (NDIM, NUM_OF_VOXELS); NDIM is 3 for the
        voxel_x/voxel_y/voxel_z coordinates used in this script
    @param intensity np.array of shape (NUM_OF_VOXELS, )
    @param samples_per_axis int, grid size along each axis
    @param verbose bool, print the collision count (the original behaviour)
    @returns (grid, collisions): grid of shape (samples_per_axis,) * NDIM
        (float64) and the number of voxels that landed in an already-occupied
        cell
    """
    coords = np.asarray(coords)
    intensity = np.asarray(intensity)
    ndim, num_voxels = coords.shape
    grid_shape = (samples_per_axis,) * ndim
    grid = np.zeros(shape=grid_shape)
    collisions = 0

    if num_voxels > 0:
        lo, hi = coords.min(axis=1), coords.max(axis=1)
        # side='right' matches bisect.bisect (bisect_right) exactly.
        inds = np.stack([
            np.searchsorted(
                np.linspace(lo[d], hi[d], samples_per_axis - 1),
                coords[d],
                side='right',
            )
            for d in range(ndim)
        ])
        flat = np.ravel_multi_index(tuple(inds), grid_shape)
        # Fancy assignment with repeated indices has no guaranteed winner, so
        # pick the LAST voxel per cell explicitly (first hit in reversed order).
        cells, pos_from_end = np.unique(flat[::-1], return_index=True)
        grid.flat[cells] = intensity[num_voxels - 1 - pos_from_end]
        # Occupancy-based count: independent of the intensity values written.
        collisions = num_voxels - cells.size

    if verbose:
        print(float(collisions))
    return grid, collisions

def data2grid(X, coords, samples_per_axis=64):
    """Convert every sample (row) of X into a dense voxel grid.

    Each row is scattered independently by ``coord2grid`` using the voxel
    coordinates shared by all samples. The per-sample collision counts that
    ``coord2grid`` returns are discarded.

    @param X np.array of shape (NUM_OF_SAMPLES, NUM_OF_VOXELS)
    @param coords np.array of shape (3, NUM_OF_VOXELS), voxel coordinates
    @param samples_per_axis int, grid size along each axis
    @returns np.array of shape
        (NUM_OF_SAMPLES, samples_per_axis, samples_per_axis, samples_per_axis)
    """
    num_samples = X.shape[0]
    X_grid = np.zeros(
        shape=(num_samples, samples_per_axis, samples_per_axis, samples_per_axis)
    )
    for i in range(num_samples):
        grid, _ = coord2grid(coords, X[i], samples_per_axis=samples_per_axis)
        X_grid[i] = grid
    return X_grid

# Subject base names; each is loaded as "<name>.mat" and saved as "<name>.npz".
files = ["Subject1", "Subject2", "Subject3", "Subject4", "Subject5"]
# Grid resolution along each axis for the voxel grids.
samples_per_axis = 46

# Guarded so that importing this module (e.g. to reuse get_data/coord2grid)
# does not run the whole conversion pipeline.
if __name__ == "__main__":
    for subject in files:
        print("Working on " + subject)
        X_train, X_test, Y_train, Y_test, coords = get_data(subject + ".mat")
        X_train_grid = data2grid(X_train, coords, samples_per_axis=samples_per_axis)
        X_test_grid = data2grid(X_test, coords, samples_per_axis=samples_per_axis)
        np.savez(
            subject + ".npz",
            X_train=X_train_grid,
            X_test=X_test_grid,
            Y_train=Y_train,
            Y_test=Y_test,
        )