import os

import nibabel as nib
import numpy as np
from joblib import Parallel, delayed
from tqdm import tqdm
from scipy.sparse import coo_matrix
from scipy.sparse.csgraph import dijkstra

from src.utils import log


def save_mid_surface(subject_id,
                     path_freesurfer,
                     hemi,
                     depth):
    """Write a surface interpolated between the white and pial surfaces.

    Reads ``<path_freesurfer>/<subject_id>/surf/<hemi>.white`` and
    ``<hemi>.pial``, linearly blends their vertex coordinates and writes the
    result (with the white surface's faces and volume metadata) to
    ``<hemi>.mid`` in the same directory, overwriting any existing file.

    Parameters
    ----------
    subject_id : str
        Subject directory name under ``path_freesurfer``.
    path_freesurfer : str
        Root directory containing the subject's FreeSurfer output.
    hemi : str
        Hemisphere prefix used in the surface file names (e.g. ``'lh'``).
    depth : float
        Blend weight: 0.0 reproduces the white surface, 1.0 the pial surface.

    Raises
    ------
    ValueError
        If the two surfaces differ in volume metadata, vertex array shape or
        face topology, in which case vertex-wise interpolation is meaningless.
    """
    dir_surf = os.path.join(path_freesurfer, subject_id, 'surf')
    path_white = os.path.join(dir_surf, f'{hemi}.white')
    path_pial = os.path.join(dir_surf, f'{hemi}.pial')
    path_mid = os.path.join(dir_surf, f'{hemi}.mid')

    coords_white, faces_white, vol_info_white = \
        nib.freesurfer.read_geometry(path_white, read_metadata=True)
    coords_pial, faces_pial, vol_info_pial = \
        nib.freesurfer.read_geometry(path_pial, read_metadata=True)

    # Explicit checks rather than `assert`, so validation survives `python -O`.
    # np.array_equal is well defined (returns False) when shapes differ,
    # unlike `np.all(a == b)`.
    same_metadata = (
        set(vol_info_white) == set(vol_info_pial)
        and all(np.array_equal(vol_info_white[key], vol_info_pial[key])
                for key in vol_info_white)
    )
    if not same_metadata:
        raise ValueError(
            f'Volume metadata of {path_white} and {path_pial} differ')
    if (coords_white.shape != coords_pial.shape
            or not np.array_equal(faces_white, faces_pial)):
        raise ValueError(
            f'{path_white} and {path_pial} do not share the same mesh')

    coords_mid = coords_white * (1.0 - depth) + coords_pial * depth
    nib.freesurfer.write_geometry(path_mid,
                                  coords_mid,
                                  faces_white,
                                  volume_info=vol_info_white)


def build_mesh_graph(coords, faces):
    """Return the symmetric, edge-length-weighted adjacency graph of a mesh.

    Each undirected edge of the triangle mesh is stored exactly once in each
    direction, weighted by its Euclidean length.

    Parameters
    ----------
    coords : numpy.ndarray, shape (n_vertices, n_dims)
        Vertex coordinates.
    faces : numpy.ndarray of int, shape (n_faces, 3)
        Vertex indices of each triangle.

    Returns
    -------
    scipy.sparse.csr_matrix, shape (n_vertices, n_vertices)
    """
    num_verts = coords.shape[0]
    # In a closed triangle mesh every edge belongs to two faces. COO -> CSR
    # conversion *sums* duplicate (i, j) entries, so without this
    # de-duplication every edge weight (and every geodesic distance) would
    # be doubled.
    edges = np.concatenate([faces[:, [0, 1]], faces[:, [1, 2]], faces[:, [2, 0]]])
    edges = np.unique(np.sort(edges, axis=1), axis=0)
    rows = np.concatenate([edges[:, 0], edges[:, 1]])
    cols = np.concatenate([edges[:, 1], edges[:, 0]])
    dists = np.linalg.norm(coords[rows] - coords[cols], axis=1)
    return coo_matrix((dists, (rows, cols)),
                      shape=(num_verts, num_verts)).tocsr()


def calculate_geodesic_neighbors(subject_id,
                                 path_freesurfer,
                                 hemi,
                                 num_neighbors=100):
    """Find the ``num_neighbors`` geodesically closest vertices of every vertex.

    Geodesic distance is approximated by the shortest path along mesh edges
    of ``<path_freesurfer>/<subject_id>/surf/<hemi>.mid`` (Dijkstra on
    edge-length-weighted edges).

    Parameters
    ----------
    subject_id : str
        Subject directory name under ``path_freesurfer``.
    path_freesurfer : str
        Root directory containing the subject's FreeSurfer output.
    hemi : str
        Hemisphere prefix used in the surface file name (e.g. ``'lh'``).
    num_neighbors : int, default 100
        Number of nearest vertices kept per vertex. The vertex itself
        (distance 0) is included among them.

    Returns
    -------
    dict
        ``'distances'``: array of shape (n_vertices, num_neighbors) with the
        path lengths, sorted ascending per row; ``'indices'``: matching
        vertex indices. Vertices unreachable from a source get distance inf.
    """
    path_mid = os.path.join(path_freesurfer, subject_id, 'surf', f'{hemi}.mid')
    coords_mid, faces_mid = nib.freesurfer.read_geometry(path_mid)
    graph = build_mesh_graph(coords_mid, faces_mid)
    num_verts = graph.shape[0]

    def compute_neighbors(v):
        # Full single-source Dijkstra, then keep the closest vertices.
        dist_row = dijkstra(csgraph=graph, indices=v)
        nbrs = np.argsort(dist_row)[:num_neighbors]
        return nbrs, dist_row[nbrs]

    results = Parallel(n_jobs=-1)(
        delayed(compute_neighbors)(v)
        for v in tqdm(range(num_verts), desc=f'{subject_id} {hemi}')
    )

    neighbors = np.vstack([nbrs for nbrs, _ in results])
    distances = np.vstack([dists for _, dists in results])

    return {'distances': distances, 'indices': neighbors}


def main():
    """Build mid-surfaces and geodesic-neighbour files for every configured subject.

    For each dataset, subject and hemisphere listed below, this writes the
    FreeSurfer ``<hemi>.mid`` surface (at ``depth`` between white and pial)
    and saves the geodesic neighbours to
    ``<datasets>/<dataset>/processed/<subject>/neighbors/surf/
    <subject>_neighbors_surf_<hemi>.npz``. Hemispheres whose ``.npz`` file
    already exists are skipped.

    Raises
    ------
    FileNotFoundError
        If the datasets root, a dataset's ``freesurfer`` directory or a
        subject's ``processed`` directory is missing.
    """
    depth = 0.5
    dict_dataset = {
        'THINGS-fMRI': ['S1', 'S2', 'S3'],
        'BOLD5000': ['CSI1', 'CSI2', 'CSI3', 'CSI4'],
    }

    # NOTE(review): `...` is a placeholder for the directory that contains
    # `datasets`; os.path.join raises TypeError until it is replaced.
    path_datasets = os.path.join(..., 'datasets')
    if not os.path.exists(path_datasets):
        raise FileNotFoundError(path_datasets)

    for dataset, subject_ids in dict_dataset.items():
        path_freesurfer = os.path.join(path_datasets, dataset, 'freesurfer')
        if not os.path.exists(path_freesurfer):
            raise FileNotFoundError(path_freesurfer)
        for subject_id in subject_ids:
            path_processed_dataset = os.path.join(path_datasets,
                                                  dataset,
                                                  'processed',
                                                  subject_id)
            for hemi in ['lh', 'rh']:
                log(f'Running subject {subject_id} {hemi} ({dataset})...')

                # Check before os.makedirs below: makedirs would create this
                # directory itself, so a later check could never fail.
                if not os.path.exists(path_processed_dataset):
                    raise FileNotFoundError(path_processed_dataset)

                path_neighbors = os.path.join(path_processed_dataset,
                                              'neighbors',
                                              'surf')
                os.makedirs(path_neighbors, exist_ok=True)
                path_neighbors_filename = os.path.join(
                    path_neighbors,
                    f'{subject_id}_neighbors_surf_{hemi}.npz')

                if os.path.exists(path_neighbors_filename):
                    log(f'Neighbors already computed for {subject_id} {hemi}...\n')
                    continue

                log(f'Saving mid-surface for {subject_id} {hemi} ({dataset})...')
                save_mid_surface(subject_id,
                                 path_freesurfer,
                                 hemi,
                                 depth=depth)

                log(f'Computing geodesic neighbors for {subject_id} {hemi} ({dataset})...')
                dict_neighbors = calculate_geodesic_neighbors(subject_id,
                                                              path_freesurfer,
                                                              hemi)
                np.savez(path_neighbors_filename, **dict_neighbors)


if __name__ == "__main__":
    # Run the processing pipeline only when executed as a script, not on import.
    main()
