import numpy as np
from sklearn.model_selection import train_test_split


"""
Create a sequence of T train and test datasets
(D_train, D_test)_1, ..., (D_train, D_test)_T.
The first training dataset has size |D_train_1| >= min_train_n.
All testing datasets have the same size |D_test| = test_n.
"""


def generate_sequence_from_fixed_dataset(X, y, T, min_train_n, seed, test_frac, test_n, covariate_shift):
    """Build T train datasets and T test datasets from one fixed dataset.

    The data is split once into a train pool and a test pool; each pool is
    then turned into a sequence of T datasets by ``generate_data``.

    Args:
        X: Feature array indexed as ``X[:, 0]``; column 0 drives the
            covariate shift.
        y: Targets aligned with the rows of ``X``.
        T: Number of datasets in each sequence.
        min_train_n: Size requested for the first training dataset.
        seed: Random state passed to every split.
        test_frac: Fraction of the data held out as the test pool.
        test_n: Size requested for each test dataset.
        covariate_shift: If true, datasets after the first are drawn only
            from samples whose first feature is at or above a cutoff that
            grows with the step index, and the training datasets are not
            cumulative.

    Returns:
        ``(seq_train_datasets, seq_test_datasets)``: two lists of
        ``(X_t, y_t)`` tuples.
    """
    if covariate_shift:
        # Range of the first feature over the full dataset (train and test).
        min_x0 = np.min(X[:, 0])
        max_x0 = np.max(X[:, 0])

        def filter_cov(X, y, t):
            # The cutoff starts at min_x0 for t=0 and moves up linearly with t.
            cutoff_t = (t/(2*T) * (max_x0-min_x0)) + min_x0
            # A boolean mask keeps the array rank intact. The previous
            # argwhere + squeeze collapsed dimensions whenever a single
            # sample survived or X had a single feature column.
            keep = X[:, 0] >= cutoff_t
            return X[keep], y[keep]
    else:
        filter_cov = None

    # We first split the test from train to avoid any test/train contamination
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_frac, random_state=seed)

    # Training datasets never share samples; they grow cumulatively unless
    # covariate shift is enabled.
    seq_train_datasets = generate_data(
        X_train, y_train, T, seed, min_train_n, increasing=(not covariate_shift), allow_overlap=False, covariate_shift=covariate_shift, filter_cov=filter_cov)
    # Test datasets have a fixed target size and may share samples.
    seq_test_datasets = generate_data(
        X_test, y_test, T, seed, test_n, increasing=False, allow_overlap=True, covariate_shift=covariate_shift, filter_cov=filter_cov)

    return seq_train_datasets, seq_test_datasets


"""
Create a sequence of T datasets, of increasing size when `increasing` is set
(otherwise each dataset targets min_n_samples samples):
return [D_1, D_2, ..., D_T]
with |D_1| >= min_n_samples
"""


def generate_data(X, y, T, seed, min_n_samples, increasing, allow_overlap, covariate_shift, filter_cov=None):
    """Split ``(X, y)`` into a sequence of T datasets.

    The first dataset is drawn from the full pool. Each of the remaining
    T-1 datasets is drawn either from the full pool (``allow_overlap``) or
    from what earlier draws left over.

    Args:
        X: Feature array; ``X.shape[0]`` is the number of samples.
        y: Targets aligned with the rows of ``X``.
        T: Number of datasets to produce.
        seed: Integer base random state; step ``t`` uses ``seed + t + 1``.
        min_n_samples: Size requested for the first dataset; also the
            per-step size when ``increasing`` is false.
        increasing: If true, each dataset is the previous one plus the newly
            drawn samples.
        allow_overlap: If true, every step draws from the full ``(X, y)``
            rather than from the leftover pool.
        covariate_shift: If true, ``filter_cov`` is applied to the pool
            before each draw after the first.
        filter_cov: Callable ``(X, y, t) -> (X, y)``; required when
            ``covariate_shift`` is true.

    Returns:
        A list of T ``(X_t, y_t)`` tuples.
    """
    N = X.shape[0]
    if increasing:
        if N/T < min_n_samples:  # if the steps are too small, we start with a big dataset
            D_t_size = (N-min_n_samples)/(T-1)
            frac_first_dataset = (N-min_n_samples)/N
        else:
            D_t_size = N/T
            frac_first_dataset = (N-D_t_size)/N
    else:
        D_t_size = min_n_samples
        frac_first_dataset = (N-min_n_samples)/N
    seq_datasets = []

    # append the first dataset
    X_1, X_rest, y_1, y_rest = train_test_split(
        X, y, test_size=frac_first_dataset, random_state=seed)
    seq_datasets.append((X_1, y_1))

    for t in range(T-1):  # append the T-1 other datasets

        # Offset by one so step 0 does not reuse the random state of the
        # initial split. With allow_overlap the pool and fraction are the
        # same as above, so reusing `seed` reproduced the first dataset.
        seed_t = seed + t + 1
        if allow_overlap:
            X_split, y_split = X, y
        else:
            X_split, y_split = X_rest, y_rest

        # apply covariate filtering if we need to
        if covariate_shift:
            X_split, y_split = filter_cov(X_split, y_split, t)

        N_remaining = X_split.shape[0]

        if N_remaining <= D_t_size:
            # Nothing to split: take whatever is in the pool. This is the
            # same condition as rest_frac <= 0, but it also covers an empty
            # pool without dividing by zero.
            # NOTE(review): X_rest is not updated here, so without overlap a
            # later step can draw these samples again.
            X_t, y_t = X_split, y_split
        else:
            rest_frac = (N_remaining-D_t_size)/N_remaining
            X_t, X_rest, y_t, y_rest = train_test_split(
                X_split, y_split, test_size=rest_frac, random_state=seed_t)
        if increasing:
            # Cumulative mode: extend the previous dataset with the new draw.
            X_t = np.concatenate((seq_datasets[-1][0], X_t))
            y_t = np.concatenate((seq_datasets[-1][1], y_t))
        seq_datasets.append((X_t, y_t))

    return seq_datasets


def get_sequence(X, y, seed, cfg_dataset, cfg_exp):
    """Generate the dataset sequences and cut them into offline and online parts.

    Args:
        X: Feature array forwarded to ``generate_sequence_from_fixed_dataset``.
        y: Targets aligned with the rows of ``X``.
        seed: Random state forwarded to the generator.
        cfg_dataset: Mapping providing ``'T'``, ``'min_train_n'``,
            ``'test_frac'``, ``'test_n'`` and ``'offline_t'``.
        cfg_exp: Mapping providing ``'covariate_shift'``.

    Returns:
        ``(offline_seq, online_seq, offline_queries, online_queries)``. The
        offline parts hold the first ``offline_t`` steps of the training and
        query sequences; the online parts hold the remaining steps.
    """
    seq_datasets, seq_queries_datasets = generate_sequence_from_fixed_dataset(
        X, y, cfg_dataset['T'], cfg_dataset['min_train_n'], seed, cfg_dataset['test_frac'], cfg_dataset['test_n'], cfg_exp['covariate_shift'])

    # Cut both sequences at the same step so that each training dataset stays
    # paired with the query dataset of its own step. The query slices used to
    # be swapped relative to the training slices.
    offline_split = cfg_dataset["offline_t"]
    offline_seq = seq_datasets[:offline_split]
    online_seq = seq_datasets[offline_split:]
    offline_queries = seq_queries_datasets[:offline_split]
    online_queries = seq_queries_datasets[offline_split:]

    return offline_seq, online_seq, offline_queries, online_queries
