import argparse
import glob
import os
import random
import shutil

import tqdm


def _class_dirs(split_dir):
    """Return the sorted names of the class subdirectories of ``split_dir``.

    Non-directory entries (stray archives, CSVs, ...) are skipped so they are
    neither treated as classes nor passed to ``os.listdir``/``glob``. Sorting
    makes the result independent of filesystem enumeration order, which is
    required for ``--seed`` to give reproducible output.
    """
    return sorted(entry.name for entry in os.scandir(split_dir) if entry.is_dir())


def _class_images(class_dir):
    """Return the sorted paths of the ``*.JPEG`` files directly in ``class_dir``."""
    return sorted(glob.glob(os.path.join(class_dir, "*.JPEG")))


def _shuffle_split(source_split, target_split, seed):
    """Copy every image of one split into a randomly re-assigned class folder.

    Each class keeps exactly as many images as it had in ``source_split``; only
    the mapping from image to class is permuted (deterministically for a given
    ``seed`` and source tree).

    Args:
        source_split: directory containing one subdirectory per class (wnid).
        target_split: directory to create the shuffled class folders in; the
            class folders must not exist yet.
        seed: seed for the permutation of labels.

    Raises:
        FileExistsError: if a target class folder already exists, or if two
            source images with the same file name would be copied into the
            same target class (which would otherwise silently overwrite one).
    """
    # Collect the image list once and reuse it both for sizing the label pool
    # and for copying, so the pool has exactly one entry per copied image.
    images_by_wnid = {
        wnid: _class_images(os.path.join(source_split, wnid))
        for wnid in _class_dirs(source_split)
    }
    available_wnids = [
        wnid for wnid, images in images_by_wnid.items() for _ in images
    ]
    for wnid in images_by_wnid:
        os.makedirs(os.path.join(target_split, wnid), exist_ok=False)
    # Local RNG: same sequence as random.seed(seed); random.shuffle(...), but
    # without mutating the module-level global random state.
    random.Random(seed).shuffle(available_wnids)

    with tqdm.tqdm(total=len(available_wnids), position=0, leave=True) as pbar:
        for images in images_by_wnid.values():
            for filename in images:
                new_wnid = available_wnids.pop()
                new_filename = os.path.join(
                    target_split, new_wnid, os.path.basename(filename)
                )
                if os.path.exists(new_filename):
                    raise FileExistsError(
                        f"Refusing to overwrite {new_filename} "
                        f"(duplicate file name from {filename})"
                    )
                shutil.copyfile(filename, new_filename)
                pbar.update(1)


def main():
    """Write a copy of ImageNet whose image labels are randomly shuffled.

    For each of the ``train`` and ``val`` splits under ``--source-data``, every
    ``*.JPEG`` image (other files are ignored) is copied into a class folder
    under ``--target-data`` chosen from a seeded random permutation of the
    original labels. Per-class image counts are preserved; only which image
    carries which label changes.

    Raises:
        ValueError: if ``--target-data`` exists and is not empty.
    """
    parser = argparse.ArgumentParser(description="Shuffle ImageNet labels")
    parser.add_argument(
        "--source-data",
        required=True,
        help="directory containing ImageNet data suitable for torchvision.datasets",
    )
    parser.add_argument(
        "--target-data",
        required=True,
        help="directory where shuffled ImageNet data will be written",
    )
    parser.add_argument("--seed", type=int, default=42, help="random seed")
    args = parser.parse_args()

    if os.path.exists(args.target_data) and os.listdir(args.target_data):
        raise ValueError(f"Target directory {args.target_data} is not empty")
    os.makedirs(args.target_data, exist_ok=True)

    for split in ["train", "val"]:
        # Each split is shuffled independently with the same seed.
        _shuffle_split(
            os.path.join(args.source_data, split),
            os.path.join(args.target_data, split),
            args.seed,
        )


if __name__ == "__main__":
    main()
