import numpy as np
import math

def generate_toy_images(shape, frac=0, v=1, verbose=True):
    """Return a zero image with a centred rectangle filled with ``v``.

    Each side of the rectangle (along axes 0 and 1) is the axis length
    scaled by ``sqrt(frac)``, so the rectangle covers roughly ``frac`` of
    the ``shape[0] * shape[1]`` area. Any trailing axes (e.g. channels) are
    filled completely inside the rectangle.

    Args:
        shape: Image shape with at least two dimensions.
        frac: Target fraction of the area to fill. ``0`` returns an
            all-zero image. A non-zero value always fills at least one
            pixel; values >= 1 fill the whole image.
        v: Value written inside the rectangle.
        verbose: When true (the default, matching the historical
            behaviour), print the L2 norm of the resulting image.

    Returns:
        ``np.ndarray`` of the given ``shape`` (NumPy's default float dtype).
    """
    img = np.zeros(shape)
    if frac == 0:
        return img
    side = frac ** 0.5
    # Clamp each side into [1, axis length]: a tiny non-zero frac still
    # marks one pixel, and the side never exceeds the axis, so the
    # centring offset below cannot go negative.
    rows = min(max(int(shape[0] * side), 1), shape[0])
    cols = min(max(int(shape[1] * side), 1), shape[1])
    row0 = (shape[0] - rows) // 2
    col0 = (shape[1] - cols) // 2
    # Ellipsis fills every trailing axis, so 2-D shapes work as well as 3-D+.
    img[row0:row0 + rows, col0:col0 + cols, ...] = v
    if verbose:
        print("Norm L2 " + str(np.linalg.norm(img)))
    return img

def ternary_generator(batch_size, shape, frac=0, gtValues=(0, 1)):
    """Yield the same fixed three-group ``(batch_x, batch_y)`` pair forever.

    Layout along the batch axis:
      * ``[0, batch_size//2)``: all-zero images, label ``gtValues[0]``;
      * ``[batch_size//2, 3*batch_size//4)``: images from
        ``generate_toy_images(..., v=-1)``, label ``gtValues[1]``;
      * ``[3*batch_size//4, batch_size)``: images from
        ``generate_toy_images(..., v=1)``, label ``gtValues[1]``.

    Args:
        batch_size: Number of samples per batch.
        shape: Per-sample image shape (any sequence; converted to a tuple).
        frac: Fraction forwarded to ``generate_toy_images``.
        gtValues: ``(negative_label, positive_label)`` pair.

    Yields:
        ``(batch_x, batch_y)`` with ``batch_x`` of shape
        ``(batch_size, *shape)`` and ``batch_y`` of shape ``(batch_size, 1)``,
        both float16.

    NOTE: the very same arrays are yielded on every step, so a consumer
    that mutates them in place alters all subsequent batches.
    """
    shape = tuple(shape)
    batch_x = np.zeros((batch_size,) + shape, dtype=np.float16)
    batch_y = np.full((batch_size, 1), gtValues[0], dtype=np.float16)
    batch_x[3 * batch_size // 4:, ] = generate_toy_images(shape, frac=frac, v=1)
    batch_x[batch_size // 2:3 * batch_size // 4, ] = generate_toy_images(shape, frac=frac, v=-1)
    # Both non-zero groups (+1 and -1 rectangles) share the positive label.
    batch_y[batch_size // 2:] = gtValues[1]
    while True:
        yield batch_x, batch_y
        

def binary_generator(batch_size, shape, frac=0, gtValues=(0, 1)):
    """Yield the same fixed two-group ``(batch_x, batch_y)`` pair forever.

    Layout along the batch axis:
      * ``[0, batch_size//2)``: all-zero images, label ``gtValues[0]``;
      * ``[batch_size//2, batch_size)``: images from
        ``generate_toy_images(..., v=1)``, label ``gtValues[1]``.

    Args:
        batch_size: Number of samples per batch.
        shape: Per-sample image shape (any sequence; converted to a tuple).
        frac: Fraction forwarded to ``generate_toy_images``.
        gtValues: ``(negative_label, positive_label)`` pair.

    Yields:
        ``(batch_x, batch_y)`` with ``batch_x`` of shape
        ``(batch_size, *shape)`` and ``batch_y`` of shape ``(batch_size, 1)``,
        both float16.

    NOTE: the very same arrays are yielded on every step, so a consumer
    that mutates them in place alters all subsequent batches.
    """
    shape = tuple(shape)
    batch_x = np.zeros((batch_size,) + shape, dtype=np.float16)
    print(batch_x.shape)
    batch_y = np.full((batch_size, 1), gtValues[0], dtype=np.float16)
    batch_x[batch_size // 2:, ] = generate_toy_images(shape, frac=frac, v=1)
    batch_y[batch_size // 2:] = gtValues[1]
    while True:
        yield batch_x, batch_y

def toy_binary_generator(batch_size, shape, frac=0, gtValues=(0, 1)):
    """Wrap ``binary_generator`` in a dataset dictionary.

    Only the ``'train'`` entry is populated; the valid/test splits and all
    sizes are ``None``.

    Args:
        batch_size: Number of samples per batch.
        shape: Per-sample image shape (converted to a tuple).
        frac: Fraction forwarded to the image generator.
        gtValues: ``(negative_label, positive_label)`` pair.

    Returns:
        dict with keys ``'train'``, ``'trainSize'``, ``'valid'``,
        ``'validSize'``, ``'test'``, ``'testSize'`` and ``'batch_size'``.
    """
    train = binary_generator(batch_size, tuple(shape), frac, gtValues=gtValues)
    return {
        'train': train, 'trainSize': None,
        'valid': None, 'validSize': None,
        'test': None, 'testSize': None,
        'batch_size': batch_size,
    }
def toy_ternary_generator(batch_size, shape, frac=0, gtValues=(0, 1)):
    """Wrap ``ternary_generator`` in a dataset dictionary.

    Only the ``'train'`` entry is populated; the valid/test splits and all
    sizes are ``None``.

    Args:
        batch_size: Number of samples per batch.
        shape: Per-sample image shape (converted to a tuple).
        frac: Fraction forwarded to the image generator.
        gtValues: ``(negative_label, positive_label)`` pair. The original
            body referenced ``gtValues`` without defining it, which raised
            NameError; it is now a parameter that mirrors
            ``toy_binary_generator``.

    Returns:
        dict with keys ``'train'``, ``'trainSize'``, ``'valid'``,
        ``'validSize'``, ``'test'``, ``'testSize'`` and ``'batch_size'``.
    """
    train = ternary_generator(batch_size, tuple(shape), frac, gtValues=gtValues)
    return {
        'train': train, 'trainSize': None,
        'valid': None, 'validSize': None,
        'test': None, 'testSize': None,
        'batch_size': batch_size,
    }