import argparse
import os.path
from argparse import ArgumentParser
from typing import Any

import numpy as np
from imageio import imwrite


def load_from_folder(folder: str, num_samples: int, start_index: int = 0) -> np.ndarray:
    """Load consecutive ``.npy`` files from ``folder`` and stack them into one array.

    Reads ``f'{folder}/{start_index}.npy'`` through
    ``f'{folder}/{start_index + num_samples - 1}.npy'`` in order. The stacked result
    has a new leading axis of length ``num_samples``; this assumes every file holds an
    array of the same shape (TODO confirm for callers with mixed data).
    """
    loaded = []
    for offset in range(num_samples):
        path = f'{folder}/{offset + start_index}.npy'
        loaded.append(np.load(path))
    return np.array(loaded)


def run(
        image_folders: list[str],
        save_path: str,
        num_samples: int,
        rows: int,
        cols: int,
        start_index: list[int] = None
) -> None:
    """Tile ``.npy`` samples into a ``rows x cols`` grid and save it as an image.

    Each folder is read with ``load_from_folder`` (``num_samples`` consecutive files
    starting at that folder's start index). Every loaded sample must be a
    ``(channels, height, width)`` array. Pixel values are mapped with
    ``x * 0.5 + 0.5``, so [-1, 1] becomes [0, 1]. They are then clipped to [0, 1],
    scaled to [0, 255], rounded to ``uint8``, and written to ``save_path`` in
    ``(height, width, channels)`` order.

    Args:
        image_folders: Folders to load samples from. Must be non-empty.
        save_path: Output image path. Its parent directory is created if missing.
        num_samples: Samples loaded per folder. Must equal ``rows * cols``.
        rows: Number of grid rows.
        cols: Number of grid columns.
        start_index: Index of the first file to load, per folder. Defaults to 0 for
            every folder. A single value is applied to all folders. Entries beyond
            ``len(image_folders)`` are ignored.

    Raises:
        ValueError: If ``num_samples != rows * cols``, ``image_folders`` is empty,
            or ``start_index`` has more than one but fewer than
            ``len(image_folders)`` entries.
    """
    # Explicit raise rather than ``assert`` so the check survives ``python -O``.
    if num_samples != rows * cols:
        raise ValueError(f'expected {rows * cols} samples, got {num_samples}')
    if not image_folders:
        raise ValueError('expected at least one image folder')

    if start_index is None:
        start_index = [0] * len(image_folders)
    elif len(start_index) == 1:
        # One value (e.g. ``--start_index 100``) applies to every folder.
        start_index = list(start_index) * len(image_folders)
    elif len(start_index) < len(image_folders):
        raise ValueError(
            f'expected {len(image_folders)} start indices, got {len(start_index)}')

    samples: np.ndarray = np.array([
        load_from_folder(folder, num_samples, first)
        for folder, first in zip(image_folders, start_index)])

    print(samples.shape)
    _, _, channels, height, width = samples.shape
    # Folder-major flattening: entries [0, num_samples) come from image_folders[0].
    # NOTE(review): since num_samples == rows * cols, the grid below is filled
    # entirely from the first folder; samples from any further folders are loaded
    # but never drawn. Confirm whether they were meant to be tiled as well.
    samples = np.reshape(samples, (-1, channels, height, width))

    # Row-major placement: grid cell (i, j) shows sample i * cols + j.
    samples_img: np.ndarray = np.zeros((channels, height * rows, width * cols))
    for i in range(rows):
        for j in range(cols):
            samples_img[:, i * height:(i + 1) * height, j * width:(j + 1) * width] = samples[i * cols + j]

    # CHW -> HWC for the writer, then [-1, 1] -> [0, 255] as uint8.
    samples_converted: np.ndarray = (np.clip(np.transpose(
        samples_img * 0.5 + 0.5, (1, 2, 0)), 0, 1) * 255).round().astype(np.uint8)

    save_dir = os.path.dirname(save_path)
    if save_dir != '':
        os.makedirs(save_dir, exist_ok=True)
    imwrite(save_path, samples_converted)


def run_from_config(config: dict[str, Any]) -> None:
    """Invoke ``run`` with every entry of ``config`` passed as a keyword argument."""
    return run(**config)


def parse_args() -> argparse.Namespace:
    """Parse the command-line options consumed by ``run``.

    ``--num_samples`` defaults to ``rows * cols`` when omitted, because ``run``
    requires the two to be equal. The previous fixed default of 50 only worked
    for grids with exactly 50 cells.

    Returns:
        The parsed namespace. ``num_samples`` is always an ``int``.
    """
    parser: ArgumentParser = ArgumentParser()
    parser.add_argument('--image_folders', type=str, nargs='+', required=True,
                        help='folders containing <index>.npy sample files')
    parser.add_argument('--save_path', type=str, required=True,
                        help='path of the output grid image')
    parser.add_argument('--num_samples', type=int, default=None,
                        help='samples loaded per folder (default: rows * cols)')
    parser.add_argument('--rows', type=int, required=True, help='number of grid rows')
    parser.add_argument('--cols', type=int, required=True, help='number of grid columns')
    parser.add_argument('--start_index', type=int, nargs='+', default=None,
                        help='index of the first file loaded from each folder (default: 0)')
    args: argparse.Namespace = parser.parse_args()
    if args.num_samples is None:
        args.num_samples = args.rows * args.cols
    return args


def get_config_from_args(args: argparse.Namespace) -> dict[str, Any]:
    """Expose the parsed CLI namespace as a keyword-argument mapping.

    The returned dict is the namespace's own attribute dictionary, not a copy.
    Mutating one therefore mutates the other.
    """
    config = vars(args)
    return config


def main() -> None:
    """Command-line entry point: parse argv, build the config, and render the grid."""
    run_from_config(get_config_from_args(parse_args()))


# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()
