import os

import numpy as np
import pandas as pd
from joblib import Parallel, delayed
from tqdm import tqdm
from scipy.stats import zscore
from sklearn.preprocessing import LabelEncoder
import nibabel.freesurfer.io as fsio
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import LeaveOneGroupOut, cross_val_score

from src.utils import log, save_pickle

# Alternative (disabled) dataset root:
# PATH_DATASETS = os.path.join(..., 'Documents', 'datasets')
# Root directory containing one sub-directory per dataset.
# NOTE(review): `...` looks like a placeholder for the base directory;
# os.path.join raises TypeError on Ellipsis, so this must be replaced with a
# real path before the module can be imported or run.
PATH_DATASETS = os.path.join(..., 'datasets')

# Dataset directory name -> subject IDs to process (an empty list skips it).
DICT_DATASETS = {
    'THINGS-fMRI': ['S1'],
    'BOLD5000': [],
}

# Trial splits to decode; process_subject only handles 'train' and 'test'.
SPLITS = ['train']

# Searchlight neighbourhood sizes to evaluate; each size gets its own output file.
NUM_NEIGHBORS = [100]


def zscore_betas(betas, df_stimuli):
    """Z-score betas independently within each session.

    Parameters
    ----------
    betas : np.ndarray
        Array whose first axis indexes trials; row ``i`` must correspond to
        row ``i`` of ``df_stimuli``.
    df_stimuli : pandas.DataFrame
        Trial table with a ``'Session'`` column.

    Returns
    -------
    np.ndarray
        Array of shape ``betas.shape`` in which the rows of each session are
        z-scored along axis 0. Float32/float64 inputs keep their dtype;
        integer inputs are promoted to a floating dtype.
    """
    indices_session = df_stimuli['Session'].values
    # np.zeros_like(betas) would inherit an integer dtype and truncate the
    # z-scores on assignment; promote to at least float32 instead.
    out_dtype = np.result_type(betas.dtype, np.float32)
    betas_zscore = np.zeros(betas.shape, dtype=out_dtype)
    for i_session in tqdm(np.unique(indices_session)):
        rows = np.where(indices_session == i_session)[0]
        betas_zscore[rows] = zscore(betas[rows], axis=0)
    return betas_zscore


def _calculate_accuracy_vertex(idx_vertex,
                               data,
                               labels_vertices,
                               indices_neighbors,
                               precomputed_folds,
                               labels_trials,
                               n_jobs=-1):
    """Cross-validated decoding accuracy of the searchlight around one vertex.

    Parameters
    ----------
    idx_vertex : int
        Index of the centre vertex.
    data : np.ndarray
        Trial-by-vertex matrix; searchlight columns are taken from
        ``indices_neighbors[idx_vertex]``.
    labels_vertices : np.ndarray
        Per-vertex labels; vertices labelled ``-1`` are skipped.
    indices_neighbors : np.ndarray
        Row ``idx_vertex`` holds the column indices of the searchlight.
    precomputed_folds : list of (train_indices, test_indices)
        Cross-validation splits passed directly as ``cv``.
    labels_trials : np.ndarray
        Class label of each row of ``data``.
    n_jobs : int, default -1
        Number of jobs ``cross_val_score`` uses to evaluate the folds.

    Returns
    -------
    tuple of (int, float)
        ``idx_vertex`` and the mean fold accuracy, or NaN when the vertex is
        labelled ``-1``.
    """
    if labels_vertices[idx_vertex] == -1:
        return idx_vertex, np.nan

    X = data[:, indices_neighbors[idx_vertex]]

    clf = LogisticRegression(
        solver='lbfgs',
        max_iter=1000,
    )

    scores = cross_val_score(
        clf,
        X, labels_trials,
        cv=precomputed_folds,
        scoring='accuracy',
        n_jobs=n_jobs,
    )

    return idx_vertex, np.mean(scores)


def calculate_accuracy_vertices(data,
                                labels_vertices,
                                indices_neighbors,
                                indices_sessions,
                                labels_trials,
                                n_jobs=-1):
    """Leave-one-session-out searchlight decoding accuracy for every vertex.

    Parameters
    ----------
    data : np.ndarray
        Trial-by-vertex matrix whose rows are aligned with ``labels_trials``.
    labels_vertices : np.ndarray
        Per-vertex labels; only vertices not labelled ``-1`` are decoded.
    indices_neighbors : np.ndarray
        Row ``v`` holds the columns of ``data`` forming vertex ``v``'s
        searchlight.
    indices_sessions : np.ndarray
        Session of each trial; used as the groups for LeaveOneGroupOut.
    labels_trials : np.ndarray
        Class label of each trial.
    n_jobs : int, default -1
        Number of jobs ``cross_val_score`` uses to evaluate the folds of each
        searchlight.

    Returns
    -------
    np.ndarray of float32, shape (len(labels_vertices),)
        Mean accuracy per vertex; NaN for vertices labelled ``-1``.
    """
    num_vertices = labels_vertices.shape[0]
    decoding_scores = np.full(num_vertices, np.nan, dtype=np.float32)
    valid_vs = np.where(labels_vertices != -1)[0]
    if valid_vs.size == 0:
        # Nothing to decode; skip building folds and return an all-NaN map.
        return decoding_scores

    # Materialise the folds once so every vertex is scored on identical splits.
    precomputed_folds = list(
        LeaveOneGroupOut().split(X=np.zeros(len(labels_trials)),
                                 y=labels_trials,
                                 groups=indices_sessions)
    )

    # Vertices are processed sequentially. n_jobs is forwarded to the per-fold
    # cross_val_score call instead, so two parallel pools are never nested.
    results = Parallel(n_jobs=1)(
        delayed(_calculate_accuracy_vertex)(
            v, data, labels_vertices,
            indices_neighbors, precomputed_folds, labels_trials,
            n_jobs=n_jobs,
        )
        for v in tqdm(valid_vs, total=len(valid_vs))
    )

    for v, acc in results:
        decoding_scores[v] = acc

    return decoding_scores


def _require_path(path):
    """Raise FileNotFoundError if ``path`` does not exist (unlike assert, not stripped by -O)."""
    if not os.path.exists(path):
        raise FileNotFoundError(path)


def process_subject(dataset, subject_id):
    """Run the surface searchlight decoding analysis for one subject.

    For each split in ``SPLITS``, each size in ``NUM_NEIGHBORS`` and both
    hemispheres, this loads the precomputed surface neighbours and betas.
    It z-scores the betas per session, keeps the split's trials, decodes
    every labelled vertex of ``<hemi>.visual.annot`` and saves the per-vertex
    accuracies as ``.npy`` under ``<dataset>/results/<subject_id>/``.
    Outputs that already exist are skipped.

    Parameters
    ----------
    dataset : str
        Dataset directory name under ``PATH_DATASETS``.
    subject_id : str
        Subject identifier used in directory and file names.

    Raises
    ------
    FileNotFoundError
        If the dataset directory, processed-subject directory, neighbour file
        or betas file is missing.
    ValueError
        If a split other than ``'train'`` or ``'test'`` is requested.
    """
    path_dataset = os.path.join(PATH_DATASETS, dataset)
    _require_path(path_dataset)

    path_processed = os.path.join(path_dataset, 'processed', subject_id)
    _require_path(path_processed)

    path_results_subject = os.path.join(path_dataset, 'results', subject_id)
    os.makedirs(path_results_subject, exist_ok=True)

    path_stimuli = os.path.join(path_processed,
                                'stimuli',
                                f'{subject_id}_stimuli.csv')
    df_stimuli = pd.read_csv(path_stimuli)

    for split in SPLITS:
        log(f'Split: {split}')
        df_split = df_stimuli[df_stimuli['Split'] == split]
        indices_split = df_split['Trial'].values
        indices_sessions = df_split['Session'].values

        if split == 'train':
            column_stimuli = 'Concept'
        elif split == 'test':
            column_stimuli = 'Stimulus'
        else:
            raise ValueError(
                f"Unknown split {split!r}; expected 'train' or 'test'.")
        # NOTE(review): Trial values are used both as labels into this Series
        # and as row positions into the betas array below; this assumes Trial
        # matches df_stimuli's index and the betas row order — confirm.
        stimuli_split = df_stimuli[column_stimuli][indices_split].values

        unique_classes = np.unique(stimuli_split)
        if len(unique_classes) >= 100:
            log('Subsampling classes (100)...')
            # A private RandomState seeded with 4321 draws exactly the same
            # classes as np.random.seed(4321) + np.random.choice, without
            # mutating the global NumPy random state.
            rng = np.random.RandomState(4321)
            subset_unique_classes = rng.choice(unique_classes, size=100,
                                               replace=False)
            indices_subset = np.isin(stimuli_split, subset_unique_classes)
            indices_split = indices_split[indices_subset]
            indices_sessions = indices_sessions[indices_subset]
            stimuli_split = stimuli_split[indices_subset]

        le = LabelEncoder()
        labels_trials = le.fit_transform(stimuli_split)
        log(f'Labels: {len(np.unique(labels_trials))}')

        for num_neighbors in NUM_NEIGHBORS:
            log(f'Neighbors: {num_neighbors}')
            for hemi in ['lh', 'rh']:
                log(f'Starting hemisphere: {hemi}')

                path_output = os.path.join(
                    path_results_subject,
                    f'{subject_id}_decode_stimuli_{hemi}_{split}_{num_neighbors}.npy')
                if os.path.exists(path_output):
                    log(f'File {path_output} already exists. Skipping...\n')
                    continue

                path_neighbors = os.path.join(
                    path_processed,
                    'neighbors',
                    'surf',
                    f'{subject_id}_neighbors_surf_{hemi}.npz')
                _require_path(path_neighbors)

                log('Loading neighbors...')
                # NOTE(review): only num_neighbors - 1 columns are kept, so a
                # searchlight has 99 vertices for num_neighbors=100. Whether
                # the centre vertex is among them depends on how 'indices'
                # was built — confirm this is intended.
                with np.load(path_neighbors) as neighbors:
                    indices_neighbors = neighbors['indices'][:, :num_neighbors - 1]

                path_betas = os.path.join(
                    path_processed,
                    'betas',
                    'surf',
                    f'{subject_id}_betas_surf_{hemi}.npy')
                _require_path(path_betas)

                log('Loading betas...')
                betas = np.load(path_betas)
                # Z-score per session over all trials, then keep this split's trials.
                betas = zscore_betas(betas, df_stimuli)
                betas = betas[indices_split]
                log(f'Betas shape: {betas.shape}')

                path_visual_annot = os.path.join(path_dataset,
                                                 'freesurfer',
                                                 subject_id,
                                                 'label',
                                                 f'{hemi}.visual.annot')
                labels_vertices, _, _ = fsio.read_annot(path_visual_annot)
                log(f'Valid labels: {np.sum(labels_vertices != -1)}')

                accuracies = calculate_accuracy_vertices(
                    betas,
                    labels_vertices,
                    indices_neighbors,
                    indices_sessions,
                    labels_trials,
                    n_jobs=-1,
                )

                log('Saving results...\n')
                np.save(path_output, accuracies)


def main():
    """Run ``process_subject`` for every (dataset, subject) pair in ``DICT_DATASETS``.

    Raises
    ------
    FileNotFoundError
        If ``PATH_DATASETS`` does not exist.
    """
    # Explicit check instead of assert, which `python -O` would strip.
    if not os.path.exists(PATH_DATASETS):
        raise FileNotFoundError(PATH_DATASETS)

    for dataset, subject_ids in DICT_DATASETS.items():
        for subject_id in subject_ids:
            log(f'Processing {dataset} {subject_id}...')
            process_subject(dataset, subject_id)


# Run the analysis only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
