import numpy as np

def simple_generator(batch_size, X, Y, shuffle=True):
    """Yield ``(batch_x, batch_y)`` float32 mini-batches forever.

    Parameters
    ----------
    batch_size : int
        Number of samples per batch.
    X : np.ndarray
        Inputs indexed along the first axis; each yielded ``batch_x`` has
        shape ``(batch_size,) + X[0].shape``.
    Y : np.ndarray
        Targets indexed along the first axis. A 1-D ``Y`` is yielded as a
        column of shape ``(batch_size, 1)``; a 2-D ``Y`` as
        ``(batch_size, Y.shape[1])``.
    shuffle : bool
        If True, every batch draws indices uniformly at random with
        replacement. If False, batches walk through the data in order and
        restart at index 0 once another full batch would not fit, so trailing
        samples that cannot fill a batch are skipped.

    Yields
    ------
    tuple of np.ndarray
        Freshly allocated ``(batch_x, batch_y)`` arrays of dtype float32.
    """
    n_samples = Y.shape[0]
    y_shape = Y.shape[1] if len(Y.shape) > 1 else 1
    index = 0
    while True:
        batch_x = np.zeros((batch_size,) + X[0].shape, dtype=np.float32)
        batch_y = np.zeros((batch_size, y_shape), dtype=np.float32)
        if shuffle:
            ind = np.random.randint(0, n_samples, size=batch_size)
        else:
            # The modulo is a no-op whenever n_samples >= batch_size (the reset
            # below keeps index + batch_size <= n_samples); it only makes tiny
            # datasets wrap around instead of indexing past the end.
            ind = np.arange(index, index + batch_size) % n_samples
            if index + 2 * batch_size > n_samples:
                index = 0
            else:
                index += batch_size
        batch_x[:] = X[ind]
        # A 1-D Y gives Y[ind] of shape (batch_size,), which cannot broadcast
        # into the (batch_size, 1) buffer; reshape it explicitly.
        batch_y[:] = np.reshape(Y[ind], (batch_size, y_shape))
        yield batch_x, batch_y
        
def simple_dataset_generator(batch_size, X, Y, shuffle=True):
    """Return a zero-argument callable that creates a mini-batch generator.

    Each call of the returned function starts a fresh infinite generator
    yielding ``(batch_x, batch_y)`` float32 batches.

    Parameters
    ----------
    batch_size : int
        Number of samples per batch.
    X : np.ndarray
        Inputs indexed along the first axis; each ``batch_x`` has shape
        ``(batch_size,) + X[0].shape``.
    Y : np.ndarray
        Targets indexed along the first axis. A 1-D ``Y`` is yielded as a
        column of shape ``(batch_size, 1)``; a 2-D ``Y`` as
        ``(batch_size, Y.shape[1])``.
    shuffle : bool
        If True, every batch draws indices uniformly at random with
        replacement. If False, batches walk through the data in order and
        restart at index 0 once another full batch would not fit.

    Returns
    -------
    callable
        A generator function taking no arguments.
    """
    def simple_dataset_generator_():
        n_samples = Y.shape[0]
        y_shape = Y.shape[1] if len(Y.shape) > 1 else 1
        index = 0
        while True:
            batch_x = np.zeros((batch_size,) + X[0].shape, dtype=np.float32)
            batch_y = np.zeros((batch_size, y_shape), dtype=np.float32)
            if shuffle:
                ind = np.random.randint(0, n_samples, size=batch_size)
            else:
                # No-op unless n_samples < batch_size, where it wraps around
                # instead of indexing past the end of the data.
                ind = np.arange(index, index + batch_size) % n_samples
                if index + 2 * batch_size > n_samples:
                    index = 0
                else:
                    index += batch_size
            batch_x[:] = X[ind]
            # A 1-D Y gives Y[ind] of shape (batch_size,), which cannot
            # broadcast into the (batch_size, 1) buffer; reshape explicitly.
            batch_y[:] = np.reshape(Y[ind], (batch_size, y_shape))
            yield batch_x, batch_y
    return simple_dataset_generator_


def simple_fair_generator(batch_size, X, Y, S, balanced=True):
    """Yield mini-batches pairing each label with a binary group attribute.

    Parameters
    ----------
    batch_size : int
        Number of samples per batch.
    X : np.ndarray
        Inputs indexed along the first axis; each ``batch_x`` has shape
        ``(batch_size,) + X[0].shape``.
    Y : np.ndarray
        Labels, either 1-D or with a single trailing column.
    S : np.ndarray
        Group attribute with values 0/1. Either 1-D or 2-D; for 2-D input only
        column 0 is used, both for grouping and for the yielded value.
    balanced : bool
        If True, each batch holds ``batch_size // 2`` rows with S == 0
        followed by ``batch_size - batch_size // 2`` rows with S == 1, each
        group sampled without replacement. If False, ``batch_size`` rows are
        sampled without replacement from the whole dataset.

    Yields
    ------
    tuple of np.ndarray
        ``(batch_x, batch_y)`` as float32, where ``batch_y`` has shape
        ``(batch_size, 2)``: column 0 is Y, column 1 is S.

    Raises
    ------
    ValueError
        From ``np.random.choice`` when a group (or the dataset) has fewer rows
        than the number requested per batch.
    """
    s_col = S[:, 0] if len(S.shape) > 1 else S
    all_ix = np.arange(S.shape[0])
    ix_s0 = np.flatnonzero(s_col == 0)
    ix_s1 = np.flatnonzero(s_col == 1)
    n_first = batch_size // 2
    while True:
        batch_x = np.zeros((batch_size,) + X[0].shape, dtype=np.float32)
        batch_y = np.zeros((batch_size, 2), dtype=np.float32)
        if balanced:
            ind = np.concatenate([
                np.random.choice(ix_s0, size=n_first, replace=False),
                np.random.choice(ix_s1, size=batch_size - n_first, replace=False),
            ])
        else:
            ind = np.random.choice(all_ix, size=batch_size, replace=False)
        batch_x[:] = X[ind]
        # ravel accepts Y of shape (N,) or (N, 1) alike.
        batch_y[:, 0] = np.ravel(Y[ind])
        batch_y[:, 1] = s_col[ind]
        yield batch_x, batch_y

def otp_generator(batch_size, X, Y):
    """Yield class-balanced ``(batch_x, batch_y)`` batches for +1/-1 labels.

    Each batch holds ``batch_size // 2`` rows labelled ``+1`` followed by
    ``batch_size - batch_size // 2`` rows labelled ``-1``. Each part is
    sampled without replacement, so odd batch sizes are supported.

    Parameters
    ----------
    batch_size : int
        Number of samples per batch.
    X : np.ndarray
        Inputs indexed along the first axis; each ``batch_x`` has shape
        ``(batch_size,) + X[0].shape``.
    Y : np.ndarray
        Labels of shape ``(N,)`` or ``(N, 1)``. Only rows with value 1 or -1
        are ever sampled.

    Yields
    ------
    tuple of np.ndarray
        ``(batch_x, batch_y)`` as float32, ``batch_y`` of shape
        ``(batch_size, 1)``.

    Raises
    ------
    ValueError
        From ``np.random.choice`` when a class has fewer rows than requested
        per batch, or when ``Y`` cannot be reshaped to 1-D of length N.
    """
    # Flatten a (N, 1) column so boolean masks and indexing stay 1-D.
    y_flat = np.reshape(Y, (Y.shape[0],))
    pos_ix = np.flatnonzero(y_flat == 1)
    neg_ix = np.flatnonzero(y_flat == -1)
    n_pos = batch_size // 2
    n_neg = batch_size - n_pos  # takes the extra row when batch_size is odd
    while True:
        batch_x = np.zeros((batch_size,) + X[0].shape, dtype=np.float32)
        batch_y = np.zeros((batch_size, 1), dtype=np.float32)
        ind = np.random.choice(pos_ix, size=n_pos, replace=False)
        batch_x[:n_pos] = X[ind]
        batch_y[:n_pos, 0] = y_flat[ind]
        ind = np.random.choice(neg_ix, size=n_neg, replace=False)
        batch_x[n_pos:] = X[ind]
        batch_y[n_pos:, 0] = y_flat[ind]
        yield batch_x, batch_y
