#!/usr/bin/env python
# -*-coding:utf-8 -*-
import numpy as np 
import os 
import utils.vis_utils as vu 


def get_test_index(latent_class, num_tt):
    """Split row indices into a held-out test set and a training set.

    Rows are grouped by their (shape, scale, rotation) latent values, read from
    columns 1, 2 and 3 of ``latent_class``. From every non-empty group, rows at
    ``num_tt`` evenly spaced interior positions are assigned to the test set;
    all remaining rows form the training set.

    Args:
        latent_class: 2-D array of latent labels, one row per sample.
        num_tt: number of test rows to take from each group. When a group has
            fewer rows than that, several grid points land on the same row and
            that row is used only once.

    Returns:
        ``(test_index, train_index)``: two 1-D integer arrays of row indices into
        ``latent_class``. ``test_index`` is ordered by group (ascending shape,
        then scale, then rotation) and by position inside each group;
        ``train_index`` is ascending.
    """
    shape_index, scale_index, rotation_index = 1, 2, 3
    test_index = []
    unique_index = [np.unique(latent_class[:, v]) for v in (shape_index, scale_index, rotation_index)]
    for i in unique_index[0]:
        by_shape = np.where(latent_class[:, shape_index] == i)[0]
        for j in unique_index[1]:
            by_scale = by_shape[latent_class[by_shape, scale_index] == j]
            for q in unique_index[2]:
                group = by_scale[latent_class[by_scale, rotation_index] == q]
                if len(group) == 0:
                    # Not every combination has to exist (e.g. after sub-setting);
                    # indexing an empty group would raise IndexError.
                    continue
                # Interior points of an even grid over [0, len(group)]; all < len(group).
                positions = np.linspace(0, len(group), num_tt + 2)[1:-1].astype(np.intp)
                # Small groups can map several grid points to the same row: keep each once.
                test_index.append(group[np.unique(positions)])
    if test_index:
        test_index = np.concatenate(test_index, axis=0)
    else:
        # np.concatenate([]) raises; an empty selection is a valid result.
        test_index = np.empty(0, dtype=np.intp)
    return test_index, np.delete(np.arange(len(latent_class)), test_index)



def get_subset_index(latent_class, skip=2):
    """Return the row indices of a position-thinned subset of the dataset.

    For every combination of shape (column 1), scale (column 2) and rotation
    (column 3), each iterated in ascending order, the distinct values of the
    second-to-last column present in that group are sorted and only every
    ``skip``-th one is kept. The same thinning is then applied to the distinct
    values of the last column within each kept slice. The surviving rows are
    returned.

    Args:
        latent_class: 2-D array of latent labels, one row per sample.
        skip: slice step used to thin the values of the last two columns.

    Returns:
        1-D integer array of row indices, ordered by shape, scale, rotation and
        then the last two columns. ``np.concatenate`` raises ``ValueError`` when
        nothing is selected (e.g. ``latent_class`` has no rows).
    """
    def narrow(rows, column, value):
        # Keep, in their current order, the rows whose `column` equals `value`.
        return rows[latent_class[rows, column] == value]

    shapes, scales, rotations = (np.unique(latent_class[:, col]) for col in (1, 2, 3))
    all_rows = np.arange(len(latent_class))
    picked = []
    for shape in shapes:
        shape_rows = narrow(all_rows, 1, shape)
        for scale in scales:
            scale_rows = narrow(shape_rows, 2, scale)
            for rotation in rotations:
                group = narrow(scale_rows, 3, rotation)
                for x_value in np.unique(latent_class[group, -2])[::skip]:
                    x_rows = narrow(group, -2, x_value)
                    picked.extend(
                        narrow(x_rows, -1, y_value)
                        for y_value in np.unique(latent_class[x_rows, -1])[::skip]
                    )
    return np.concatenate(picked, axis=0)
    

def show_train_and_test(tr_index, tt_index, latent_class, images, s_shape, s_rotation, s_scale, show=True):
    """Render training and test images that match one shape/scale/rotation.

    Args:
        tr_index: row indices of the training samples.
        tt_index: row indices of the test samples.
        latent_class: 2-D latent label array; columns 1, 2 and 3 hold shape,
            scale and rotation (the layout used by ``get_test_index`` and
            ``print_info``).
        images: image array indexed by the same row indices.
        s_shape: shape value to keep.
        s_rotation: rotation value to keep.
        s_scale: scale value to keep.
        show: forwarded to ``vu.create_canvas``.
    """
    # (column, required value): column 2 is scale and column 3 is rotation, so each
    # argument must be compared against its own column.
    required = ((1, s_shape), (2, s_scale), (3, s_rotation))

    def _select(index):
        # Keep the indices whose latent values satisfy every requirement.
        index = np.asarray(index, dtype=np.intp)
        for column, value in required:
            index = index[latent_class[index, column] == value]
        return index

    tr_subset = _select(tr_index)
    tt_subset = _select(tt_index)
    # Up to 100 random training images; np.random.choice(..., replace=False)
    # raises if asked for more samples than exist.
    n_show = min(100, len(tr_subset))
    tr_image_show = images[tr_subset[np.random.choice(len(tr_subset), n_show, replace=False)]]
    tt_image_show = images[tt_subset]
    print(np.shape(tr_image_show), np.shape(tt_image_show))
    # NOTE(review): the meaning of the 10 / 4 arguments is defined by
    # vu.create_canvas (not visible here) — presumably a grid width; confirm.
    _ = vu.create_canvas(tr_image_show, 10, show)
    _ = vu.create_canvas(tt_image_show, 4, show)
    
    
def prepare_data(path="../image_dataset/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz", num_tt=32, opt="train_vae", skip=2):
    """Load the archive, thin it with ``get_subset_index`` and split it into train/test.

    Args:
        path: location of an ``.npz`` archive with a ``latents_classes`` array
            and, when images are requested, an ``imgs`` array.
        num_tt: test samples per (shape, scale, rotation) group, forwarded to
            ``get_test_index``.
        opt: ``"train_vae"`` or ``"train_im"`` loads the images. Any other value
            substitutes a zero placeholder of length ``len(latents_classes)``,
            so only the returned labels are meaningful.
        skip: slice step forwarded to ``get_subset_index`` (default 2, the value
            previously hard-coded here).

    Returns:
        ``(train_images, train_labels, test_images, test_labels)``; the image
        arrays are cast to ``float32``.
    """
    # The `with` block closes the file handle held by the NpzFile. Members are read
    # into memory on access, so the arrays stay usable after the archive is closed.
    # NOTE(review): allow_pickle=True permits unpickling object arrays; only load
    # trusted archives. Kept for compatibility with the original call.
    with np.load(path, allow_pickle=True) as dataset_zip:
        latents_classes = dataset_zip['latents_classes']
        if opt in ("train_vae", "train_im"):
            imgs = dataset_zip['imgs']
        else:
            # Labels-only mode: a dummy array keeps the indexing below uniform.
            imgs = np.zeros([len(latents_classes)])

    subset_index = get_subset_index(latents_classes, skip)
    imgs = imgs[subset_index]
    latents_classes = latents_classes[subset_index]

    tt_index, tr_index = get_test_index(latents_classes, num_tt)
    return (imgs[tr_index].astype(np.float32), latents_classes[tr_index],
            imgs[tt_index].astype(np.float32), latents_classes[tt_index])


def print_info(tr_label):
    """Print the distinct values and their counts for the first four label columns.

    Columns 0-3 are reported as color, shape, scale and rotation respectively.
    """
    names = ("color", "shape", "scale", "rotation")
    for column, name in enumerate(names):
        values_and_counts = np.unique(tr_label[:, column], return_counts=True)
        print(name, values_and_counts)
            
