"""Functions for training a regression model from fmri to music embeddings."""

import dataclasses
import functools
from typing import Callable, Generic, TypeVar

import numpy as np
from himalaya.ridge import RidgeCV
from himalaya.scoring import correlation_score
from sklearn.model_selection import KFold
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

from fmri2music import emb_loader, fmri_loader, hparam_tuner, utils
from fmri2music.predictor import OnlinePredictor

# For some embeddings, the name of the embedding used for training differs from the
# name of the embedding that the trained model is used to predict.
# The MuLan text embedding, for example, is available for the GTZAN data but not for
# FMA (for which we would need it when retrieving).
# Maps training embedding name -> embedding name handed to the predictor. Names that
# are missing from this map are used unchanged (see `train_predictor`).
TRAIN_PRED_EMB_NAME_MAP = {
    "window15s-stride15s-mv101txt-avg": "window15s-stride15s-mv101-avg"
}


T = TypeVar("T")
U = TypeVar("U")


@dataclasses.dataclass
class TrnVal(Generic[T]):
    """Pair holding the training and the validation variant of one thing.

    The payload can be anything that exists once per split, e.g. a data array,
    a list of names, a path or a score.
    """

    # Variant belonging to the training split.
    trn: T
    # Variant belonging to the validation split.
    val: T

    def map_fn(self, f: Callable[[T], U]) -> "TrnVal[U]":
        """Return a new pair with `f` applied to `trn` first and then to `val`."""
        mapped_trn = f(self.trn)
        mapped_val = f(self.val)
        return TrnVal(trn=mapped_trn, val=mapped_val)


def validate_stim_names(
    fmri_stim_names: list[str], music_emb_stim_names: list[str]
) -> None:
    """Verify that fMRI data and music embeddings refer to the same stimuli.

    Embedding names are reduced to the first component returned by
    `utils.parse_long_key` before the comparison; duplicates are ignored on
    both sides because the comparison is done on sets.

    Raises:
        ValueError: If the two sets of stimulus names differ. Both set
            differences are printed beforehand.
    """
    music_emb_stim_names = {
        utils.parse_long_key(long_key)[0] for long_key in music_emb_stim_names
    }
    fmri_stim_names = set(fmri_stim_names)

    # Happy path first: nothing to report besides the confirmation.
    if music_emb_stim_names == fmri_stim_names:
        print("Stim names from fmri and music embeddings are the same.")
        return

    # The sets disagree: print the diff in both directions, then fail.
    print("Stim names from fmri and music embeddings differ.")
    print(f"Sizes: {len(fmri_stim_names)=}, {len(music_emb_stim_names)=}")
    print("fmri - music: ", fmri_stim_names - music_emb_stim_names)
    print("music - fmri: ", music_emb_stim_names - fmri_stim_names)
    raise ValueError("Stim names from fmri and music embeddings differ.")


def align_resp_to_stim(
    resp: TrnVal[np.ndarray], stim_names: TrnVal[list[str]], delay: int
) -> TrnVal[np.ndarray]:
    """Align the response with the available stimulus (music embeddings).

    Target embeddings (e.g. from MuLan) are available for certain time slices.
    In order to predict them, we average the corresponding fMRI responses along
    the time dimension. For example, if we have a music embedding for the music
    segment from 0s to 3s, we average the first two fMRI response vectors (which
    correspond to the times 0s-1.5s and 1.5s-3s).

    In cases where the time slices are not aligned, we pick the temporally closest
    fMRI data.

    Args:
        resp: Raw fMRI responses per split; rows are consecutive scans, with
            10 scans per clip in the same clip order as `stim_names`.
        stim_names: Long stimulus keys per split, one per embedding slice.
        delay: Haemodynamic response delay, in unit 'fmri slices'.

    Returns:
        The aligned, averaged responses (one row per embedding slice) for the
        training and the validation split.
    """

    # The response to stimuli is delayed. To align we rotate the response matrix,
    # i.e., move later time steps forward. The overflow *is* valid, because of the
    # way the data was collected.
    resp = resp.map_fn(functools.partial(utils.rotate_rows, n=delay))

    # Number of fMRI scans per music clip (15s duration).
    fmri_scans_per_stim = 10

    stim_name_groups = stim_names.map_fn(utils.group_stim_names)

    seconds_to_idx_map = {}  # Collected only to be printed.

    def average_slices(split_resp: np.ndarray, clip_groups) -> np.ndarray:
        """Temporally average the responses of one split, slice by slice."""
        averaged = []
        # Loop over clips (15s each).
        for clip_idx, slice_names in enumerate(clip_groups):
            clip_offset = clip_idx * fmri_scans_per_stim
            clip_resp = split_resp[clip_offset : clip_offset + fmri_scans_per_stim]

            # Loop over slices of a clip (varying length, depending on emb type).
            for _, start_s, end_s in map(utils.parse_long_key, slice_names):
                start_idx, end_idx = utils.seconds_to_idx(
                    start_s, end_s, fmri_scans_per_stim
                )
                averaged.append(np.mean(clip_resp[start_idx:end_idx], axis=0))
                seconds_to_idx_map[(start_s, end_s)] = (start_idx, end_idx)
        return np.array(averaged)

    # Both splits go through exactly the same procedure (trn first, then val).
    resp_mean = TrnVal(
        trn=average_slices(resp.trn, stim_name_groups.trn),
        val=average_slices(resp.val, stim_name_groups.val),
    )

    print(f"Seconds to response indices: {sorted(seconds_to_idx_map.items())}")
    print(f"Response shapes: trn={resp_mean.trn.shape}, val={resp_mean.val.shape}")
    return resp_mean


def prepare_roi(subject_name: str) -> list[tuple[str, np.ndarray]]:
    """Load the cortical ROIs of a subject as indices into its voxel set.

    Args:
        subject_name: Subject identifier passed to the `fmri_loader` functions.

    Returns:
        One `(label, indices)` tuple per ROI whose label starts with "ctx".
        `indices` are the positions within the flattened `tvoxels` array whose
        voxel id belongs to that ROI.
    """
    fs_roi = fmri_loader.load_fs_roi(subject_name)

    # NOTE(review): the nested indexing looks like a MATLAB struct loaded from a
    # .mat file (field 1 = labels, field 2 = voxel ids) -- confirm with the loader.
    roi_struct = fs_roi["fsROI"][0][0]
    labels = [entry[0][0] for entry in roi_struct[1]]
    voxels = [entry[0][:, 0] for entry in roi_struct[2]]

    # Cerebral Cortex voxels. Flattened explicitly so that the positions found
    # below are 1-D indices regardless of the stored array shape (this is what
    # the deprecated `np.in1d` did implicitly).
    vset = fmri_loader.load_vset(subject_name)
    tvoxels = np.ravel(vset["tvoxels"])

    # Use only cortex
    ctx_voxels = []
    for label, roi_voxels in zip(labels, voxels):
        if label[:3] == "ctx":
            ind = np.flatnonzero(np.isin(tvoxels, roi_voxels))
            ctx_voxels.append((label, ind))

    return ctx_voxels


def load_stimulus(emb_name: str) -> tuple[TrnVal[np.ndarray], TrnVal[list[str]]]:
    """Load the music embeddings in experiment order for both splits.

    Args:
        emb_name: Name of the GTZAN embedding to load via `emb_loader`.

    Returns:
        A pair `(stimarr, stimname)`: the stacked embedding arrays and the
        matching embedding keys, each as a `TrnVal`. Within a clip, keys are in
        sorted order; clips follow the order from
        `fmri_loader.load_experiment_order`.

    Raises:
        ValueError: If the stimuli of the experiment and of the embeddings
            differ (raised by `validate_stim_names`).
    """
    stimname_trn, stimname_val = fmri_loader.load_experiment_order()

    emb_from_gtzan_key = emb_loader.get_gtzan_emb(emb_name)

    validate_stim_names(stimname_trn + stimname_val, list(emb_from_gtzan_key.keys()))

    # Sort once; filtering a sorted list by prefix keeps it sorted, so this gives
    # the same per-clip order as sorting the matches of every clip separately.
    sorted_keys = sorted(emb_from_gtzan_key.keys())

    def collect(clip_names: list[str]) -> tuple[np.ndarray, list[str]]:
        """Gather embeddings and keys of all slices of the given clips."""
        embeddings = []
        full_names = []
        for clip_name in clip_names:
            for key in sorted_keys:
                if key.startswith(clip_name):
                    embeddings.append(emb_from_gtzan_key[key])
                    full_names.append(key)
        return np.array(embeddings), full_names

    stimarr_trn, stimname_trn_full = collect(stimname_trn)
    stimarr_val, stimname_val_full = collect(stimname_val)

    stimarr = TrnVal(trn=stimarr_trn, val=stimarr_val)
    stimname = TrnVal(trn=stimname_trn_full, val=stimname_val_full)

    return stimarr, stimname


def train_regressor_with_cross_validation(
    stim: TrnVal[np.ndarray],
    resp: TrnVal[np.ndarray],
    ctx_voxels: list[tuple[str, np.ndarray]],
    hparams: hparam_tuner.HParams,
    num_splits: int,
) -> tuple[TrnVal[np.ndarray], TrnVal[float]]:
    """Train with K-fold cross validation over the training split.

    Every training sample is predicted by a model whose fold did not contain
    it, so the returned trn predictions are out-of-fold predictions. The val
    predictions come from one additional model trained on all training data.

    Returns:
        `(predictions, correlations)`. `correlations.trn` is the mean of the
        per-fold held-out correlations; `correlations.val` is the validation
        correlation of the model trained on all training data.
    """
    splitter = KFold(n_splits=num_splits, shuffle=True, random_state=20170915)

    heldout_pred = np.zeros_like(stim.trn)
    fold_corrs = []

    for fit_idx, heldout_idx in splitter.split(stim.trn):
        # Within a fold, the 'val' slot carries the held-out part of the
        # training data, not the original validation split.
        fold_stim = TrnVal(trn=stim.trn[fit_idx], val=stim.trn[heldout_idx])
        fold_resp = TrnVal(trn=resp.trn[fit_idx], val=resp.trn[heldout_idx])
        fold_pred, fold_corr = train_regressor(
            stim=fold_stim,
            resp=fold_resp,
            ctx_voxels=ctx_voxels,
            hparams=hparams,
        )

        heldout_pred[heldout_idx] = fold_pred.val
        fold_corrs.append(fold_corr.val)

    mean_fold_corr = np.mean(fold_corrs).item()

    # One more run on the complete training data yields the predictions for
    # the original validation split.
    full_pred, full_corr = train_regressor(stim, resp, ctx_voxels, hparams)

    predictions = TrnVal(trn=heldout_pred, val=full_pred.val)
    correlations = TrnVal(trn=mean_fold_corr, val=full_corr.val)
    return predictions, correlations


def train_regressor(
    stim: TrnVal[np.ndarray],
    resp: TrnVal[np.ndarray],
    ctx_voxels: list[tuple[str, np.ndarray]],
    hparams: hparam_tuner.HParams,
) -> tuple[TrnVal[np.ndarray], TrnVal[float]]:
    """Train a regressor for each brain region and combine into an ensemble.

    The training runs in two stages:

    1. A ridge regression (fMRI response -> embedding) is fitted per ROI and
       scored by the mean of its internal cross-validation scores.
    2. The best `hparams["nroi_used_for_ensemble"]` ROIs are refitted; their
       predictions are averaged and passed through `utils.normalize_preds`.

    Args:
        stim: Target embeddings for the training and validation split.
        resp: Aligned fMRI responses; columns are indexed by the ROI voxel
            indices in `ctx_voxels`.
        ctx_voxels: `(roi_name, voxel_indices)` per candidate ROI.
        hparams: Must contain "nroi_used_for_ensemble". The optional
            "voxel_num_limit" restricts the candidates to the first N entries
            of `ctx_voxels`.

    Returns:
        `(predictions, correlation_scores)`: the normalized ensemble
        predictions and their mean correlation with the targets, per split.
    """

    # Hparams are tuned by the RidgeCV class with cross validation.
    # Candidate regularization strengths: 1e7, 1e6, ... (20 values, factor 0.1).
    alpha = np.array([10000000 * (0.1**i) for i in range(20)])
    ncv = 5

    ridge = RidgeCV(
        alphas=alpha, cv=ncv, solver_params={"score_func": correlation_score}
    )
    # Inputs are centered but not scaled.
    preprocess_pipeline = make_pipeline(
        StandardScaler(with_mean=True, with_std=False),
    )
    pipeline = make_pipeline(
        preprocess_pipeline,
        ridge,
    )

    y_trn, y_val = stim.trn, stim.val

    selected_voxels = ctx_voxels[: hparams.get("voxel_num_limit", len(ctx_voxels))]

    # Stage 1: collect cross-validation scores per brain region for later
    # ensemble. Only the fit is needed here; `ridge` is the very estimator inside
    # `pipeline`, so its `cv_scores_` reflect the fit that just happened.
    cv_scores = []
    for _, croi_voxels in selected_voxels:
        pipeline.fit(resp.trn[:, croi_voxels], y_trn)
        cv_scores.append(np.mean(ridge.cv_scores_))

    # Stage 2: refit the best-scoring ROIs and keep their predictions.
    y_preds = TrnVal([], [])
    nroi_used_for_ensemble = hparams["nroi_used_for_ensemble"]
    print(f"Training final ensemble based on top {nroi_used_for_ensemble} ROIs.")
    for roiidx in np.argsort(cv_scores)[-nroi_used_for_ensemble:]:
        # `cv_scores` is aligned with `selected_voxels`, so index that list.
        croi_name, croi_voxels = selected_voxels[roiidx]
        x_trn = resp.trn[:, croi_voxels]
        x_val = resp.val[:, croi_voxels]
        pipeline.fit(x_trn, y_trn)
        y_trn_pred = pipeline.predict(x_trn)
        y_preds.trn.append(y_trn_pred)
        y_val_pred = pipeline.predict(x_val)
        y_preds.val.append(y_val_pred)

        r_trn = correlation_score(y_trn_pred, y_trn)
        r_val = correlation_score(y_val_pred, y_val)
        print(
            f"r_trn={np.mean(r_trn)} r_val={np.mean(r_val)} "
            f"(std trn={np.std(r_trn)} val={np.std(r_val)}), "
            f"{croi_name=} roi_dim={x_trn.shape[1]}"
        )

    y_preds = y_preds.map_fn(np.array)
    # Mean across different ROIs.
    y_preds_ens = y_preds.map_fn(lambda m: np.mean(m, axis=0))
    y_preds_ens_norm = y_preds_ens.map_fn(utils.normalize_preds)

    correlation_scores = TrnVal(
        trn=correlation_score(y_preds_ens_norm.trn, y_trn),
        val=correlation_score(y_preds_ens_norm.val, y_val),
    ).map_fn(np.mean)
    predictions = TrnVal(trn=y_preds_ens_norm.trn, val=y_preds_ens_norm.val)

    return predictions, correlation_scores


def exclude_category_from_trn(
    stim: TrnVal[np.ndarray],
    resp: TrnVal[np.ndarray],
    stim_names: list[str],
    excluded_genre: str,
) -> tuple[TrnVal[np.ndarray], TrnVal[np.ndarray], list[str]]:
    """Leave one category out, for category generalization analysis.

    A training sample belongs to the excluded category when the part of its
    stimulus name before the first "." equals `excluded_genre`. Validation
    data is left untouched.

    Args:
        stim: Stimulus embeddings; rows of `stim.trn` correspond to `stim_names`.
        resp: fMRI responses; rows of `resp.trn` correspond to `stim_names`.
        stim_names: Names of the training stimuli.
        excluded_genre: Category to drop from the training data.

    Returns:
        `(stim, resp, stim_names)` with the excluded rows / names removed from
        the training part. New `TrnVal` objects are returned; the arguments are
        not modified.
    """
    # Explicit bool dtype keeps the mask valid for indexing even when
    # `stim_names` is empty (an empty list would otherwise become float).
    keep_mask = np.array(
        [stim_name.split(".")[0] != excluded_genre for stim_name in stim_names],
        dtype=bool,
    )

    # Build fresh pairs instead of overwriting `.trn` on the caller's objects.
    filtered_stim = dataclasses.replace(stim, trn=stim.trn[keep_mask, :])
    filtered_resp = dataclasses.replace(resp, trn=resp.trn[keep_mask, :])
    kept_names = [
        stim_name for stim_name, keep in zip(stim_names, keep_mask) if keep
    ]

    return filtered_stim, filtered_resp, kept_names


def train_predictor(
    name: str,
    emb_name: str,
    subject_name: str,
    num_xval_splits: int,
    hparams: hparam_tuner.HParams,
) -> OnlinePredictor:
    """Run a full training session and wrap the predictions in a predictor.

    Loads the stimulus embeddings and the fMRI responses of `subject_name`,
    aligns them, optionally removes one genre from the training data
    (`hparams["excluded_genre"]`), trains the ROI ensemble (with cross
    validation if `num_xval_splits > 1`) and returns an `OnlinePredictor`
    holding the predictions for all training and validation stimuli.

    Raises:
        ValueError: If a genre should be excluded but no training stimulus
            was removed.
    """
    print("Loading stimulus data...")
    stim, stim_names = load_stimulus(emb_name)

    print("Loading response data...")
    resp_raw = fmri_loader.load_resp_data(subject_name)
    raw_pair = TrnVal(trn=resp_raw[0], val=resp_raw[1])
    resp = align_resp_to_stim(
        raw_pair, stim_names, delay=hparams["haemodynamic_resp_delay"]
    )

    excluded_genre = hparams.get("excluded_genre", None)
    if excluded_genre:
        num_trn_names_before = len(stim_names.trn)
        print(f"Exclude {excluded_genre} category from training data...")
        stim, resp, stim_names.trn = exclude_category_from_trn(
            stim, resp, stim_names.trn, excluded_genre
        )
        # The predictor name records which genre was held out.
        name += "_ex_" + excluded_genre
        if len(stim_names.trn) == num_trn_names_before:
            raise ValueError("No change from original stimulus.")

    print("Loading ROI data...")
    ctx_voxels = prepare_roi(subject_name)

    if num_xval_splits > 1:
        print(f"Training with xval (num_splits={num_xval_splits})...")
        preds, _ = train_regressor_with_cross_validation(
            stim, resp, ctx_voxels, hparams, num_xval_splits
        )
    else:
        print("Training...")
        preds, corr = train_regressor(stim, resp, ctx_voxels, hparams)
        print(f"{corr=}")

    # Some embeddings are predicted under a different name than they were
    # trained with; fall back to the training name otherwise.
    pred_emb_name = TRAIN_PRED_EMB_NAME_MAP.get(emb_name, emb_name)

    all_stim_names = stim_names.trn + stim_names.val
    all_preds = np.concatenate((preds.trn, preds.val), axis=0)
    return OnlinePredictor(name, pred_emb_name, all_stim_names, all_preds)
