import os

import nibabel as nib
import numpy as np
import networkx as nx
import nibabel.freesurfer.io as fsio
from scipy.spatial.distance import cdist

from src.utils import log


# Root directory containing one sub-folder per dataset (see process_subject).
# NOTE(review): `...` looks like a redacted/placeholder path. os.path.join
# raises TypeError for a non-path argument, so this must be replaced with a
# real directory before the module can even be imported.
PATH_DATASETS = os.path.join(..., 'datasets')

# Dataset folder name -> subject IDs to process. Each dataset name is joined
# onto PATH_DATASETS and each subject ID is used as a folder name under the
# dataset's 'processed', 'results' and 'freesurfer' directories.
DICT_DATASETS = {
    'THINGS-fMRI': ['S1', 'S2', 'S3'],
    'BOLD5000': ['CSI1', 'CSI2', 'CSI3', 'CSI4'],
}


def build_surface_graph(faces, coords):
    """Build an undirected mesh graph weighted by Euclidean edge length.

    Args:
        faces: iterable of triangles; the first three entries of each are
            vertex indices into ``coords``.
        coords: per-vertex coordinates, indexable by vertex index.

    Returns:
        ``nx.Graph`` with one edge per triangle side; the ``weight``
        attribute is the distance between the two endpoints.
    """
    mesh = nx.Graph()
    for tri in faces:
        a, b, c = tri[0], tri[1], tri[2]
        # Walk the triangle's sides in order: a-b, b-c, c-a.
        for start, end in ((a, b), (b, c), (c, a)):
            side_length = np.linalg.norm(coords[start] - coords[end])
            mesh.add_edge(start, end, weight=side_length)
    return mesh


def get_largest_component(region_vertices, faces):
    """Return the vertices of the largest connected piece of a surface region.

    Two region vertices are connected when they share a triangle whose three
    vertices all belong to the region.

    Args:
        region_vertices: iterable of vertex indices forming the region.
        faces: triangle vertex indices, shape ``(n_faces, 3)``.

    Returns:
        list of vertex indices in the largest component. If no triangle lies
        entirely inside the region, all of ``region_vertices`` are returned
        (as a list, so the return type is the same in both cases).
    """
    region_vertices = list(region_vertices)
    faces = np.asarray(faces, dtype=np.intp).reshape(-1, 3)

    # Triangles fully inside the region (vectorised instead of a per-face
    # Python loop over the whole hemisphere for every label).
    inside = faces[np.isin(faces, region_vertices).all(axis=1)]

    G = nx.Graph()
    # Same edge insertion order as iterating faces: (a, b), (b, c), (c, a).
    G.add_edges_from(inside[:, [0, 1, 1, 2, 2, 0]].reshape(-1, 2).tolist())

    if G.number_of_nodes() == 0:
        return region_vertices
    return list(max(nx.connected_components(G), key=len))


def find_surface_mean_vertex(region_vertices, faces, coords):
    """Return the region vertex with the smallest mean geodesic distance.

    Geodesic distance is approximated by shortest paths along mesh edges,
    weighted by Euclidean edge length, using only edges whose two endpoints
    both belong to the region.

    Args:
        region_vertices: iterable of vertex indices forming the region.
        faces: triangle vertex indices, shape ``(n_faces, 3)``.
        coords: vertex coordinates, shape ``(n_vertices, 3)``.

    Returns:
        The vertex (taken from ``region_vertices``) minimising the mean
        shortest-path distance to all region vertices; ties go to the
        earliest vertex in ``region_vertices``. Returns ``None`` when the
        region is empty or not connected through its own edges (for example
        when it contains a vertex that belongs to no triangle).
    """
    region_vertices = list(region_vertices)
    if not region_vertices:
        return None

    faces = np.asarray(faces, dtype=np.intp).reshape(-1, 3)
    coords = np.asarray(coords)

    # Build only the region's edges. The previous version rebuilt the whole
    # hemisphere graph for every region and then took a subgraph; the edge
    # set is the same (mesh edges with both endpoints in the region).
    edges = faces[:, [0, 1, 1, 2, 2, 0]].reshape(-1, 2)
    edges = edges[np.isin(edges, region_vertices).all(axis=1)]
    edge_len = np.linalg.norm(coords[edges[:, 0]] - coords[edges[:, 1]], axis=1)

    G = nx.Graph()
    G.add_nodes_from(region_vertices)
    G.add_weighted_edges_from(
        zip(edges[:, 0].tolist(), edges[:, 1].tolist(), edge_len.tolist()))

    # In an undirected graph either every vertex reaches all the others or
    # none does, so one connectivity test replaces running Dijkstra from every
    # vertex only to discard each result.
    if not nx.is_connected(G):
        return None

    best_vertex = None
    best_mean_distance = np.inf
    for v in region_vertices:
        lengths = nx.single_source_dijkstra_path_length(G, v)
        mean_distance = np.mean([lengths[u] for u in region_vertices])
        if mean_distance < best_mean_distance:
            best_mean_distance = mean_distance
            best_vertex = v

    return best_vertex


def _require_path(path):
    """Raise ``FileNotFoundError`` unless *path* exists.

    Used instead of ``assert`` so the checks still run under ``python -O``.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(path)


def _compute_hemisphere_centroids(labels, names, faces, coords):
    """Return ``(vertex_index, coordinate_row)`` for each label's surface centroid.

    Args:
        labels: per-vertex label ids as returned by ``fsio.read_annot``.
        names: label names (bytes), indexed by label id.
        faces: triangle vertex indices of the surface.
        coords: vertex coordinates of the surface.

    Returns:
        list of ``(vertex_index, coords[vertex_index])`` tuples, one per label
        whose centroid could be computed. Failing labels are logged and skipped.
    """
    centroids = []
    for label_index in np.unique(labels):
        # Negative ids (unlabeled vertices) and ids without a name are ignored.
        if label_index < 0 or label_index >= len(names):
            continue

        name = names[label_index].decode('utf-8')
        # np.unique only yields ids that occur, so the region is never empty.
        region_vertices = get_largest_component(
            np.where(labels == label_index)[0], faces)

        log(f'Computing surface centroid for {name} '
            f'({len(region_vertices)} vertices)')
        try:
            center_idx = find_surface_mean_vertex(region_vertices, faces, coords)
        except Exception as e:
            # Best effort: one problematic label must not abort the subject.
            log(f'Failed for {name}: {e}')
            continue

        if center_idx is None:
            # find_surface_mean_vertex returns None when no vertex qualifies.
            # coords[None] would not raise: numpy would return the whole vertex
            # array with an extra axis, writing a garbage row to the label file.
            log(f'Skipping {name}: no surface centroid found.')
            continue

        centroids.append((center_idx, coords[center_idx]))
    return centroids


def _write_centroid_label(path, centroids):
    """Write centroids as an ASCII ``.label`` file.

    Layout: a ``#!ascii label file`` header, the number of entries, then one
    ``<vertex> <x> <y> <z> 0.0`` line per centroid.
    """
    lines = ['#!ascii label file\n', f'{len(centroids)}\n']
    lines += [f'{idx} {coord[0]} {coord[1]} {coord[2]} 0.0\n'
              for idx, coord in centroids]
    with open(path, 'w') as f:
        f.writelines(lines)


def process_subject(dataset, subject_id):
    """Compute and save one surface centroid per visual-area label of a subject.

    For each hemisphere, reads ``label/{hemi}.visual.annot`` and
    ``surf/{hemi}.mid`` from the subject's FreeSurfer directory. For each
    label, it takes the surface mean vertex of the label's largest connected
    piece and writes all of them to ``label/{hemi}.mid.centroids.label``.
    Hemispheres whose output file already exists are skipped.

    Args:
        dataset: dataset folder name under ``PATH_DATASETS``.
        subject_id: subject folder name inside the dataset.

    Raises:
        FileNotFoundError: if a required input directory or file is missing.
    """
    path_dataset = os.path.join(PATH_DATASETS, dataset)
    _require_path(path_dataset)
    _require_path(os.path.join(path_dataset, 'processed', subject_id))

    # Ensure the per-subject results directory exists (nothing is written
    # there by this function).
    os.makedirs(os.path.join(path_dataset, 'results', subject_id), exist_ok=True)

    path_freesurfer = os.path.join(path_dataset, 'freesurfer')
    _require_path(path_freesurfer)
    path_subject = os.path.join(path_freesurfer, subject_id)

    for hemi in ['lh', 'rh']:
        log(f'Processing subject {subject_id} ({hemi})...')
        path_annot = os.path.join(path_subject, 'label', f'{hemi}.visual.annot')
        path_mid = os.path.join(path_subject, 'surf', f'{hemi}.mid')
        path_output_centroids = os.path.join(path_subject, 'label',
                                             f'{hemi}.mid.centroids.label')
        _require_path(path_annot)
        _require_path(path_mid)

        if os.path.exists(path_output_centroids):
            # Skip before loading the annotation and surface; they are unused.
            log(f'Centroids for subject {subject_id} ({hemi}) already exist, '
                f'skipping.')
            continue

        labels, _, names = fsio.read_annot(path_annot)
        coords, faces = fsio.read_geometry(path_mid)

        centroids = _compute_hemisphere_centroids(labels, names, faces, coords)
        _write_centroid_label(path_output_centroids, centroids)


def main():
    """Run centroid extraction for every (dataset, subject) pair in DICT_DATASETS.

    Fails with AssertionError (unless run with ``-O``) when PATH_DATASETS
    does not exist.
    """
    assert os.path.exists(PATH_DATASETS)

    jobs = [
        (dataset, subject_id)
        for dataset, subject_ids in DICT_DATASETS.items()
        for subject_id in subject_ids
    ]
    for dataset, subject_id in jobs:
        log(f'Processing {dataset} {subject_id}...')
        process_subject(dataset, subject_id)


# Script entry point: process every subject listed in DICT_DATASETS.
if __name__ == "__main__":
    main()
