#!/usr/bin/env python3
# examples/images/cifar10/export_cifar_lt.py
import os
import sys
import shutil
import argparse
from typing import List

from torchvision import transforms
from torchvision.utils import save_image

# Put the conditional-flow-matching repository root (three directories above
# this file) at the front of sys.path so that `torchcfm` can be imported when
# this script is run directly.
current_dir = os.path.dirname(os.path.abspath(__file__))
root_dir = os.path.join(current_dir, '../../../')
sys.path.insert(0, os.path.abspath(root_dir))

from torchcfm.utils import ImbalanceCIFAR10, ImbalanceCIFAR100
from torchvision import datasets


def build_transform(resize: int) -> transforms.Compose:
    """Compose the preprocessing pipeline applied to every exported image.

    Args:
        resize: Target edge length in pixels. When truthy and positive, images
            are resized to ``(resize, resize)`` before tensor conversion;
            otherwise no resize step is added.

    Returns:
        A ``transforms.Compose`` holding the optional ``Resize`` step followed
        by ``ToTensor``.
    """
    wants_resize = bool(resize) and resize > 0
    resize_steps = [transforms.Resize((resize, resize))] if wants_resize else []
    return transforms.Compose([*resize_steps, transforms.ToTensor()])


def _parse_args() -> argparse.Namespace:
    """Define and parse the command-line options of the export script."""
    p = argparse.ArgumentParser("Export CIFAR images to a folder")
    p.add_argument("--dataset", type=str, required=True,
                   choices=["cifar10", "cifar100", "cifar10_lt", "cifar100_lt"],
                   help="dataset to export")
    p.add_argument("--out_dir", type=str, required=True,
                   help="output directory (created if missing)")
    p.add_argument("--root", type=str, default="./data",
                   help="torchvision dataset root directory")
    p.add_argument("--split", type=str, default="train", choices=["train", "test"],
                   help="train/test split to export")
    p.add_argument("--imb_type", type=str, default="exp", choices=["exp", "step"],
                   help="imbalance type forwarded to the long-tail datasets")
    p.add_argument("--imb_factor", type=float, default=0.01,
                   help="imbalance factor forwarded to the long-tail datasets")
    p.add_argument("--rand_number", type=int, default=0,
                   help="rand_number forwarded to the long-tail datasets")
    p.add_argument("--num_images", type=int, default=-1,
                   help="maximum number of images to save (-1 or any value <= 0 saves all)")
    p.add_argument("--resize", type=int, default=32,
                   help="resize images to this square size (<=0 disables resizing)")
    p.add_argument("--per_class_subdir", action="store_true",
                   help="save images into one subdirectory per class "
                        "(every non-empty class gets at least one image)")
    p.add_argument("--overwrite", action="store_true",
                   help="delete out_dir first if it already exists")
    return p.parse_args()


def _build_dataset(args: argparse.Namespace, tfm):
    """Instantiate the dataset selected by ``args.dataset`` with transform ``tfm``."""
    common = dict(
        root=args.root,
        train=(args.split == "train"),
        transform=tfm,
        download=True,
    )
    if args.dataset == "cifar10":
        return datasets.CIFAR10(**common)
    if args.dataset == "cifar100":
        return datasets.CIFAR100(**common)

    # Long-tailed variants take the imbalance options on top of the common ones.
    long_tail = dict(
        imb_type=args.imb_type,
        imb_factor=args.imb_factor,
        rand_number=args.rand_number,
        **common,
    )
    if args.dataset == "cifar10_lt":
        return ImbalanceCIFAR10(**long_tail)
    return ImbalanceCIFAR100(**long_tail)  # cifar100_lt


def _allocate_save_counts(actual_counts: List[int], budget: int) -> List[int]:
    """Split an export budget across classes, at least one image per non-empty class.

    Args:
        actual_counts: Number of available images for each class index.
        budget: Desired total number of images to export.

    Returns:
        Per-class export counts. Each non-empty class gets one image first
        (so the total may exceed ``budget`` when ``budget`` is smaller than
        the number of non-empty classes); the rest of the budget is shared
        proportionally to class size and never exceeds what a class holds.
        When ``budget`` covers the whole dataset the result equals
        ``actual_counts``.
    """
    save_counts = [1 if count > 0 else 0 for count in actual_counts]
    total = sum(actual_counts)
    target = min(budget, total)
    remaining = target - sum(save_counts)
    if remaining <= 0:
        return save_counts

    non_empty = [c for c, count in enumerate(actual_counts) if count > 0]

    # Integer floor of the proportional share; integer arithmetic avoids the
    # float truncation surprises of int(remaining * ratio).
    for c in non_empty:
        share = remaining * actual_counts[c] // total
        save_counts[c] += min(share, actual_counts[c] - save_counts[c])

    # Flooring leaves a few images unassigned. Hand them out one at a time,
    # largest fractional remainder first, skipping classes that are full.
    # leftover never exceeds the total free capacity, so the loop terminates.
    leftover = target - sum(save_counts)
    by_remainder = sorted(
        non_empty,
        key=lambda c: (remaining * actual_counts[c]) % total,
        reverse=True,
    )
    while leftover > 0:
        for c in by_remainder:
            if leftover == 0:
                break
            if save_counts[c] < actual_counts[c]:
                save_counts[c] += 1
                leftover -= 1
    return save_counts


def _export_per_class(ds, args: argparse.Namespace, limit: int, suffix: str) -> int:
    """Export up to ``limit`` images into per-class subdirectories; return the count saved."""
    num_classes = 10 if args.dataset in ("cifar10", "cifar10_lt") else 100

    # Group sample indices by label. Each ds[i] loads and transforms the image.
    class_indices = [[] for _ in range(num_classes)]
    for i in range(len(ds)):
        _, label = ds[i]
        class_indices[label].append(i)

    actual_counts = [len(indices) for indices in class_indices]
    print(f"[Info] Original class distribution: {actual_counts}")

    empty_classes = [i for i, count in enumerate(actual_counts) if count == 0]
    if empty_classes:
        print(f"[Warning] Classes with no images: {empty_classes}")

    save_counts = _allocate_save_counts(actual_counts, limit)
    planned = sum(save_counts)
    print(f"[Info] Target save distribution: {save_counts}")

    # A directory is created for every class, including empty ones.
    for cls_idx in range(num_classes):
        os.makedirs(os.path.join(args.out_dir, f"{cls_idx:02d}"), exist_ok=True)

    print(f"[Export] dataset={args.dataset}, split={args.split}, total={len(ds)}, save={planned}, out_dir={args.out_dir}")

    saved = 0
    class_saved_counts = [0] * num_classes
    for cls_idx, indices in enumerate(class_indices):
        for idx in indices[:save_counts[cls_idx]]:
            img, label = ds[idx]
            cls_dir = os.path.join(args.out_dir, f"{label:02d}")
            # File names are numbered per class, starting from 000000.
            path = os.path.join(cls_dir, f"{args.split}{suffix}_{class_saved_counts[label]:06d}.png")
            save_image(img, path)
            class_saved_counts[label] += 1
            saved += 1
            if saved % 1000 == 0:
                print(f"  saved {saved}/{planned}")

    print(f"[Info] Final class distribution: {class_saved_counts}")
    return saved


def _export_flat(ds, args: argparse.Namespace, limit: int, suffix: str) -> int:
    """Export the first ``limit`` samples directly into ``args.out_dir``; return the count saved."""
    print(f"[Export] dataset={args.dataset}, split={args.split}, total={len(ds)}, save={limit}, out_dir={args.out_dir}")

    saved = 0
    for i in range(limit):
        img, _ = ds[i]
        path = os.path.join(args.out_dir, f"{args.split}{suffix}_{i:06d}.png")
        save_image(img, path)
        saved += 1
        if saved % 1000 == 0:
            print(f"  saved {saved}/{limit}")
    return saved


def main():
    """Export a CIFAR-10/100 (optionally long-tailed) split as PNG files.

    Parses the command line, prepares the output directory (removing it first
    when ``--overwrite`` is given), builds the requested dataset and writes the
    images either flat into ``out_dir`` or, with ``--per_class_subdir``, into
    one zero-padded subdirectory per class index.

    ``--num_images`` caps the number of exported images in both modes. In
    per-class mode every non-empty class still receives at least one image,
    so the cap can be exceeded when it is smaller than the number of
    non-empty classes.
    """
    args = _parse_args()

    # Prepare the output directory.
    if args.overwrite and os.path.exists(args.out_dir):
        shutil.rmtree(args.out_dir)
    os.makedirs(args.out_dir, exist_ok=True)

    tfm = build_transform(args.resize)
    ds = _build_dataset(args, tfm)

    # A non-positive --num_images means "export everything".
    limit = min(args.num_images if args.num_images > 0 else len(ds), len(ds))
    # Long-tailed exports are tagged in the file name.
    suffix = "_lt" if args.dataset.endswith("_lt") else ""

    if args.per_class_subdir:
        saved = _export_per_class(ds, args, limit, suffix)
    else:
        saved = _export_flat(ds, args, limit, suffix)

    print(f"[Done] saved {saved} images to {args.out_dir}")


# Run the export only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()