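# Script that generates an "N-mixture" version of an image dataset: in every
# image, a fraction of the patches is replaced by patches borrowed from other
# images and a further fraction by Gaussian-noise patches.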
import os
import math
import numpy as np
from PIL import Image
from glob import glob
from tqdm import tqdm
from multiprocessing import Pool, cpu_count

# Below is the absolute path of the input directory of datasets.
# A sample folder structure is:
# {PATH_DATASET_PREFIX}/
# ├── {KEY_CIFAR_10}/
# │   ├── test/
# │   │   ├── 0/
# │   │   │   ├── test_batch_1001.png
# │   │   │   ├── test_batch_1010.png
# │   │   │   └── ...
# │   │   ├── 1/
# │   │   └── ...
# │   └── train/
# │       └── ...
# Root directory that contains both the input dataset and the generated output.
PATH_DATASET_PREFIX = "/tmp/data"
# Name of the input dataset folder under PATH_DATASET_PREFIX.
# NOTE(review): the constant is named KEY_CIFAR_10 but points at "cifar-100";
# the name is kept because other code in this file references it.
KEY_CIFAR_10 = "cifar-100"
# Upper bound on worker processes; process_dataset caps it at cpu_count().
CPU_CORE = 50
cfg = {
    # Side length, in pixels, of the square patches an image is split into.
    "patch_size": 4,
    # Fraction of patches replaced by patches taken from donor images.
    "ratio_vn": 0.30,
    # Fraction of patches replaced by clipped Gaussian noise.
    "ratio_e": 0.10,
    # True: all donors for a class come from one randomly chosen other class.
    # False: donors come from every other class.
    "should_borrow_vn_from_single_class": True,
    # Output dataset folder name under PATH_DATASET_PREFIX.
    "output_folder_name": "cifar-100-nmix",
    # Fraction of the images of each class that is processed.
    "test_subset": 1.0,
    # Not read by the functions visible in this file.
    "num_classes": 100,
}

#######################################################################
# Utils
#######################################################################

# Description: split a channel-first image into patches, scanning the rows
#              top-to-bottom and, within each row, left-to-right
# Input: np.ndarray of shape [C,H,W] (e.g. [3,32,32]) and the patch side length
# Output: a list of np.ndarray slices of img, each [C,patch_size,patch_size]
#         (patches on the bottom/right border are smaller when H or W is not a
#         multiple of patch_size); e.g. 64 patches for 32x32 with patch_size=4
def get_patches(img, patch_size):
    return [
        img[:, top:top + patch_size, left:left + patch_size]
        for top in range(0, img.shape[1], patch_size)
        for left in range(0, img.shape[2], patch_size)
    ]

# Description: rebuild a full image from non-overlapping patches given in the
#              same order get_patches produces them (row by row, left to right)
# Input: a list of np.ndarray patches, the patch side length, and the target
#        image shape (C,H,W)
# Output: np.ndarray of shape [C,H,W] with dtype uint8
def combine_patches(patches, patch_size, img_shape):
    channels, height, width = img_shape
    canvas = np.zeros((channels, height, width), dtype=np.uint8)
    # Top-left corner of every patch slot, in the order the patches are stored.
    origins = [
        (top, left)
        for top in range(0, height, patch_size)
        for left in range(0, width, patch_size)
    ]
    for slot, (top, left) in enumerate(origins):
        canvas[:, top:top + patch_size, left:left + patch_size] = patches[slot]
    return canvas

#######################################################################
# Dataset generation - multi processes
#######################################################################

def _load_image_chw(path):
    """Load an image file as an RGB uint8 array in channel-first [3, H, W] layout."""
    # The context manager closes the file handle as soon as the pixels are
    # copied. Forcing RGB matches the 3-channel noise patches generated in
    # process_one_image and avoids a transpose failure on 2-D (grayscale) data.
    with Image.open(path) as im:
        return np.array(im.convert("RGB")).transpose(2, 0, 1)


def process_one_image(img_path, cls, donor_pool, cfg, output_dir, seed):
    """Create and save the N-mixture version of a single image.

    A fraction ``cfg["ratio_vn"]`` of the image's patches is replaced by
    randomly chosen patches of randomly chosen donor images, and a further
    fraction ``cfg["ratio_e"]`` by clipped Gaussian noise (mean 127, std 40).
    The result is written to ``{output_dir}/{cls}/{basename(img_path)}``.

    Args:
        img_path: Path of the source image.
        cls: Class label; its string form is the output sub-folder name.
        donor_pool: Sequence of image paths to borrow patches from.
        cfg: Dict providing "patch_size", "ratio_vn" and "ratio_e".
        output_dir: Root output directory of the current split.
        seed: Base integer seed. It is combined with a checksum of
            ``img_path`` so every image gets its own random stream that is
            identical from run to run.

    Raises:
        ValueError: If donor patches are required but ``donor_pool`` is empty.
    """
    # Function-scope import keeps this block self-contained (stdlib only).
    import zlib

    # The built-in hash() of a string is salted per interpreter run, which
    # made the output irreproducible even with a fixed seed. CRC32 is stable
    # across runs and processes; the full 32-bit value is used so that
    # distinct paths rarely share a random stream.
    path_key = zlib.crc32(str(img_path).encode("utf-8"))
    rng = np.random.default_rng(seed + path_key)
    patch_size = cfg["patch_size"]

    img = _load_image_chw(img_path)
    patches = get_patches(img, patch_size)
    num_patches = len(patches)
    n_vn = math.floor(num_patches * cfg["ratio_vn"])
    n_e = math.floor(num_patches * cfg["ratio_e"])

    if n_vn > 0 and len(donor_pool) == 0:
        raise ValueError(
            f"donor_pool is empty but {n_vn} donor patches are required for {img_path}"
        )

    # Draw all positions at once, without replacement, so that the donor and
    # noise positions never overlap.
    replace_idxs = rng.choice(num_patches, size=n_vn + n_e, replace=False)
    vn_idxs = replace_idxs[:n_vn]
    e_idxs = replace_idxs[n_vn:]

    for patch_idx in vn_idxs:
        # Index into the pool directly instead of rng.choice(donor_pool),
        # which would convert the whole pool to an array for every image.
        donor_path = donor_pool[rng.integers(len(donor_pool))]
        donor_patches = get_patches(_load_image_chw(donor_path), patch_size)
        # Sample among the patches the donor actually has rather than a
        # hard-coded 64, which is only right for 32x32 images with 4x4 patches.
        patches[patch_idx] = donor_patches[rng.integers(len(donor_patches))]

    for patch_idx in e_idxs:
        noise = rng.normal(loc=127, scale=40, size=(3, patch_size, patch_size))
        patches[patch_idx] = np.clip(noise, 0, 255).astype(np.uint8)

    new_img = combine_patches(patches, patch_size, img.shape)
    out_path = os.path.join(output_dir, str(cls), os.path.basename(img_path))
    # Safe when several workers race to create the same class folder.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    Image.fromarray(new_img.transpose(1, 2, 0)).save(out_path)

def _process_one_image_task(task):
    """Unpack one task tuple and forward it to process_one_image.

    Pool.imap_unordered passes a single argument to its callable, and the
    callable must live at module level so that it can be pickled.
    """
    return process_one_image(*task)


def process_dataset(split, cfg, seed=42):
    """Generate the N-mixture version of one dataset split in parallel.

    Reads ``{PATH_DATASET_PREFIX}/{KEY_CIFAR_10}/{split}/<class>/*.png`` and
    writes the processed images, under the same class folders, to
    ``{PATH_DATASET_PREFIX}/{cfg["output_folder_name"]}/{split}``.

    Args:
        split: Name of the split sub-folder, e.g. "train" or "test".
        cfg: Configuration dict; this function reads "output_folder_name",
            "test_subset" and "should_borrow_vn_from_single_class" and passes
            the dict on to process_one_image.
        seed: Base integer seed for image selection, donor-class choice and
            the per-image random streams.

    Raises:
        ValueError: If single-class borrowing is enabled and a class has no
            other non-empty class to borrow from.
    """
    path_input = os.path.join(PATH_DATASET_PREFIX, KEY_CIFAR_10, split)
    path_output = os.path.join(PATH_DATASET_PREFIX, cfg["output_folder_name"], split)
    os.makedirs(path_output, exist_ok=True)
    # NOTE(review): despite its name, "test_subset" is applied to every split
    # passed in, including "train" - confirm that this is intended.
    subset_ratio = cfg.get("test_subset", 1.0)

    class_dirs = sorted(
        d for d in os.listdir(path_input)
        if os.path.isdir(os.path.join(path_input, d))
    )

    # glob() yields files in a filesystem-dependent order; sorting makes the
    # seeded selection below pick the same images on every machine and run.
    all_class_paths = {
        cls_str: sorted(glob(os.path.join(path_input, cls_str, "*.png")))
        for cls_str in class_dirs
    }

    tasks = []
    for cls_index, cls_str in enumerate(class_dirs):
        img_paths = all_class_paths[cls_str]
        if not img_paths:
            print(f"[WARNING] Class {cls_str} has no images in split '{split}'. Skipping...")
            continue

        os.makedirs(os.path.join(path_output, cls_str), exist_ok=True)
        # At least one image, and never more than the class contains (a
        # ratio above 1.0 would otherwise make the sampling call fail).
        n_select = min(len(img_paths), max(1, int(len(img_paths) * subset_ratio)))
        selected_paths = (
            np.random.default_rng(seed)
            .choice(img_paths, size=n_select, replace=False)
            .tolist()  # plain str objects: cheaper to pickle than numpy strings
        )

        # Build donor pool
        if cfg["should_borrow_vn_from_single_class"]:
            donor_candidates = [
                d for d in class_dirs
                if d != cls_str and all_class_paths[d]
            ]
            if not donor_candidates:
                raise ValueError(
                    f"Class {cls_str} in split '{split}' has no other "
                    "non-empty class to borrow patches from."
                )
            # Numeric folder names seed the donor choice with their value;
            # other names fall back to the position in the sorted class list
            # instead of crashing in int().
            try:
                cls_offset = int(cls_str)
            except ValueError:
                cls_offset = cls_index
            donor_class = np.random.default_rng(seed + cls_offset).choice(donor_candidates)
            donor_pool = all_class_paths[donor_class]
        else:
            donor_pool = [p for d in class_dirs if d != cls_str for p in all_class_paths[d]]

        tasks.extend(
            (img_path, cls_str, donor_pool, cfg, path_output, seed)
            for img_path in selected_paths
        )

    with Pool(min(CPU_CORE, cpu_count())) as pool:
        # starmap() blocks until every task has finished, so wrapping it in
        # tqdm only showed a bar that jumped to 100% at the end. The lazy
        # imap_unordered() yields results as they complete; the chunk size
        # keeps the inter-process overhead low for many small tasks.
        finished = pool.imap_unordered(_process_one_image_task, tasks, chunksize=32)
        for _ in tqdm(finished, total=len(tasks)):
            pass

if __name__ == "__main__":
    # Generate the N-mixture dataset for both splits, training split first.
    for split in ("train", "test"):
        process_dataset(split, cfg)
