# Copyright (C) 2022-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# Extracting crops for pre-training
# --------------------------------------------------------

import os
import argparse
from tqdm import tqdm
from PIL import Image
import functools
from multiprocessing import Pool
import math


def arg_parser():
    """Build the command-line parser for the crop-extraction script.

    Returns:
        argparse.ArgumentParser: parser exposing ``--crops``, ``--root-dir``,
        ``--output-dir``, ``--imsize``, ``--nthread``, ``--max-subdir-levels``
        and ``--ideal-number-pairs-in-dir``.
    """
    # The title must go in `description`: the first positional argument of
    # ArgumentParser is `prog`, which would replace the program name in the
    # usage line.
    parser = argparse.ArgumentParser(
        description="Generate cropped image pairs from image crop list"
    )

    parser.add_argument("--crops", type=str, required=True, help="crop file")
    parser.add_argument("--root-dir", type=str, required=True, help="root directory")
    parser.add_argument(
        "--output-dir", type=str, required=True, help="output directory"
    )
    parser.add_argument("--imsize", type=int, default=256, help="size of the crops")
    parser.add_argument(
        "--nthread",
        type=int,
        required=True,
        help="number of simultaneous worker processes",
    )
    parser.add_argument(
        "--max-subdir-levels",
        type=int,
        default=5,
        help="maximum number of subdirectories",
    )
    parser.add_argument(
        "--ideal-number-pairs-in-dir",
        type=int,
        default=500,
        help="number of pairs stored in a dir",
    )
    return parser


def main(args):
    """Extract every crop pair listed in ``args.crops`` and write a listing.

    Reads the crop file and assigns each crop pair a hierarchical output path,
    so that directories stay at roughly ``args.ideal_number_pairs_in_dir``
    entries. It then crops, resizes and saves the images, in parallel when
    ``args.nthread > 1``. Finally it writes one output path stem per line to
    ``<output_dir>/listing.txt``.

    Args:
        args: namespace produced by ``arg_parser()``.
    """
    import contextlib

    listing_path = os.path.join(args.output_dir, "listing.txt")

    print(f"Loading list of crops ... ({args.nthread} threads)")
    crops, num_crops_to_generate = load_crop_file(args.crops)

    print(f"Preparing jobs ({len(crops)} candidate image pairs)...")
    # Number of directory levels: log base `ideal` of the crop count, capped
    # by --max-subdir-levels. math.log() fails for 0 crops and yields 0 levels
    # for a single crop; 0 levels previously caused a ZeroDivisionError below.
    # Therefore always use at least one level.
    num_levels = 1
    if num_crops_to_generate > 1:
        num_levels = math.ceil(
            math.log(num_crops_to_generate, args.ideal_number_pairs_in_dir)
        )
    num_levels = max(1, min(num_levels, args.max_subdir_levels))
    num_pairs_in_dir = max(1, math.ceil(num_crops_to_generate ** (1 / num_levels)))

    jobs = prepare_jobs(crops, num_levels, num_pairs_in_dir)
    # Drop the reference to the parsed list before creating the worker pool;
    # every job already carries what it needs.
    del crops

    os.makedirs(args.output_dir, exist_ok=True)
    call = functools.partial(save_image_crops, args)

    # Use a process pool only when more than one worker is requested. The
    # `with` statement guarantees the pool is shut down; it was previously
    # never closed. nullcontext() yields None for the sequential case.
    pool_ctx = Pool(args.nthread) if args.nthread > 1 else contextlib.nullcontext()

    print(f"Generating cropped images to {args.output_dir} ...")
    with pool_ctx as pool, open(listing_path, "w") as listing:
        mmap = pool.imap_unordered if pool is not None else map
        listing.write("# pair_path\n")
        for results in tqdm(mmap(call, jobs), total=len(jobs)):
            for path in results:
                listing.write(f"{path}\n")
    print("Finished writing listing to", listing_path)


def load_crop_file(path):
    """Parse a crop-list file.

    Fields on each line are separated by ", ". A 3-field line
    ``img1, img2, rotation`` starts a new image pair. Each following 8-field
    line ``l1, r1, t1, b1, l2, r2, t2, b2`` adds one pair of crop rectangles to
    the most recent image pair. Lines starting with "#" and blank lines are
    skipped.

    Args:
        path: path to the crop-list text file.

    Returns:
        (pairs, num_crops_to_generate): ``pairs`` is a list of
        ``(img1, img2, rotation, [(rect1, rect2), ...])``. Each rect is a
        ``(left, top, right, bottom)`` tuple, as later passed to
        ``Image.crop``. ``num_crops_to_generate`` is the total number of
        rectangle pairs.

    Raises:
        ValueError: if a rectangle line appears before any image-pair line,
            or a line has an unexpected number of fields.
    """
    with open(path) as f:
        data = f.read().splitlines()

    pairs = []
    num_crops_to_generate = 0
    for line in tqdm(data):
        # Skip comments, and blank lines that would otherwise break unpacking.
        if not line.strip() or line.startswith("#"):
            continue
        fields = line.split(", ")
        if len(fields) < 8:
            img1, img2, rotation = fields
            pairs.append((img1, img2, int(rotation), []))
        else:
            if not pairs:
                raise ValueError(
                    f"crop line found before any image pair in {path}: {line!r}"
                )
            l1, r1, t1, b1, l2, r2, t2, b2 = map(int, fields)
            rect1, rect2 = (l1, t1, r1, b1), (l2, t2, r2, b2)
            pairs[-1][-1].append((rect1, rect2))
            num_crops_to_generate += 1
    return pairs, num_crops_to_generate


def prepare_jobs(pairs, num_levels, num_pairs_in_dir):
    """Turn parsed crop pairs into one job per image pair, with output paths.

    Every crop pair gets a running index, counted across all image pairs. Its
    output path stem has ``num_levels - 1`` directory components: the index's
    digits in base ``num_pairs_in_dir``, most significant first. The full
    running index follows as the last component. All components are written
    in lowercase hex without the ``0x`` prefix.

    Rotations in the range [-60, 60] are reset to 0.

    Returns:
        list of ``((img1, img2), rotation, crops, paths)`` tuples, one per
        entry of ``pairs``; ``paths[k]`` is the output path stem for
        ``crops[k]``.
    """
    # Place value of each path digit, most significant first.
    place_values = [num_pairs_in_dir**exp for exp in range(num_levels - 1, -1, -1)]

    def get_path(global_idx):
        remainder = global_idx
        parts = []
        for place in place_values[: num_levels - 1]:
            digit, remainder = divmod(remainder, place)
            parts.append(digit)
        # The last component is the full running index, not the remainder.
        parts.append(global_idx)
        return "/".join(hex(part)[2:] for part in parts)

    jobs = []
    next_idx = 0
    for img1, img2, rotation, crops in tqdm(pairs):
        if -60 <= rotation <= 60:
            rotation = 0  # small angles are most likely not a true rotation
        count = len(crops)
        paths = [get_path(i) for i in range(next_idx, next_idx + count)]
        next_idx += count
        jobs.append(((img1, img2), rotation, crops, paths))
    return jobs


def load_image(path):
    """Open the image at ``path`` and return an RGB copy of it.

    Raises:
        OSError: if the image cannot be opened or converted. The failure is
            printed first, and the original exception is chained as
            ``__cause__`` so it is not lost.
    """
    try:
        # Open via a context manager so the underlying file is always closed;
        # convert() returns a separate image, so it stays valid afterwards.
        with Image.open(path) as img:
            return img.convert("RGB")
    except Exception as e:
        print("skipping", path, e)
        raise OSError(f"cannot load image {path!r}") from e


def save_image_crops(args, data):
    """Crop, resize, optionally rotate, and save every crop pair of one job.

    Args:
        args: namespace providing ``root_dir``, ``output_dir`` and ``imsize``.
        data: one job as built by ``prepare_jobs``:
            ``((img1_path, img2_path), rotation, crops, paths)``.

    Returns:
        list of path stems that were written (``<stem>_1.jpg`` and
        ``<stem>_2.jpg`` under ``args.output_dir``), or ``[]`` when either
        source image could not be loaded.

    Raises:
        FileExistsError: if an output file already exists.
    """
    img_pair, rot, crops, paths = data
    try:
        img1, img2 = [
            load_image(os.path.join(args.root_dir, impath)) for impath in img_pair
        ]
    except OSError:
        # load_image already printed the reason; skip this image pair.
        return []

    tgt_size = (args.imsize, args.imsize)
    tgt_area = tgt_size[0] * tgt_size[1]

    # The rotation is the same for every crop of this job: snap it to the
    # nearest multiple of 90 degrees once, and pick the matching transpose.
    rot90 = (round(rot / 90) % 4) * 90
    transpose = {
        90: Image.Transpose.ROTATE_90,
        180: Image.Transpose.ROTATE_180,
        270: Image.Transpose.ROTATE_270,
    }.get(rot90)

    def prepare_crop(img, rect, transpose_op=None):
        # Crop, resize to the target size, then apply the optional transpose.
        img = img.crop(rect)
        width, height = img.size
        # Use Lanczos when the crop has more than 4x the target pixel count
        # (strong downscaling); otherwise use bicubic.
        interp = (
            Image.Resampling.LANCZOS
            if width * height > 4 * tgt_area
            else Image.Resampling.BICUBIC
        )
        img = img.resize(tgt_size, resample=interp)
        if transpose_op is not None:
            img = img.transpose(transpose_op)
        return img

    results = []
    for (rect1, rect2), path in zip(crops, paths):
        # Only the second image of the pair is rotated.
        crop1 = prepare_crop(img1, rect1)
        crop2 = prepare_crop(img2, rect2, transpose)

        fullpath1 = os.path.join(args.output_dir, path + "_1.jpg")
        fullpath2 = os.path.join(args.output_dir, path + "_2.jpg")
        os.makedirs(os.path.dirname(fullpath1), exist_ok=True)

        # Refuse to overwrite existing outputs. This is an explicit check
        # rather than `assert`, which is stripped under `python -O`.
        for fullpath in (fullpath1, fullpath2):
            if os.path.isfile(fullpath):
                raise FileExistsError(fullpath)
        crop1.save(fullpath1)
        crop2.save(fullpath2)
        results.append(path)

    return results


if __name__ == "__main__":
    # Script entry point: build the CLI parser, parse the command line,
    # and hand the resulting namespace to main().
    parser = arg_parser()
    args = parser.parse_args()
    main(args)
