import os

import cortex
import numpy as np

from fmri2music.training_encoder import prepare_roi


def brain_mapping(
    subject_name: str,
    xfm_name: str,
    tvoxels: np.ndarray,
    scores: np.ndarray,
    cmap: str,
    vmin: float = 0.0,
    vmax: float = 0.88,
) -> "cortex.Volume":
    """Paint per-voxel scores into a pycortex volume.

    A zero-filled buffer with one element per voxel of the subject/transform
    volume is created, the scores are written at the positions given by
    ``tvoxels``, and the buffer is reshaped back to the volume shape and
    wrapped in a ``cortex.Volume``. Voxels not listed in ``tvoxels`` stay 0.

    Args:
        subject_name: Subject identifier handed to pycortex.
        xfm_name: Transform name handed to pycortex.
        tvoxels: Indices into the flattened volume. Assumed to have shape
            ``(n_voxels, 1)`` so that it lines up with ``scores[:, None]``
            -- TODO confirm against ``prepare_roi``.
        scores: Scores to display, one per entry of ``tvoxels``
            (presumably 1-D of length ``n_voxels``).
        cmap: Colormap name forwarded to ``cortex.Volume``.
        vmin: Lower bound of the colour scale.
        vmax: Upper bound of the colour scale.

    Returns:
        The ``cortex.Volume`` holding the mapped scores.
    """
    # The random volume is only used to learn the volume shape for this
    # subject/transform pair; its data is discarded.
    volume_shape = list(cortex.Volume.random(subject_name, xfm_name).shape)

    flat = np.zeros(volume_shape).ravel()
    flat[tvoxels] = scores[:, None]
    print(f"Max cc: {flat.max()}")

    return cortex.Volume(
        flat.reshape(volume_shape),
        subject_name,
        xfm_name,
        vmin=vmin,
        vmax=vmax,
        cmap=cmap,
    )


def compare_mapping(
    subject_name: str,
    xfm_name: str,
    tvoxels: np.ndarray,
    all_scores: dict[str, np.ndarray],
) -> "cortex.dataset.VolumeRGB":
    """Colour every voxel by the model that scores highest on it.

    Models are numbered in the iteration order of ``all_scores`` and each one
    is assigned a fixed colour from an internal five-entry palette. A voxel
    takes the colour of the model with the strictly largest score; ties keep
    the earlier model. Voxels where no model scores above 0 are left
    uncoloured and get alpha 0.

    Args:
        subject_name: Subject identifier handed to pycortex.
        xfm_name: Transform name handed to pycortex.
        tvoxels: 2-D index array whose first column holds positions in the
            flattened volume, one row per scored voxel.
        all_scores: Mapping from model name to that model's scores, one value
            per row of ``tvoxels``.

    Returns:
        A ``cortex.dataset.VolumeRGB`` with the winning-model colours.

    Raises:
        ValueError: If more models are supplied than the palette has colours.
    """
    palette = [
        (255, 255, 0),  # yellow
        (0, 255, 255),  # cyan ("light blue")
        (255, 0, 0),  # red
        (0, 128, 128),  # teal
        (128, 0, 128),  # purple
    ]
    n_models = len(all_scores)
    # Fail before any pycortex work instead of hitting a bare IndexError on
    # the palette lookup further down.
    if n_models > len(palette):
        raise ValueError(
            f"compare_mapping supports at most {len(palette)} models, "
            f"got {n_models}"
        )

    # The random volume is only used to learn the volume shape.
    volume_shape = list(cortex.Volume.random(subject_name, xfm_name).shape)

    # Winner-takes-all over models; model ids are 1-based so that 0 can mean
    # "no model scored above zero here".
    n_voxels = len(tvoxels)
    best_score = np.zeros(n_voxels)
    best_model = np.zeros(n_voxels, dtype=int)
    for model_id, scores in enumerate(all_scores.values(), start=1):
        improved = scores > best_score
        best_model[improved] = model_id
        best_score[improved] = scores[improved]

    red = np.zeros(volume_shape).ravel()
    green = np.zeros(volume_shape).ravel()
    blue = np.zeros(volume_shape).ravel()
    for model_id, (r, g, b) in enumerate(palette[:n_models], start=1):
        winners = tvoxels[best_model == model_id, 0]
        red[winners] = r
        green[winners] = g
        blue[winners] = b

    red = red.reshape(volume_shape)
    green = green.reshape(volume_shape)
    blue = blue.reshape(volume_shape)

    # Fully transparent wherever no colour was written.
    alpha = ((red + green + blue) != 0).astype(red.dtype)

    # NOTE(review): the channels hold 0-255 floats while vmin/vmax are 0/1;
    # confirm how VolumeRGB scales non-uint8 channels (the half-intensity
    # palette entries may saturate).
    return cortex.dataset.VolumeRGB(
        red,
        green,
        blue,
        subject_name,
        xfm_name,
        vmin=0,
        vmax=1,
        alpha=alpha,
    )


def mapper(
    subject_name: str,
    xfm_name: str,
    model_names: list[str],
) -> None:
    """Render encoding scores for one or several models as a flatmap PNG.

    Scores are loaded from
    ``./data/encoding/<subject_name>/scores/cc_<model>.npy`` and the image is
    written into the same directory. With a single model the scores are drawn
    with the ``afmhot`` colormap and a colorbar; with several models each
    voxel is coloured by its best-scoring model (see ``compare_mapping``).

    Args:
        subject_name: Subject identifier, used for ``prepare_roi``, pycortex
            and the data directory.
        xfm_name: pycortex transform name, used in both rendering modes.
        model_names: Names of the models whose score files should be mapped.

    Raises:
        ValueError: If ``model_names`` is empty.
    """
    # Checked first so an empty request fails before any I/O happens.
    if not model_names:
        raise ValueError("model_names must contain at least one model name")

    _, tvoxels = prepare_roi(subject_name)

    map_dir = f"./data/encoding/{subject_name}/scores"
    os.makedirs(map_dir, exist_ok=True)

    if len(model_names) == 1:
        # Single model: plot its score values directly.
        print(f"Loading {model_names}'s embeddings...")
        scores = np.load(f"{map_dir}/cc_{model_names[0]}.npy")
        volume = brain_mapping(subject_name, xfm_name, tvoxels, scores, "afmhot")
        print("Mapping scores to brain surface...")
        cortex.quickflat.make_png(
            f"{map_dir}/{model_names[0]}.png",
            volume,
            recache=False,
            with_colorbar=True,
            with_labels=False,
            with_curvature=True,
        )
        return

    # Several models: plot which model wins at each voxel.
    all_scores = {}
    for name in model_names:
        print(name)
        print(f"Loading {name}'s embeddings...")
        all_scores[name] = np.load(f"{map_dir}/cc_{name}.npy")
    volume = compare_mapping(subject_name, xfm_name, tvoxels, all_scores)

    print("Mapping scores to brain surface...")
    # NOTE(review): the file name is the repr of the list
    # (e.g. "['a', 'b'].png"); kept as-is in case other scripts rely on it.
    cortex.quickflat.make_png(
        f"{map_dir}/{model_names}.png",
        volume,
        recache=False,
        with_colorbar=False,
        with_labels=False,
        with_curvature=True,
    )
