# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
""" Train networks """
import os, gc
# os.environ['CUDA_VISIBLE_DEVICES'] = '1'

import numpy as np
import tensorflow as tf
import objax

# custom
from utils.datasets import load_dataset, load_test_batch, do_augmentation
from utils.models import load_network, save_network_parameters
from utils.optimizers import make_optimizer
from utils.learner import train, valid


# Globals
# Best validation accuracy seen so far; the training loop overwrites it
# whenever an epoch improves on it.
_best_acc   = 0.
# NOTE(review): not referenced in the visible part of this file -- presumably
# reserved for loss-based model selection; confirm before removing.
_best_loss  = 100.


# Common configurations
_seed       = 215        # seed for NumPy's global RNG
_dataset    = 'cifar10'  # one of: 'mnist', 'svhn', 'cifar10', 'pubfig'
_store_name = 'base'     # suffix used in the stored checkpoint file name


# Per-dataset hyper-parameters. Every branch defines the same set of names:
#   _network    : architecture name handed to load_network()
#   _num_batchs : batch argument handed to train() (presumably the mini-batch
#                 size -- confirm against utils.learner)
#   _num_epochs : number of training epochs
#   _optimizer  : optimizer name handed to make_optimizer()
#   _learn_rate : initial learning rate
#   _decay_rate : coefficient of the L2 penalty added to the training loss
#   _schedulelr : epochs at which the learning rate is multiplied by _schedratio
#   _schedratio : multiplicative learning-rate decay factor
# The commented-out groups are alternative configurations for the same dataset.

# MNIST
if 'mnist' == _dataset:
    # (Feedforward Network)
    _network    = 'FFNet'
    _num_batchs = 64
    _num_epochs = 20
    _optimizer  = 'SGD'
    _learn_rate = 0.1
    _decay_rate = 0.0
    _schedulelr = [10]
    _schedratio = 0.5

# SVHN
elif 'svhn' == _dataset:
    # (Feedforward Network)
    # _network    = 'FFNet'
    # _num_batchs = 50
    # _num_epochs = 150
    # _optimizer  = 'SGD'
    # _learn_rate = 0.02
    # _decay_rate = 0.0
    # _schedulelr = [60]
    # _schedratio = 0.5

    # (ConvNet)
    _network    = 'ConvNet'
    _num_batchs = 50
    _num_epochs = 100
    _optimizer  = 'SGD'
    _learn_rate = 0.02
    _decay_rate = 0.0
    _schedulelr = [40]
    _schedratio = 0.5

# CIFAR-10
elif 'cifar10' == _dataset:
    # (VGG16)
    # _network    = 'VGG16'
    # _num_batchs = 128
    # _num_epochs = 100
    # _optimizer  = 'Momentum'
    # _learn_rate = 0.01
    # _decay_rate = 0.0005
    # _schedulelr = [40, 80]
    # _schedratio = 0.4

    # (ConvNet)
    # _network    = 'ConvNet'
    # _num_batchs = 64
    # _num_epochs = 100
    # _optimizer  = 'Momentum'
    # _learn_rate = 0.03
    # _decay_rate = 0.0005
    # _schedulelr = [40, 80]
    # _schedratio = 0.1

    # (ResNet18)
    _network    = 'ResNet18'
    _num_batchs = 64
    _num_epochs = 100
    _optimizer  = 'Momentum'
    _learn_rate = 0.03
    _decay_rate = 0.0005
    _schedulelr = [40, 80]
    _schedratio = 0.1


# PubFig (public figure - face dataset)
elif 'pubfig' == _dataset:
    # (Pretrained VGG-Face)
    # _network    = 'VGGFace'
    # _num_batchs = 50
    # _num_epochs = 20
    # _optimizer  = 'Momentum'
    # _learn_rate = 0.001
    # _decay_rate = 0.0
    # _schedulelr = [10]
    # _schedratio = 0.5

    # (Pretrained InceptionResNetV1)
    # NOTE(review): the training loop iterates range(_num_epochs), i.e. epochs
    # 0..9, so the milestone at epoch 10 below is never reached and the
    # learning rate stays constant for this configuration.
    _network    = 'InceptionResNetV1'
    _num_batchs = 50
    _num_epochs = 10
    _optimizer  = 'Momentum'
    _learn_rate = 0.001
    _decay_rate = 0.0
    _schedulelr = [10]
    _schedratio = 0.5

# Fail fast: without this branch an unknown dataset name would leave every
# hyper-parameter undefined and only surface later as a NameError.
else:
    raise ValueError(
        "Unsupported dataset '{}': expected one of "
        "'mnist', 'svhn', 'cifar10', 'pubfig'".format(_dataset))



"""
    Main code starts from here
"""
if __name__ == '__main__':
    # Script entry point: load the configured dataset and network, train for
    # _num_epochs epochs with a step-wise learning-rate schedule, evaluate
    # after every epoch and keep the parameters of the most accurate model.

    # Set the random seed (for the reproducible experiments).
    # Only NumPy's global RNG is seeded here.
    np.random.seed(_seed)

    # Data
    (X_train, Y_train), (X_test, Y_test) = load_dataset(_dataset)
    print(' : Use the dataset - {}'.format(_dataset))

    # Set the pretrained flag (only PubFig starts from pretrained weights)
    set_pretrain = _dataset in ['pubfig']

    # Model
    model = load_network(_dataset, _network, use_pretrain=set_pretrain)
    print(' : Use the network - {}, [pretrained: {}]'.format(type(model).__name__, set_pretrain))
    # print (model.vars())

    # Variables to optimize: all model variables by default. For VGGFace only
    # the entries after the first 24 (in iteration order of model.vars()) are
    # trained -- per the original author: last 4-layers (conv + 3-fc).
    train_vars = model.vars()
    if _network in ['VGGFace']:
        train_vars = objax.VarCollection(
            (k, v)
            for i, (k, v) in enumerate(model.vars().items())
            if i > 23)

    # Jitted inference function (model called with training=False).
    # NOTE(review): the Jit is given train_vars, which for VGGFace is only a
    # subset of model.vars() -- confirm this is intended.
    predict = objax.Jit(lambda x: model(x, training=False), train_vars)

    # Set the model savefile: models/<dataset>/<network>/best_model_<name>.npz
    storedir = os.path.join('models', _dataset, _network)
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(storedir, exist_ok=True)
    storefile = os.path.join(storedir, 'best_model_{}.npz'.format(_store_name))

    # Losses
    def loss(x, label):
        """Return the training objective for one batch.

        The objective is the mean sparse cross-entropy of the model's logits
        (model called with training=True) plus `_decay_rate` times an L2
        penalty. The penalty is half the sum of squared values of every
        variable in `train_vars` whose name ends with '.w'.
        """
        logit = model(x, training=True)
        loss_xe = objax.functional.loss.cross_entropy_logits_sparse(logit, label).mean()
        loss_l2 = 0.5 * sum((v.value ** 2).sum() for k, v in train_vars.items() if k.endswith('.w'))
        return loss_xe + _decay_rate * loss_l2

    # Gradient/value function of the loss with respect to train_vars, and the
    # optimizer that updates those same variables.
    gv  = objax.GradValues(loss, train_vars)
    opt = make_optimizer(train_vars, _optimizer)
    print(' : Use the optimizer - {}'.format(type(opt).__name__))

    def train_op(x, y, lr):
        """Run one optimization step on batch (x, y) with learning rate `lr`.

        Returns the value(s) produced by the gradient function `gv`.
        """
        g, v = gv(x, y)
        opt(lr=lr, grads=g)
        return v

    # Compile the training step; the plain function is rebound to its jitted
    # version. gv.vars() contains the model variables.
    train_op = objax.Jit(train_op, gv.vars() + opt.vars())

    # Current learning rate; decayed at the epochs listed in _schedulelr.
    learning_rate = _learn_rate

    # Training
    for epoch in range(_num_epochs):
        # : update the learning rate
        if epoch in _schedulelr:
            learning_rate = learning_rate * _schedratio
            print(' : Update the learning rate to [{:.4f}]'.format(learning_rate))

        # : train one epoch
        clean_loss = train(epoch, X_train, Y_train, _num_batchs, train_op, learning_rate, augment=do_augmentation)

        # : run eval
        clean_batch = load_test_batch(_dataset)
        clean_acc   = valid(epoch, X_test, Y_test, clean_batch, predict)
        print(' : Loss [train %.4f]  Accuracy [clean %.2f]' % (clean_loss, clean_acc))

        # Store the model only when this epoch beats the best accuracy so far
        if clean_acc > _best_acc:
            _best_acc = clean_acc
            save_network_parameters(model, storefile)
            print(' : Store the model, to [{}]'.format(storefile))

    print(' : Done')
    # Fin.
