from inits import *
import time


def shuffle_aligned_list(data):
    """Shuffle several arrays with one shared random permutation.

    A single permutation of ``range(data[0].shape[0])`` is drawn and applied
    to every array in ``data``, so row ``k`` of one array stays paired with
    row ``k`` of the others.

    Args:
        data: sequence of arrays; every array is indexed with the same
            permutation, so each must have at least ``data[0].shape[0]`` rows.

    Returns:
        list of the permuted arrays, in the same order as ``data``.
    """
    n_rows = data[0].shape[0]
    order = np.random.permutation(n_rows)
    shuffled = [arr[order] for arr in data]
    return shuffled


def batch_generator(data, batch_size, unlab_data=None, num_lab_each_epoch=1, shuffle=True):
    """Yield mini-batches from aligned arrays, forever.

    Two modes:

    * ``unlab_data is None``: consecutive ``batch_size``-row windows of the
      aligned arrays in ``data`` are yielded. When the next window would run
      past the end, the window index resets (an "epoch" boundary) and, if
      ``shuffle``, the arrays are reshuffled first.
    * ``unlab_data`` given: every batch is ``num_lab_each_epoch`` rows drawn
      at random, without replacement, from the labelled arrays ``data``. They
      are followed by ``batch_size - num_lab_each_epoch`` consecutive rows of
      the aligned unlabelled arrays in ``unlab_data``, which are windowed and
      reshuffled as above.

    A trailing partial window at the end of an epoch is skipped. If the
    windowed arrays are shorter than one window, the whole (short) array is
    yielded every time.

    Args:
        data: list of arrays sharing their first dimension (the labelled
            arrays in the semi-supervised mode).
        batch_size: total number of rows per yielded batch.
        unlab_data: optional list of arrays sharing their first dimension.
            It is indexed in lockstep with ``data`` (one output per array in
            ``unlab_data``).
        num_lab_each_epoch: labelled rows per batch in the semi-supervised
            mode. Despite the name, this is per batch.
        shuffle: reshuffle the windowed arrays at the start of every epoch.

    Yields:
        list of arrays, one per input array, each holding one batch.

    Raises:
        ValueError: on the first ``next()`` if, in the semi-supervised mode,
            ``num_lab_each_epoch`` exceeds ``batch_size``. ``np.random.choice``
            also raises ValueError if it exceeds the number of labelled rows.
    """
    if unlab_data is not None:
        num_lab = data[0].shape[0]
        lab_target = data
        unlab_target = unlab_data
        # Rows taken from the unlabelled arrays in each batch.
        unlab_step = batch_size - num_lab_each_epoch
        if unlab_step < 0:
            raise ValueError(
                "num_lab_each_epoch (%d) must not exceed batch_size (%d)"
                % (num_lab_each_epoch, batch_size))
        if shuffle:
            unlab_target = shuffle_aligned_list(unlab_target)
        batch_count = 0
        while True:
            # Reset only when the next window would overrun the data; using
            # ``>`` (not ``>=``) keeps the final full window of each epoch.
            if (batch_count + 1) * unlab_step > len(unlab_target[0]):
                batch_count = 0
                if shuffle:
                    unlab_target = shuffle_aligned_list(unlab_target)
            start = batch_count * unlab_step
            end = start + unlab_step

            lab_idx = np.random.choice(num_lab, num_lab_each_epoch, replace=False)
            batch_count += 1
            yield [np.concatenate([lab_target[i][lab_idx], unlab[start:end]], axis=0)
                   for i, unlab in enumerate(unlab_target)]
    else:
        if shuffle:
            data = shuffle_aligned_list(data)
        batch_count = 0
        while True:
            # Same ``>`` boundary: the last full window is yielded before reset.
            if (batch_count + 1) * batch_size > len(data[0]):
                batch_count = 0
                if shuffle:
                    data = shuffle_aligned_list(data)

            start = batch_count * batch_size
            end = start + batch_size
            batch_count += 1
            yield [d[start:end] for d in data]

def batch_generator_target(data, batch_size, unlab_data=None, num_lab_each_epoch=1, shuffle=True):
    """Yield mixed labelled/unlabelled mini-batches, forever.

    Every batch is ``num_lab_each_epoch`` rows drawn at random, without
    replacement, from the labelled arrays ``data``. They are followed by
    ``batch_size - num_lab_each_epoch`` consecutive rows of the aligned
    unlabelled arrays in ``unlab_data``. When the next unlabelled window
    would run past the end, the window index resets and, if ``shuffle``, the
    unlabelled arrays are reshuffled. A trailing partial window is skipped.

    Args:
        data: list of labelled arrays sharing their first dimension.
        batch_size: total number of rows per yielded batch.
        unlab_data: list of unlabelled arrays sharing their first dimension,
            indexed in lockstep with ``data``. It is required despite the
            ``None`` default.
        num_lab_each_epoch: labelled rows per batch. Despite the name, this
            is per batch.
        shuffle: reshuffle the unlabelled arrays at the start of every epoch.

    Yields:
        list of arrays, one per array in ``unlab_data``, each holding one batch.

    Raises:
        TypeError: on the first ``next()`` if ``unlab_data`` is None.
        ValueError: on the first ``next()`` if ``num_lab_each_epoch`` exceeds
            ``batch_size``. ``np.random.choice`` also raises ValueError if it
            exceeds the number of labelled rows.
    """
    if unlab_data is None:
        raise TypeError("batch_generator_target requires unlab_data")
    num_lab = data[0].shape[0]
    lab_target = data
    unlab_target = unlab_data
    # Rows taken from the unlabelled arrays in each batch.
    unlab_step = batch_size - num_lab_each_epoch
    if unlab_step < 0:
        raise ValueError(
            "num_lab_each_epoch (%d) must not exceed batch_size (%d)"
            % (num_lab_each_epoch, batch_size))
    if shuffle:
        unlab_target = shuffle_aligned_list(unlab_target)
    batch_count = 0
    while True:
        # Reset only when the next window would overrun the data; using ``>``
        # (not ``>=``) keeps the final full window of each epoch.
        if (batch_count + 1) * unlab_step > len(unlab_target[0]):
            batch_count = 0
            if shuffle:
                unlab_target = shuffle_aligned_list(unlab_target)
        start = batch_count * unlab_step
        end = start + unlab_step

        lab_idx = np.random.choice(num_lab, num_lab_each_epoch, replace=False)
        batch_count += 1
        yield [np.concatenate([lab_target[i][lab_idx], unlab[start:end]], axis=0)
               for i, unlab in enumerate(unlab_target)]


def shuffle_multiple_aligned_list(data):
    """Independently shuffle each (X, Y) source pair, keeping X/Y aligned.

    ``data`` is ``[X_source, Y_source]``, where both are sequences of arrays
    with one entry per source. For each source a fresh permutation of its X
    rows is drawn and applied to both that source's X and Y arrays.

    Returns:
        ``[shuffled_X_list, shuffled_Y_list]``.
    """
    X_source, Y_source = data
    shuffled_X, shuffled_Y = [], []
    for idx, x_arr in enumerate(X_source):
        perm = np.random.permutation(x_arr.shape[0])
        shuffled_X.append(x_arr[perm])
        shuffled_Y.append(Y_source[idx][perm])
    return [shuffled_X, shuffled_Y]


def batch_generator_source(data, batch_size, shuffle=True):
    """Yield batches that stack one window from every source, forever.

    ``data`` is ``[X_sources, Y_sources]``, where both are sequences of arrays
    with one entry per source. Each step takes rows ``[start:end]`` of every
    source's X and Y arrays and concatenates them along axis 0. When the next
    window would run past the end of the last source's X array, the window
    index resets and, if ``shuffle``, every source is reshuffled with
    ``shuffle_multiple_aligned_list``. A trailing partial window is skipped.

    Args:
        data: ``[X_sources, Y_sources]`` as described above.
        batch_size: rows taken from each source per batch, so a batch has up
            to ``batch_size * len(X_sources)`` rows.
        shuffle: reshuffle every source at the start of every epoch.

    Yields:
        ``[batch_X, batch_Y]``: the concatenated per-source windows.
    """
    if shuffle:
        data = shuffle_multiple_aligned_list(data)
    batch_count = 0
    while True:
        # Epoch length comes from the last source's X array. Using ``>`` (not
        # ``>=``) keeps the final full window of each epoch.
        # NOTE(review): sources shorter than the last one produce short (or
        # empty) slices; confirm all sources are meant to be equal-sized.
        if (batch_count + 1) * batch_size > len(data[0][-1]):
            batch_count = 0

            if shuffle:
                data = shuffle_multiple_aligned_list(data)

        start = batch_count * batch_size
        end = start + batch_size
        batch_count += 1
        batch_X_parts, batch_Y_parts = [], []
        for i, x_src in enumerate(data[0]):
            batch_X_parts.append(x_src[start:end])
            batch_Y_parts.append(data[1][i][start:end])
        batch_X_data = np.concatenate(batch_X_parts, axis=0)
        batch_Y_data = np.concatenate(batch_Y_parts, axis=0)
        yield [batch_X_data, batch_Y_data]
