import numpy as np
# Splits a concatenated time-series dataset into its individual series (chunks),
# then performs contiguous-block k-fold splitting within each chunk.

def chunk_Kfold(chunk_start_index, chunk_end_index, k, block_shuffle=True):
    """Yield contiguous-block k-fold splits over the index range of one chunk.

    The half-open range ``[chunk_start_index, chunk_end_index)`` is divided
    into ``k`` contiguous validation blocks of ``(end - start) // k`` indices
    each. For every block, that block is the validation set and every other
    index of the chunk is the training set. When the chunk length is not
    divisible by ``k``, the trailing ``(end - start) % k`` indices never fall
    in a validation block; they are always part of the training set.

    Args:
        chunk_start_index: First index of the chunk (inclusive). Float values
            holding whole numbers are accepted and converted to ``int``.
        chunk_end_index: End of the chunk (exclusive).
        k: Number of folds / validation blocks.
        block_shuffle: If True, the order in which the validation blocks are
            yielded is shuffled with the global ``np.random`` state.

    Yields:
        ``(train_indices, validation_indices)`` tuples of integer arrays.
    """
    chunk_start_index = int(chunk_start_index)
    chunk_end_index = int(chunk_end_index)
    all_indices = np.arange(chunk_start_index, chunk_end_index, dtype=int)
    block_size = (chunk_end_index - chunk_start_index) // k
    ks = np.arange(k)
    if block_shuffle:
        np.random.shuffle(ks)
    for split_index in ks:
        validation_start_index = chunk_start_index + block_size * split_index
        # Exclusive end, matching np.arange's stop argument.
        validation_end_index = validation_start_index + block_size
        validation_indices = np.arange(validation_start_index, validation_end_index, dtype=int)
        # The end index is exclusive, so it belongs to the training set;
        # comparing with `>` here would silently drop that sample entirely.
        in_validation = (all_indices >= validation_start_index) & (all_indices < validation_end_index)
        train_indices = all_indices[~in_validation]
        yield train_indices, validation_indices

class ConcatedBlockKfold:
    """Block k-fold over several time series stored back to back in one array.

    Chunk ``i`` occupies the half-open index range
    ``[sum(chunk_lens[:i]), sum(chunk_lens[:i + 1]))``. Each chunk is split
    independently with ``chunk_Kfold``. Every fold combines one validation
    block (and the matching training indices) from every chunk.
    """

    def __init__(self, chunk_lens, n_splits = 5):
        """
        Args:
            chunk_lens: Length of each concatenated chunk, in storage order.
            n_splits: Number of folds (validation blocks per chunk).
        """
        self.n_splits = n_splits
        self.chunk_lens = chunk_lens

    def split(self, X=None, Y=None, block_shuffle=True):
        """Yield ``n_splits`` ``(train_indices, validation_indices)`` pairs.

        Args:
            X, Y: Unused. Presumably accepted for scikit-learn-style call
                compatibility.
            block_shuffle: Forwarded to ``chunk_Kfold``. When True, each
                chunk's block order is shuffled independently.

        Yields:
            Integer index arrays concatenated across all chunks.
        """
        chunk_end_indices = np.cumsum(self.chunk_lens)
        # np.concatenate works on all NumPy versions (np.concat is a 2.0+ alias).
        # Seeding with an int 0 keeps the boundaries integral instead of float.
        chunk_start_indices = np.concatenate(([0], chunk_end_indices))[:-1]
        fold_generators = [
            chunk_Kfold(start, end, self.n_splits, block_shuffle=block_shuffle)
            for start, end in zip(chunk_start_indices, chunk_end_indices)
        ]
        for _ in range(self.n_splits):
            # Advance every chunk's generator by one fold, then merge the
            # per-chunk train and validation arrays.
            train_parts, validation_parts = zip(*[next(gen) for gen in fold_generators])
            yield np.concatenate(train_parts), np.concatenate(validation_parts)
            
class BlockKfold:
    """Contiguous-block k-fold splitter for a single dataset of known length.

    The indices ``0 .. dataset_len - 1`` are divided into ``n_splits``
    contiguous validation blocks of ``dataset_len // n_splits`` indices each.
    If the length is not divisible by ``n_splits``, the trailing remainder
    indices never appear in a validation block; they are always used for
    training.
    """

    def __init__(self, n_splits, block_shuffle=True):
        """
        Args:
            n_splits: Number of folds / validation blocks.
            block_shuffle: If True, every ``split`` call yields the blocks in
                an order shuffled with the global ``np.random`` state.
        """
        self.n_splits = n_splits
        self.block_shuffle = block_shuffle

    def split(self, dataset_len):
        """Yield ``(train_indices, validation_indices)`` integer arrays.

        Args:
            dataset_len: Number of samples in the dataset.
        """
        all_indices = np.arange(0, dataset_len, dtype=int)
        block_size = dataset_len // self.n_splits
        ks = np.arange(self.n_splits)
        if self.block_shuffle:
            np.random.shuffle(ks)

        for split_index in ks:
            validation_start_index = block_size * split_index
            validation_end_index = validation_start_index + block_size  # exclusive
            validation_indices = np.arange(validation_start_index, validation_end_index, dtype=int)
            # validation_end_index is exclusive, so it belongs to training;
            # using `>` would drop that sample from both sets.
            in_validation = (all_indices >= validation_start_index) & (all_indices < validation_end_index)
            train_indices = all_indices[~in_validation]
            yield train_indices, validation_indices
        

class DatasetBlockKfold:
    """Run ``BlockKfold`` independently on several separate datasets.

    All splits are computed once, at construction time. When shuffling is
    enabled, each dataset's block order is therefore shuffled independently
    and then stays fixed for every later ``split`` call.
    """

    def __init__(self, dataset_lens, n_splits=5, block_shuffle=True):
        """
        Args:
            dataset_lens: Length of each dataset.
            n_splits: Number of folds per dataset.
            block_shuffle: Forwarded to ``BlockKfold``.
        """
        self.n_splits = n_splits
        self.dataset_lens = dataset_lens
        self.block_shuffle = block_shuffle
        splitter = BlockKfold(n_splits, block_shuffle=block_shuffle)
        # One materialized list of (train, validation) pairs per dataset.
        self.kfold_splits = []
        for length in dataset_lens:
            self.kfold_splits.append(list(splitter.split(length)))

    def split(self):
        """Yield, for each fold, ``[train_per_dataset, validation_per_dataset]``.

        Each element is a tuple holding one index array per dataset, in the
        order of ``dataset_lens``.
        """
        for fold_idx in range(self.n_splits):
            fold_pairs = [per_dataset[fold_idx] for per_dataset in self.kfold_splits]
            # Transpose [(train, val), ...] into [(train, ...), (val, ...)].
            yield [group for group in zip(*fold_pairs)]