from typing import Optional, overload

import numpy as np
import torch


@overload
def train_val_test_split(
    *ts: torch.Tensor,
    train_size: float,
    test_size: float,
    return_index: bool = ...,
    rs: Optional[np.random.RandomState] = ...,
) -> tuple[torch.Tensor, ...]:
    # Typing-only signature for tensor inputs. `return_index` and `rs` carry
    # defaults so this overload accepts the same calls as the implementation.
    # NOTE: with return_index=True the first three returned elements are the
    # numpy index arrays produced by `rs.permutation`, not tensors.
    ...


@overload
def train_val_test_split(
    *ts: np.ndarray,
    train_size: float,
    test_size: float,
    return_index: bool = ...,
    rs: Optional[np.random.RandomState] = ...,
) -> tuple[np.ndarray, ...]:
    # Typing-only signature for numpy inputs. `return_index` and `rs` carry
    # defaults so this overload accepts the same calls as the implementation.
    ...


def train_val_test_split(
    *ts,
    train_size,
    test_size,
    return_index=False,
    rs: Optional[np.random.RandomState] = None,
):
    """Randomly split equal-length arrays/tensors into train/val/test parts.

    One random permutation of ``range(N)`` (N = common length of the inputs)
    is drawn and cut into three consecutive pieces:
    - the first ``int(N * train_size)`` indices form the training split;
    - the last ``int(N * test_size)`` indices form the test split;
    - the rows in between form the validation split.
    The same indices are applied to every input, so rows stay aligned
    across inputs.

    Args:
        *ts: Arrays or tensors indexed along their first dimension; all must
            have the same ``len``.
        train_size: Fraction of rows assigned to the training split.
        test_size: Fraction of rows assigned to the test split (may be 0).
        return_index: If True, prepend the three index arrays to the output.
        rs: Random state used for the permutation; a fresh, unseeded
            ``np.random.RandomState`` is created when None.

    Returns:
        A flat tuple. With ``return_index=True`` it is
        ``(idx_train, idx_val, idx_test, t1_train, t1_val, t1_test, ...)``;
        otherwise ``(t1_train, t1_val, t1_test, t2_train, ...)``. Index
        arrays are the numpy arrays produced by ``rs.permutation``.

    Raises:
        ValueError: If no input is given, input lengths differ, a size is
            negative, or train and test together exceed the number of rows.
    """
    if not ts:
        raise ValueError('No tensor given')
    if len({len(t) for t in ts}) != 1:
        raise ValueError('Incongruent tensor sizes')
    if train_size < 0 or test_size < 0:
        raise ValueError('train_size and test_size must be non-negative')

    if rs is None:
        rs = np.random.RandomState()

    n = len(ts[0])
    n_train = int(n * train_size)
    n_test = int(n * test_size)
    # Checked on the truncated counts so float rounding in the fractions
    # (e.g. 0.7 + 0.3) does not trigger a spurious error.
    if n_train + n_test > n:
        raise ValueError('train_size + test_size exceeds the available data')

    idx = rs.permutation(n)

    # Use an explicit upper bound rather than a negative slice: with
    # n_test == 0, idx[-0:] would select the whole permutation.
    val_end = n - n_test
    idx_train = idx[:n_train]
    idx_val = idx[n_train:val_end]
    idx_test = idx[val_end:]

    out = [idx_train, idx_val, idx_test] if return_index else []
    for x in ts:
        out.extend((x[idx_train], x[idx_val], x[idx_test]))
    return tuple(out)
