import numpy as np
from os import listdir
from os.path import isfile, isdir, join
import os
import json
import random
import re

# Show the working directory: the split CSV files are read from it and the
# generated JSON files are written to it.
cwd = os.getcwd()
print(cwd)

# Root folder that contains one sub-folder of images per class label.
data_path = "D:/Dataset/miniImagenet/images/"
# Alternative location:
# data_path = "E:/GitHub/Embedded_AI/MultiProj/MethodDesign/datasets/data/mini-imagenet/images"

# Output split name -> base name of the CSV file describing that split.
datasetmap = {'base': 'train', 'val': 'val', 'novel': 'test'}
dataset_list = ['base', 'val', 'novel']

# Running class index (incremented before use, so the first class gets 0)
# and the class labels in the order they are first encountered.
cl = -1
folderlist = []

# Per-split accumulators; each split gets its own independent container.
filelists = {split: {} for split in dataset_list}        # label -> image paths
filelists_flat = {split: [] for split in dataset_list}   # all image paths
labellists_flat = {split: [] for split in dataset_list}  # class index per image

# For every split, collect the images of each class named in the split's CSV
# file, then flatten them into parallel lists of image paths and class indices.
for dataset in dataset_list:
    # The CSV is looked up relative to the current working directory.
    with open(datasetmap[dataset] + ".csv", "r") as lines:
        next(lines, None)  # skip the first row (presumably the column header)
        for line in lines:
            line = line.strip()
            if not line:
                # Tolerate blank / trailing empty lines instead of crashing.
                continue
            # The class label is the last comma-separated field. Taking it
            # with rsplit avoids the regex and also works when the file name
            # in the first field contains more than one dot.
            label = line.rsplit(",", 1)[-1]
            if label in filelists[dataset]:
                continue
            folderlist.append(label)
            # Every file found in the class folder is listed, not only the
            # files named in the CSV. Paths are stored relative to data_path
            # and always use "/" as the separator, whatever the platform.
            filelists[dataset][label] = [
                label + "/" + fname
                for fname in listdir(join(data_path, label))
            ]

    # Classes get consecutive indices in first-seen order; `cl` keeps counting
    # across splits, so an index is also a position in `folderlist`.
    for filelist in filelists[dataset].values():
        cl += 1
        random.shuffle(filelist)
        filelists_flat[dataset] += filelist
        labellists_flat[dataset] += [cl] * len(filelist)

# Write one "<split>.json" file per split into the current working directory.
# Each file holds three keys: "label_names" (the labels of all splits, in
# first-seen order), "image_names" and "image_labels" (parallel lists for
# this split only).
for dataset in dataset_list:
    payload = {
        "label_names": folderlist,
        "image_names": filelists_flat[dataset],
        "image_labels": labellists_flat[dataset],
    }
    # json.dump replaces the hand-built string plus seek-back-one-character
    # trick: it escapes quotes/backslashes, copes with empty lists, and does
    # not do arithmetic on text-mode tell() values. The separators reproduce
    # the previous layout (no space after ",", one space after ":"), and
    # ensure_ascii=False keeps non-ASCII names unescaped as before.
    # `with` guarantees the file is closed even if serialisation fails.
    with open(dataset + ".json", "w") as fo:
        json.dump(payload, fo, separators=(",", ": "), ensure_ascii=False)
    print("%s -OK" % dataset)