"""fMRI data loading utilities."""

import functools
import os

import h5py
import numpy as np
import scipy.io

from fmri2music import utils


def get_data_dir() -> str:
    """Resolve the root data directory.

    The path comes from the ``DATA_DIR`` environment variable and falls back to
    ``./data`` when the variable is unset. Paths for embeddings, predictions and
    other data files are built on top of this directory.

    Returns:
        The data directory path exactly as resolved (not normalized).

    Raises:
        ValueError: If the resolved path is empty or is not an existing directory.
    """
    candidate = os.getenv("DATA_DIR", os.path.join(".", "data"))
    is_usable = bool(candidate) and os.path.isdir(candidate)
    if is_usable:
        return candidate
    raise ValueError(
        f"Provided DATA_DIR does not exist: '{candidate}'. Set it in the .env file."
    )


@utils.synchronized
@functools.lru_cache(maxsize=16)
def load_resp_data(subject_name: str) -> tuple[np.ndarray, np.ndarray]:
    """Load the training and validation fMRI responses for one subject.

    Reads ``<DATA_DIR>/resp/<subject_name>_RespData.mat`` with h5py and returns
    the ``RespData/respTrn`` and ``RespData/respVal`` datasets, each fully read
    into memory and transposed.

    This function loads a ~1GB file into RAM. It may crash if you don't have
    enough RAM. Results are memoized per subject by ``functools.lru_cache``
    (up to 16 subjects).

    Args:
        subject_name: Subject identifier used to build the file name.

    Returns:
        ``(resp_trn, resp_val)``: the transposed training and validation
        response arrays.

    NOTE(review): the returned arrays are the cached objects themselves;
    mutating them in place changes what later calls for the same subject
    return. Copy before modifying.
    """
    resp_data_path = f"{get_data_dir()}/resp/{subject_name}_RespData.mat"
    print(f"Response data path: {resp_data_path}")
    with h5py.File(resp_data_path, "r") as f:
        resp_group = f["RespData"]
        # `[...]` reads each h5py dataset completely into an in-memory ndarray.
        # The transpose presumably undoes MATLAB's column-major layout as seen
        # through HDF5 — confirm against the .mat writer.
        resp_trn: np.ndarray = resp_group["respTrn"][...].T
        resp_val: np.ndarray = resp_group["respVal"][...].T
    print(resp_trn.shape, resp_val.shape)
    return resp_trn, resp_val


@utils.synchronized
@functools.lru_cache(maxsize=5)
def load_fs_roi(subject_name: str) -> dict:
    """Load the ``fsROI.mat`` file of one subject.

    The file is read from ``<DATA_DIR>/MRI/fsROI/<subject_name>/fsROI.mat`` with
    ``scipy.io.loadmat``, and the resulting variable-name -> value dict is
    returned. Results are memoized per subject (up to 5 subjects) by
    ``functools.lru_cache``; the cached dict is shared between callers.
    """
    subject_roi_dir = f"{get_data_dir()}/MRI/fsROI/{subject_name}"
    mat_contents = scipy.io.loadmat(f"{subject_roi_dir}/fsROI.mat")
    return mat_contents


@utils.synchronized
@functools.lru_cache(maxsize=5)
def load_vset(subject_name: str) -> dict:
    """Load the ``vset_099.mat`` file of one subject.

    The file is read from ``<DATA_DIR>/MRI/fsROI/<subject_name>/vset_099.mat``
    with ``scipy.io.loadmat``, and the resulting variable-name -> value dict is
    returned. Results are memoized per subject (up to 5 subjects) by
    ``functools.lru_cache``; the cached dict is shared between callers.
    """
    subject_roi_dir = f"{get_data_dir()}/MRI/fsROI/{subject_name}"
    mat_contents = scipy.io.loadmat(f"{subject_roi_dir}/vset_099.mat")
    return mat_contents


@utils.synchronized
@functools.lru_cache(maxsize=1)
def load_experiment_order() -> tuple[list[str], list[str]]:
    """Load clip names of the train and validation splits, in fMRI order.

    Reads ``TrnOrd`` from ``<DATA_DIR>/stim/ExpTrnOrder.mat`` and ``ValOrd``
    from ``<DATA_DIR>/stim/ExpValOrder.mat``. Each is a 2-D (clip x session)
    cell array; names are flattened session by session (all clips of session
    0, then all clips of session 1, ...).

    For the validation split each session repeats its clips 4 times, so only
    the first quarter of the rows of every session column is kept.

    Returns:
        ``(train_clip_names, val_clip_names)``. The lists are memoized by
        ``functools.lru_cache`` and shared between callers.

    Raises:
        ValueError: If a clip name occurs in both splits, or if either cell
            array is not 2-D.
    """
    stimname_trn = scipy.io.loadmat(f"{get_data_dir()}/stim/ExpTrnOrder.mat")["TrnOrd"]
    stimname_val = scipy.io.loadmat(f"{get_data_dir()}/stim/ExpValOrder.mat")["ValOrd"]

    # Tuple-unpacking the shapes also rejects arrays that are not 2-D.
    nmusic_trn, nsession_trn = stimname_trn.shape
    nmusic_val, nsession_val = stimname_val.shape

    # Test clip was repeated 4 times in a single session (10 x 4 clips in a single
    # session). NOTE(review): assumes the first nmusic_val // 4 rows hold each
    # distinct clip once — confirm against the experiment design.
    val_repeats = 4
    nmusic_val_unique = nmusic_val // val_repeats

    # Each cell holds a 1-element array; `[0]` extracts the clip name itself.
    stimname_trn_flat = [
        stimname_trn[cmusic, csession][0]
        for csession in range(nsession_trn)
        for cmusic in range(nmusic_trn)
    ]
    stimname_val_flat = [
        stimname_val[cmusic, csession][0]
        for csession in range(nsession_val)
        for cmusic in range(nmusic_val_unique)
    ]

    print(f"{len(stimname_trn_flat)=}, {len(stimname_val_flat)=}")

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    shared_clips = set(stimname_trn_flat) & set(stimname_val_flat)
    if shared_clips:
        raise ValueError(
            "Train and validation stimuli must be disjoint. "
            f"Found #{len(shared_clips)} shared clips: {sorted(shared_clips)}"
        )

    return stimname_trn_flat, stimname_val_flat


def get_trn_clip_names() -> list[str]:
    """Returns the clip names of the GTZAN training split, in fMRI order.

    A fresh copy is returned so callers may modify it (e.g. sort it) without
    corrupting the list memoized by ``load_experiment_order``, which every
    later call would otherwise see.
    """
    return list(load_experiment_order()[0])


def get_val_clip_names() -> list[str]:
    """Returns the clip names of the GTZAN validation split, in fMRI order.

    A fresh copy is returned so callers may modify it (e.g. sort it) without
    corrupting the list memoized by ``load_experiment_order``, which every
    later call would otherwise see.
    """
    return list(load_experiment_order()[1])


def export_predictions(
    file_name: str,
    emb_name: str,
    gtzan_slice_names: list[str],
    gtzan_preds: np.ndarray,
    gtzan_clip_names_trn: list[str],
    fma_clip_names_trn: list[str],
    gtzan_clip_names_val: list[str],
    fma_clip_names_val: list[str],
) -> None:
    """Save predictions and their clip-name metadata to ``<DATA_DIR>/pred/<file_name>``.

    All arguments are stored as named arrays in a single ``np.savez`` archive.
    Because a file object is passed to ``np.savez``, ``file_name`` is used
    verbatim (no ``.npz`` suffix is appended).

    Args:
        file_name: Output file name, relative to ``<DATA_DIR>/pred``.
        emb_name: Name of the embedding the predictions were made for.
        gtzan_slice_names: One name per row of ``gtzan_preds``.
        gtzan_preds: Prediction array whose first axis matches ``gtzan_slice_names``.
        gtzan_clip_names_trn: GTZAN clip names of the training split.
        fma_clip_names_trn: FMA clip names paired index-wise with ``gtzan_clip_names_trn``.
        gtzan_clip_names_val: GTZAN clip names of the validation split.
        fma_clip_names_val: FMA clip names paired index-wise with ``gtzan_clip_names_val``.

    Raises:
        ValueError: If the number of slice names does not match the number of
            prediction rows, or if a GTZAN/FMA name list pair differs in length.
    """
    if len(gtzan_slice_names) != gtzan_preds.shape[0]:
        raise ValueError(
            "Valid predictions must have one vector per stimulus. "
            f"Got #{len(gtzan_slice_names)} names and predictions of shape {gtzan_preds.shape}."
        )

    # Train pair is checked before the validation pair, as before.
    for gtzan_names, fma_names in (
        (gtzan_clip_names_trn, fma_clip_names_trn),
        (gtzan_clip_names_val, fma_clip_names_val),
    ):
        if len(gtzan_names) != len(fma_names):
            # A single message string (the previous version passed two
            # positional args, so the message rendered as a tuple).
            raise ValueError(
                "Must have one FMA clip name per GTZAN clip name. "
                f"Got #{len(gtzan_names)} GTZAN names and #{len(fma_names)} FMA names."
            )

    result_path = f"{get_data_dir()}/pred/{file_name}"
    with open(result_path, mode="wb") as f:
        np.savez(
            f,
            gtzan_slice_names=np.asarray(gtzan_slice_names),
            gtzan_preds=gtzan_preds,
            emb_name=emb_name,
            gtzan_clip_names_trn=gtzan_clip_names_trn,
            fma_clip_names_trn=fma_clip_names_trn,
            gtzan_clip_names_val=gtzan_clip_names_val,
            fma_clip_names_val=fma_clip_names_val,
        )
    print(f"Saved #{len(gtzan_slice_names)} predictions to {result_path}")
