# tf.compat.v1.disable_eager_execution()
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.backend import image_data_format
assert(image_data_format() == 'channels_last')
import numpy as np
import argparse, os
import yaml
from utils import MeanIoUIgnore, SparseCategoricalAccuracyIgnore, SparseCategoricalCrossentropyIgnore
from callbacks import TensorBoardExt
from models.model_res_tiny import IterNetwork

# import tensorflow as tf
# from tensorflow import debug as tf_debug
# sess = K.get_session()
# sess = tf_debug.LocalCLIDebugWrapperSession(sess)
# K.set_session(sess)

# gpuList = tf.config.experimental.list_physical_devices('GPU')
# for gpu in gpuList:
  # tf.config.experimental.set_memory_growth(gpu, True)

class YamlArgumentParser(argparse.ArgumentParser):
    """ArgumentParser whose ``@file`` arguments are one-line YAML mappings.

    Each line of an argument file (see ``fromfile_prefix_chars``) holds a single
    ``key: value`` pair and is converted to command-line tokens:

    * ``key: true``   -> ``['--key']``            (for ``store_true`` flags)
    * ``key: false``  -> RuntimeError             (omit the line instead)
    * ``key: [1, 2]`` -> ``['--key', '1', '2']``  (YAML list)
    * ``key: 1 2``    -> ``['--key', '1', '2']``  (space-separated list)
    * ``key: value``  -> ``['--key', 'value']``

    Blank and comment-only lines produce no tokens.
    """

    # Spellings treated as booleans. They are matched exactly (not as
    # substrings), so e.g. a path containing "true" is passed through unchanged.
    _TRUE_STRINGS = ('True', 'true')
    _FALSE_STRINGS = ('False', 'false')

    def convert_arg_line_to_args(self, arg_line):
        """Convert one line of an argument file into argparse tokens.

        Args:
            arg_line: a single line read from an ``@file``.

        Returns:
            list of str: tokens to hand to argparse (empty for blank/comment lines).

        Raises:
            RuntimeError: if the line is not a single ``key: value`` mapping, or
                if it sets a flag to false.
        """
        yaml_dict = yaml.safe_load(arg_line)
        if yaml_dict is None:  # blank or comment-only line
            return []
        if not isinstance(yaml_dict, dict) or len(yaml_dict) != 1:
            raise RuntimeError('expected a single "key: value" pair per line, got: ' + repr(arg_line))
        key, value = next(iter(yaml_dict.items()))
        option = '--' + str(key)
        if isinstance(value, list):
            # YAML flow/block list: one token per element
            return [option] + [str(item) for item in value]
        text = str(value)
        if text in self._FALSE_STRINGS:
            raise RuntimeError('remove argument ' + option + ' to set it to false')
        if text in self._TRUE_STRINGS:
            # store_true flags take no value
            return [option]
        # to handle lists given as space-separated values
        return [option] + text.split(' ')

# Command-line options. They can also be read from a file given as '@path',
# one "key: value" YAML pair per line (see YamlArgumentParser).
parser = YamlArgumentParser(fromfile_prefix_chars='@')
# paths and dataset
parser.add_argument('-d', '--data_dir', type=str, default='/fs/scratch/')
parser.add_argument('-l', '--log_dir', type=str, default='/home//exp/interactive')
parser.add_argument('-ds', '--dataset', type=str, default='')
# architecture and optimisation
parser.add_argument('-ct', '--conv_type', type=str, default='mobileV2')
parser.add_argument('-o', '--optimizer', type=str, default='sgdmom')
parser.add_argument('-lr', '--learning_rate', type=float, default=0)
parser.add_argument('-wd', '--weight_decay', type=float, default=0)
parser.add_argument('-dr', '--dropout_rate', type=float, default=0)
parser.add_argument('-nl', '--num_layers', type=int, default=0)
parser.add_argument('-nf', '--num_features', type=int, default=0)
parser.add_argument('-nb', '--num_blocks', type=int, default=0)
parser.add_argument('-nr', '--num_repeats', type=int, default=0)
parser.add_argument('-s', '--seed', type=int, default=42)
parser.add_argument('-ne', '--num_epochs', type=int, default=0)
parser.add_argument('-p', '--patience', type=int, default=0)
parser.add_argument('-dns', '--downsample', action='store_true')
parser.add_argument('-gp', '--global_pool', action='store_true')
parser.add_argument('-sp', '--share_parameters', action='store_true')
parser.add_argument('-bs', '--batch_size', type=int, default=0)
# run mode and data handling
parser.add_argument('-bm', '--batch_mode', action='store_true')
parser.add_argument('-mtv', '--merge_train_valid', action='store_true')
parser.add_argument('-wc', '--weight_classes', action='store_true')
parser.add_argument('-lw', '--loss_weights', nargs='+', type=float)
# Plain dict copy of the parsed namespace; later code adds derived entries to it.
params = dict(vars(parser.parse_args()))

# Keras fit verbosity: 2 prints one line per epoch (suited to batch-job logs),
# 1 shows an interactive progress bar.
verbose = 2 if params['batch_mode'] else 1

# Name of the file (inside log_dir) that receives the resolved configuration.
filenameConfig = 'config.yml'

# Resolve dataset-specific input locations and crop sizes. They are stored in
# params so they also end up in the dumped config file.
if params['dataset'] == 'camvid':
    dataExt = '360x480'
    camvidDir = os.path.join(params['data_dir'], 'Data', 'camvid')
    params['pathDataTrain'] = os.path.join(camvidDir, 'camvid-' + dataExt + '-train.npz')
    params['pathDataValid'] = os.path.join(camvidDir, 'camvid-' + dataExt + '-val.npz')
    params['pathDataTest'] = os.path.join(camvidDir, 'camvid-' + dataExt + '-test.npz')
    params['crop'] = [352, 480]
elif params['dataset'] == 'cityscapes':
    # Resolved under --data_dir like the other datasets rather than a
    # hard-coded absolute path.
    params['pathData'] = os.path.join(params['data_dir'], 'OpenData', 'cityscapes')
    params['crop'] = [512, 1024]
elif params['dataset'] == 'mnist':
    # Unfinished. The intended inputs were
    # <data_dir>/Data/mnist_semseg/{train,valid,test}_small.npz with a 64x64
    # crop (unverified) and 10 classes.
    raise NotImplementedError('dataset mnist is not supported yet')
elif params['dataset'] == 'cifar10':
    params['pathData'] = os.path.join(params['data_dir'], 'OpenData', 'cifar', 'cifar-10-batches-py')
else:
    raise NotImplementedError('unknown dataset: ' + repr(params['dataset']))

# Relative weight of each output's loss (one per layer/iteration). The default
# is uniform. Weights are normalised to sum to one before being given to Keras.
if params['loss_weights'] is None:
    weights = np.ones(params['num_layers'], dtype=np.float64)
    # record the weights actually used so the dumped config is complete
    params['loss_weights'] = [float(x) for x in weights]
    weights = weights / np.sum(weights)
    print('uniformly weighting losses:', weights)
else:
    weights = np.asarray(params['loss_weights'], dtype=np.float64)
    if len(weights) != params['num_layers']:
        raise RuntimeError('provide as many weights for the loss as there are losses in the network: is:' + str(len(weights)) + ' - should:' + str(params['num_layers']))
    weightSum = np.sum(weights)
    if weightSum <= 0:
        # a zero (or negative) sum would silently yield NaN / sign-flipped weights
        raise RuntimeError('loss weights must have a positive sum, got: ' + str(params['loss_weights']))
    weights = weights / weightSum
    print('custom weighting of losses:', weights)
# Keys must match the model's output names; 'output_<k>' (1-based) is the naming
# the original code relies on ("name is defined by tf.keras modules").
lossWeights = {'output_' + str(index + 1): float(weight) for index, weight in enumerate(weights)}


# Seed NumPy and TensorFlow random number generators for reproducibility.
np.random.seed(params['seed'])
tf.random.set_seed(params['seed'])

# DATASETS
# Each branch only constructs the dataset object. Splits, input shape and class
# count are then read through the same accessor methods.
if params['dataset'] == 'camvid':
    from data.datasets import Camvid
    dataset = Camvid(params['pathDataTrain'], params['pathDataValid'], pathTest=params['pathDataTest'],
            num_layers=params['num_layers'], batch_size=params['batch_size'], crop=params['crop'], merge_train_valid=params['merge_train_valid'])
elif params['dataset'] == 'mnist': # TODO
    raise NotImplementedError
elif params['dataset'] == 'cityscapes':
    from data.datasets import Cityscapes
    dataset = Cityscapes(params['pathData'],
            num_layers=params['num_layers'], batch_size=params['batch_size'], crop=params['crop'], merge_train_valid=params['merge_train_valid'])
elif params['dataset'] == 'cifar10':
    from data.datasets import Cifar10
    dataset = Cifar10(params['pathData'],
            num_layers=params['num_layers'], batch_size=params['batch_size'], merge_train_valid=params['merge_train_valid'])
else:
    raise NotImplementedError
dataTrain = dataset.get_train_set()
dataValid = dataset.get_valid_set()
# No test split is loaded for cifar10. Define dataTest as None so later code can
# check for it instead of hitting an undefined name.
dataTest = None if params['dataset'] == 'cifar10' else dataset.get_test_set()
inputShape = dataset.get_shape()
params['numClasses'] = int(dataset.get_num_classes())
params['input_shape'] = list(inputShape)

# Persist the resolved configuration (CLI args plus derived entries) in log_dir.
os.makedirs(params['log_dir'], exist_ok=True)
with open(os.path.join(params['log_dir'], filenameConfig), 'w') as configFile:
    yaml.dump(params, configFile)

# MODEL
# TODO: use single dict for this call?
model = IterNetwork(num_classes=params['numClasses'],
        num_blocks=params['num_blocks'], num_iterations=params['num_layers'], num_feat_base=params['num_features'], num_repeats=params['num_repeats'],
        dropout_rate=params['dropout_rate'], weight_decay=params['weight_decay'],
        downsample=params['downsample'], global_pool=params['global_pool'], conv_type=params['conv_type'], share_parameters=params['share_parameters'], shuffle=False)
# Build with an unknown batch dimension. tuple() also accepts list-like shapes.
model.build((None,) + tuple(inputShape))
model.summary()
# Number of trainable parameters (replaces a hard-coded 0 placeholder).
# NOTE(review): Keras de-duplicates trainable_weights, so variables shared
# across iterations (share_parameters) should be counted once — verify.
numParams = int(sum(keras.backend.count_params(weight) for weight in model.trainable_weights))

# TRAINING
# Every optimizer starts at learning_rate=0.0. The LearningRateScheduler
# callback passed to fit() sets the real value at the start of each epoch.
optimizerName = params['optimizer']
if optimizerName == 'adam':
    myOptimizer = keras.optimizers.Adam(learning_rate=0.0)
elif optimizerName == 'rmsprop':
    myOptimizer = keras.optimizers.RMSprop(learning_rate=0.0)
elif optimizerName in ('cosine', 'sgdmom'):
    # Both use Nesterov SGD. For 'cosine' the decay is applied by a callback,
    # because passing keras.experimental.CosineDecay as the SGD learning rate
    # did not work here (possibly a TF bug, per the original note).
    myOptimizer = keras.optimizers.SGD(learning_rate=0.0, momentum=0.9, nesterov=True)
else:
    raise NotImplementedError

segmentationDatasets = ('camvid', 'cityscapes', 'mnist')
if params['dataset'] in segmentationDatasets:
    # Segmentation uses the "*Ignore" metric/loss variants from utils
    # (presumably skipping a void label — see utils).
    metricsList = [SparseCategoricalAccuracyIgnore(), MeanIoUIgnore(params['numClasses'])]
    classWeight = None
    if params['weight_classes'] and params['dataset'] != 'mnist':
        # data/class_balance.yml maps dataset name -> {class index: weight}
        # for indices 0..C-1.
        with open(os.path.join('data', 'class_balance.yml'), 'r') as balanceFile:
            perClass = yaml.safe_load(balanceFile)[params['dataset']]
            classWeight = np.array([perClass[classIndex] for classIndex in range(len(perClass))])
    loss = SparseCategoricalCrossentropyIgnore(class_weight=classWeight, from_logits=True)
elif params['dataset'] == 'cifar10':
    metricsList = [tf.keras.metrics.SparseCategoricalAccuracy()]
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
else:
    raise NotImplementedError

# TODO: use weighting of loss for single loss networks
# NOTE(review): run_eagerly=True runs step by step without graph compilation,
# which is slower. Presumably kept for debugging the custom metrics/losses.
model.compile(optimizer=myOptimizer, loss=loss, metrics=metricsList,
              loss_weights=lossWeights, run_eagerly=True)

tb = keras.callbacks.TensorBoard(log_dir=params['log_dir'], histogram_freq=0, profile_batch=0)
# NOTE(review): tbExt is constructed but not added to any callback list.
tbExt = TensorBoardExt(log_dir=params['log_dir'], histogram_freq=0, profile_batch=0)
checkpointFile = os.path.join(params['log_dir'], 'model')
# Keep the weights that maximise the last output's validation metric.
# cifar10 is compiled with SparseCategoricalAccuracy (no mIoU metric).
# Keras prefixes metric names with the output name ('output_<k>_') only when the
# model has more than one output. A missing monitor key means ModelCheckpoint
# never saves, and loading the best checkpoint later fails.
if params['dataset'] == 'cifar10':
    monitoredMetric = 'sparse_categorical_accuracy'
else:
    monitoredMetric = 'miou_ignore'
if params['num_layers'] > 1:
    name_last_output = 'val_output_' + str(params['num_layers']) + '_' + monitoredMetric
else:
    name_last_output = 'val_' + monitoredMetric
checkpoint = keras.callbacks.ModelCheckpoint(checkpointFile, monitor=name_last_output, verbose=1, save_best_only=True, save_weights_only=True, mode='max')
# Default summary writer, used by the learning-rate schedules to log 'lr'.
fileWriter = tf.summary.create_file_writer(os.path.join(params['log_dir'], 'metrics'))
fileWriter.set_as_default()
def lr_schedule(epochIndex, lr=None):
    """Step schedule: full rate for the first 70% of epochs, then x0.1, then x0.01.

    Args:
        epochIndex: zero-based epoch index supplied by LearningRateScheduler.
        lr: current optimizer learning rate, passed positionally by Keras'
            LearningRateScheduler. Ignored, because the schedule depends only
            on the epoch.

    Returns:
        The learning rate for this epoch, also logged as the 'lr' scalar to the
        default summary writer.
    """
    baseRate = params['learning_rate']
    if epochIndex < params['num_epochs'] * 0.7:
        learning_rate = baseRate
    elif epochIndex < params['num_epochs'] * 0.85:
        learning_rate = baseRate * 0.1
    else:
        # Final 15%. Also covers epochs at/after num_epochs, which previously
        # left learning_rate unbound.
        learning_rate = baseRate * 0.01
    tf.summary.scalar('lr', data=learning_rate, step=epochIndex)
    return learning_rate
learning_rate_steps = keras.callbacks.LearningRateScheduler(lr_schedule)
def cosine_schedule(epochIndex, lr=None, alpha=0): # inspired by https://www.tensorflow.org/api_docs/python/tf/compat/v1/train/cosine_decay
    """Cosine-annealed learning rate computed with NumPy.

    Args:
        epochIndex: zero-based epoch index; values above num_epochs are clamped.
        lr: current optimizer learning rate. Keras' LearningRateScheduler passes
            it as the second positional argument, so it must not land in
            ``alpha``. Ignored.
        alpha: floor of the schedule as a fraction of params['learning_rate'].

    Returns:
        params['learning_rate'] * ((1 - alpha) * 0.5 * (1 + cos(pi * t / T)) + alpha)
        with t = min(epochIndex, num_epochs) and T = num_epochs. The value is
        also logged as the 'lr' scalar.
    """
    epochIndex = min(epochIndex, params['num_epochs'])
    cosine_decay = 0.5 * (1.0 + np.cos(np.pi * epochIndex / params['num_epochs']))
    decayed = (1.0 - alpha) * cosine_decay + alpha
    learning_rate = params['learning_rate'] * decayed
    tf.summary.scalar('lr', data=learning_rate, step=epochIndex)
    return learning_rate
learning_rate_cosine = keras.callbacks.LearningRateScheduler(cosine_schedule)
def cosine_schedule_builtin(epochIndex, lr=None, alpha=0): # inspired by https://www.tensorflow.org/api_docs/python/tf/compat/v1/train/cosine_decay
    """Cosine-annealed learning rate computed by tf.compat.v1.train.cosine_decay.

    Args:
        epochIndex: zero-based epoch index, used as the global step (decay_steps = num_epochs).
        lr: current optimizer learning rate. Keras' LearningRateScheduler passes
            it as the second positional argument; before this parameter existed
            it was silently used as ``alpha``. Ignored.
        alpha: floor of the schedule as a fraction of params['learning_rate'].

    Returns:
        The decayed learning rate (a tensor), also logged as the 'lr' scalar.
    """
    # In eager mode cosine_decay returns a callable, hence the trailing ().
    learning_rate = tf.compat.v1.train.cosine_decay(params['learning_rate'], epochIndex, params['num_epochs'], alpha=alpha)()
    # Warm-restart variant, kept for reference:
    # num_cycles = 4
    # assert(params['num_epochs'] % num_cycles == 0)
    # learning_rate = tf.compat.v1.train.cosine_decay_restarts(params['learning_rate'], epochIndex, params['num_epochs'] // num_cycles, t_mul=1.0, alpha=alpha)()
    tf.summary.scalar('lr', data=learning_rate, step=epochIndex)
    return learning_rate
learning_rate_cosine_builtin = keras.callbacks.LearningRateScheduler(cosine_schedule_builtin)
# Alternatives previously tried (not used):
#   keras.callbacks.ReduceLROnPlateau('val_loss', verbose=1, patience=params['patience'])
#   keras.callbacks.EarlyStopping(monitor='val_loss', patience=params['patience'] * 2, verbose=1)
# TensorBoard and the best-model checkpoint are always attached. The
# learning-rate callback depends on the optimizer choice.
scheduleCallback = learning_rate_cosine_builtin if params['optimizer'] == 'cosine' else learning_rate_steps
callbacksList = [tb, checkpoint, scheduleCallback]

# With merge_train_valid the test split is monitored during training instead of
# the validation split (presumably because validation data is merged into
# training — TODO confirm).
# NOTE(review): the "best" checkpoint is then selected on test data.
validation_data = dataTest if params['merge_train_valid'] else dataValid
history = model.fit(dataTrain, validation_data=validation_data,
        epochs=params['num_epochs'], callbacks=callbacksList, verbose=verbose)

# Save the final-epoch weights next to the best checkpoint.
checkpointFileLast = os.path.join(params['log_dir'], 'model_last')
model.save_weights(checkpointFileLast)

# Dump the per-epoch training history. Learning rates are cast to plain floats
# (presumably they can be NumPy scalars, which yaml.dump would not write as
# plain numbers).
results = history.history
if 'lr' in results:
    results['lr'] = list(map(float, results['lr']))
historyPath = os.path.join(params['log_dir'], 'history.yml')
with open(historyPath, 'w') as historyFile:
    yaml.dump(results, historyFile)

def evaluate_model(filenameModel):
    """Load the weights in ``filenameModel`` into the global model and evaluate every split.

    Args:
        filenameModel: weights path accepted by ``model.load_weights``.

    Returns:
        dict with keys 'train', 'valid' and 'test'. Each value is the
        ``return_dict=True`` result of ``model.evaluate``, or None when the
        split is not evaluated: 'valid' when merge_train_valid is set, 'test'
        when no test set was loaded (dataTest is None).
    """
    model.load_weights(filenameModel)
    eval_train = model.evaluate(dataTrain, return_dict=True) # includes data augmentation
    eval_valid = None if params['merge_train_valid'] else model.evaluate(dataValid, return_dict=True)
    eval_test = None if dataTest is None else model.evaluate(dataTest, return_dict=True)
    return {'train': eval_train, 'valid': eval_valid, 'test': eval_test}

# Evaluate the best checkpoint (saved by ModelCheckpoint) and the final-epoch
# weights, in that order, and store both results together.
evalResults = {
    'best': evaluate_model(checkpointFile),
    'last': evaluate_model(checkpointFileLast),
}
with open(os.path.join(params['log_dir'], 'eval.yml'), 'w') as evalFile:
    yaml.dump(evalResults, evalFile)
