import os
import argparse
from PIL import Image
from tqdm import tqdm
from torchvision import datasets
from torchvision.transforms import functional as F
import torchvision.transforms as T
import random
import math
from multiprocessing import Pool, cpu_count


def parse_args():
    """Build the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace with ``src_dir``, ``dst_dir``, ``crop_size``,
        ``num_crops``, ``scale_min``, ``scale_max``, ``num_workers`` and
        ``seed`` attributes.
    """
    cli = argparse.ArgumentParser(description='Fast n-crop generator from ImageFolder dataset')

    # Mandatory dataset locations: (flag, help text).
    required_paths = (
        ('--src_dir', 'Path to original ImageFolder dataset'),
        ('--dst_dir', 'Path to save cropped dataset'),
    )
    for flag, help_text in required_paths:
        cli.add_argument(flag, type=str, required=True, help=help_text)

    # Optional tuning knobs: (flag, value type, default, help text).
    tunables = (
        ('--crop_size', int, 224, 'Crop size for RandomResizedCrop'),
        ('--num_crops', int, 5, 'Number of crops per image'),
        ('--scale_min', float, 0.08, 'Min scale for RandomResizedCrop'),
        ('--scale_max', float, 1.0, 'Max scale for RandomResizedCrop'),
        ('--num_workers', int, 8, 'Number of processes for multiprocessing'),
        ('--seed', int, 42, 'Random seed'),
    )
    for flag, value_type, default_value, help_text in tunables:
        cli.add_argument(flag, type=value_type, default=default_value, help=help_text)

    return cli.parse_args()


def get_random_crop_params(img, scale_range, ratio_range=(3. / 4., 4. / 3.)):
    """Sample a crop box for a random-resized crop.

    Adapted from ``torchvision.transforms.RandomResizedCrop.get_params``, but
    drawing from Python's ``random`` module.

    Args:
        img: Object exposing ``size`` as ``(width, height)`` (e.g. a PIL image).
        scale_range: ``(min, max)`` fraction of the image area to cover.
        ratio_range: ``(min, max)`` width/height aspect ratio of the crop.

    Returns:
        Tuple ``(top, left, height, width)`` of a box that lies inside the
        image and has strictly positive height and width.
    """
    width, height = img.size
    area = height * width

    # The log-space bounds do not change between attempts, so compute them once.
    log_lo, log_hi = math.log(ratio_range[0]), math.log(ratio_range[1])

    for _ in range(10):
        target_area = random.uniform(*scale_range) * area
        # Sampling the ratio in log space makes r and 1/r equally likely.
        aspect_ratio = math.exp(random.uniform(log_lo, log_hi))

        w = int(round(math.sqrt(target_area * aspect_ratio)))
        h = int(round(math.sqrt(target_area / aspect_ratio)))

        # Reject boxes that do not fit AND boxes that rounded down to zero
        # pixels (possible for tiny images or very small scales); a 0-sized
        # box cannot be cropped and resized.
        if 0 < w <= width and 0 < h <= height:
            i = random.randint(0, height - h)
            j = random.randint(0, width - w)
            return i, j, h, w

    # Fallback: largest centered crop whose aspect ratio is within bounds.
    min_ratio, max_ratio = min(ratio_range), max(ratio_range)
    in_ratio = float(width) / float(height)
    if in_ratio < min_ratio:
        # Image is too tall: keep full width, trim height.
        w = width
        h = int(round(w / min_ratio))
    elif in_ratio > max_ratio:
        # Image is too wide: keep full height, trim width.
        h = height
        w = int(round(h * max_ratio))
    else:
        # Aspect ratio already acceptable: use the whole image.
        w = width
        h = height
    i = (height - h) // 2
    j = (width - w) // 2
    return i, j, h, w


def process_image(args):
    """Generate and save the random resized crops for one source image.

    Args:
        args: 7-tuple ``(img_path, label_name, dst_dir, crop_size, num_crops,
            scale_min, scale_max)``, packed into one argument so the function
            can be mapped over a process pool.

    Crops are written as JPEGs to
    ``<dst_dir>/<label_name>/<stem>_crop<k>.jpg`` for ``k`` in
    ``range(num_crops)``, where ``<stem>`` is the source file name without
    its extension.

    NOTE: source images in the same class that share a stem map to the same
    output file names, so later crops overwrite earlier ones.

    Any exception is caught and reported on stdout; nothing is raised and
    nothing is returned, so one bad image does not abort the whole run.
    """
    img_path, label_name, dst_dir, crop_size, num_crops, scale_min, scale_max = args
    try:
        # Image.open() is lazy and keeps the file open; the context manager
        # releases the handle as soon as the converted copy exists.
        with Image.open(img_path) as source:
            img = source.convert('RGB')

        base_name = os.path.splitext(os.path.basename(img_path))[0]
        save_dir = os.path.join(dst_dir, label_name)
        # exist_ok: several workers may create the same class directory.
        os.makedirs(save_dir, exist_ok=True)

        scale_range = (scale_min, scale_max)
        for crop_idx in range(num_crops):
            top, left, height, width = get_random_crop_params(img, scale_range)
            crop = F.resized_crop(img, top, left, height, width, size=(crop_size, crop_size))
            crop_save_path = os.path.join(save_dir, f'{base_name}_crop{crop_idx}.jpg')
            crop.save(crop_save_path, quality=90, subsampling=0)

    except Exception as e:
        print(f"[Error] Failed to process {img_path}: {e}")


def _run_seeded_task(seeded_task):
    """Seed this process's global RNG for one image, then run process_image."""
    task_seed, task = seeded_task
    # Seeding per task (rather than once in the parent) makes an image's crops
    # depend only on its seed string, not on which worker handles it or on the
    # multiprocessing start method.
    random.seed(task_seed)
    return process_image(task)


def main():
    """Scan the source ImageFolder and generate crops for every image in parallel.

    Seeding the parent process alone does not control the random state of the
    pool workers, so each task carries its own seed string built from
    ``--seed`` and the image path relative to the dataset root. With the same
    seed and dataset layout, every image therefore gets the same crops on
    every run.
    """
    args = parse_args()
    random.seed(args.seed)

    print("🔍 Scanning original dataset...")
    dataset = datasets.ImageFolder(root=args.src_dir)
    tasks = []

    for img_path, label in tqdm(dataset.imgs, desc="Preparing tasks"):
        label_name = dataset.classes[label]
        task = (
            img_path,
            label_name,
            args.dst_dir,
            args.crop_size,
            args.num_crops,
            args.scale_min,
            args.scale_max,
        )
        # Relative path keeps the seed independent of where the dataset lives.
        rel_path = os.path.relpath(img_path, dataset.root)
        tasks.append((f"{args.seed}:{rel_path}", task))

    print(f"🚀 Launching {args.num_workers} workers to generate crops...")
    with Pool(args.num_workers) as pool:
        # Drain the iterator just to drive the progress bar; results are unused.
        for _ in tqdm(pool.imap_unordered(_run_seeded_task, tasks), total=len(tasks)):
            pass

    print(f"\n✅ Done. All cropped images saved to: {args.dst_dir}")


# Run the crop generator only when this file is executed as a script,
# not when it is imported.
if __name__ == '__main__':
    main()
