import numpy as np
import os
from tqdm import tqdm
import imageio
import skimage
import skimage.transform
import glob


# Root directory holding the Omniglot images; generated D-*/y-* files are written here too.
DATA_DIR = "ANONYMIZED_OMNIGLOT"

def l2_dist(x, y):
    """Return the L2 norm of ``x - y``.

    With numpy's default ``ord=None``/``axis=None`` the difference is
    flattened first, so for 2-D images this is the Frobenius distance.
    """
    difference = x - y
    return np.linalg.norm(difference)

def fix_index(idx):
    """Map a character index to an episode label.

    Indices below 10 all collapse to label 0; an index of 10 or more is
    shifted down by 9 (so 10 -> 1, 11 -> 2, ...).
    """
    if idx >= 10:
        return idx - 9
    return 0

def _load_images(root):
    """Load and resize every image under ``root/<alphabet>/<character>/``.

    Entries are visited in sorted order so character indices are reproducible;
    ``glob.glob`` alone returns them in filesystem-dependent order.

    Args:
        root: directory whose children are alphabet directories.

    Returns:
        (images, char_ids): ``images[j]`` is the j-th image resized to 28x28
        with ``skimage.transform.resize``; ``char_ids[j]`` is the 0-based
        global index of the character directory that image came from.
    """
    images = []
    char_ids = []
    char_idx = 0
    for alphabet in tqdm(sorted(glob.glob(os.path.join(root, '*')))):
        for character in sorted(glob.glob(os.path.join(alphabet, '*'))):
            for path in sorted(glob.glob(os.path.join(character, '*'))):
                raw = imageio.imread(path)
                images.append(np.asarray(skimage.transform.resize(raw, (28, 28))))
                char_ids.append(char_idx)
            char_idx += 1
    return images, char_ids


def read_omniglot():
    """Write per-episode distance/label files from Omniglot and return the images.

    Reads every image under ``DATA_DIR/images_background_small1`` (layout
    ``<alphabet>/<character>/<image>``). Then, for each target character
    ``k`` in ``0..c-1`` and each of ``T`` episodes, it draws ``S`` distinct
    images at random from character ``k`` together with the distractor
    characters 10-13. It writes two comma-separated text files into
    ``DATA_DIR``:

    * ``D-{k}-{i}-of-{T-1}``: the ``S x S`` matrix of pairwise ``l2_dist``.
    * ``y-{k}-{i}-of-{T-1}``: the ``S`` labels from ``fix_index`` (0 for the
      target character, 1-4 for distractor characters 10-13).

    Sampling uses the global ``np.random`` state; seed it beforehand for
    reproducible output.

    Returns:
        list of every loaded image as returned by ``skimage.transform.resize``
        (28x28), in sorted directory order.

    Raises:
        ValueError: if fewer than ``S`` images are eligible for some target
            (e.g. ``DATA_DIR`` is wrong and nothing was loaded).
    """
    root = os.path.join(DATA_DIR, 'images_background_small1')
    data, cht = _load_images(root)

    T = 50  # T in [5, 50] for training, and T = 5 for test.
    S = 80  # images per episode
    c = 10  # number of target characters: 0..c-1
    distractors = {10, 11, 12, 13}
    for k in range(c):
        # The eligible pool depends only on k, so build it once per target.
        eligible = [j for j, ch in enumerate(cht) if ch == k or ch in distractors]
        if len(eligible) < S:
            raise ValueError(
                "need {} images of character {} or distractors {} under {!r}, "
                "found {}".format(S, k, sorted(distractors), root, len(eligible)))
        for i in range(T):
            # A random permutation truncated to S yields S distinct images
            # in random order.
            picks = np.random.permutation(eligible)[:S]
            xx = [data[j] for j in picks]
            yy = [fix_index(cht[j]) for j in picks]
            D = np.array([[l2_dist(a, b) for b in xx] for a in xx])
            suffix = "{}-{}-of-{}".format(k, i, T - 1)
            np.savetxt(os.path.join(DATA_DIR, "D-" + suffix), D, delimiter=',')
            np.savetxt(os.path.join(DATA_DIR, "y-" + suffix), yy, delimiter=',')
    return data

if __name__ == "__main__":
    # Run the episode-file generation only when executed as a script.
    read_omniglot()
