import os

import numpy as np
import pandas as pd
from joblib import Parallel, delayed
from tqdm import tqdm
from scipy.stats import zscore
from sklearn.preprocessing import LabelEncoder
import nibabel.freesurfer.io as fsio
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import LeaveOneGroupOut, cross_val_score
from sklearn.linear_model import Ridge
from src.utils import log, save_pickle

# Root directory that holds one sub-directory per dataset.
# NOTE(review): the literal `...` (Ellipsis) is not a valid os.path.join
# argument and raises TypeError when this module is loaded; presumably it is
# a placeholder for a redacted base directory -- restore the real path.
PATH_DATASETS = os.path.join(..., 'Documents', 'datasets')
# PATH_DATASETS = os.path.join(..., 'datasets')
# NOTE(review): PATH_IMAGES is not referenced anywhere in the visible code.
PATH_IMAGES = os.path.join('images')

# Dataset name -> subject ids to process (an empty list skips the dataset).
DICT_DATASETS = {
    'THINGS-fMRI': ['S1', 'S2', 'S3'],
    'BOLD5000': [],
}
# Values of the stimuli table's 'Split' column to process.
SPLITS = ['train']
# How many neighbor columns are kept per vertex for each searchlight run.
LIST_NEIGHBORS = [50]
# Columns of dict_images_pca['images_transformed'] used as decoding targets.
LIST_COMPONENTS = [0, 1, 2, 3]


def zscore_betas(betas, df_stimuli):
    """Z-score ``betas`` along axis 0, separately within each session.

    Rows are grouped by ``df_stimuli['Session']`` and every group is
    standardized on its own with ``scipy.stats.zscore(..., axis=0)``.

    Parameters
    ----------
    betas : np.ndarray
        Array whose first axis is aligned row-for-row with ``df_stimuli``.
    df_stimuli : pd.DataFrame
        Table with a ``'Session'`` column giving the session of each row.

    Returns
    -------
    np.ndarray
        Array with the shape of ``betas``. Floating/complex input dtypes
        are preserved; any other dtype (e.g. integers) is returned as
        float64 so the z-scores are not truncated on assignment.
    """
    indices_session = df_stimuli['Session'].values

    # np.zeros_like(betas) alone would inherit an integer dtype and silently
    # round every z-score; keep inexact dtypes, promote the rest to float64.
    if np.issubdtype(betas.dtype, np.inexact):
        out_dtype = betas.dtype
    else:
        out_dtype = np.float64
    betas_zscore = np.zeros_like(betas, dtype=out_dtype)

    sessions = np.unique(indices_session)
    for i_session in tqdm(sessions, total=len(sessions)):
        indices = np.where(indices_session == i_session)[0]
        betas_zscore[indices] = zscore(betas[indices], axis=0)
    return betas_zscore


def _vertex_score(idx, X, y, neighbors, sessions, labels):
    """Return ``(idx, score)`` for a single vertex.

    ``score`` is the mean of the Ridge ``score`` values obtained with
    leave-one-group-out cross-validation over ``sessions``, using the
    columns ``neighbors[idx]`` of ``X`` as features and ``y`` as target.
    Vertices labelled -1 are not evaluated and get ``nan``.
    """
    # Guard clause: unlabeled vertices are skipped.
    if labels[idx] == -1:
        return idx, np.nan

    # Features are the columns of X listed for this vertex.
    features = X[:, neighbors[idx]]

    # One estimator is refit on every fold; each held-out group is scored.
    ridge = Ridge()
    fold_scores = [
        ridge.fit(features[train], y[train]).score(features[test], y[test])
        for train, test in LeaveOneGroupOut().split(features, y, groups=sessions)
    ]

    return idx, float(np.mean(fold_scores))


def run_searchlight(
    X: np.ndarray,
    y: np.ndarray,
    neighbors: list,
    sessions: np.ndarray,
    labels: np.ndarray,
    n_jobs: int = -1
) -> np.ndarray:
    """
    Score every labelled vertex with ``_vertex_score`` in parallel.

    Vertices whose label is -1 are never dispatched and keep NaN in the
    returned float32 array of length ``labels.shape[0]``.
    """
    # Vertices that carry a real label; all others stay NaN in the output.
    valid_vertices = np.nonzero(labels != -1)[0]

    # Fan the per-vertex work out over joblib workers.
    pool = Parallel(n_jobs=n_jobs)
    score_vertex = delayed(_vertex_score)
    results = pool(
        score_vertex(vertex, X, y, neighbors, sessions, labels)
        for vertex in tqdm(valid_vertices, total=len(valid_vertices))
    )

    # Scatter the (index, score) pairs into a NaN-initialised vector.
    scores = np.full(labels.shape[0], np.nan, dtype=np.float32)
    for vertex, r2 in results:
        scores[vertex] = r2

    return scores


def process_subject(dataset, subject_id):
    """Run the searchlight decoding for one subject and pickle the scores.

    For every split in ``SPLITS``, neighborhood size in ``LIST_NEIGHBORS``
    and component index in ``LIST_COMPONENTS``, a searchlight is run on each
    hemisphere (``'lh'`` and ``'rh'``) and the resulting ``{hemi: scores}``
    dict is written with ``save_pickle`` into ``<results>/components/``.
    Combinations whose output file already exists are skipped.

    Parameters
    ----------
    dataset : str
        Name of the dataset directory below ``PATH_DATASETS``.
    subject_id : str
        Subject identifier used in directory and file names.

    Raises
    ------
    AssertionError
        If a required directory or input file is missing, or if the
        z-scored betas contain NaN.
    """
    path_dataset = os.path.join(PATH_DATASETS, dataset)
    assert os.path.exists(path_dataset), path_dataset

    path_results = os.path.join(path_dataset, 'results', subject_id)
    assert os.path.exists(path_results), path_results

    path_subject = os.path.join(path_dataset, 'processed', subject_id)
    assert os.path.exists(path_subject), path_subject

    path_stimuli = os.path.join(path_subject,
                                'stimuli',
                                f'{subject_id}_stimuli.csv')
    assert os.path.exists(path_stimuli), path_stimuli
    df_stimuli = pd.read_csv(path_stimuli)

    path_results_components = os.path.join(path_results, 'components')
    os.makedirs(path_results_components, exist_ok=True)

    for split in SPLITS:
        log(f'Split: {split}')
        path_images_pca = os.path.join('..', 'images', 'subjects', f'images_{split}_pca_{subject_id}.pickle')
        assert os.path.exists(path_images_pca), path_images_pca

        dict_images_pca = pd.read_pickle(path_images_pca)

        # Session labels and 'Trial' values of this split. They do not depend
        # on hemisphere, neighborhood size or component, so filter only once.
        df_split = df_stimuli[df_stimuli['Split'] == split]
        indices_sessions = df_split['Session'].values
        indices_split = df_split['Trial'].values

        for num_neighbors in LIST_NEIGHBORS:
            log(f'Num neighbors: {num_neighbors}')
            for component in LIST_COMPONENTS:
                log(f'Component: {component}')
                path_output = os.path.join(path_results_components, f'{subject_id}_components_{split}_{num_neighbors}_{component}.pickle')

                if os.path.exists(path_output):
                    log(f'Skipping for {component} {num_neighbors} {split}')
                    continue

                label_component = dict_images_pca['images_transformed'][:, component]

                decoding_scores = {}

                for hemi in ['lh', 'rh']:
                    log(f'Processing hemisphere: {hemi}')
                    path_neighbors = os.path.join(path_subject, 'neighbors', 'surf', f'{subject_id}_neighbors_surf_{hemi}.npz')
                    # Close the .npz archive as soon as the array is read so
                    # file handles do not pile up across iterations.
                    with np.load(path_neighbors) as dict_neighbors:
                        indices_neighbors = dict_neighbors['indices'][:, :num_neighbors]

                    path_betas = os.path.join(path_subject, 'betas', 'surf', f'{subject_id}_betas_surf_{hemi}.npy')
                    assert os.path.exists(path_betas), path_betas

                    # Z-score over all rows first, then keep the split's rows
                    # ('Trial' values are used as row indices into betas).
                    betas = zscore_betas(np.load(path_betas), df_stimuli)
                    assert not np.isnan(betas).any()
                    betas = betas[indices_split, :]

                    path_visual_annot = os.path.join(path_dataset, 'freesurfer', subject_id, 'label', f'{hemi}.visual.annot')
                    labels_vertices, _, _ = fsio.read_annot(path_visual_annot)

                    log(f'Valid labels: {np.sum(labels_vertices != -1)}')
                    log('Decoding component...')
                    decoding_scores[hemi] = run_searchlight(
                        betas,
                        label_component,
                        indices_neighbors,
                        indices_sessions,
                        labels_vertices,
                        n_jobs=-1
                    )

                # Write only after both hemispheres are done: a file saved
                # after 'lh' alone would make the exists-check above skip
                # this combination forever if the run were interrupted.
                save_pickle(decoding_scores, path_output)


def main():
    """Process every (dataset, subject) pair listed in ``DICT_DATASETS``."""
    assert os.path.exists(PATH_DATASETS)

    # Flatten the dataset -> subjects mapping lazily, keeping its order.
    pairs = (
        (dataset, subject_id)
        for dataset, subject_ids in DICT_DATASETS.items()
        for subject_id in subject_ids
    )
    for dataset, subject_id in pairs:
        log(f'Processing {dataset} {subject_id}...')
        process_subject(dataset, subject_id)


# Run the pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()
