import os
import numpy as np
import time
import tensorflow as tf

def ball(path, T=20, seed=0, ratio=0.8):
    """Load, shuffle and truncate the KVAE bouncing-ball image sequences.

    Reads ``images_train_ball.npy`` and ``images_test_ball.npy`` from the
    directory ``path``, shuffles each array along its first axis, keeps the
    leading ``T`` entries of the second axis and appends a singleton axis.

    Args:
        path: Directory holding the two ``.npy`` files. A ``'/'`` is appended
            before the file names, so no trailing separator is required.
        T: Number of leading frames to keep. With ``T == 1`` only frame 0 is
            kept and the frame axis is dropped.
        seed: Value passed to ``np.random.seed``; note this reseeds NumPy's
            global generator. The train permutation is drawn first, then the
            test permutation.
        ratio: Unused by this loader; kept so existing callers keep working.

    Returns:
        Tuple ``(data_train, data_test)``. Assuming the stored arrays are laid
        out as ``(N, frames, H, W)`` (not verified here), each result is
        ``(N, T, H, W, 1)``, or ``(N, H, W, 1)`` when ``T == 1``.

    Raises:
        ValueError: If either file cannot be read. The underlying error is
            chained as ``__cause__``.
    """
    start = time.time()
    path = path + '/'
    try:
        data_train = np.load(path + 'images_train_ball.npy')
        data_test = np.load(path + 'images_test_ball.npy')
    except (OSError, ValueError, EOFError) as err:
        # Keep ValueError for existing callers, but preserve what went wrong.
        raise ValueError(
            'could not load bouncing ball data from %r' % path) from err

    np.random.seed(seed)
    data_train = data_train[np.random.permutation(data_train.shape[0])]
    data_test = data_test[np.random.permutation(data_test.shape[0])]
    print('KVAE bouncing shape loaded, now processing...')

    # Truncate to the first T frames and add the trailing singleton axis in a
    # single indexing step. For T == 1 the frame axis is removed, so one
    # fewer slice is needed in front of the new axis.
    if T == 1:
        data_train = data_train[:, 0, :, :, np.newaxis]
        data_test = data_test[:, 0, :, :, np.newaxis]
    else:
        data_train = data_train[:, :T, :, :, np.newaxis]
        data_test = data_test[:, :T, :, :, np.newaxis]
    end = time.time()
    print(data_train.shape)
    print('KVAE bouncing shape loaded and processed in %.2f seconds...' % (end - start))
    return data_train, data_test

def sprites_act(path, seed=0, return_labels = False):
    """Load the sprites frame arrays for every action/direction pair.

    For each action in (walk, spellcard, slash) and each direction in
    (front, left, right) the files ``<action>_<direction>_frames_train.npy``
    and ``<action>_<direction>_frames_test.npy`` are loaded, with ``path``
    used as a plain string prefix (it must already end with a separator).
    The per-pair arrays are concatenated along axis 0 and shuffled.

    Args:
        path: String prefix prepended verbatim to every file name.
        seed: Value passed to ``np.random.seed`` (reseeds NumPy's global
            generator). The train permutation is drawn before the test one.
        return_labels: When true, the matching ``..._attributes_*.npy`` files
            are loaded as well, and a 9-way one-hot array is built for each
            pair with the hot index ``3 * action_index + direction_index``.

    Returns:
        ``(X_train, X_test)``, or
        ``(X_train, X_test, A_train, A_test, D_train, D_test)`` when
        ``return_labels`` is true. Attribute and one-hot arrays are shuffled
        with the same permutation as their frames.
    """
    directions = ['front', 'left', 'right']  # 'back' is intentionally left out
    actions = ['walk', 'spellcard', 'slash']
    t0 = time.time()

    def _one_hot(attrs, hot):
        # One 9-way indicator per (sequence, frame); the leading two
        # dimensions are taken from the attribute array.
        encoded = np.zeros([attrs.shape[0], attrs.shape[1], 9])
        encoded[:, :, hot] = 1
        return encoded

    frames_train, frames_test = [], []
    attrs_train, attrs_test = [], []
    onehot_train, onehot_test = [], []

    for act, action in enumerate(actions):
        for i, direction in enumerate(directions):
            label = 3 * act + i
            print(action, direction, act, i, label)
            prefix = path + '%s_%s_' % (action, direction)
            frames_train.append(np.load(prefix + 'frames_train.npy'))
            frames_test.append(np.load(prefix + 'frames_test.npy'))
            if not return_labels:
                continue

            a = np.load(prefix + 'attributes_train.npy')
            attrs_train.append(a)
            onehot_train.append(_one_hot(a, label))

            a = np.load(prefix + 'attributes_test.npy')
            attrs_test.append(a)
            onehot_test.append(_one_hot(a, label))

    X_train = np.concatenate(frames_train, axis=0)
    X_test = np.concatenate(frames_test, axis=0)

    np.random.seed(seed)
    order = np.random.permutation(X_train.shape[0])
    X_train = X_train[order]
    if return_labels:
        A_train = np.concatenate(attrs_train, axis=0)[order]
        D_train = np.concatenate(onehot_train, axis=0)[order]

    order = np.random.permutation(X_test.shape[0])
    X_test = X_test[order]
    if return_labels:
        A_test = np.concatenate(attrs_test, axis=0)[order]
        D_test = np.concatenate(onehot_test, axis=0)[order]
        print(A_test.shape, D_test.shape, X_test.shape, 'shapes')

    print(X_train.shape, X_test.min(), X_test.max())
    print('data loaded in %.2f seconds...' % (time.time() - t0))

    if not return_labels:
        return X_train, X_test
    return X_train, X_test, A_train, A_test, D_train, D_test

def moving_mnist(path, T = 5, seed=0, ratio=0.9, conv=True):
    """Load moving MNIST, keep the first ``T`` frames and split train/test.

    Reads ``movingMNIST/mnist_test_seq.npy`` with ``path`` used as a plain
    string prefix (it must already end with a separator). The stored array is
    indexed frame-first and transposed so sequences come first; it is assumed
    to be 4-D ``(frames, N, width, height)`` (the transpose requires 4 axes).

    Args:
        path: String prefix in front of ``'movingMNIST/'``.
        T: Number of leading frames to keep. If fewer frames are stored, all
            of them are kept. With ``T == 1`` the frame axis is dropped.
        seed: Value passed to ``np.random.seed`` (reseeds NumPy's global
            generator) before drawing the split permutation.
        ratio: Fraction of sequences placed in the training split; the
            remainder forms the test split.
        conv: If true, append a trailing singleton axis (for conv nets);
            otherwise flatten each frame into a vector (for MLPs).

    Returns:
        Tuple ``(train, test)`` of float32 arrays scaled by the maximum pixel
        value of the kept frames.

    Raises:
        ValueError: If the file cannot be read. The underlying error is
            chained as ``__cause__``.
    """
    start = time.time()
    path = path + 'movingMNIST/'
    try:
        data = np.load(path + 'mnist_test_seq.npy')
    except (OSError, ValueError, EOFError) as err:
        # Keep ValueError for existing callers, but preserve what went wrong.
        raise ValueError(
            'could not load moving MNIST data from %r' % path) from err
    print('moving MNIST loaded, now processing...')

    # keep the first T frames
    data = data[:T]
    print(data.shape)
    data = data.transpose(1, 0, 2, 3)   # shape (N, T, width, height)
    # normalise pixel into [0, 1]
    data = np.asarray(data, dtype='f') / float(np.max(data))
    N, n_frames = data.shape[0], data.shape[1]
    if not conv:
        # for MLP: flatten each frame. Use the number of frames actually
        # kept rather than T, so a T larger than the stored sequence length
        # cannot fail or silently mix pixels across frames.
        data = data.reshape(N, n_frames, -1)
    else:
        # for conv nets
        data = np.expand_dims(data, axis=4)
    if T == 1:
        data = data[:, 0]

    # split the sequences into train/test according to ratio
    np.random.seed(seed)
    ind = np.random.permutation(range(N))
    n_train = int(N * ratio)
    ind_train = ind[:n_train]
    ind_test = ind[n_train:]
    end = time.time()
    print(data.shape)
    print('moving MNIST loaded and processed in %.2f seconds...' % (end - start))
    return data[ind_train], data[ind_test]

