import os

import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import networkx as nx
import nibabel.freesurfer.io as fsio
import gdist
from scipy.spatial.distance import cdist
from tqdm.auto import tqdm

from src.utils import log, save_pickle


# Root directory that holds one sub-directory per dataset.
# NOTE(review): the literal `...` (Ellipsis) is not a valid path component --
# os.path.join raises TypeError on it -- so this looks like a placeholder that
# must be replaced with the real base directory before the script can run.
PATH_DATASETS = os.path.join(..., 'datasets')

# Subject identifiers to process, keyed by dataset directory name.
DICT_DATASETS = {
    'THINGS-fMRI': ['S1', 'S2', 'S3'],
    'BOLD5000': ['CSI1', 'CSI2', 'CSI3', 'CSI4'],
}


def _require_path(path):
    """Raise ``FileNotFoundError`` unless *path* exists on disk."""
    if not os.path.exists(path):
        raise FileNotFoundError(f'Required path does not exist: {path}')


def _area_name(raw_name):
    """Return the second ``_``-separated field of an annotation label name."""
    # read_annot hands names back as bytes; formatting bytes with an f-string
    # embeds their repr (``b'...'``), which leaks a quote into the result for
    # two-field names, so decode to text first.
    if isinstance(raw_name, bytes):
        text = raw_name.decode('utf-8', errors='replace')
    else:
        text = str(raw_name)
    fields = text.split('_')
    # A name without an underscore has no second field; keep it whole rather
    # than failing with an opaque IndexError.
    return fields[1] if len(fields) > 1 else text


def process_subject(dataset, subject_id):
    """Compute and pickle centroid-to-centroid geodesic distances for a subject.

    For each hemisphere (``lh``, ``rh``) the function reads, from the
    subject's FreeSurfer directory, the ``{hemi}.visual.annot`` annotation,
    the ``{hemi}.mid`` surface and the ``{hemi}.mid.centroids.label`` label.
    It then calls ``gdist.compute_gdist`` once per centroid vertex to fill a
    square distance matrix and saves ``{'areas': ..., 'matrix': ...}`` to
    ``results/<subject_id>/<subject_id>_geodesic_matrix_<hemi>.pickle``.
    A hemisphere whose output file already exists is skipped.

    NOTE(review): nothing here checks that ``areas`` has one entry per
    centroid, or that its order matches the matrix rows -- confirm against
    how the annotation and label files were produced.

    Args:
        dataset: Name of the dataset directory under ``PATH_DATASETS``.
        subject_id: Subject directory name used under ``processed``,
            ``results`` and ``freesurfer``.

    Raises:
        FileNotFoundError: If a required directory or input file is missing.
    """
    # Explicit raises instead of ``assert``: assertions are stripped under
    # ``python -O``, after which ``os.makedirs`` below would silently create
    # a results tree inside a dataset that does not exist.
    path_dataset = os.path.join(PATH_DATASETS, dataset)
    _require_path(path_dataset)

    path_processed = os.path.join(path_dataset, 'processed', subject_id)
    _require_path(path_processed)

    path_results_subject = os.path.join(path_dataset, 'results', subject_id)
    os.makedirs(path_results_subject, exist_ok=True)

    path_freesurfer = os.path.join(path_dataset, 'freesurfer')
    _require_path(path_freesurfer)

    path_subject_fs = os.path.join(path_freesurfer, subject_id)

    for hemi in ['lh', 'rh']:
        log(f'Processing subject {subject_id} ({hemi})...')
        path_annot = os.path.join(path_subject_fs, 'label',
                                  f'{hemi}.visual.annot')
        path_mid = os.path.join(path_subject_fs, 'surf', f'{hemi}.mid')
        path_centroids = os.path.join(path_subject_fs, 'label',
                                      f'{hemi}.mid.centroids.label')
        _require_path(path_mid)

        path_output = os.path.join(
            path_results_subject,
            f'{subject_id}_geodesic_matrix_{hemi}.pickle')
        if os.path.exists(path_output):
            log(f'Skipping {path_output}')
            continue

        # Fail early, with a clear message, before any heavy computation.
        _require_path(path_annot)
        _require_path(path_centroids)

        # Only the label names are needed from the annotation.
        _, _, raw_names = fsio.read_annot(path_annot)
        names = [_area_name(raw_name) for raw_name in raw_names]

        coords, faces = fsio.read_geometry(path_mid)
        coords = coords.astype(np.float64)
        faces = faces.astype(np.int32)

        # ndmin=1 keeps a label holding a single vertex usable as a sequence.
        centroid_vertices = np.array(
            nib.freesurfer.read_label(path_centroids),
            dtype=np.int32,
            ndmin=1,
        )
        n_centroids = len(centroid_vertices)
        geodesic_matrix = np.zeros((n_centroids, n_centroids))

        # Row i holds the distances from centroid i to every centroid.
        for i, src in tqdm(enumerate(centroid_vertices), total=n_centroids):
            geodesic_matrix[i, :] = gdist.compute_gdist(
                coords,
                faces,
                source_indices=np.array([src], dtype=np.int32),
                target_indices=centroid_vertices,
            )

        dict_geodesic = {
            'areas': names,
            'matrix': geodesic_matrix,
        }
        save_pickle(dict_geodesic, path_output)



def main():
    """Run ``process_subject`` for every dataset/subject in ``DICT_DATASETS``.

    Raises:
        FileNotFoundError: If the ``PATH_DATASETS`` directory does not exist.
    """
    # Explicit raise rather than ``assert`` so the check survives ``python -O``.
    if not os.path.exists(PATH_DATASETS):
        raise FileNotFoundError(
            f'Datasets directory does not exist: {PATH_DATASETS}')

    for dataset, subject_ids in DICT_DATASETS.items():
        for subject_id in subject_ids:
            log(f'Processing {dataset} {subject_id}...')
            process_subject(dataset, subject_id)


# Run the full pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()
