# -*- coding: utf-8 -*-
""" 
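Group-lasso experiments: train a dense Lasagne network on the digits or MNIST
dataset with a selectable optimizer and pickle the per-epoch training history.

Based on the code of:
    Simone Scardapane, Group LASSO regularization for neural networks (Theano/Lasagne)
    Preprint: https://arxiv.org/abs/1607.00485
    https://bitbucket.org/ispamm/group-lasso-deep-networks/src/master/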
"""

# Necessary imports
import argparse
import os
import pickle
import time

import lasagne
from lasagne.nonlinearities import leaky_rectify, softmax
import numpy as np
import sklearn.datasets
import sklearn.model_selection
import sklearn.preprocessing
from tabulate import tabulate
import tensorflow as tf
import theano
import theano.tensor as T

import adaptos_optimizer

# Command-line interface: dataset, optimizer and seed are mandatory; the
# learning rate is optional and falls back to a per-dataset table below.
parser = argparse.ArgumentParser(description='Experiments for Group Lasso')
parser.add_argument('--dataset', required=True, choices=["digits", "mnist"], help="dataset to be used (digits or mnist)")
parser.add_argument('--lr', required=False, type=float, help="learning rate")
parser.add_argument('--optimizer', required=True, choices=["sgd", "adagrad", "adam", "adaptos", "tos"], help="optimizer to use (sgd, adagrad, adam, adaptos, tos)")
parser.add_argument('--seed', required=True, type=int, help="random seed")
args = parser.parse_args()

# Short module-level aliases used throughout the rest of the script.
update = args.optimizer
lr = args.lr
dataset = args.dataset
seed = args.seed

# Fallback learning rates, keyed by optimizer name, one table per dataset.
best_lrs_digits = {"sgd": 0.1, "adagrad": 0.1, "adam": 0.01, "adaptos": 1.0, "tos": 1.0}
best_lrs_mnist = {"sgd": 0.01, "adagrad": 0.01, "adam": 0.001, "adaptos": 1.0, "tos": 1.0}

# `--lr` was not given on the command line: pick the table entry for this
# dataset/optimizer pair. Identity comparison is the correct test for None.
if lr is None:
    if dataset == "mnist":
        lr = best_lrs_mnist[update]
    elif dataset == "digits":
        lr = best_lrs_digits[update]

# Seed NumPy's global RNG and hand Lasagne its own RandomState built from the
# same seed.
np.random.seed(seed)
r = np.random.RandomState(seed)
lasagne.random.set_rng(r)

# Per-dataset hyperparameters and raw data, selected in a single branch.
if dataset == "digits":
    max_epochs, batch_size = 3000, 300
    neurons_per_layer = np.asarray([40, 20])
    reg_factor = 10**-3

    digits = sklearn.datasets.load_digits()
    X, y = digits.data, digits.target
elif dataset == "mnist":
    max_epochs, batch_size = 500, 400
    neurons_per_layer = np.asarray([400, 300, 100])
    reg_factor = 10**-4

    # Merge the Keras train/test partitions into one pool and flatten every
    # image to a 784-element row; the pool is re-split further down.
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
    X = np.vstack((x_train, x_test)).reshape(-1, 784)
    y = np.hstack((y_train, y_test))

# Min-max scale the features (the scaler is fitted on the full pool, before
# the split below).
scaler = sklearn.preprocessing.MinMaxScaler()
X = scaler.fit_transform(X)

# Hold out 25% of the samples for testing.
X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(
    X, y, test_size=0.25
)

# Symbolic placeholders for a batch of inputs and its integer labels.
input_var = T.matrix(name='X')
target_var = T.ivector(name='y')

# Utility function for minibatches
def iterate_minibatches(inputs, targets, batchsize, shuffle=False):
    """Yield ``(inputs, targets)`` minibatches of exactly ``batchsize`` rows.

    Trailing samples that do not fill a complete batch are dropped, so
    nothing is yielded when there are fewer than ``batchsize`` samples.

    Args:
        inputs: indexable collection of samples.
        targets: indexable collection of labels, same length as ``inputs``.
        batchsize: number of samples per batch.
        shuffle: when true, batches are drawn in a random order produced by
            NumPy's global RNG; otherwise consecutive slices are used.

    Yields:
        Tuples ``(inputs[sel], targets[sel])`` where ``sel`` is a slice
        (unshuffled) or an index array (shuffled).
    """
    assert len(inputs) == len(targets)
    n_samples = len(inputs)

    order = None
    if shuffle:
        order = np.arange(n_samples)
        np.random.shuffle(order)

    for begin in range(0, n_samples - batchsize + 1, batchsize):
        end = begin + batchsize
        selection = slice(begin, end) if order is None else order[begin:end]
        yield inputs[selection], targets[selection]

# Build the dense classifier: input layer, one leaky-ReLU hidden layer per
# entry of `neurons_per_layer`, and a softmax output with one unit per class.
n_features = X.shape[1]
n_classes = len(np.unique(y))

network = lasagne.layers.InputLayer((None, n_features), input_var)
for n_units in neurons_per_layer:
    network = lasagne.layers.DenseLayer(
        network,
        n_units,
        nonlinearity=leaky_rectify,
        W=lasagne.init.GlorotNormal(),
    )
network = lasagne.layers.DenseLayer(
    network,
    n_classes,
    nonlinearity=softmax,
    W=lasagne.init.GlorotNormal(),
)

# Snapshot of the initial parameter values and the list of trainable shared
# variables handed to the optimizers.
params_original = lasagne.layers.get_all_param_values(network)
params = lasagne.layers.get_all_params(network, trainable=True)
# Writes the snapshot straight back, so the parameters are left unchanged.
lasagne.layers.set_all_param_values(network, params_original)

# Per-sample cross-entropy between the network output and the labels.
prediction = lasagne.layers.get_output(network)
loss = lasagne.objectives.categorical_crossentropy(prediction, target_var)

# Define the regularized loss function
def groupl1(x):
    """Group-lasso penalty of a 2-D symbolic tensor, one group per row.

    Every row contributes its Euclidean norm scaled by the square root of
    the number of columns; the contributions are summed into a scalar.
    """
    group_size = x.shape[1]
    row_norms = T.sqrt(T.sum(x**2, axis=1))
    return T.sum(T.sqrt(group_size) * row_norms)
def regularizer(x):
    """Penalty applied to each regularizable parameter: L1 plus group lasso."""
    return lasagne.regularization.l1(x) + groupl1(x)


# Mean cross-entropy without and with the weighted penalty term.
loss_unreg = loss.mean()
loss_reg = loss_unreg + reg_factor * lasagne.regularization.regularize_network_params(network, regularizer)

# Update function
# One zero-argument builder per optimizer name, so that only the update graph
# of the optimizer selected on the command line is constructed (previously all
# five graphs were built and four were discarded).
# The Lasagne optimizers differentiate the regularized loss. adaptos/tos are
# given the unregularized loss together with mu=reg_factor -- presumably they
# account for the penalty themselves; confirm in adaptos_optimizer.
update_builders = {
    "sgd": lambda: lasagne.updates.sgd(loss_reg, params, learning_rate=lr),
    "adagrad": lambda: lasagne.updates.adagrad(loss_reg, params, learning_rate=lr),
    "adam": lambda: lasagne.updates.adam(loss_reg, params, learning_rate=lr),
    "adaptos": lambda: adaptos_optimizer.adaptos(loss_unreg, params, learning_rate=lr, mu=reg_factor),
    "tos": lambda: adaptos_optimizer.tos(loss_unreg, params, learning_rate=lr, mu=reg_factor),
}
updates_reg = update_builders[update]()

# training function: returns the regularized loss of the batch and applies
# the selected optimizer's parameter updates
train_fn = theano.function([input_var, target_var], loss_reg, updates=updates_reg, allow_input_downcast=True)

# Objective-only function: evaluates the same regularized loss as the
# training function but performs no parameter updates.
loss_fn = theano.function(
    inputs=[input_var, target_var],
    outputs=loss_reg,
    allow_input_downcast=True,
)

# Accuracy function: fraction of samples in the batch whose arg-max class
# matches the label, computed from the deterministic forward pass.
test_prediction = lasagne.layers.get_output(network, deterministic=True)
predicted_labels = T.argmax(test_prediction, axis=1)
is_correct = T.eq(predicted_labels, target_var)
test_acc = T.mean(is_correct, dtype=theano.config.floatX)
test_fn = theano.function(
    inputs=[input_var, target_var],
    outputs=test_acc,
    allow_input_downcast=True,
)

# Per-epoch history: mean objective and accuracy on the train and test
# splits, plus wall-clock seconds elapsed since training started.
tr_obj = np.zeros(max_epochs)
te_obj = np.zeros(max_epochs)
tr_acc = np.zeros(max_epochs)
te_acc = np.zeros(max_epochs)
times = np.zeros(max_epochs)


def _run_epoch(objective_fn, inputs, targets):
    """Run one shuffled pass over ``inputs``/``targets`` in minibatches.

    ``objective_fn`` is called first on every batch (``train_fn`` also
    updates the parameters, ``loss_fn`` only evaluates), then ``test_fn``
    measures accuracy on the same batch.

    Returns:
        ``(mean objective, mean accuracy)`` over the batches, or
        ``(nan, nan)`` when the split is too small to yield a single batch.
    """
    objective_total = 0
    accuracy_total = 0
    n_batches = 0
    for input_batch, target_batch in iterate_minibatches(inputs, targets, batch_size, shuffle=True):
        objective_total += objective_fn(input_batch, target_batch)
        accuracy_total += test_fn(input_batch, target_batch)
        n_batches += 1
    # Count batches explicitly instead of relying on a loop index that is
    # undefined (or stale from a previous loop) when nothing was yielded.
    if n_batches == 0:
        return np.nan, np.nan
    return objective_total / n_batches, accuracy_total / n_batches


# Without at least one full training batch no update would ever be applied;
# fail with a clear message instead of an obscure error inside the loop.
if len(X_train) < batch_size:
    raise ValueError(
        f"training split has {len(X_train)} samples, fewer than batch_size={batch_size}"
    )

# Train network
start = time.time()
for epoch in range(max_epochs):
    # print for sanity
    if epoch % 100 == 0:
        print(epoch, "... ", end="")

    # training pass (updates parameters), then evaluation pass on the test split
    tr_obj[epoch], tr_acc[epoch] = _run_epoch(train_fn, X_train, y_train)
    te_obj[epoch], te_acc[epoch] = _run_epoch(loss_fn, X_test, y_test)

    # cumulative time, which therefore includes the evaluation passes
    times[epoch] = time.time() - start

# extract all the trained parameter values
params_trained = lasagne.layers.get_all_param_values(network, trainable=True)

# Everything recorded during the run, keyed by the names used when reloading.
data = {
    "tr_obj": tr_obj,
    "te_obj": te_obj,
    "tr_acc": tr_acc,
    "te_acc": te_acc,
    "times": times,
    "params": params_trained,
}

save_file = f"runs/{dataset}_{update}_{lr}_{seed}.pkl"
# Create the output directory if it is missing, so the results of a finished
# training run are not lost to a FileNotFoundError at the very last step.
os.makedirs(os.path.dirname(save_file), exist_ok=True)
with open(save_file, 'wb') as f:
    pickle.dump(data, f)

# import matplotlib
# import matplotlib.pyplot as plt
# matplotlib.use('Agg')
# plt.figure()
# plt.title('')
# plt.plot(tr_obj, label='train obj')
# plt.plot(te_obj, label='test obj')
# plt.legend()
# plt.savefig("obj.png")

# plt.figure()
# plt.title('')
# plt.plot(tr_acc, label='train acc')
# plt.plot(te_acc, label='test acc')
# plt.legend()
# plt.savefig("acc.png")