import torch
import numpy as np
import torch.nn as nn
import torch.functional as F

# Input files: the feature array and the matching integer class labels.
data_name = 'features.npy'
label_name = 'labels.npy'

# Training hyper-parameters.
n_epoch = 1000
batchsize = 256
lr = 1e-3

# SECURITY NOTE(review): allow_pickle=True lets np.load unpickle object
# arrays, which can execute arbitrary code when the file is untrusted. It is
# kept so existing feature files keep loading; drop it if features.npy is a
# plain numeric array.
data = np.load(data_name, allow_pickle=True).astype('float32')
label = np.load(label_name).astype('int64')

# The last axis of the features is the input width of the classifier.
d_model = data.shape[-1]
# Labels are used as 0-based class ids, so the class count is max + 1.
# np.max returns a NumPy scalar; cast it so the layer size is a built-in int.
n_class = int(np.max(label)) + 1

# A single linear layer producing one score per class, moved to the GPU.
model = nn.Linear(d_model, n_class).cuda()
# NOTE: the 1e-4 regularisation strength must be passed by keyword. The third
# positional parameter of torch.optim.SGD is `momentum`, not `weight_decay`,
# so passing it positionally would configure a negligible momentum instead.
optimizer = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=1e-4)
# Loss applied to the layer's raw outputs and the integer class labels.
criterion = nn.CrossEntropyLoss()

def shuffle(X, Y):
    """Reorder X and Y along their first axis with one shared permutation.

    The permutation has length ``X.shape[0]`` and is drawn from NumPy's
    global random state. The same index order is applied to both inputs, so
    rows that were paired before the call stay paired afterwards.

    Returns:
        A dict holding the reordered inputs under the keys ``"X_shuffle"``
        and ``"Y_shuffle"``.
    """
    order = np.random.permutation(X.shape[0])
    return {"X_shuffle": X[order], "Y_shuffle": Y[order]}

def get_mini_batches(X, Y, mini_batch_size):
    """Shuffle X and Y together and cut them into consecutive mini-batches.

    Both inputs are first reordered by ``shuffle`` and then sliced along the
    first axis in steps of ``mini_batch_size``.

    Returns:
        A list of ``[X_batch, Y_batch]`` pairs. Each pair holds
        ``mini_batch_size`` examples, except that a final shorter pair is
        appended when the number of examples is not a multiple of
        ``mini_batch_size``.
    """
    shuffled = shuffle(X, Y)
    xs, ys = shuffled["X_shuffle"], shuffled["Y_shuffle"]
    num_complete, leftover = divmod(xs.shape[0], mini_batch_size)

    # Full-sized batches first.
    bounds = [
        (k * mini_batch_size, (k + 1) * mini_batch_size)
        for k in range(num_complete)
    ]
    mini_batches = [[xs[lo:hi], ys[lo:hi]] for lo, hi in bounds]

    # Then whatever is left over, as one final smaller batch.
    if leftover != 0:
        tail_start = num_complete * mini_batch_size
        mini_batches.append([xs[tail_start:], ys[tail_start:]])
    return mini_batches

# Training: every epoch reshuffles the data and takes one optimizer step per
# mini-batch.
for epoch in range(n_epoch):
    mini_batches = get_mini_batches(data, label, batchsize)
    for i, (feature, label_int) in enumerate(mini_batches):
        # Skip the whole batch as soon as any feature value is NaN.
        if np.isnan(feature).any():
            continue

        inputs = torch.tensor(feature).cuda()
        targets = torch.tensor(label_int).cuda()

        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()

        # Log the loss at steps 1, 101, 201, ... of each epoch.
        if i % 100 == 1:
            loss_text = str(loss.detach().cpu().numpy())
            print(f'epoch={epoch} step={i} loss={loss_text}')
        
# Evaluation: classification accuracy of the model over the whole data set.
correct_num = 0
total_num = 0
# No effect on a plain Linear layer, but keeps evaluation correct if layers
# with train/eval-dependent behaviour are added to the model later.
model.eval()
with torch.no_grad():
    # Slice the data set directly rather than reusing the `mini_batches`
    # variable left over from the training loop: that name does not exist
    # when n_epoch == 0, and shuffling is irrelevant for accuracy.
    # NOTE(review): unlike training, batches containing NaN features are not
    # skipped here, so such rows are still counted in the accuracy.
    for i, start in enumerate(range(0, data.shape[0], batchsize)):
        im = data[start:start + batchsize]
        label_i = label[start:start + batchsize]

        fs = torch.tensor(im).cuda()
        # Call the module itself (not .forward) so registered hooks run.
        out = model(fs)

        # Labels never need to visit the GPU; compare on the host.
        predict_index = np.argmax(out.cpu().numpy(), axis=-1)
        correct_num += int(np.sum(label_i == predict_index))
        total_num += fs.shape[0]
        if i % 1000 == 1:
            print(i)

# Guard against an empty data set instead of raising ZeroDivisionError.
print(correct_num / total_num if total_num else float('nan'))