import os
import argparse
from engine.datasets import dataset_classes

from engine.tools.utils import makedirs, save_as_json, set_random_seed
from engine.datasets.utils import get_few_shot_setup_name
from engine.datasets.benchmark import generate_fewshot_dataset

# Command-line interface: selects the dataset, the few-shot configuration
# (shot count and seed) and the directories used for data and index files.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--data_dir", type=str, default="./data", help="where the dataset is saved",
)
parser.add_argument(
    "--indices_dir", type=str, default="./indices", help="where the (few-shot) indices is saved",
)
# "--dataset" must be registered exactly once: adding the same option string
# twice makes argparse raise "conflicting option string" at start-up. This
# single definition restricts values to the supported datasets and defaults
# to "imagenet".
# NOTE(review): assumes "imagenet" is a key of dataset_classes -- confirm.
parser.add_argument(
    "--dataset", type=str, default="imagenet", choices=dataset_classes.keys(),
    help="Name of the dataset",
)
parser.add_argument(
    "--shot", type=int, default=4, choices=[1, 2, 4, 8, 16],
    help="train shot number. note that val shot is automatically set to min(4, shot)",
)
parser.add_argument(
    "--seed", type=int, default=1, help="seed number",
)
# Parsed at module level because the entry point below passes `args` to main().
args = parser.parse_args()

def main(args):
    """Create the few-shot index file for one (dataset, shot, seed) setup.

    The index file lives at
    ``<indices_dir>/<dataset>/<few-shot setup name>.json``. If it already
    exists nothing is regenerated; otherwise a new few-shot split is sampled
    from the benchmark's train/val sets and saved as JSON.

    Args:
        args: Parsed command-line namespace providing ``data_dir``,
            ``indices_dir``, ``dataset``, ``shot`` and ``seed``.

    Raises:
        ValueError: If ``args.dataset`` is not a key of ``dataset_classes``.
    """
    # A negative seed skips seeding altogether.
    if args.seed >= 0:
        print(f"Setting fixed seed: {args.seed}")
        set_random_seed(args.seed)

    # Validate with an explicit exception rather than `assert`: assertions are
    # stripped under `python -O`, which would let an unknown dataset reach the
    # directory creation and dictionary lookup below.
    if args.dataset not in dataset_classes:
        raise ValueError(
            f"Unsupported dataset {args.dataset!r}; "
            f"expected one of {list(dataset_classes)}"
        )

    few_shot_index_file = os.path.join(
        args.indices_dir,
        args.dataset,
        get_few_shot_setup_name(args.shot, args.seed) + ".json"
    )
    if os.path.exists(few_shot_index_file):
        # The split was already sampled; it is only reported, not reloaded.
        print(f"Few-shot data exists at {few_shot_index_file}.")
        return

    # No index file yet: sample a new split and persist it.
    print(f"Few-shot data does not exist at {few_shot_index_file}. Sample a new split.")
    makedirs(os.path.dirname(few_shot_index_file))
    benchmark = dataset_classes[args.dataset](args.data_dir)
    few_shot_dataset = generate_fewshot_dataset(
        benchmark.train,
        benchmark.val,
        num_shots=args.shot,
        # Per the --shot help text, val shots end up as min(4, shot).
        max_val_shots=4,
    )
    save_as_json(few_shot_dataset, few_shot_index_file)


# Script entry point: `args` is the namespace produced by the module-level
# `parser.parse_args()` call earlier in this file.
if __name__ == "__main__":
    main(args)
    
