# -*- coding: utf-8 -*-
"""
This script generates a miniImageNet pickle file from the dataset produced by
yaoyao-liu's mini-imagenet-tools (distributed via a Google Drive link).

The dataset contains 100 classes of 84x84 images, as defined by Ravi & Larochelle,
split into train/val/test for few-shot learning. This script merges all splits
into a single dataset (600 images per class), following Timm Hess's repository:

https://github.com/TimmHess/TwoComplementaryPerspectivesCL/tree/main/src/benchmarks

Usage:
    python make_mini_imagenet_2pkl.py --root_dir ./miniImageNet --final_name miniImageNet.pkl
"""

import argparse
import os
import pickle
import cv2
from tqdm import tqdm


def _find_class_dirs(root_dir: str) -> list:
    """Return the paths of all ``root_dir/<split>/<class>`` folders, sorted.

    Only directories two levels below *root_dir* count as class folders.
    Loose files at either level are ignored, for example a previously
    written output pickle.
    """
    class_dirs = []
    # os.listdir order is arbitrary and filesystem-dependent. Sort so that the
    # integer label assigned to each class is reproducible across machines.
    for split in sorted(os.listdir(root_dir)):
        split_path = os.path.join(root_dir, split)
        if not os.path.isdir(split_path):
            continue
        for class_name in sorted(os.listdir(split_path)):
            class_path = os.path.join(split_path, class_name)
            if os.path.isdir(class_path):
                class_dirs.append(class_path)
    return class_dirs


def _load_images(class_dirs: list) -> dict:
    """Read every image in *class_dirs*, labelling it with its folder's index.

    Returns ``{"data": [...], "labels": [...]}``. Element ``k`` of ``data`` is
    the array returned by ``cv2.imread`` for one file, and element ``k`` of
    ``labels`` is the position of that file's folder in *class_dirs*.
    Files that OpenCV cannot decode are skipped and reported.
    """
    dataset = {"data": [], "labels": []}
    skipped = []
    for label, class_dir in tqdm(enumerate(class_dirs), total=len(class_dirs), desc="Classes"):
        # Sorted so the sample order inside each class is deterministic as well.
        samples = sorted(os.listdir(class_dir))
        for sample in tqdm(samples, desc=f"Class {label}", leave=False):
            sample_path = os.path.join(class_dir, sample)
            if not os.path.isfile(sample_path):
                continue
            img = cv2.imread(sample_path)
            # cv2.imread signals failure by returning None instead of raising,
            # e.g. for stray non-image files such as .DS_Store. Never store
            # None as a sample.
            if img is None:
                skipped.append(sample_path)
                continue
            dataset["data"].append(img)
            dataset["labels"].append(label)
    if skipped:
        print(f"Warning: skipped {len(skipped)} unreadable file(s), e.g. {skipped[0]}")
    return dataset


def main(root_dir: str, final_name: str):
    """Merge every class folder under *root_dir* into a single pickled dataset.

    The expected layout is ``root_dir/<split>/<class>/<image files>``. Every
    class folder found across all splits gets its own integer label. Labels
    are assigned in sorted folder order, so repeated runs produce the same
    mapping.

    The pickled object is a dict ``{"data": [...], "labels": [...]}`` written
    with ``pickle.HIGHEST_PROTOCOL`` to ``os.path.join(root_dir, final_name)``.

    Args:
        root_dir: Directory containing the split folders.
        final_name: File name of the output pickle, created inside *root_dir*.
    """
    class_dirs = _find_class_dirs(root_dir)
    print(f"Found {len(class_dirs)} class folders.")

    dataset = _load_images(class_dirs)

    output_path = os.path.join(root_dir, final_name)
    print(f"Storing dataset to {output_path}")
    with open(output_path, "wb") as handle:
        pickle.dump(dataset, handle, protocol=pickle.HIGHEST_PROTOCOL)
    print("Done!")


if __name__ == "__main__":
    # Command-line entry point: both options are mandatory and are forwarded
    # unchanged to main().
    cli = argparse.ArgumentParser(
        description="Generate miniImageNet.pkl from dataset folders."
    )
    cli.add_argument(
        "--root_dir",
        type=str,
        required=True,
        help="Root directory of miniImageNet dataset",
    )
    cli.add_argument(
        "--final_name",
        type=str,
        required=True,
        help="Name of the output pickle file",
    )
    options = cli.parse_args()

    main(root_dir=options.root_dir, final_name=options.final_name)
