import torch
import numpy as np
import torch.nn as nn
import torch.functional as F

# Training hyperparameters.
n_epoch = 300
batchsize = 256
lr = 1e-3

# Input files (NumPy .npy format), resolved relative to the working directory.
data_name = "hmm_cond_prob_feature.npy"
label_name = "hmm_cond_prob_label.npy"

# Features are cast to float32 and labels to int64 right after loading.
# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects;
# only use it on files from a trusted source.
data = np.load(data_name, allow_pickle=True).astype(np.float32)
label = np.load(label_name).astype(np.int64)

def shuffle(X, Y):
    """Apply one shared random permutation to the first axis of X and Y.

    Args:
        X: Array whose first axis indexes samples.
        Y: Array with as many entries along its first axis as X.

    Returns:
        A dict with keys "X_shuffle" and "Y_shuffle" holding the permuted
        X and Y; entry i of both comes from the same original index.

    Raises:
        ValueError: If X and Y have a different number of samples.
    """
    m = X.shape[0]
    if len(Y) != m:
        # Without this check a longer Y would be silently accepted and the
        # sample/label pairing would no longer cover all of Y.
        raise ValueError(
            f"X and Y must have the same number of samples, got {m} and {len(Y)}"
        )
    # Index with the permutation array directly; converting it to a Python
    # list first only adds an O(m) detour.
    permutation = np.random.permutation(m)
    return {"X_shuffle": X[permutation], "Y_shuffle": Y[permutation]}

def get_mini_batches(X, Y, mini_batch_size):
    """Shuffle X and Y together and split them into consecutive mini-batches.

    Args:
        X: Array whose first axis indexes samples.
        Y: Array of labels aligned with X along the first axis.
        mini_batch_size: Positive number of samples per batch.

    Returns:
        A list of ``[X_batch, Y_batch]`` pairs. Every batch holds
        ``mini_batch_size`` samples except possibly the last one, which holds
        the remainder. The list is empty when X has no samples.

    Raises:
        ValueError: If ``mini_batch_size`` is not positive.
    """
    if mini_batch_size <= 0:
        raise ValueError(
            f"mini_batch_size must be positive, got {mini_batch_size}"
        )

    shuffles = shuffle(X, Y)
    X_shuffle = shuffles["X_shuffle"]
    Y_shuffle = shuffles["Y_shuffle"]
    num_examples = X_shuffle.shape[0]

    # Stepping by the batch size yields all full batches and, because slicing
    # clips at the end of the array, the trailing partial batch as well.
    return [
        [
            X_shuffle[start:start + mini_batch_size],
            Y_shuffle[start:start + mini_batch_size],
        ]
        for start in range(0, num_examples, mini_batch_size)
    ]

mini_batches = get_mini_batches(data, label, batchsize)

# Each feature row is read as 16 consecutive slices of width 200. For every
# slice, sum(p * log(p + 1e-15)) is computed per sample; the 1e-15 offset keeps
# log() finite for zero entries.
# NOTE(review): presumably each slice is a probability distribution — confirm.
# NOTE(review): this quantity is the *negative* of the Shannon entropy; the
# sign is kept as written — confirm whether a leading minus was intended.
ent_list = []
for batch in mini_batches:
    features = batch[0]
    for j in range(16):
        vec = features[:, 200 * j: 200 * j + 200]
        ent = np.sum(np.log(vec + 1e-15) * vec, axis=-1)
        # Collect every slice's result; appending outside this loop would keep
        # only the last slice of each batch.
        ent_list.append(ent)

print(np.mean(np.concatenate(ent_list)))