
import torch
from sklearn.model_selection import KFold, StratifiedKFold


def imbalance_preserving_downsample(X: torch.Tensor, y: torch.Tensor, factor):
    """Downsample a binary-labelled dataset while roughly keeping its class ratio.

    Each class is shrunk independently: of the ``n`` rows carrying a given
    label, only the first ``int(n / factor)`` are kept. Rows whose label is
    neither 1 nor 0 are dropped.

    Args:
        X: Samples, indexed along the first dimension by the mask ``y == label``.
        y: Labels aligned with the first dimension of ``X``.
        factor: Downsampling factor; must be larger than 1.

    Returns:
        A tuple ``(X_down, y_down)``. Rows labelled 1 come first, followed by
        rows labelled 0. ``y_down`` is rebuilt from scratch in torch's default
        dtype and placed on the same device as ``y``.

    Raises:
        ValueError: If ``factor`` is not larger than 1.
    """
    if factor <= 1:
        # ValueError is a subclass of Exception, so callers that caught the
        # former generic Exception keep working.
        raise ValueError('Downsample factor should be larger than 1')

    def downsample(rows: torch.Tensor) -> torch.Tensor:
        # Truncating division: a class with fewer than `factor` rows ends up
        # empty after downsampling.
        keep = int(rows.shape[0] / factor)
        return rows[:keep]

    # The "min"/"maj" names follow the label convention used here
    # (1 = minority, 0 = majority); actual class sizes are not checked.
    X_min_down = downsample(X[y == 1])
    X_maj_down = downsample(X[y == 0])

    # Create the labels on y's device so the returned labels do not silently
    # fall back to the CPU when the inputs live on an accelerator.
    y_min_down = torch.ones(X_min_down.shape[0], device=y.device)
    y_maj_down = torch.zeros(X_maj_down.shape[0], device=y.device)

    X_down = torch.cat((X_min_down, X_maj_down), 0)
    y_down = torch.cat((y_min_down, y_maj_down), 0)
    return X_down, y_down

class imbalance_preserving_Kfold():
    """K-fold splitter that folds each class of a binary dataset separately.

    Rows labelled 0 and rows labelled 1 are each split with their own
    ``KFold(n_splits)``; the i-th fold of both classes is then concatenated,
    so every train/validation split contains a share of each class.
    """

    def __init__(self, n_splits=5):
        """Store the number of folds passed on to ``KFold``."""
        self.n_splits = n_splits

    def split(self, X: torch.Tensor, y: torch.Tensor):
        """Yield per-fold train/validation tensors.

        Args:
            X: Samples, indexed along the first dimension by ``y == label``.
            y: Labels aligned with the first dimension of ``X``. Rows whose
                label is neither 0 nor 1 are ignored.

        Yields:
            Tuples ``(x_train, y_train, x_val, y_val)``. Within each tensor the
            label-0 rows come first, followed by the label-1 rows. Labels are
            rebuilt in torch's default dtype on the same device as ``y``.
        """
        # Naming convention only: label 0 is called "majority" and label 1
        # "minority"; the actual class sizes are not checked.
        x_maj = X[y == 0]
        x_min = X[y == 1]
        maj_folds = KFold(self.n_splits).split(x_maj)
        min_folds = KFold(self.n_splits).split(x_min)
        # Build labels on y's device so they do not silently land on the CPU
        # when the inputs live on an accelerator.
        device = y.device

        for (maj_train_idx, maj_val_idx), (min_train_idx, min_val_idx) in zip(maj_folds, min_folds):
            x_maj_train, x_maj_val = x_maj[maj_train_idx], x_maj[maj_val_idx]
            x_min_train, x_min_val = x_min[min_train_idx], x_min[min_val_idx]

            y_maj_train = torch.zeros(x_maj_train.shape[0], device=device)
            y_maj_val = torch.zeros(x_maj_val.shape[0], device=device)
            y_min_train = torch.ones(x_min_train.shape[0], device=device)
            y_min_val = torch.ones(x_min_val.shape[0], device=device)

            x_train = torch.cat((x_maj_train, x_min_train), 0)
            y_train = torch.cat((y_maj_train, y_min_train), 0)
            x_val = torch.cat((x_maj_val, x_min_val), 0)
            y_val = torch.cat((y_maj_val, y_min_val), 0)
            yield x_train, y_train, x_val, y_val