"""Functions for training a model from music embeddings to brain activity."""

import dataclasses
import os
from typing import Callable, Generic, TypeVar

import matplotlib.pyplot as plt
import numpy as np
import torch
from himalaya.backend import set_backend
from himalaya.kernel_ridge import (
    ColumnKernelizer,
    Kernelizer,
    KernelRidgeCV,
    MultipleKernelRidgeCV,
)
from himalaya.ridge import BandedRidgeCV, ColumnTransformerNoStack, RidgeCV
from himalaya.scoring import correlation_score, correlation_score_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from statsmodels.stats.multitest import fdrcorrection

from fmri2music import fmri_loader
from fmri2music.training import align_resp_to_stim, load_stimulus

T = TypeVar("T")
U = TypeVar("U")


@dataclasses.dataclass
class TrnVal(Generic[T]):
    """Pair holding the training-side and validation-side version of a value.

    The payload can be anything that naturally comes as a train/validation
    pair, e.g. data arrays, file paths, or fitted models.
    """

    # Value associated with the training split.
    trn: T
    # Value associated with the validation split.
    val: T

    def map_fn(self, f: Callable[[T], U]) -> "TrnVal[U]":
        """Return a new TrnVal built from f(trn) and f(val), in that order."""
        mapped_trn = f(self.trn)
        mapped_val = f(self.val)
        return TrnVal(trn=mapped_trn, val=mapped_val)


def mono_regressor(
    stim: dict[str, TrnVal[np.ndarray]], resp: TrnVal[np.ndarray], device: int
) -> tuple[np.ndarray, object]:
    """Train an encoder for mono feature space.

    Fits a cross-validated (kernel) ridge regression from one stimulus
    feature space to the responses, scores the validation predictions with
    the correlation score, and zeroes the scores of targets that do not pass
    the FDR-based significance test.

    Args:
        stim: Mapping from feature-space name to its train/validation
            features (2-D arrays, samples x features). Only one feature space
            is used; if several are given, the last one is taken.
        resp: Train/validation responses.
        device: CUDA device index to run on. A negative value, or an index
            that is not available, selects the CPU torch backend.

    Returns:
        A tuple ``(score, fig)``: the per-target correlation scores as a
        NumPy array (non-significant entries set to 0) and the matplotlib
        figure holding the score histogram.

    Raises:
        ValueError: If ``stim`` is empty.
    """
    if not stim:
        raise ValueError("mono_regressor needs at least one stimulus feature space.")
    # Do not rebind the ``stim`` argument; pick the (last) entry explicitly.
    stim_name, features = list(stim.items())[-1]

    alphas = np.logspace(-12, 12, 25)
    cv = 5

    x_trn, x_val = features.trn.astype("float32"), features.val.astype("float32")
    y_trn, y_val = resp.trn.astype("float32"), resp.val.astype("float32")

    n_samples_trn, n_features = x_trn.shape
    n_samples_val = x_val.shape[0]

    # The chained comparison short-circuits, so CUDA is only queried when a
    # non-negative device index was requested.
    if 0 <= device < torch.cuda.device_count():
        # NOTE(review): CUDA_VISIBLE_DEVICES is set after torch has already
        # been asked for the device count; confirm it still takes effect.
        os.environ["CUDA_VISIBLE_DEVICES"] = str(device)
        backend = set_backend("torch_cuda", on_error="warn")
        print("Running on GPU...")
    else:
        if device >= 0:
            print("The CUDA device you specified is not available.")
        # Always bind a backend on the CPU path so the score conversion
        # below cannot hit an unbound local.
        backend = set_backend("torch", on_error="warn")
        print("Running on CPU...")

    # Primal ridge when there are at least as many samples as features,
    # kernel ridge otherwise.
    if n_samples_trn >= n_features:
        print("Solving Ridge regression...")
        ridge_cls = RidgeCV
    else:
        print("Solving Kernel Ridge regression...")
        ridge_cls = KernelRidgeCV

    ridge = ridge_cls(
        alphas=alphas, cv=cv, solver_params={"score_func": correlation_score}
    )
    # Features are centered but not scaled before the regression.
    pipeline = make_pipeline(StandardScaler(with_mean=True, with_std=False), ridge)

    pipeline.fit(x_trn, y_trn)
    y_val_pred = pipeline.predict(x_val)
    score = correlation_score(y_val_pred, y_val)
    score = backend.to_numpy(score)
    print(f"Mean score: {score.mean()}")

    significant_voxels = fdr_correction(score, n_samples_val)
    score[np.logical_not(significant_voxels)] = 0

    fig = plt.figure()
    plt.hist(score, np.linspace(0, np.max(score), 100), alpha=1.0, label=stim_name)
    plt.title(r"Histogram of correlation coefficient score")
    plt.legend()

    return score, fig


def multiple_regressor(
    stims: dict[str, TrnVal[np.ndarray]],
    resp: TrnVal[np.ndarray],
    n_iter: int,
    device: int,
) -> tuple[dict[str, np.ndarray], dict[str, np.ndarray], object]:
    """Train an encoder for multiple feature space.

    The feature spaces are concatenated column-wise and fitted jointly with
    a banded ridge (or multiple-kernel ridge when there are more features
    than training samples) using himalaya's random-search solver. Validation
    predictions are scored jointly and split per feature space; scores that
    do not pass the FDR-based significance test are set to 0.

    Args:
        stims: Mapping from feature-space name to its train/validation
            features (2-D arrays, samples x features).
        resp: Train/validation responses.
        n_iter: Passed to the random-search solver as ``n_iter``.
        device: CUDA device index to run on. A negative value, or an index
            that is not available, selects the CPU torch backend.

    Returns:
        A tuple ``(all_scores, all_preds, fig)``. Both dicts are keyed by
        ``"all"`` (joint model) and by each feature-space name; ``fig`` is
        the matplotlib figure with the per-space score histograms.

    Raises:
        ValueError: If ``stims`` is empty.
    """
    if not stims:
        raise ValueError(
            "multiple_regressor needs at least one stimulus feature space."
        )

    alphas = np.logspace(-12, 12, 25)
    cv = 5
    n_targets_batch = 1000
    n_alphas_batch = 1000
    n_targets_batch_refit = 1000
    solver = "random_search"

    solver_params = dict(
        n_iter=n_iter,
        alphas=alphas,
        n_targets_batch=n_targets_batch,
        n_alphas_batch=n_alphas_batch,
        n_targets_batch_refit=n_targets_batch_refit,
    )

    # The chained comparison short-circuits, so CUDA is only queried when a
    # non-negative device index was requested.
    if 0 <= device < torch.cuda.device_count():
        # NOTE(review): CUDA_VISIBLE_DEVICES is set after torch has already
        # been asked for the device count; confirm it still takes effect.
        os.environ["CUDA_VISIBLE_DEVICES"] = str(device)
        backend = set_backend("torch_cuda", on_error="warn")
        print("Running on GPU...")
    else:
        if device >= 0:
            print("The CUDA device you specified is not available.")
        # The CPU path really selects the CPU backend, matching the message.
        backend = set_backend("torch", on_error="warn")
        print("Running on CPU...")

    y_trn, y_val = backend.asarray(resp.trn), backend.asarray(resp.val)
    n_samples_trn = y_trn.shape[0]

    # Stack every feature space side by side, remembering each width so the
    # columns can be routed back to their own preprocessing step.
    xs_trn = []
    xs_val = []
    n_features_list = []
    for model_name, stim in stims.items():
        print(f"Shapes of {model_name}: trn={stim.trn.shape}, val={stim.val.shape}")
        n_features_list.append(stim.trn.shape[1])
        xs_trn.append(stim.trn.astype("float32"))
        xs_val.append(stim.val.astype("float32"))
    x_trn = np.concatenate(xs_trn, 1)
    x_val = np.concatenate(xs_val, 1)
    n_features = x_trn.shape[1]
    # Taken from the stacked validation matrix rather than from a leftover
    # loop variable.
    n_samples_val = x_val.shape[0]

    start_and_end = np.concatenate([[0], np.cumsum(n_features_list)])
    slices = [
        slice(start, end) for start, end in zip(start_and_end[:-1], start_and_end[1:])
    ]
    print(slices)

    if n_features > n_samples_trn:
        print("Solving Multiple Kernel Ridge regression using random search...")
        ridge = MultipleKernelRidgeCV(
            kernels="precomputed", solver=solver, solver_params=solver_params, cv=cv
        )
        # Center each feature space, then turn it into a linear kernel.
        preprocess_pipeline = make_pipeline(
            StandardScaler(with_mean=True, with_std=False), Kernelizer(kernel="linear")
        )
        kernelizers_tuples = [
            (name, preprocess_pipeline, slice_)
            for name, slice_ in zip(stims.keys(), slices)
        ]
        column_kernelizer = ColumnKernelizer(kernelizers_tuples)
        pipeline = make_pipeline(
            column_kernelizer,
            ridge,
        )
    else:
        print("Solving Banded Ridge regression using random search...")
        ridge = BandedRidgeCV(
            groups="input", solver=solver, solver_params=solver_params, cv=cv
        )
        # Center each feature space; the spaces stay separate (no stacking)
        # so the banded ridge sees one group per space.
        preprocess_pipeline = make_pipeline(
            StandardScaler(with_mean=True, with_std=False),
        )
        ct_tuples = [
            (name, preprocess_pipeline, slice_)
            for name, slice_ in zip(stims.keys(), slices)
        ]

        column_transform = ColumnTransformerNoStack(ct_tuples)
        pipeline = make_pipeline(
            column_transform,
            ridge,
        )

    pipeline.fit(x_trn, y_trn)
    y_val_pred = pipeline.predict(x_val)
    y_val_pred = backend.to_numpy(y_val_pred)
    scores = pipeline.score(x_val, y_val)
    scores = backend.to_numpy(scores)
    print(f"Mean score: {scores.mean()}")

    y_val_pred_split = pipeline.predict(x_val, split=True)
    y_val_pred_split = backend.to_numpy(y_val_pred_split)
    split_scores = correlation_score_split(y_val, y_val_pred_split)
    split_scores = backend.to_numpy(split_scores)
    print("n_features_space, n_samples_test, n_voxels", y_val_pred_split.shape)
    # Mean split score of the first two feature spaces; slicing keeps this
    # safe when only one space was supplied.
    print(*(space_scores.mean() for space_scores in split_scores[:2]))

    all_scores = {}
    all_preds = {}

    print("all")
    significant_voxels = fdr_correction(scores, n_samples_val)
    scores[np.logical_not(significant_voxels)] = 0
    all_scores["all"] = scores
    all_preds["all"] = y_val_pred

    fig = plt.figure()
    # Shared bins for every histogram, computed before any row of
    # split_scores is zeroed in place below.
    hist_bins = np.linspace(0, np.max(split_scores), 100)
    for score, pred, model_name in zip(split_scores, y_val_pred_split, stims.keys()):
        print(model_name)
        plt.hist(
            score,
            hist_bins,
            alpha=0.8,
            label=model_name,
        )
        significant_voxels = fdr_correction(score, n_samples_val)
        score[np.logical_not(significant_voxels)] = 0
        all_scores[model_name] = score
        all_preds[model_name] = pred

    plt.title(r"Histogram of correlation coefficient score split between kernels")
    plt.legend()

    return all_scores, all_preds, fig


def fdr_correction(ccs: np.ndarray, valclipnum: int) -> np.ndarray:
    """Flag entries of ``ccs`` that are significantly positive after FDR control.

    An empirical null distribution is built from correlations between pairs
    of random Gaussian series of length ``valclipnum``. Each entry of ``ccs``
    gets a one-sided p-value (the fraction of null correlations strictly
    greater than it), and the p-values are passed to ``fdrcorrection`` with
    ``alpha=0.05`` and ``method="indep"``.

    The null draws come from NumPy's global random state, so results vary
    between calls unless that state is seeded by the caller.

    Args:
        ccs: 1-D array of correlation coefficients, one per voxel.
        valclipnum: Number of samples each correlation was computed from.

    Returns:
        Boolean array with one entry per element of ``ccs``; True where the
        null hypothesis is rejected.

    Raises:
        ValueError: If ``valclipnum`` is smaller than 2, where a correlation
            coefficient is undefined.
    """
    if valclipnum < 2:
        raise ValueError("valclipnum must be at least 2 to compute correlations.")

    ccs = np.asarray(ccs)
    nvoxels = len(ccs)
    if nvoxels == 0:
        # Nothing to test; skip the null simulation entirely.
        print(f"Number of voxels with significant positive correlation: {0}")
        return np.zeros(0, dtype=bool)

    # Make random correlation coefficient histogram.
    # Max num of cortex voxels = 400 x 400
    a = np.random.randn(400, valclipnum)
    b = np.random.randn(400, valclipnum)
    # Off-diagonal block of the joint matrix: correlations of rows of b
    # against rows of a.
    rccs = np.corrcoef(a, b)[400:, :400].ravel()[:nvoxels]

    # Count null values strictly greater than each observed value with one
    # sort plus a binary search instead of a full scan per voxel.
    n_null = rccs.size
    n_greater = n_null - np.searchsorted(np.sort(rccs), ccs, side="right")
    # Normalize by the number of null samples actually drawn, which is
    # smaller than nvoxels when there are more than 400 * 400 voxels.
    px = n_greater / n_null
    # An undefined correlation carries no evidence; never call it significant.
    px[np.isnan(ccs)] = 1.0

    rejected, _ = fdrcorrection(px, alpha=0.05, method="indep", is_sorted=False)
    print(
        f"Number of voxels with significant positive correlation: {int(np.count_nonzero(rejected))}"
    )

    return rejected


def train_predictor(
    name: str,
    reg_type: str,
    n_iter: int,
    emb_name: list[str],
    subject_name: str,
    device: int,
) -> None:
    """Train an encoding model for one subject and save its scores to disk.

    Loads the stimulus embeddings and the subject's responses, fits either
    the mono or the multiple feature-space regressor, and writes the results
    under ``./data/encoding/{subject_name}/scores``.

    Args:
        name: Tag used in the output file names (``cc_{name}.npy``,
            ``preds_{name}.npy``, ``dist_{name}.png``).
        reg_type: ``"mono"`` when one embedding is given, ``"multi"`` when
            two or more are given.
        n_iter: Forwarded to ``multiple_regressor``; unused in mono mode.
        emb_name: Names of the embeddings to load. A single string is
            accepted and treated as a one-element list.
        subject_name: Subject whose response data is loaded.
        device: CUDA device index forwarded to the regressor.

    Raises:
        ValueError: If no embedding name is given, or if ``reg_type`` does
            not match the number of distinct embeddings.
    """
    print("Loading stimulus data...")

    if isinstance(emb_name, str):  # Make list if the number of model is one
        # Wrap the name; list("abc") would split it into characters.
        emb_name = [emb_name]
    if not emb_name:
        raise ValueError("At least one embedding name is required.")

    # Check the mode before any expensive loading. Duplicate names collapse
    # into a single feature space, so count distinct names.
    n_feature_spaces = len(dict.fromkeys(emb_name))
    if n_feature_spaces == 1:
        if reg_type != "mono":
            raise ValueError(
                "When using Multi Mode, specify two or more model features."
            )
    elif reg_type != "multi":
        raise ValueError("When using Mono Mode, specify only one model features.")

    all_stim = {}
    for ename in emb_name:
        print(f"Loading {ename}'s embeddings...")
        stim, stim_names = load_stimulus(ename)
        all_stim[ename] = stim

    print("Loading response data...")
    resp_raw = fmri_loader.load_resp_data(subject_name)
    # NOTE(review): only the stim_names of the last loaded embedding are used
    # for alignment; this presumes all embeddings share them -- confirm.
    resp = align_resp_to_stim(
        TrnVal(trn=resp_raw[0], val=resp_raw[1]), stim_names, delay=2
    )

    print("Training...")

    scores_dir = f"./data/encoding/{subject_name}/scores"
    os.makedirs(scores_dir, exist_ok=True)

    if len(all_stim) == 1:
        scores, fig = mono_regressor(all_stim, resp, device)
        np.save(f"{scores_dir}/cc_{name}.npy", scores)
    else:
        all_scores, all_preds, fig = multiple_regressor(all_stim, resp, n_iter, device)
        # Dicts are stored by np.save as pickled object arrays, so reading
        # them back needs np.load(..., allow_pickle=True).
        np.save(f"{scores_dir}/cc_{name}.npy", all_scores)
        np.save(f"{scores_dir}/preds_{name}.npy", all_preds)

    fig.savefig(f"{scores_dir}/dist_{name}.png")
    # Release the figure so repeated runs do not accumulate open figures.
    plt.close(fig)
