#!/usr/bin/env python

"""
Lasagne implementation of CIFAR-10 examples from "Deep Residual Learning for Image Recognition" (http://arxiv.org/abs/1512.03385)

Check the accompanying files for pretrained models. The 32-layer network (n=5) achieves a validation error of 7.42%,
while the 56-layer network (n=9) achieves an error of 6.75%, which is roughly equivalent to the results reported in the paper.
"""

from __future__ import print_function
import adaptos_optimizer

import sys
import os
import time
import pickle

import numpy as np
import theano
import theano.tensor as T
import lasagne

# Larger networks (n >= 9) need a recursion limit above Python's default
# (presumably for recursive graph traversal during compilation/pickling).
sys.setrecursionlimit(10_000)

# ##################### Load data from CIFAR-10 dataset #######################
# This code assumes the CIFAR-10 dataset from 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
# has been extracted into the current working directory.

def unpickle(file):
    """Load and return the object stored in a pickled CIFAR-10 batch file.

    Args:
        file: Path of the pickled batch (e.g. ``cifar-10-batches-py/data_batch_1``).

    Returns:
        The unpickled object. ``load_data`` reads its ``'data'`` and
        ``'labels'`` entries.
    """
    # NOTE(security): pickle.load can execute arbitrary code. Only call this
    # on trusted files such as the official CIFAR-10 download.
    with open(file, 'rb') as fo:
        # encoding='latin1' lets Python 3 decode byte strings from Python 2
        # pickles (presumably how the CIFAR-10 batches were written).
        return pickle.load(fo, encoding='latin1')

def load_data(data_dir='cifar-10-batches-py'):
    """Load CIFAR-10, normalise it and build the mirror-augmented training set.

    The five training batches and the test batch are read from *data_dir*.
    Pixels are scaled to [0, 1] and each flat row is turned into a
    channel-first (3, 32, 32) image. The per-pixel mean of the training
    images is subtracted from every image. Finally, a horizontally mirrored
    copy of each training image is appended, doubling the training set.

    Args:
        data_dir: Directory containing the extracted ``data_batch_1`` ...
            ``data_batch_5`` and ``test_batch`` files.

    Returns:
        dict with ``X_train`` / ``X_test`` (floatX image arrays) and
        ``Y_train`` / ``Y_test`` (int32 label vectors).
    """
    train_batches = [unpickle(os.path.join(data_dir, f'data_batch_{j}'))
                     for j in range(1, 6)]
    test_batch = unpickle(os.path.join(data_dir, 'test_batch'))
    batches = train_batches + [test_batch]

    # Number of training rows, taken from the data instead of hard-coding 50000.
    n_train = sum(len(b['data']) for b in train_batches)

    x = np.concatenate([b['data'] for b in batches]) / np.float32(255)
    y = np.concatenate([b['labels'] for b in batches])

    # Each row is three consecutive 1024-value channel planes, each a
    # row-major 32x32 image, so one reshape yields (N, 3, 32, 32) directly.
    # This gives the same values as the dstack/reshape/transpose sequence,
    # but the result is contiguous.
    x = x.reshape((x.shape[0], 3, 32, 32))

    # Subtract the per-pixel mean of the training images from all images.
    pixel_mean = np.mean(x[:n_train], axis=0)
    x -= pixel_mean

    # Append horizontally mirrored training images (flip the width axis).
    X_train = x[:n_train]
    Y_train = y[:n_train]
    X_train = np.concatenate((X_train, X_train[:, :, :, ::-1]), axis=0)
    Y_train = np.concatenate((Y_train, Y_train), axis=0)

    X_test = x[n_train:]
    Y_test = y[n_train:]

    return dict(
        X_train=lasagne.utils.floatX(X_train),
        Y_train=Y_train.astype('int32'),
        X_test=lasagne.utils.floatX(X_test),
        Y_test=Y_test.astype('int32'),
    )

# ##################### Build the neural network model #######################

from lasagne.layers import Conv2DLayer as ConvLayer
#from lasagne.layers.dnn import Conv2DDNNLayer as ConvLayer
from lasagne.layers import ElemwiseSumLayer
from lasagne.layers import InputLayer
from lasagne.layers import DenseLayer
from lasagne.layers import GlobalPoolLayer
from lasagne.layers import PadLayer
from lasagne.layers import ExpressionLayer
from lasagne.layers import NonlinearityLayer
from lasagne.nonlinearities import softmax, rectify
from lasagne.layers import batch_norm

def build_cnn(input_var=None, n=5):
    """Build the CIFAR-10 residual network (6n+2 weighted layers).

    The network is a 16-filter 3x3 conv stem and three stages of *n*
    residual blocks each (16, 32 and 64 filters). The first block of stages
    two and three halves the spatial size and doubles the filters. These are
    followed by global average pooling and a 10-way softmax dense layer.

    Args:
        input_var: Optional Theano tensor bound to the (None, 3, 32, 32)
            input layer.
        n: Number of residual blocks per stage (n=5 gives 32 layers).

    Returns:
        The final ``DenseLayer`` (softmax output) of the network.
    """

    def _halved_shape(shape):
        # Output shape of X[:, :, ::2, ::2]. Slicing keeps ceil(size / 2)
        # rows/cols, which matches the stride-2 'same' conv in the residual
        # branch. Unknown (None) sizes stay unknown.
        def half(d):
            return None if d is None else (d + 1) // 2
        return (shape[0], shape[1], half(shape[2]), half(shape[3]))

    # Create a residual learning building block with two stacked 3x3 conv
    # layers, as in the paper.
    def residual_block(l, increase_dim=False, projection=False):
        """Return ``relu(F(l) + shortcut(l))`` where F is conv-BN-relu-conv-BN.

        If *increase_dim* is set, the first conv uses stride 2 and the filter
        count doubles. The shortcut is then either a strided 1x1 conv
        (``projection=True``, option B in the paper) or a subsampled identity
        zero-padded along the channel axis (option A).
        """
        input_num_filters = l.output_shape[1]
        if increase_dim:
            first_stride = (2, 2)
            out_num_filters = input_num_filters * 2
        else:
            first_stride = (1, 1)
            out_num_filters = input_num_filters

        stack_1 = batch_norm(ConvLayer(l, num_filters=out_num_filters, filter_size=(3,3), stride=first_stride, nonlinearity=rectify, pad='same', W=lasagne.init.HeNormal(gain='relu'), flip_filters=False))
        stack_2 = batch_norm(ConvLayer(stack_1, num_filters=out_num_filters, filter_size=(3,3), stride=(1,1), nonlinearity=None, pad='same', W=lasagne.init.HeNormal(gain='relu'), flip_filters=False))

        # Shortcut connection. A separate name is used so the boolean
        # `projection` flag is not overwritten by a layer object.
        if not increase_dim:
            shortcut = l
        elif projection:
            # Projection shortcut, option B in the paper.
            shortcut = batch_norm(ConvLayer(l, num_filters=out_num_filters, filter_size=(1,1), stride=(2,2), nonlinearity=None, pad='same', b=None, flip_filters=False))
        else:
            # Identity shortcut, option A in the paper: keep every second
            # pixel, then zero-pad the channel axis by out_num_filters // 4
            # on each side to match the doubled filter count.
            identity = ExpressionLayer(l, lambda X: X[:, :, ::2, ::2], _halved_shape)
            shortcut = PadLayer(identity, [out_num_filters // 4, 0, 0], batch_ndim=1)

        return NonlinearityLayer(ElemwiseSumLayer([stack_2, shortcut]), nonlinearity=rectify)

    # Building the network
    l_in = InputLayer(shape=(None, 3, 32, 32), input_var=input_var)

    # first layer, output is 16 x 32 x 32
    l = batch_norm(ConvLayer(l_in, num_filters=16, filter_size=(3,3), stride=(1,1), nonlinearity=rectify, pad='same', W=lasagne.init.HeNormal(gain='relu'), flip_filters=False))

    # first stack of residual blocks, output is 16 x 32 x 32
    for _ in range(n):
        l = residual_block(l)

    # second stack of residual blocks, output is 32 x 16 x 16
    l = residual_block(l, increase_dim=True)
    for _ in range(1, n):
        l = residual_block(l)

    # third stack of residual blocks, output is 64 x 8 x 8
    l = residual_block(l, increase_dim=True)
    for _ in range(1, n):
        l = residual_block(l)

    # average pooling
    l = GlobalPoolLayer(l)

    # fully connected layer
    network = DenseLayer(
            l, num_units=10,
            W=lasagne.init.HeNormal(),
            nonlinearity=softmax)

    return network

# ############################# Batch iterator ###############################

def iterate_minibatches(inputs, targets, batchsize, shuffle=False, augment=False):
    """Yield ``(inputs, targets)`` mini-batches of exactly *batchsize* samples.

    Samples left over after the last full batch are dropped.

    Args:
        inputs: Array of samples indexed along axis 0. When *augment* is set,
            it must be 4D (N, C, H, W).
        targets: Array of labels with the same length as *inputs*.
        batchsize: Number of samples per batch.
        shuffle: If True, visit samples in a fresh random order (drawn from
            ``np.random``) on every call.
        augment: If True, apply the paper's augmentation: zero-pad each image
            by 4 pixels on every side and take a random H x W crop.
            Augmented batches are float32.

    Yields:
        Tuples ``(batch_inputs, batch_targets)``.

    Raises:
        ValueError: If *inputs* and *targets* differ in length.
    """
    if len(inputs) != len(targets):
        raise ValueError("inputs and targets must have the same length "
                         "(got %d and %d)" % (len(inputs), len(targets)))
    if shuffle:
        indices = np.arange(len(inputs))
        np.random.shuffle(indices)
    for start_idx in range(0, len(inputs) - batchsize + 1, batchsize):
        if shuffle:
            excerpt = indices[start_idx:start_idx + batchsize]
        else:
            excerpt = slice(start_idx, start_idx + batchsize)
        # Index once: with shuffled indices this is a fancy-indexing copy.
        batch_inputs = inputs[excerpt]
        if augment:
            # As in the paper: pad with 4 pixels on each side, then randomly
            # crop back to the original height and width.
            height, width = batch_inputs.shape[2:]
            padded = np.pad(batch_inputs, ((0, 0), (0, 0), (4, 4), (4, 4)), mode='constant')
            cropped = np.empty(batch_inputs.shape, dtype=np.float32)
            # Offsets in [0, 8] inclusive (randint's upper bound is
            # exclusive), so the crop window always lies inside the padded
            # image. This replaces the deprecated random_integers(0, high=8).
            crops = np.random.randint(0, 9, size=(batchsize, 2))
            for r, (dy, dx) in enumerate(crops):
                cropped[r] = padded[r, :, dy:dy + height, dx:dx + width]
            batch_inputs = cropped

        yield batch_inputs, targets[excerpt]

# ############################## Main program ################################

def main(num_epochs=82, update="momentum", lr=0.1, reg_factor=0.0001):
    """Train the CIFAR-10 ResNet with an L1 + group-lasso penalty and save results.

    Args:
        num_epochs: Number of passes over the mirror-augmented training set.
        update: Optimizer key: "sgd", "adagrad", "momentum", "tos" or "adaptos".
        lr: Initial learning rate. It is multiplied by 0.1 after epochs 41 and 61.
        reg_factor: Weight of the L1 + group-lasso penalty.
            - "sgd", "adagrad" and "momentum" minimise ``loss + reg_factor * penalty``.
            - "tos" and "adaptos" receive the unregularised loss with
              ``mu=reg_factor``; how they use it is defined in ``adaptos_optimizer``.

    Per-epoch training objective, test loss, accuracies (%) and wall time,
    plus the trained parameter values, are pickled to
    ``runs2/cifar_{update}_{lr}_{reg_factor}.pkl``. If the dataset directory
    is missing, a message is printed and the function returns early.

    Raises:
        ValueError: If *update* is not a known optimizer key.
    """
    tr_obj = np.zeros(num_epochs)  # mean training objective (loss + penalty)
    te_obj = np.zeros(num_epochs)  # mean test cross-entropy
    tr_acc = np.zeros(num_epochs)  # training accuracy in %
    te_acc = np.zeros(num_epochs)  # test accuracy in %
    times = np.zeros(num_epochs)   # seconds per epoch (training + evaluation)

    # Check if cifar data exists
    if not os.path.exists("./cifar-10-batches-py"):
        print("CIFAR-10 dataset can not be found. Please download the dataset from 'https://www.cs.toronto.edu/~kriz/cifar.html'.")
        return

    # Load the dataset
    print("Loading data...")
    data = load_data()
    X_train = data['X_train']
    Y_train = data['Y_train']
    X_test = data['X_test']
    Y_test = data['Y_test']

    # Prepare Theano variables for inputs and targets
    input_var = T.tensor4('inputs')
    target_var = T.ivector('targets')

    # Create neural network model
    print("Building model and compiling functions...")
    network = build_cnn(input_var)
    print("number of parameters in model: %d" % lasagne.layers.count_params(network, trainable=True))

    def groupl1(x):
        """Group-lasso penalty for the shared parameter *x*.

        - 4D: one group per slice along axis 0 (per output filter for
          Lasagne conv weights), weighted by sqrt(group size).
        - 1D: the whole vector is a single, unweighted group.
        - Otherwise: one group per row (sum over axis 1), weighted by
          sqrt(dims[1]).
        """
        dims = x.get_value(borrow=True).shape
        if len(dims) == 1:
            return T.sqrt(T.sum(x ** 2))
        if len(dims) == 4:
            group_size, axes = np.prod(dims[1:]), (1, 2, 3)
        else:
            group_size, axes = dims[1], 1
        # The group size is known statically, so use a floatX constant.
        # Taking sqrt of the symbolic int64 shape would upcast to float64.
        weight = lasagne.utils.floatX(np.sqrt(group_size))
        return T.sum(T.sqrt(T.sum(x ** 2, axis=axes)) * weight)

    def regularizer(x):
        return lasagne.regularization.l1(x) + groupl1(x)

    # Training loss: cross-entropy, with and without the penalty.
    prediction = lasagne.layers.get_output(network)
    loss_unreg = lasagne.objectives.categorical_crossentropy(prediction, target_var).mean()
    loss_reg = loss_unreg + reg_factor * lasagne.regularization.regularize_network_params(network, regularizer)

    params = lasagne.layers.get_all_params(network, trainable=True)
    sh_lr = theano.shared(lasagne.utils.floatX(lr))

    # Build only the requested optimizer's updates. Building all of them
    # eagerly creates gradient graphs and optimizer state that are never used.
    update_builders = {
        "sgd": lambda: lasagne.updates.sgd(loss_reg, params, learning_rate=sh_lr),
        "adagrad": lambda: lasagne.updates.adagrad(loss_reg, params, learning_rate=sh_lr),
        "momentum": lambda: lasagne.updates.momentum(loss_reg, params, learning_rate=sh_lr, momentum=0.9),
        "tos": lambda: adaptos_optimizer.tos(loss_unreg, params, learning_rate=sh_lr, mu=reg_factor),
        "adaptos": lambda: adaptos_optimizer.adaptos_with_momentum(loss_unreg, params, learning_rate=sh_lr, mu=reg_factor, momentum=0.9),
    }
    if update not in update_builders:
        raise ValueError("unknown update %r; expected one of %s" % (update, sorted(update_builders)))
    updates = update_builders[update]()

    # One training step returns both the objective and the batch accuracy.
    # A separate accuracy function would repeat the forward pass. With
    # Lasagne's default batch_norm settings it would also update the
    # running averages a second time.
    train_acc_expr = T.mean(T.eq(T.argmax(prediction, axis=1), target_var), dtype=theano.config.floatX)
    train_fn = theano.function([input_var, target_var], [loss_reg, train_acc_expr], updates=updates)

    # Create a loss expression for validation/testing
    test_prediction = lasagne.layers.get_output(network, deterministic=True)
    test_loss = lasagne.objectives.categorical_crossentropy(test_prediction, target_var).mean()
    test_acc = T.mean(T.eq(T.argmax(test_prediction, axis=1), target_var), dtype=theano.config.floatX)

    # Compile a second function computing the validation loss and accuracy:
    val_fn = theano.function([input_var, target_var], [test_loss, test_acc])

    print("Starting training...")
    for epoch in range(num_epochs):
        train_err = 0
        train_acc = 0
        train_batches = 0
        start_time = time.time()

        # shuffle=True draws a fresh permutation each epoch, so no separate
        # (memory-heavy) copy-and-shuffle of X_train is needed.
        for inputs, targets in iterate_minibatches(X_train, Y_train, 128, shuffle=True, augment=True):
            err, acc = train_fn(inputs, targets)
            train_err += err
            train_acc += acc
            train_batches += 1

        # And a full pass over the validation data:
        val_err = 0
        val_acc = 0
        val_batches = 0
        for inputs, targets in iterate_minibatches(X_test, Y_test, 500, shuffle=False):
            err, acc = val_fn(inputs, targets)
            val_err += err
            val_acc += acc
            val_batches += 1

        tr_obj[epoch] = train_err / train_batches
        te_obj[epoch] = val_err / val_batches
        tr_acc[epoch] = train_acc / train_batches * 100
        te_acc[epoch] = val_acc / val_batches * 100
        times[epoch] = time.time() - start_time

        # Then we print the results for this epoch:
        print("Epoch {} of {} took {:.3f}s".format(epoch + 1, num_epochs, times[epoch]))
        print("  training loss:\t\t{:.6f}".format(tr_obj[epoch]))
        print("  training accuracy:\t\t{:.2f} %".format(tr_acc[epoch]))
        print("  validation loss:\t\t{:.6f}".format(te_obj[epoch]))
        print("  validation accuracy:\t\t{:.2f} %".format(te_acc[epoch]))

        # adjust learning rate as in paper
        # 32k and 48k iterations should be roughly equivalent to 41 and 61 epochs
        if (epoch + 1) == 41 or (epoch + 1) == 61:
            new_lr = sh_lr.get_value() * 0.1
            print("New LR:" + str(new_lr))
            sh_lr.set_value(lasagne.utils.floatX(new_lr))

    results = {
        "tr_obj": tr_obj,
        "te_obj": te_obj,
        "tr_acc": tr_acc,
        "te_acc": te_acc,
        "times": times,
        "params": lasagne.layers.get_all_param_values(network, trainable=True),
    }
    # Create the output directory up front. A missing directory would
    # otherwise crash here and lose the results of the whole training run.
    save_dir = "runs2"
    os.makedirs(save_dir, exist_ok=True)
    save_file = os.path.join(save_dir, f"cifar_{update}_{lr}_{reg_factor}.pkl")
    with open(save_file, 'wb') as f:
        pickle.dump(results, f)

import argparse


def parse_args(argv=None):
    """Parse the command-line options for one training run.

    Args:
        argv: Argument list to parse. ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``num_epochs`` (int), ``reg_factor`` (float),
        ``lr`` (float) and ``optimizer`` (str).
    """
    parser = argparse.ArgumentParser(description='Experiments for Group Lasso')
    parser.add_argument('--num_epochs', required=True, type=int, help="num epochs")
    parser.add_argument('--reg_factor', required=True, type=float, help="reg factor")
    parser.add_argument('--lr', required=True, type=float, help="learning rate")
    parser.add_argument('--optimizer', required=True,
                        choices=["momentum", "sgd", "adagrad", "adaptos", "tos"],
                        help="optimizer to use (momentum, sgd, adagrad, adaptos, tos)")
    return parser.parse_args(argv)


if __name__ == '__main__':
    # Parse arguments only when run as a script, so importing this module
    # does not read sys.argv or exit.
    args = parse_args()
    main(num_epochs=args.num_epochs, update=args.optimizer, lr=args.lr, reg_factor=args.reg_factor)
