import os

import numpy as np
import pandas as pd
from joblib import Parallel, delayed
from tqdm import tqdm
from scipy.stats import zscore
from skdim.id import lPCA, MLE, TwoNN, CorrInt, FisherS
import nibabel.freesurfer.io as fsio

from src.utils import log, save_pickle

# PATH_DATASETS = os.path.join(..., 'Documents', 'datasets')
# Root directory that contains one sub-directory per dataset.
# NOTE(review): the first argument is a literal `...` (Ellipsis) placeholder;
# os.path.join raises TypeError for it, so substitute the real base directory
# before running this module.
PATH_DATASETS = os.path.join(..., 'datasets')

# Dataset name -> subject identifiers; main() processes every pair.
DICT_DATASETS = {
    'THINGS-fMRI': ['S1', 'S2', 'S3'],
    'BOLD5000': ['CSI1', 'CSI2', 'CSI3', 'CSI4'],
}

# Estimator name -> estimator instance. The name becomes part of the output
# file name; commented-out entries are currently disabled.
ESTIMATORS = {
    'lpca': lPCA(ver='participation_ratio'),
    'mle': MLE(),
    # 'twoNN': TwoNN(),
    # 'corrInt': CorrInt(),
    # 'fisherS': FisherS()
}

# Values matched against the 'Split' column of the stimuli CSV.
SPLITS = ['train']

# Number of leading columns of the neighbor index table kept for each vertex.
NUM_NEIGHBORS = [50]


def zscore_betas(betas, df_stimuli):
    """Z-score betas along the first axis, separately for each session.

    Rows of ``betas`` are matched positionally against the rows of
    ``df_stimuli``: every row whose 'Session' value equals a given session is
    standardized together with the other rows of that session.

    Args:
        betas: Array whose first axis is indexed by the row positions of
            ``df_stimuli``.
        df_stimuli: DataFrame providing a 'Session' column.

    Returns:
        Array with the same shape as ``betas`` holding the per-session
        z-scores. Floating-point inputs keep their dtype; any other dtype is
        promoted to float64. Rows not referenced by any session stay at zero.

    Raises:
        AssertionError: If ``betas`` contains NaN values.
    """
    betas = np.asarray(betas)
    if np.isnan(betas).any():
        # Explicit raise instead of `assert` so the check survives `python -O`;
        # the exception type is kept for backward compatibility.
        raise AssertionError('Betas contain NaN values')

    indices_session = df_stimuli['Session'].values

    # A plain zeros_like would inherit an integer dtype from `betas` and
    # silently truncate the z-scores when they are written into the buffer.
    if np.issubdtype(betas.dtype, np.inexact):
        out_dtype = betas.dtype
    else:
        out_dtype = np.float64
    betas_zscore = np.zeros_like(betas, dtype=out_dtype)

    sessions = np.unique(indices_session)
    for i_session in tqdm(sessions, total=len(sessions)):
        indices = np.where(indices_session == i_session)[0]
        betas_zscore[indices] = zscore(betas[indices], axis=0)

    # Columns with zero variance inside a session produce NaN (0 / 0) in
    # zscore; map those to 0 so downstream estimators receive finite values.
    return np.nan_to_num(betas_zscore)


def compute_dimensionality_for_vertex(idx_vertex,
                                      data,
                                      indices_neighbors,
                                      labels_vertices,
                                      method):
    """Estimate the dimensionality of the data around a single vertex.

    The columns of ``data`` listed in ``indices_neighbors[idx_vertex]`` are
    handed to ``method.fit`` and the fitted ``dimension_`` is reported.

    Args:
        idx_vertex: Index of the vertex to process.
        data: 2-D array whose columns are selected by the neighbor indices.
        indices_neighbors: Per-vertex collection of column indices.
        labels_vertices: Per-vertex labels; -1 marks a vertex to skip.
        method: Estimator exposing ``fit(X)`` that returns an object with a
            ``dimension_`` attribute.

    Returns:
        Tuple ``(idx_vertex, dimension)``. The dimension is NaN when the
        vertex is labeled -1 or its neighborhood data contains NaN.
    """
    if labels_vertices[idx_vertex] != -1:
        patch = data[:, indices_neighbors[idx_vertex]]
        # Only fit on fully finite neighborhoods.
        if not np.isnan(patch).any():
            return idx_vertex, method.fit(patch).dimension_

    # Skipped vertex or NaN-contaminated neighborhood.
    return idx_vertex, np.nan


def calculate_dimensionality_vertices(data,
                                      labels_vertices,
                                      indices_neighbors,
                                      method,
                                      n_jobs=-1):
    """Estimate dimensionality for every labeled vertex in parallel.

    Args:
        data: 2-D array passed through to
            ``compute_dimensionality_for_vertex``.
        labels_vertices: Per-vertex labels; vertices labeled -1 are skipped.
        indices_neighbors: Per-vertex collection of column indices into
            ``data``.
        method: Estimator forwarded to each per-vertex job; its class name is
            used as the progress-bar description.
        n_jobs: Number of joblib workers (default -1).

    Returns:
        float32 array with one entry per vertex; entries of skipped vertices
        (and of vertices whose job returned NaN) are NaN.
    """
    # Start from all-NaN so vertices that are never scheduled stay NaN.
    out = np.full(len(labels_vertices), np.nan, dtype=np.float32)

    vertices_to_run = np.nonzero(labels_vertices != -1)[0]
    progress = tqdm(vertices_to_run, desc=method.__class__.__name__)
    jobs = (
        delayed(compute_dimensionality_for_vertex)(
            vertex, data, indices_neighbors, labels_vertices, method
        )
        for vertex in progress
    )

    # Each job reports its own vertex index, so results are scattered back
    # by index rather than by position.
    for vertex, value in Parallel(n_jobs=n_jobs)(jobs):
        out[vertex] = value

    return out


def _compute_hemisphere_dimensionality(path_dataset,
                                       path_processed,
                                       subject_id,
                                       hemi,
                                       df_stimuli,
                                       indices_split,
                                       num_neighbors,
                                       estimator):
    """Load one hemisphere's inputs and estimate per-vertex dimensionality."""
    log(f'Starting hemisphere: {hemi}')
    path_neighbors = os.path.join(path_processed,
                                  'neighbors',
                                  'surf',
                                  f'{subject_id}_neighbors_surf_{hemi}.npz')
    assert os.path.exists(path_neighbors), path_neighbors

    log('Loading neighbors...')
    # np.load returns an NpzFile that keeps the archive open; the context
    # manager releases the file handle once the array has been read.
    with np.load(path_neighbors) as neighbors:
        indices_neighbors = neighbors['indices'][:, :num_neighbors]

    path_betas = os.path.join(path_processed,
                              'betas',
                              'surf',
                              f'{subject_id}_betas_surf_{hemi}.npy')
    assert os.path.exists(path_betas), path_betas

    log('Loading betas...')
    # Z-score over all trials first, then keep only the rows of this split.
    betas = zscore_betas(np.load(path_betas), df_stimuli)
    betas = betas[indices_split]
    log(f'Betas shape: {betas.shape}')

    path_visual_annot = os.path.join(path_dataset,
                                     'freesurfer',
                                     subject_id,
                                     'label',
                                     f'{hemi}.visual.annot')
    labels, _, _ = fsio.read_annot(path_visual_annot)
    log(f'Valid labels: {np.sum(labels != -1)}')

    return calculate_dimensionality_vertices(
        betas,
        labels,
        indices_neighbors,
        estimator,
        n_jobs=-1
    )


def process_subject(dataset, subject_id):
    """Compute and save per-vertex dimensionality maps for one subject.

    For every combination of split (SPLITS), neighborhood size
    (NUM_NEIGHBORS) and estimator (ESTIMATORS), both hemispheres ('lh', 'rh')
    are processed and the result, a dict mapping hemisphere to an array, is
    pickled under ``<dataset>/results/<subject_id>/dimensionality``.
    Combinations whose output file already exists are skipped.

    Args:
        dataset: Name of the dataset directory below PATH_DATASETS.
        subject_id: Subject identifier used in directory and file names.

    Raises:
        AssertionError: If the dataset directory, the processed-subject
            directory, or a neighbors/betas file is missing.
    """
    path_dataset = os.path.join(PATH_DATASETS, dataset)
    assert os.path.exists(path_dataset), path_dataset

    path_processed = os.path.join(path_dataset, 'processed', subject_id)
    assert os.path.exists(path_processed), path_processed

    path_results_subject = os.path.join(path_dataset, 'results', subject_id, 'dimensionality')
    os.makedirs(path_results_subject, exist_ok=True)

    path_stimuli = os.path.join(path_processed,
                                'stimuli',
                                f'{subject_id}_stimuli.csv')
    df_stimuli = pd.read_csv(path_stimuli)

    for split in SPLITS:
        log(f'Split: {split}')
        # NOTE(review): 'Trial' values are used directly as row indices into
        # the betas array; this assumes they are 0-based row positions.
        # Confirm against the stimuli CSV.
        indices_split = df_stimuli.loc[df_stimuli['Split'] == split, 'Trial'].values

        for num_neighbors in NUM_NEIGHBORS:
            log(f'Neighbors: {num_neighbors}')
            for name, estimator in ESTIMATORS.items():
                log(f'Running estimator {name}')
                path_output = os.path.join(
                    path_results_subject,
                    f'{subject_id}_dimensionality_{split}_{num_neighbors}_{name}.pickle')
                if os.path.exists(path_output):
                    log(f'File {path_output} already exists. Skipping...\n')
                    continue

                # NOTE(review): betas are reloaded and re-z-scored for every
                # estimator / neighborhood size. Caching them per hemisphere
                # would save time at the cost of higher peak memory.
                dict_dims = {}
                for hemi in ['lh', 'rh']:
                    dict_dims[hemi] = _compute_hemisphere_dimensionality(
                        path_dataset,
                        path_processed,
                        subject_id,
                        hemi,
                        df_stimuli,
                        indices_split,
                        num_neighbors,
                        estimator,
                    )

                log('Saving results...\n')
                save_pickle(dict_dims, path_output)


def main():
    """Run process_subject for every (dataset, subject) pair in DICT_DATASETS.

    Raises:
        AssertionError: If PATH_DATASETS does not exist.
    """
    assert os.path.exists(PATH_DATASETS)

    # Flatten the dataset -> subjects mapping into (dataset, subject) pairs.
    pairs = (
        (dataset, subject_id)
        for dataset, subject_ids in DICT_DATASETS.items()
        for subject_id in subject_ids
    )
    for dataset, subject_id in pairs:
        log(f'Processing {dataset} {subject_id}...')
        process_subject(dataset, subject_id)


# Run the full pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
