import numpy as np
from sklearn.model_selection import train_test_split

def analyze_fold(X):
    """Compute per-feature range and moment statistics for one data split.

    A "feature" is the slice ``X[:, :, i]`` along the third axis; each
    statistic is taken over every element of that slice.

    Args:
        X: array with at least three dimensions; features are indexed on
            axis 2.

    Returns:
        Tuple ``(min_max, mu_sigma)`` of float arrays, each of shape
        ``(X.shape[2], 2)``. ``min_max[i]`` holds ``(min, max)`` and
        ``mu_sigma[i]`` holds ``(mean, std)`` of feature ``i``. ``std`` is
        NumPy's default population standard deviation (``ddof=0``).
    """
    n_features = X.shape[2]
    ranges = np.zeros((n_features, 2))
    moments = np.zeros((n_features, 2))

    # Moving axis 2 to the front lets us iterate feature slices directly;
    # each yielded view is the same view as ``X[:, :, idx]``.
    for idx, feature in enumerate(np.moveaxis(X, 2, 0)):
        ranges[idx] = feature.min(), feature.max()
        moments[idx] = feature.mean(), feature.std()

    return ranges, moments



def get_kfold_stats(X, Y, train_fraction, valid_fraction, nfolds, stratify):
    """Build ``nfolds`` random train/validation/test splits and their stats.

    Each "fold" is an independent shuffled split seeded with the fold index
    (``random_state = 0 .. nfolds - 1``). The splits are therefore
    reproducible, but unlike classic k-fold the test sets of different
    folds may overlap.

    Splitting happens in two stages:

    1. ``X``/``Y`` -> train+validation (``train_fraction`` of the samples)
       and test (the remaining samples).
    2. train+validation -> train and validation, where ``valid_fraction`` is
       a fraction of the stage-1 train+validation part, not of the whole
       dataset.

    Per-feature statistics are computed on the final training part only,
    via ``analyze_fold``.

    Args:
        X: sample array whose first axis indexes samples; ``analyze_fold``
            additionally requires features on axis 2.
        Y: labels/targets aligned with ``X`` along the first axis.
        train_fraction: float in (0, 1), share of samples kept for
            train+validation.
        valid_fraction: float in (0, 1), share of the train+validation part
            held out for validation.
        nfolds: number of random splits to generate.
        stratify: if truthy, stratify stage 1 on ``Y`` and stage 2 on the
            stage-1 training labels.

    Returns:
        dict with keys ``'nfolds'``; ``'training_folds'``,
        ``'validation_folds'`` and ``'test_folds'`` (lists with one
        ``[X, y]`` pair per fold); and ``'min_max'`` and ``'mu_sigma'``
        (lists with one ``analyze_fold`` result array per fold).

    The global NumPy RNG is not touched; randomness comes solely from the
    per-fold ``random_state``.
    """
    training_folds = []
    validation_folds = []
    test_folds = []
    min_max_stats = []
    mu_sigma_stats = []

    for rs in range(nfolds):
        # Pass ``train_size`` rather than ``test_size=1 - train_fraction``.
        # The subtraction adds float error (1 - 0.7 == 0.30000000000000004),
        # and scikit-learn rounds float test sizes up with ``ceil``, which
        # would turn a 7/3 split of 10 samples into 6/4.
        X_train, X_test, y_train, y_test = train_test_split(
            X, Y,
            train_size=train_fraction,
            stratify=Y if stratify else None,
            random_state=rs,
        )
        X_train, X_val, y_train, y_val = train_test_split(
            X_train, y_train,
            test_size=valid_fraction,
            stratify=y_train if stratify else None,
            random_state=rs,
        )

        # Statistics come from the training part only, so validation and
        # test data do not leak into any normalisation built from them.
        min_max, mu_sigma = analyze_fold(X_train)

        training_folds.append([X_train, y_train])
        validation_folds.append([X_val, y_val])
        test_folds.append([X_test, y_test])
        min_max_stats.append(min_max)
        mu_sigma_stats.append(mu_sigma)

    return {
        'nfolds': nfolds,
        'training_folds': training_folds,
        'validation_folds': validation_folds,
        'test_folds': test_folds,
        'min_max': min_max_stats,
        'mu_sigma': mu_sigma_stats,
    }



def analyze(X, Y, train_fraction, valid_fraction, nfolds, stratify=False):
    """Split ``X``/``Y`` into ``nfolds`` random folds and gather statistics.

    Thin public wrapper: forwards every argument to ``get_kfold_stats`` and
    returns its result dictionary unchanged. Stratification is off by
    default.
    """
    return get_kfold_stats(
        X, Y, train_fraction, valid_fraction, nfolds, stratify=stratify
    )