
from math import floor
from math import ceil

import numpy as np

# Shared random generator used by load_data to place the lines. The fixed
# seed (42) makes the sequence of generated datasets reproducible from run to
# run; successive load_data calls advance this generator, so each call yields
# different samples.
rng = np.random.default_rng(42)


def _draw_line(canvas, label, line_len, gen):
    """Draw one randomly placed line of ones onto ``canvas`` in place.

    Args:
        canvas: A single ``(height, width, 1)`` image array; modified in place.
        label: Line orientation -- 1 horizontal, 2 vertical, 3 diagonal
            (main or anti-diagonal, picked at random).
        line_len: Number of pixels the line covers.
        gen: ``numpy.random.Generator`` used for the random placement.

    Raises:
        NotImplementedError: If ``label`` is not 1, 2 or 3.
    """
    height, width = canvas.shape[:2]
    # Pixels before / after the line's anchor pixel; for an even length the
    # extra pixel goes after the anchor. pad1 + pad2 + 1 == line_len.
    padf = (line_len - 1) / 2
    pad1, pad2 = floor(padf), ceil(padf)

    if label == 1:  # horizontal: any row, anchor column kept in bounds
        i = gen.choice(range(height))
        j = gen.choice(range(pad1, width - pad2))
        canvas[i, j - pad1:j + pad2 + 1] = 1
    elif label == 2:  # vertical: any column, anchor row kept in bounds
        i = gen.choice(range(pad1, height - pad2))
        j = gen.choice(range(width))
        canvas[i - pad1:i + pad2 + 1, j] = 1
    elif label == 3:  # diagonal: anchor kept in bounds on both axes
        i = gen.choice(range(pad1, height - pad2))
        j = gen.choice(range(pad1, width - pad2))
        rows = list(range(i - pad1, i + pad2 + 1))
        if gen.choice(2):  # flip to the anti-diagonal half of the time
            rows.reverse()
        cols = list(range(j - pad1, j + pad2 + 1))
        canvas[rows, cols] = 1
    else:
        raise NotImplementedError(label)


def load_data(height=12, width=12, line_len=5, train_size=0.75, size=5000,
              shuffle=True, random_state=None):
    """Generate a synthetic 4-class "line orientation" image dataset.

    Every sample is a ``(height, width, 1)`` float32 image of zeros. Class 0
    images stay blank; classes 1, 2 and 3 receive a single line of
    ``line_len`` ones drawn horizontally, vertically or diagonally at a random
    position. The ``size`` samples are spread as evenly as possible over the
    four classes (the first ``size % 4`` classes get one extra sample).

    Args:
        height, width: Image dimensions in pixels.
        line_len: Length of the drawn line; must satisfy
            ``1 <= line_len <= min(height, width)``.
        train_size: Fraction of samples placed in the training split; the
            remainder forms the validation split.
        size: Total number of samples.
        shuffle: Shuffle the samples before splitting. Samples are generated
            grouped by class, so without shuffling the validation split holds
            only the last class(es).
        random_state: Seed or ``numpy.random.Generator`` used for line
            placement and shuffling. ``None`` uses the module-level ``rng``.

    Returns:
        ``((X_train, y_train), (X_val, y_val))`` where ``X_*`` have shape
        ``(n, height, width, 1)`` (float32) and ``y_*`` are one-hot integer
        labels of shape ``(n, 4)``.

    Raises:
        ValueError: If ``line_len`` or ``train_size`` is out of range.
    """
    if not 1 <= line_len <= min(height, width):
        raise ValueError(
            f"line_len must be in [1, min(height, width)] = "
            f"[1, {min(height, width)}], got {line_len}")
    if not 0 <= train_size <= 1:
        raise ValueError(f"train_size must be in [0, 1], got {train_size}")

    # The module-level generator is only looked up when no override is given.
    gen = rng if random_state is None else np.random.default_rng(random_state)

    n_classes = 4
    per_class, remainder = divmod(size, n_classes)
    # The first `remainder` classes absorb one extra sample each.
    counts = [per_class + (c < remainder) for c in range(n_classes)]

    x = np.zeros((size, height, width, 1), dtype=np.float32)
    labels = np.repeat(np.arange(n_classes), counts)
    for canvas, label in zip(x, labels):
        if label != 0:  # class 0 is the blank image
            _draw_line(canvas, label, line_len, gen)
    # One-hot rows; indexing the identity keeps shape (size, n_classes) even
    # when size == 0.
    y = np.eye(n_classes, dtype=int)[labels]

    if shuffle:
        order = gen.permutation(size)
        x, y = x[order], y[order]

    sep_index = int(train_size * size)
    X_train, y_train = x[:sep_index], y[:sep_index]
    X_val, y_val = x[sep_index:], y[sep_index:]

    return (X_train, y_train), (X_val, y_val)


if __name__ == '__main__':
    import matplotlib.pyplot as plt

    # Class meaning follows how load_data draws each label.
    class_names = ['empty', 'horizontal', 'vertical', 'diagonal']

    # size=4 yields one sample per class; train_size=1. keeps all of them in
    # the first ("training") split.
    (X, Y), _ = load_data(
        height=12, width=12, line_len=5, train_size=1., size=4)
    _, axes = plt.subplots(2, 2)
    # Title each panel from its one-hot label, not its position, so the title
    # reflects the sample's actual class regardless of ordering.
    for image, label, ax in zip(X, Y.argmax(axis=1), axes.flat):
        # Drop the trailing channel axis and pin the colour scale so the blank
        # class-0 image is not autoscaled.
        ax.imshow(image[..., 0], cmap='gray', vmin=0, vmax=1)
        ax.set_title(f'{label}: {class_names[label]}')
    plt.show()
