import numpy as np
import sklearn.model_selection as skl_model_selection
import torch as th
import torch.utils.data as th_data


def th_dataset_to_ndarrays(
    dataset: th_data.Dataset[tuple[th.Tensor, ...]]
) -> tuple[np.ndarray, ...]:
    """Materialise a sized dataset of tensor tuples as NumPy arrays.

    Every item ``dataset[i]`` is expected to be a tuple of tensors with the
    same arity as ``dataset[0]``. The ``k``-th tensors of all items are
    stacked along a new leading axis, yielding one array per tuple position.

    Args:
        dataset: An index-addressable dataset that defines ``__len__`` and
            whose items are tuples of tensors.

    Returns:
        A tuple with one ``np.ndarray`` per tuple position; array ``k`` has
        shape ``(len(dataset), *dataset[0][k].shape)``.

    Raises:
        TypeError: If ``dataset`` does not define ``__len__``.
        ValueError: If ``dataset`` is empty (the tuple arity is unknown).
    """
    if not hasattr(dataset, "__len__"):
        raise TypeError("dataset must define __len__ to be converted")
    n_items = len(dataset)
    if n_items == 0:
        raise ValueError("cannot convert an empty dataset: tuple arity is unknown")

    # The first item fixes the number of output arrays.
    columns: list[list[th.Tensor]] = [[t] for t in dataset[0]]
    # Index explicitly: a generic Dataset is not guaranteed to be iterable.
    for idx in range(1, n_items):
        for pos, t in enumerate(dataset[idx]):
            columns[pos].append(t)

    # detach()/cpu() make the conversion work for tensors that require grad
    # or live on an accelerator, where a bare .numpy() would raise.
    return tuple(th.stack(col).detach().cpu().numpy() for col in columns)


def get_train_val_test_fold(
    xs: np.ndarray,
    ys: np.ndarray,
    train_fold_idx: int,
    n_folds: int,
    n_splits: int,
    test_size=0.125,
    is_stratified: tuple[bool, bool] = (False, False),
    rstates: tuple[int, int] = (42, 1337),
) -> tuple[list[int], list[int], list[int]]:
    """Return sorted train/val/test index lists for one cross-validation fold.

    An outer shuffled K-fold split with ``n_folds`` folds separates
    (train+val) from test. Fold ``train_fold_idx`` (1-based) is selected.
    Its (train+val) part is then split once more by a shuffle splitter into
    train and validation.

    Args:
        xs: Feature array, indexed along its first axis.
        ys: Target array, indexed like ``xs`` (used for stratification).
        train_fold_idx: 1-based index of the outer fold to return.
        n_folds: Number of outer folds.
        n_splits: ``n_splits`` passed to the inner shuffle splitter; only its
            first split is used.
        test_size: ``test_size`` passed to the inner shuffle splitter, i.e.
            the share of the (train+val) subset held out as validation.
        is_stratified: ``(outer, inner)`` flags selecting the stratified
            variant of each splitter.
        rstates: ``(outer, inner)`` random seeds.

    Returns:
        ``(train_indices, val_indices, test_indices)`` as sorted lists of
        positions into ``xs``/``ys``.

    Raises:
        ValueError: If ``train_fold_idx`` is not in ``[1, n_folds]``.
    """
    import itertools

    # An explicit check (not assert) so it survives ``python -O``; otherwise
    # an out-of-range index would silently select the last fold.
    if not 1 <= train_fold_idx <= n_folds:
        raise ValueError(
            f"train_fold_idx must be in [1, {n_folds}], got {train_fold_idx}"
        )

    # Outer split: (train+val) vs. test.
    kfold_cls = (
        skl_model_selection.StratifiedKFold
        if is_stratified[0]
        else skl_model_selection.KFold
    )
    kfold = kfold_cls(n_splits=n_folds, shuffle=True, random_state=rstates[0])
    # Folds are 1-based: skip the first ``train_fold_idx - 1`` of them.
    train_val_indices, test_indices = next(
        itertools.islice(kfold.split(xs, ys), train_fold_idx - 1, None)
    )

    # Inner split: train vs. val, drawn from the (train+val) subset only.
    ssplit_cls = (
        skl_model_selection.StratifiedShuffleSplit
        if is_stratified[1]
        else skl_model_selection.ShuffleSplit
    )
    ssplit = ssplit_cls(
        n_splits=n_splits, test_size=test_size, random_state=rstates[1]
    )
    train_pos, val_pos = next(
        ssplit.split(xs[train_val_indices], ys[train_val_indices])
    )
    # Positions are relative to the subset; map them back to the full arrays.
    train_indices: np.ndarray = train_val_indices[train_pos]
    val_indices: np.ndarray = train_val_indices[val_pos]
    return (
        sorted(train_indices.tolist()),
        sorted(val_indices.tolist()),
        sorted(test_indices.tolist()),
    )


def get_train_val_fold(
    xs: np.ndarray,
    ys: np.ndarray,
    n_splits: int,
    val_size=0.125,
    rstate: int = 1337,
    *,
    is_stratified: bool = True,
) -> tuple[list[int], list[int]]:
    """Return sorted train/validation index lists from one shuffle split.

    Args:
        xs: Feature array, indexed along its first axis.
        ys: Target array, indexed like ``xs`` (used for stratification).
        n_splits: ``n_splits`` passed to the shuffle splitter; only its first
            split is used.
        val_size: ``test_size`` passed to the splitter, i.e. the share of
            samples held out as validation.
        rstate: Random seed for the splitter.
        is_stratified: Use ``StratifiedShuffleSplit`` (default, the previous
            behaviour) or plain ``ShuffleSplit`` when ``False``.

    Returns:
        ``(train_indices, val_indices)`` as sorted lists of positions into
        ``xs``/``ys``.
    """
    splitter_cls = (
        skl_model_selection.StratifiedShuffleSplit
        if is_stratified
        else skl_model_selection.ShuffleSplit
    )
    ssplit = splitter_cls(n_splits=n_splits, test_size=val_size, random_state=rstate)
    train_indices: np.ndarray
    val_indices: np.ndarray
    train_indices, val_indices = next(ssplit.split(xs, ys))
    return sorted(train_indices.tolist()), sorted(val_indices.tolist())
