import scipy.io
import numpy as np
import torch
import sys
import os
from os import path
from pathlib import Path

# Absolute, symlink-resolved location of this source file, as a plain string.
# The path-selection logic below inspects it to decide which machine we run on.
abs_path = os.fspath(Path(__file__).resolve())

# Module-level debug switch (not read by any of the code in this section).
DEBUG = False


def ensure_paths_exist(path2CachedDataset, path2ChosenRoi, path2DkIdx2LabelMappings):
    """Fail fast if any of the required brain-data files is missing.

    Each missing path is printed as ``cannot find <path>``. Then an
    ``EnvironmentError`` (an alias of ``OSError``) naming every missing path
    is raised.

    Args:
        path2CachedDataset: path to the cached dataset ``.mat`` file.
        path2ChosenRoi: path to the chosen-ROI ``.mat`` file.
        path2DkIdx2LabelMappings: path to the index-to-label ``.mat`` file.

    Raises:
        EnvironmentError: if at least one of the paths does not exist.
    """
    # Report order matches the original: ROI file, label mapping, dataset.
    required = (path2ChosenRoi, path2DkIdx2LabelMappings, path2CachedDataset)
    # Check each path once instead of twice.
    missing = [p for p in required if not path.exists(p)]
    for p in missing:
        print(f'cannot find {p}')
    if missing:
        raise EnvironmentError('Cannot find brain data files!: ' + ', '.join(str(p) for p in missing))


# Choose data locations from a substring of this file's absolute path.
# The '/content/drive/MyDrive' prefix presumably means a Google Colab session
# with Drive mounted; the 'ubuntu' branch is a specific Linux host.
if 'content' in abs_path:
    path2PreprocessRepo = "/content/drive/MyDrive/brain-data-processing/"
    path2CachedDataset = path2PreprocessRepo + "all_fc_sc.mat"  # alternative: "all_scans_mean_norm.mat"
    path2DkIdx2LabelMappings = path2PreprocessRepo + "desikan_killiany_cortical_index_to_label.mat"
    path2ChosenRoi = path2PreprocessRepo + "desikan_roi_zhengwu.mat"
elif 'ubuntu' in abs_path:
    path2PreprocessRepo = "/home/ubuntu/proximal_gradient_topology_inference/data/brain_data/"
    path2CachedDataset = path2PreprocessRepo + "all_fc_sc.mat"
    path2DkIdx2LabelMappings = path2PreprocessRepo + "desikan_killiany_cortical_index_to_label.mat"
    path2ChosenRoi = path2PreprocessRepo + "desikan_roi_zhengwu.mat"
else:
    # Unknown machine: there is no sensible default location, so refuse to import.
    print(f'\t\tSpecify where data and node labels live!!: {os.getcwd()}')
    raise ValueError('Specify where data and node labels live!')

# Validate at import time so a misconfigured machine fails immediately.
ensure_paths_exist(path2CachedDataset, path2ChosenRoi, path2DkIdx2LabelMappings)



def load_brain_data(datatype = np.float32):
    """Load the cached brain dataset and assemble its node-label metadata.

    Reads three ``.mat`` files through the module-level paths
    ``path2CachedDataset``, ``path2ChosenRoi`` and ``path2DkIdx2LabelMappings``.
    All are loaded with ``scipy.io.loadmat(..., squeeze_me=True)``.

    Args:
        datatype: currently unused and accepted only for backward
            compatibility. NOTE(review): presumably meant to cast the loaded
            data; confirm before relying on it.

    Returns:
        A ``(brain_data, metadata)`` tuple:

        - ``brain_data`` is the ``'data'`` entry of the cached dataset file. Its
          structure is not visible here.
        - ``metadata`` is a dict with keys ``'idx2label_cortical'``,
          ``'cortical_roi'``, ``'subcortical_roi'``, ``'include_subcortical'``,
          ``'subcortical_first'``, ``'tasktype'``, ``'taskype'`` (legacy
          misspelling kept for existing callers) and ``'atlas'``.

    Raises:
        ValueError: if the dataset's atlas is not ``'desikan'``.
    """
    brain_data_file = scipy.io.loadmat(path2CachedDataset, squeeze_me=True)
    brain_data = brain_data_file['data']
    atlas = brain_data_file['atlas']
    tasktype = brain_data_file['task']
    include_subcortical = (brain_data_file['include_subcortical'] > 0)

    chosen_roi = scipy.io.loadmat(path2ChosenRoi, squeeze_me=True)

    # Only the Desikan-Killiany label mapping is available; reject other atlases
    # before loading the label file.
    if atlas != 'desikan':
        raise ValueError(f'atlas {atlas} not supported. Only Desikan.')

    mapping_file = scipy.io.loadmat(path2DkIdx2LabelMappings, squeeze_me=True)
    mapping = mapping_file['desikan_killiany_cortical_index_to_label_struct']

    # Shorten hemisphere prefixes: "ctx-lh" -> "L", "ctx-rh" -> "R".
    cleaned_label_names = [
        str(x).replace("ctx-lh", "L").replace("ctx-rh", "R").strip()
        for x in mapping['label_name'].item()
    ]
    cortical_index_to_labels = {
        'lobe': mapping['lobe'].item(),
        'freesurfer_label_id': mapping['freesurfer_label_id'].item(),
        'label_name': cleaned_label_names,
    }

    metadata = {
        'idx2label_cortical': cortical_index_to_labels,
        'cortical_roi': chosen_roi['cortical'],
        'subcortical_roi': chosen_roi['subcortical'],
        'include_subcortical': include_subcortical,
        'subcortical_first': True,
        'taskype': tasktype,  # legacy misspelled key; kept for backward compatibility
        'atlas': atlas,
        'tasktype': tasktype,
    }

    return brain_data, metadata


def subnetwork_masks(lobes: np.ndarray):
    """Build per-lobe node-pair masks for a cortical parcellation.

    Args:
        lobes: 1-D numpy array with one lobe name per node. Every entry must be
            one of 'temporal', 'frontal', 'parietal', 'insular', 'occipital'
            or 'None'; anything else trips an assertion.

    Returns:
        A dict mapping 'temporal', 'frontal', 'occipital', 'parietal' and 'full'
        to bool tensors of shape (1, N, N), where N = len(lobes). A lobe mask
        is True at [0, i, j] when both node i and node j belong to that lobe.
        'full' is all True. No 'insular' entry is produced.
    """
    known_lobes = ['temporal', 'frontal', 'parietal', 'insular', 'occipital', 'None']
    for name in lobes:
        assert name in known_lobes

    n_nodes = len(lobes)
    masks = {}
    # Insertion order (and therefore dict order) matches the original mapping.
    for lobe_name in ('temporal', 'frontal', 'occipital', 'parietal'):
        masks[lobe_name] = vector_mask_to_matrix_mask(lobes == lobe_name)
    masks['full'] = torch.ones((n_nodes, n_nodes), dtype=torch.bool).view(1, n_nodes, n_nodes)
    return masks


def vector_mask_to_matrix_mask(vm):
    """Expand a 1-D node mask into a (1, N, N) node-pair mask.

    Entry [0, i, j] of the result is True iff both ``vm[i]`` and ``vm[j]`` are
    truthy. The mask therefore selects the square block of rows and columns
    belonging to the masked nodes.

    Args:
        vm: 1-D array-like with a ``shape`` attribute, e.g. a numpy array or
            torch tensor.

    Returns:
        bool tensor of shape (1, N, N), where N = len(vm).

    Raises:
        AssertionError: if ``vm`` is not 1-D.
    """
    assert len(vm.shape)==1
    N = len(vm)
    # bool() per element reproduces Python truthiness for numpy/torch scalars.
    node_in_mask = torch.tensor([bool(x) for x in vm], dtype=torch.bool)
    # Outer logical AND instead of an O(N^2) Python double loop.
    mask = node_in_mask[:, None] & node_in_mask[None, :]
    return mask.view(1, N, N)


def apply_subnetwork_mask(full_network_tensor: torch.Tensor, sub_network_mask: torch.Tensor):
    """Extract the sub-network selected by a square-block boolean mask.

    Args:
        full_network_tensor: (N, N) or (B, N, N) tensor of node-pair values.
        sub_network_mask: (N, N) or (1, N, N) bool tensor. It must have the
            form produced by ``vector_mask_to_matrix_mask``: True at [i, j]
            iff both node i and node j are in the sub-network.

    Returns:
        (B, N_sub, N_sub) tensor holding the selected rows and columns in their
        original order. B is 1 for a 2-D input.

    Raises:
        AssertionError: on mismatched sizes, a non-bool mask, or a mask batch
            dimension other than 1.
        RuntimeError: (from ``view``) if the mask is not square-block shaped.
    """
    assert full_network_tensor.shape[-1] == sub_network_mask.shape[-1], 'tensor and mask must be same size'
    # A non-bool tensor would silently trigger integer (gather) indexing.
    assert sub_network_mask.dtype == torch.bool, 'sub_network_mask must be a bool tensor'

    # number of nodes in full network
    N_full = full_network_tensor.shape[-1]

    # Promote 2-D inputs to a batch of one. This must happen BEFORE the 3-D
    # checks, otherwise 2-D masks would always be rejected.
    if full_network_tensor.dim() == 2:
        full_network_tensor = full_network_tensor.unsqueeze(0)
    if sub_network_mask.dim() == 2:
        sub_network_mask = sub_network_mask.unsqueeze(0)

    assert sub_network_mask.dim() == 3, 'sub_network_mask must be 3D'
    assert sub_network_mask.shape[0] == 1, f'batch_dim of sub_net_mask must be 1: {sub_network_mask}'

    # Number of nodes in the subnetwork. For a square-block mask every selected
    # row contains exactly N_sub True entries. Convert to a Python int for view().
    row_counts = sub_network_mask.sum(dim=2)
    N_sub = int(row_counts.max()) if row_counts.numel() > 0 else 0

    batch_size = full_network_tensor.shape[0]
    # Broadcast the single mask slice over the batch. Boolean indexing flattens
    # the selected entries in row-major order (slice 0 first, then row by row),
    # so the flat result reshapes directly into (B, N_sub, N_sub).
    sub_network = full_network_tensor[torch.broadcast_to(sub_network_mask, (batch_size, N_full, N_full))]
    return sub_network.view(batch_size, N_sub, N_sub)


if __name__ == "__main__":
    # Smoke check 1: load the real dataset and build the per-lobe masks.
    brain_data, metadata = load_brain_data(datatype=np.float32)
    lobe_labels = metadata["idx2label_cortical"]["lobe"]
    sub_network_masks = subnetwork_masks(lobe_labels)

    # Smoke check 2: extract nodes 0, 1 and 3 from a random batch of two 4x4
    # networks. The mask is the outer AND of the node selection vector.
    kept_nodes = torch.tensor([True, True, False, True])
    mask = kept_nodes[:, None] & kept_nodes[None, :]
    b = torch.rand(2, 4, 4)
    out = apply_subnetwork_mask(b, mask)
