from pathlib import Path
from argparse import ArgumentParser
from torchvision.io import read_image
from torchvision.utils import save_image

def split_grid(grid, image_res=256, padding_width=2):
    """Cut a padded image grid back into its individual square tiles.

    The grid is read as a lattice of ``image_res`` x ``image_res`` tiles, each
    preceded (vertically and horizontally) by ``padding_width`` pixels of
    padding. This presumably matches the layout written by
    ``torchvision.utils.make_grid(..., padding=padding_width)`` — confirm
    against whatever produced the grid.

    Args:
        grid: Tensor indexed as ``(channels, height, width)``.
        image_res: Side length of each square tile, in pixels.
        padding_width: Padding in front of every tile along both axes, in pixels.

    Returns:
        A list of tile slices of ``grid`` (torch basic slicing yields views,
        not copies), ordered row by row and left to right within a row. Only
        cells that fit completely inside the grid are returned.
    """
    cell = image_res + padding_width
    # Floor division counts the cells whose full extent lies inside the grid.
    n_rows = grid.shape[1] // cell
    n_cols = grid.shape[2] // cell

    def _tile(row, col):
        # Skip the leading padding, then step one full cell per index.
        top = padding_width + cell * row
        left = padding_width + cell * col
        return grid[:, top:top + image_res, left:left + image_res]

    return [_tile(row, col) for row in range(n_rows) for col in range(n_cols)]

def main():
    """Command-line entry point: split a saved image grid into tile PNGs.

    Reads the grid image given as ``path_img``, divides its pixel values by
    255, cuts it into tiles with ``split_grid`` and writes each tile as
    ``tmp_<idx>.png`` into ``--out-dir`` (the current directory by default).
    Tile size and padding default to 256 and 2 pixels, but can be overridden
    for grids produced with other settings.
    """
    parser = ArgumentParser(description="Split an image grid into individual tiles.")
    parser.add_argument("path_img", help="path of the grid image to split")
    parser.add_argument(
        "--image-res",
        type=int,
        default=256,
        help="side length in pixels of each square tile (default: %(default)s)",
    )
    parser.add_argument(
        "--padding",
        type=int,
        default=2,
        help="padding in pixels in front of each tile (default: %(default)s)",
    )
    parser.add_argument(
        "--out-dir",
        type=Path,
        default=Path("."),
        help="directory the tiles are written to (default: current directory)",
    )
    args = parser.parse_args()

    # A non-positive tile size or negative padding would make the cell stride
    # meaningless (or zero, which split_grid would divide by).
    if args.image_res <= 0:
        parser.error("--image-res must be a positive integer")
    if args.padding < 0:
        parser.error("--padding must be a non-negative integer")

    path_img = Path(args.path_img)
    # NOTE(review): assumes read_image returns 8-bit values in [0, 255] so this
    # rescales to [0, 1] floats for save_image — confirm for 16-bit inputs.
    grid = read_image(str(path_img)) / 255

    args.out_dir.mkdir(parents=True, exist_ok=True)
    tiles = split_grid(grid, image_res=args.image_res, padding_width=args.padding)
    for idx, img in enumerate(tiles):
        # With the default out-dir this is exactly "tmp_<idx>.png", as before.
        save_image(img, str(args.out_dir / f"tmp_{idx}.png"))


# Run the CLI only when executed as a script, so importing this module
# (e.g. to reuse split_grid) has no side effects.
if __name__ == "__main__":
    main()