import argparse
from concurrent.futures import ProcessPoolExecutor
from functools import partial
from io import BytesIO
from pathlib import Path
from typing import Callable, List, Optional

import albumentations
import numpy as np
from PIL import Image
from tqdm import tqdm


def wrapper(
    args: argparse.Namespace,
    loader: Callable,
    converter: Callable,
    multiprocessing: bool = True,
    takes_dir: bool = False,
):
    """Load items from ``args.input`` and write each one to ``args.output`` as a PNG.

    Args:
        args: Parsed command-line arguments. Reads ``input``, ``output``, ``start``, ``amount`` and, when
            ``multiprocessing`` is true, ``num_workers``. The namespace is also handed to ``loader`` and
            ``converter`` unchanged.
        loader: Called as ``loader(file, args)``. Must return an iterable of items, or ``None`` if the file
            should be skipped.
        converter: Called as ``converter(item, target_path, args)`` and expected to write ``target_path``.
        multiprocessing: If true and ``args.num_workers`` is not ``None``, conversions are submitted to a
            process pool; otherwise they run in the calling process.
        takes_dir: If true, an input directory is passed to ``loader`` as a whole instead of being expanded
            into its (sorted) entries.

    Raises:
        Warning: Via ``_check_amount`` if fewer than ``args.amount`` items were handled.

    Note:
        Output files are named after the item's stem when the item is a ``Path`` and after the running
        index otherwise. This rule is applied on both the pooled and the in-process path so that the same
        input produces the same file names regardless of ``num_workers``.
    """
    _info_msg(args.input, args.output, args.amount, args.start)
    args.output.mkdir(exist_ok=True, parents=True)

    if args.input.is_dir() and not takes_dir:
        files = sorted(args.input.iterdir())
    else:
        files = [args.input]

    use_pool = multiprocessing and args.num_workers is not None
    # Value of ``count`` at which the requested amount has been reached (``None`` means "no limit").
    stop_at = None if args.amount is None else args.amount + args.start

    count = 0
    # The context manager closes the bar even if a loader or converter raises.
    with tqdm(total=args.amount) as pbar:

        def _on_done(future):
            # Worker exceptions would otherwise be lost with the future; advance the bar and report them.
            pbar.update()
            error = future.exception()
            if error is not None:
                print(error)

        for file in files:
            pbar.set_description(str(file.name))
            iterable = loader(file, args)
            if iterable is not None:
                pool = ProcessPoolExecutor(max_workers=args.num_workers) if use_pool else None
                try:
                    for img in iterable:
                        # Items before ``start`` are counted but not converted.
                        if count < args.start:
                            count += 1
                            continue

                        filename = f"{img.stem}.png" if isinstance(img, Path) else f"{count:06}.png"
                        target = args.output / filename
                        if pool is None:
                            converter(img, target, args)
                            pbar.update()
                        else:
                            pool.submit(converter, img, target, args).add_done_callback(_on_done)
                        count += 1

                        if stop_at is not None and count == stop_at:
                            break
                finally:
                    # Wait for outstanding jobs before moving on to the next file.
                    if pool is not None:
                        pool.shutdown(wait=True)

            if stop_at is not None and count == stop_at:
                break

    _check_amount(args.amount, count - args.start)


def convert_lmdb(args):
    """Export the images stored in an LMDB database as center-cropped square PNG files.

    Every value in the database is decoded as an image, rescaled so that its shorter side equals
    ``args.size`` and center-cropped to ``args.size`` x ``args.size``. Files are named after the running
    index of the entry.

    Args:
        args: Parsed command-line arguments. Reads ``input`` (database path), ``output`` (target directory),
            ``size``, ``start`` and ``amount``.

    Raises:
        Warning: Via ``_check_amount`` if fewer than ``args.amount`` images were written.
    """
    import lmdb  # lazy import

    _info_msg(args.input, args.output, args.amount, args.start)
    args.output.mkdir(parents=True, exist_ok=True)

    size = args.size
    # Value of ``count`` at which the requested amount has been reached (``None`` means "no limit").
    stop_at = None if args.amount is None else args.amount + args.start
    count = 0
    # Tracked explicitly so the summary below also works when no image was processed at all.
    resized = False

    env = lmdb.open(str(args.input), readonly=True)
    try:
        with env.begin() as txn, tqdm(total=args.amount) as pbar:
            for _key, val in txn.cursor():
                # Entries before ``start`` are counted but not exported.
                if count < args.start:
                    count += 1
                    continue
                img = Image.open(BytesIO(val))
                w, h = img.size
                scale = size / min(w, h)
                if scale != 1:  # from https://github.com/openai/guided-diffusion/blob/main/datasets/lsun_bedroom.py
                    img = img.resize(
                        (int(round(scale * w)), int(round(scale * h))),
                        resample=Image.BOX,
                    )
                    w, h = img.size
                    resized = True
                # Offset-based box: always exactly ``size`` pixels wide/high, also for odd sizes.
                left = (w - size) // 2
                top = (h - size) // 2
                img = img.crop(box=(left, top, left + size, top + size))
                img.save(args.output / f"{count:06}.png")
                count += 1
                pbar.update()
                if stop_at is not None and count == stop_at:
                    break
    finally:
        env.close()

    if resized:
        print("Resizing was used.")

    _check_amount(args.amount, count - args.start)


def _array_to_png(array, path, args):
    """Save a single image array to ``path``.

    Args:
        array: Image as a NumPy array. A 3-D array whose first axis has length 3 is treated as
            channel-first and transposed to channel-last. Floating-point data is assumed to lie in
            [0, 1] and is scaled to 8-bit.
        path: Destination file path.
        args: Unused; present so the function matches the converter call signature.
    """
    # Only a 3-D array can be channel-first; the rank check keeps 2-D images with three rows away from
    # a transpose that would raise.
    if array.ndim == 3 and array.shape[0] == 3:  # CHW -> HWC
        array = np.transpose(array, (1, 2, 0))

    if np.issubdtype(array.dtype, np.floating):  # 0-1 -> 0-255
        # Clip first: values slightly outside [0, 1] would otherwise overflow in the uint8 cast.
        array = (np.clip(array, 0.0, 1.0) * 255).astype(np.uint8)

    Image.fromarray(array).save(path)


def _dir_loader(path: Path, args: argparse.Namespace) -> Optional[List[Path]]:
    """Return the sorted entries of ``path`` if it is a directory, otherwise ``None``.

    ``args`` is not used; it is accepted so the function matches the loader call signature.
    """
    if not path.is_dir():
        return None
    entries = list(path.iterdir())
    entries.sort()
    return entries


def _npz_loader(path: Path, args: argparse.Namespace) -> Optional[np.ndarray]:
    """Load the array stored under ``args.key`` from an ``.npz`` archive.

    Args:
        path: Candidate file. Anything whose suffix is not ``.npz`` is skipped.
        args: Parsed command-line arguments; only ``key`` is read.

    Returns:
        The array stored under ``args.key``, or ``None`` if ``path`` is not an ``.npz`` file.
    """
    if path.suffix != ".npz":
        return None
    # Indexing reads the member fully into memory, so the archive can be closed before returning.
    with np.load(path) as archive:
        return archive[args.key]


def _dhariwal2021_pp(img, path, args):
    """Adapted from guided_diffusion/image_datasets.py.

    Reads the image at ``img``, converts it to RGB, shrinks it so that its shorter side equals
    ``args.size`` (repeated halving with a box filter, then one bicubic resize), center-crops it to
    ``args.size`` x ``args.size`` and saves the result to ``path``.
    """
    image_size = args.size
    # ``convert`` returns an independent, fully loaded image, so the source file can be closed right away.
    with Image.open(img) as source:
        pil_image = source.convert("RGB")

    # Halve with a box filter while the image is at least twice as large as needed.
    while min(*pil_image.size) >= 2 * image_size:
        pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX)

    scale = image_size / min(*pil_image.size)
    pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC)

    arr = np.array(pil_image)
    crop_y = (arr.shape[0] - image_size) // 2
    crop_x = (arr.shape[1] - image_size) // 2

    arr = arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size]

    _array_to_png(arr, path, args)


def _rombach2022_pp(img, path, args):
    """Adapted from https://github.com/CompVis/taming-transformers/blob/24268930bf1dce879235a7fddd0b2355b84d7ea6/taming/data/base.py#L23.

    Reads the image at ``img`` as RGB, applies ``SmallestMaxSize(args.size)`` followed by a
    ``CenterCrop`` to ``args.size`` x ``args.size`` and saves the result to ``path``.
    """
    rescaler = albumentations.SmallestMaxSize(max_size=args.size)
    cropper = albumentations.CenterCrop(height=args.size, width=args.size)
    preprocessor = albumentations.Compose([rescaler, cropper])

    # Decode the pixels inside the ``with`` block so the source file is closed as soon as they are in memory.
    with Image.open(img) as source:
        rgb = source if source.mode == "RGB" else source.convert("RGB")
        image = np.array(rgb).astype(np.uint8)
    image = preprocessor(image=image)["image"]

    _array_to_png(image, path, args)


def _info_msg(input, output, amount, start):
    """Print a one-line summary of the extraction that is about to run.

    ``amount`` may be ``None``, in which case the message says "all".
    """
    quantity = "all" if amount is None else amount
    print(f"Extracting {quantity} images from {input} to {output} starting from image {start}.")


def _check_amount(amount, count):
    """Raise ``Warning`` if a requested ``amount`` was given and ``count`` falls short of it.

    Nothing happens when ``amount`` is ``None`` (no limit requested).
    """
    if amount is None:
        return
    if count < amount:
        raise Warning(f"{count} images were converted, which is less than the specified {amount}.")


def parse_args():
    """Build the command-line interface, parse the arguments and run the selected command.

    Layout: ``<input> convert <output> {npz,lmdb}`` and ``<input> preprocess <output> {dhariwal2021,rombach2022}``.
    Each leaf sub-command stores its handler in ``func``, which is then called with the parsed namespace.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("input", type=Path, help="Path to input file/directory.")
    # Sub-commands are required at every level: without one there is no ``func`` to dispatch to, so
    # argparse should report a usage error rather than the call below failing with AttributeError.
    s_root = parser.add_subparsers(dest="command", required=True)

    p_convert = s_root.add_parser("convert", help="Convert different data types to .png images.")
    p_convert.add_argument("output", type=Path, help="Path to output directory.")
    p_convert.add_argument("--start", type=int, default=0, help="Starting index.")
    p_convert.add_argument("--amount", type=int, help="Number of images to convert (default: all).")
    p_convert.add_argument("--num_workers", type=int, help="Number of worker processes to use (if applicable).")
    s_convert = p_convert.add_subparsers(dest="source_format", required=True)

    p_npz = s_convert.add_parser("npz", help="Convert a single or a directory of .npz files.")
    p_npz.add_argument("--key", help="Key of array to extract (default: arr_0).", default="arr_0")
    p_npz.set_defaults(func=partial(wrapper, loader=_npz_loader, converter=_array_to_png))

    p_lmdb = s_convert.add_parser("lmdb", help="Convert a LMDB database.")
    p_lmdb.add_argument("--size", type=int, default=256)
    p_lmdb.set_defaults(func=convert_lmdb)

    p_preprocess = s_root.add_parser("preprocess", help="Perform preprocessing of a certain paper.")
    p_preprocess.add_argument("output", type=Path, help="Path to output directory.")
    p_preprocess.add_argument("--size", type=int, default=256, help="Desired output size (default: 256).")
    p_preprocess.add_argument("--start", type=int, default=0, help="Starting index.")
    p_preprocess.add_argument("--amount", type=int, help="Number of images to preprocess (default: all).")
    p_preprocess.add_argument(
        "--num_workers", default=4, type=int, help="Number of worker processes to use (if applicable)."
    )
    s_preprocess = p_preprocess.add_subparsers(dest="method", required=True)

    p_dhariwal2021 = s_preprocess.add_parser("dhariwal2021", help="Perform preprocessing according to Dhariwal2021.")
    p_dhariwal2021.set_defaults(func=partial(wrapper, loader=_dir_loader, converter=_dhariwal2021_pp, takes_dir=True))

    p_rombach2022 = s_preprocess.add_parser("rombach2022", help="Perform preprocessing according to Rombach2022.")
    p_rombach2022.set_defaults(func=partial(wrapper, loader=_dir_loader, converter=_rombach2022_pp, takes_dir=True))

    args = parser.parse_args()
    args.func(args)


# Run the command-line interface only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    parse_args()
