import os
import json
from PIL import Image

import torch
import torch.nn.functional as F


def gen_chess_array(
    grid_size: int,
    patch_size: int,
    num_examples: int,
    pad2size: int = -1,
    seed: int = 2024,
):
    """Build a batch of random blocky RGB images as a uint8 NumPy array.

    Every image is a ``grid_size`` x ``grid_size`` grid of square cells, each
    ``patch_size`` pixels on a side. Each colour channel of each cell is drawn
    independently as either 0 or 255.

    Args:
        grid_size: Number of cells along each image side.
        patch_size: Side length of one cell, in pixels.
        num_examples: Number of images to generate.
        pad2size: If positive and larger than ``grid_size * patch_size``, the
            images are zero-padded on the bottom and right up to
            ``pad2size`` x ``pad2size``. Otherwise no padding is applied
            (images are never cropped).
        seed: Seed passed to ``torch.manual_seed``. Note that this reseeds
            torch's global random number generator.

    Returns:
        A uint8 array of shape ``(num_examples, height, width, 3)``.
    """
    # Reseeding the global generator makes the output a function of `seed`.
    torch.manual_seed(seed)

    # One random bit per (image, channel, cell), scaled to 0 / 255.
    cell_shape = (num_examples, 3, grid_size, grid_size)
    cells = torch.randint(2, size=cell_shape, dtype=torch.float32) * 255.0

    # Nearest-neighbour upsampling blows each cell up into a solid patch.
    images = F.interpolate(cells, scale_factor=patch_size, mode="nearest")

    margin = pad2size - grid_size * patch_size
    if pad2size > 0 and margin > 0:
        # Pad tuple is (left, right, top, bottom) for the last two dims;
        # the original content stays anchored at the top-left corner.
        images = F.pad(images, (0, margin, 0, margin))

    # Channels-first -> channels-last, the layout image libraries expect.
    channels_last = images.permute(0, 2, 3, 1)
    return channels_last.to(torch.uint8).numpy()


def gen_chess_images(save_folder, **kwargs):
    """Generate random grid images and save them as PNG files with a config.

    The images come from ``gen_chess_array`` and are written to
    ``{root_dir}/{save_folder}/{idx}.png``. The generation settings that were
    actually used are stored alongside them in ``config.json``.

    Args:
        save_folder: Sub-folder of ``root_dir`` that receives the files.
        **kwargs: Optional overrides for the generation settings
            ``grid_size`` (default 5), ``patch_size`` (80), ``num_examples``
            (100), ``pad2size`` (512) and ``seed`` (2024), plus ``root_dir``
            (default ``"images"``). Any other key is ignored, but a warning
            is emitted so that misspelled options do not go unnoticed.
    """
    import warnings

    config = {
        "grid_size": 5,
        "patch_size": 80,
        "num_examples": 100,
        "pad2size": 512,
        "seed": 2024,
    }

    # Unknown keys used to be dropped silently, so a typo such as
    # ``gridsize=8`` produced a dataset with default settings. Keep ignoring
    # them (callers may pass extra keys) but make it visible.
    unknown = sorted(
        key for key in kwargs if key not in config and key != "root_dir"
    )
    if unknown:
        warnings.warn(
            "gen_chess_images: ignoring unrecognized option(s): "
            + ", ".join(unknown),
            stacklevel=2,
        )
    config.update(
        (key, value) for key, value in kwargs.items() if key in config
    )

    arrays = gen_chess_array(**config)

    root_dir = kwargs.get("root_dir", "images")
    out_dir = f"{root_dir}/{save_folder}"
    # makedirs creates missing parents, so root_dir needs no separate call.
    os.makedirs(f"{out_dir}/", exist_ok=True)

    for idx, array in enumerate(arrays):
        Image.fromarray(array).save(f"{out_dir}/{idx}.png")

    # Persist the effective settings so the dataset can be regenerated.
    with open(f"{out_dir}/config.json", "w", encoding="utf-8") as f:
        json.dump(config, f, indent=4)
