import os

import numpy as np
import pandas as pd
from joblib import Parallel, delayed
from tqdm import tqdm
from scipy.stats import zscore
import nibabel.freesurfer.io as fsio
from sklearn.model_selection import LeaveOneGroupOut
from sklearn.linear_model import Ridge
from src.utils import log, save_pickle, load_pickle

# Root directory holding one sub-folder per dataset name (see DICT_DATASETS).
# NOTE(review): `...` (Ellipsis) is a placeholder for the real root directory;
# os.path.join raises TypeError on it at import time, so it must be replaced
# before this script can run.
PATH_DATASETS = os.path.join(..., 'Documents', 'datasets')
# Alternative dataset root (disabled).
# PATH_DATASETS = os.path.join(..., 'datasets')
# NOTE(review): not referenced anywhere in this file; process_subject loads
# '../images/properties_THINGS.pickle' directly — confirm which is intended.
PATH_IMAGES = os.path.join('images')

# Dataset name -> subject IDs to process; an empty list means the dataset is
# skipped entirely by main().
DICT_DATASETS = {
    'THINGS-fMRI': ['S1', 'S2', 'S3'],
    'BOLD5000': [],
}
# Values of the stimuli CSV 'Split' column to decode.
SPLITS = ['train']
# Searchlight sizes: number of neighbour indices kept per vertex.
LIST_NEIGHBORS = [50]
# Keys of d['properties'] to decode; one output pickle is written per key.
LIST_CONCEPTS = ['heavy', 'moves', 'natural', 'size']


def zscore_betas(betas, df_stimuli):
    """Z-score every column of ``betas`` independently within each session.

    Row ``i`` of ``betas`` is assigned to the session stored at position ``i``
    of ``df_stimuli['Session']`` (assumes the two are row-aligned — confirm
    against the producer of the betas file). Each session's rows are
    standardized along axis 0 with ``scipy.stats.zscore``.

    Returns an array with the same shape and dtype as ``betas``.
    """
    session_of_row = df_stimuli['Session'].values
    standardized = np.zeros_like(betas)

    unique_sessions = np.unique(session_of_row)
    for session in tqdm(unique_sessions, total=len(unique_sessions)):
        rows = np.nonzero(session_of_row == session)[0]
        standardized[rows] = zscore(betas[rows], axis=0)

    return standardized


def _vertex_score(idx, X, y, neighbors, sessions, labels):
    """Return ``(idx, mean held-out R²)`` for one searchlight vertex.

    Vertices labelled ``-1`` are not scored and yield ``(idx, nan)``.
    Otherwise a default ``Ridge`` model is fit on the columns
    ``neighbors[idx]`` of ``X`` for every ``LeaveOneGroupOut`` split over
    ``sessions``, and the R² from ``Ridge.score`` on each held-out group is
    averaged.
    """
    if labels[idx] != -1:
        features = X[:, neighbors[idx]]
        regressor = Ridge()
        # Ridge.fit returns the estimator itself, so fit and score chain.
        fold_r2 = [
            regressor.fit(features[train_rows], y[train_rows]).score(
                features[test_rows], y[test_rows]
            )
            for train_rows, test_rows in LeaveOneGroupOut().split(
                features, y, groups=sessions
            )
        ]
        return idx, float(np.mean(fold_r2))

    return idx, np.nan


def run_searchlight(
    X: np.ndarray,
    y: np.ndarray,
    neighbors: list,
    sessions: np.ndarray,
    labels: np.ndarray,
    n_jobs: int = -1
) -> np.ndarray:
    """
    Score every labelled vertex with ``_vertex_score`` using joblib.

    Only vertices whose label differs from -1 are dispatched; all others
    keep ``nan``. The result is a float32 array of length ``labels.shape[0]``
    holding one mean R² per vertex.
    """
    vertex_ids = np.nonzero(labels != -1)[0]

    # Lazily build one delayed task per labelled vertex; tqdm wraps dispatch.
    tasks = (
        delayed(_vertex_score)(vertex, X, y, neighbors, sessions, labels)
        for vertex in tqdm(vertex_ids, total=len(vertex_ids))
    )
    fold_means = Parallel(n_jobs=n_jobs)(tasks)

    per_vertex = np.full(labels.shape[0], np.nan, dtype=np.float32)
    for vertex, r2 in fold_means:
        per_vertex[vertex] = r2
    return per_vertex


def _property_targets(df_split, d, name_property):
    """Return the z-scored property value of each stimulus row in ``df_split``.

    Every value of ``df_split['Concept']`` is looked up in ``d['concepts']`` and
    replaced by the matching entry of ``d['properties'][name_property]['mean']``.

    Raises:
        ValueError: if a concept matches zero or more than one entry, instead of
            failing with an opaque broadcast error or silently misaligning values.
    """
    concepts = np.asarray(d['concepts'])
    means = np.asarray(d['properties'][name_property]['mean'])
    stimulus_concepts = df_split['Concept'].values

    targets = np.zeros((len(df_split),))
    for concept in df_split['Concept'].unique():
        matches = np.flatnonzero(concepts == concept)
        if len(matches) != 1:
            raise ValueError(
                f'Concept {concept!r} matched {len(matches)} entries in the '
                f'property table (expected exactly 1)'
            )
        targets[stimulus_concepts == concept] = means[matches[0]]

    # NOTE: zscore yields nan if every stimulus has the same property value.
    return zscore(targets)


def _load_neighbors(path_subject, subject_id, hemi, num_neighbors):
    """Load the first ``num_neighbors`` neighbour indices per vertex for ``hemi``."""
    path_neighbors = os.path.join(path_subject, 'neighbors', 'surf', f'{subject_id}_neighbors_surf_{hemi}.npz')
    # NpzFile keeps the archive open until closed; the context manager releases it.
    with np.load(path_neighbors) as archive:
        return archive['indices'][:, :num_neighbors]


def _load_split_betas(path_subject, subject_id, hemi, df_stimuli, trial_rows):
    """Load ``hemi`` surface betas, z-score them per session, and keep ``trial_rows``."""
    path_betas = os.path.join(path_subject, 'betas', 'surf', f'{subject_id}_betas_surf_{hemi}.npy')
    assert os.path.exists(path_betas), path_betas

    betas = zscore_betas(np.load(path_betas), df_stimuli)
    # nan appears e.g. when a column is constant within a session (zero std).
    assert not np.isnan(betas).any()
    return betas[trial_rows, :]


def process_subject(dataset, subject_id):
    """Decode concept properties from surface betas with a vertex searchlight.

    For each split in SPLITS, neighbourhood size in LIST_NEIGHBORS and property
    in LIST_CONCEPTS, runs ``run_searchlight`` on both hemispheres and pickles
    ``{'lh': scores, 'rh': scores}`` to
    ``<results>/properties/{subject_id}_properties_{split}_{num_neighbors}_{name_property}.pickle``.
    Outputs that already exist are skipped.

    Raises:
        AssertionError: if an expected input path is missing, or the
            per-session z-scored betas contain nan.
        ValueError: if a stimulus concept is missing or duplicated in the
            property table.
    """
    path_dataset = os.path.join(PATH_DATASETS, dataset)
    assert os.path.exists(path_dataset)

    path_results = os.path.join(path_dataset, 'results', subject_id)
    assert os.path.exists(path_results)

    path_subject = os.path.join(path_dataset, 'processed', subject_id)
    assert os.path.exists(path_subject)

    path_stimuli = os.path.join(path_subject, 'stimuli', f'{subject_id}_stimuli.csv')
    assert os.path.exists(path_stimuli)
    df_stimuli = pd.read_csv(path_stimuli)

    path_results_components = os.path.join(path_results, 'components')
    os.makedirs(path_results_components, exist_ok=True)
    # Outputs are written here; create it in case save_pickle does not create
    # parent directories itself.
    path_results_properties = os.path.join(path_results, 'properties')
    os.makedirs(path_results_properties, exist_ok=True)

    d = load_pickle('../images/properties_THINGS.pickle')

    for split in SPLITS:
        df_split = df_stimuli[df_stimuli['Split'] == split]
        # Row indices into the full betas array, and the session of each trial
        # (the CV groups); both follow df_split's row order.
        trial_rows = df_split['Trial'].values
        sessions_split = df_split['Session'].values

        for num_neighbors in LIST_NEIGHBORS:
            for name_property in LIST_CONCEPTS:
                path_output = os.path.join(
                    path_results_properties,
                    f'{subject_id}_properties_{split}_{num_neighbors}_{name_property}.pickle',
                )
                if os.path.exists(path_output):
                    print(f'Skipping for {name_property} {num_neighbors} {split}')
                    continue

                array_properties = _property_targets(df_split, d, name_property)

                decoding_scores = {}
                for hemi in ['lh', 'rh']:
                    indices_neighbors = _load_neighbors(path_subject, subject_id, hemi, num_neighbors)
                    betas = _load_split_betas(path_subject, subject_id, hemi, df_stimuli, trial_rows)

                    path_visual_annot = os.path.join(path_dataset, 'freesurfer', subject_id, 'label', f'{hemi}.visual.annot')
                    labels_vertices, _, _ = fsio.read_annot(path_visual_annot)

                    decoding_scores[hemi] = run_searchlight(
                        betas,
                        array_properties,
                        indices_neighbors,
                        sessions_split,
                        labels_vertices,
                        n_jobs=-1,
                    )
                    # Drop this hemisphere's betas before the next one is loaded.
                    del betas

                # Save once both hemispheres are done: the existence check above
                # would otherwise treat an interrupted lh-only file as complete.
                save_pickle(decoding_scores, path_output)


def main():
    """Process every (dataset, subject) pair listed in DICT_DATASETS, in order."""
    assert os.path.exists(PATH_DATASETS)

    subject_pairs = (
        (dataset, subject_id)
        for dataset, subject_ids in DICT_DATASETS.items()
        for subject_id in subject_ids
    )
    for dataset, subject_id in subject_pairs:
        log(f'Processing {dataset} {subject_id}...')
        process_subject(dataset, subject_id)


if __name__ == "__main__":
    main()
