# pylint: disable = line-too-long
import sys, os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'  # reduce TF verbosity TODO: not on module level..? :/
import tensorflow as tf

from deepltl.train.common import *
from deepltl.utils import utils
from deepltl.optimization import lr_schedules
from deepltl.models import transformer
from deepltl.data import vocabulary
from deepltl.data import datasets


def run():
    """Parse CLI arguments, then either train or evaluate a Transformer.

    Two problems are supported:
      * ``ltl``: LTL formulas (``LTLTracesDataset``) mapped to traces.
      * ``sat``: propositional formulas (``BooleanSatDataset``) mapped to
        assignments.

    When ``params.test`` is set (presumably a ``--test`` flag defined in
    ``argparser()`` -- TODO confirm), the latest checkpoint is loaded and
    evaluated on a subset of the test split. Otherwise the model is trained,
    resuming from the latest checkpoint when one exists.
    """
    params = _parse_params()
    input_vocab, target_vocab, dataset = _build_problem(params)
    _derive_model_params(params, input_vocab, target_vocab)
    if params.test:
        _test(params, dataset, input_vocab, target_vocab)
    else:
        _train(params, dataset)


def _parse_params():
    """Parse common and model-specific CLI arguments and call ``setup`` with them."""
    parser = argparser()
    # add specific arguments
    parser.add_argument('--problem', type=str, default='ltl', choices=['ltl', 'sat'], help='available problems: ltl, sat')
    parser.add_argument('--d-embed-enc', type=int, default=128)
    parser.add_argument('--d-embed-dec', type=int, default=None)
    parser.add_argument('--d-ff', type=int, default=512)
    parser.add_argument('--ff-activation', default='relu')
    parser.add_argument('--num-heads', type=int, default=4)
    parser.add_argument('--num-layers', type=int, default=4)
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--warmup-steps', type=int, default=4000)
    parser.add_argument('--pos-enc', type=str, default=None, help='available tree positional encodings: tree-branch-up, tree-branch-down')
    parser.add_argument('--format', type=str, default=None, help='format of formulas, needs to be specified if tree positional encoding is used')
    parser.add_argument('--force-load', default=False, action='store_true', help='Assure that weights from checkpoint are loaded, fail otherwise')
    params = parser.parse_args()
    setup(**vars(params))
    return params


def _build_problem(params):
    """Create the vocabularies and dataset object for ``params.problem``.

    Side effects on ``params``:
      * fills in ``ds_name`` with a default dataset name when it is unset;
      * sets the problem-specific ``max_encode_length`` / ``max_decode_length``.

    Returns:
        ``(input_vocab, target_vocab, dataset)``.

    Exits the process (non-zero) for an unknown problem. This should not be
    reachable from the CLI, because ``--problem`` is restricted by ``choices``.
    """
    if params.ds_name is None:
        params.ds_name = utils.dataset_name(num_aps=5, tree_size=35, num_formulas=1000000)
    dataset_name = params.ds_name
    aps = ['a', 'b', 'c', 'd', 'e', 'f']
    consts = ['0', '1']
    # The input vocabulary's ``eos`` flag is only enabled when no custom
    # (tree) positional encoding is requested.
    eos = params.pos_enc is None

    if params.problem == 'ltl':
        input_vocab = vocabulary.LTLVocabulary(aps=aps, consts=consts, ops=['U', 'X', '!', '&', 'F', 'G', 'W', '|', '>'], eos=eos)
        target_vocab = vocabulary.TraceVocabulary(aps=aps, consts=consts, ops=['&', '|', '!'])
        dataset = datasets.LTLTracesDataset(dataset_name, input_vocab, target_vocab, data_dir=params.data_dir, format=params.format)
        params.max_encode_length = 256  # use less for shorter training data
        params.max_decode_length = 80  # use less for shorter training data
    elif params.problem == 'sat':
        input_vocab = vocabulary.LTLVocabulary(aps, consts, ['!', '&', '|', '<->', 'xor'], eos=eos)
        target_vocab = vocabulary.TraceVocabulary(aps, consts, [], special=[])
        dataset = datasets.BooleanSatDataset(dataset_name, data_dir=params.data_dir, formula_vocab=input_vocab, assignment_vocab=target_vocab)
        params.max_encode_length = 37
        params.max_decode_length = 12
    else:
        # Fail like the other fatal paths (non-zero exit) instead of returning silently.
        sys.exit(f'{params.problem} is not a valid problem')
    return input_vocab, target_vocab, dataset


def _derive_model_params(params, input_vocab, target_vocab):
    """Copy vocabulary metadata onto ``params``, finalize the embedding sizes and print all params."""
    params.input_vocab_size = input_vocab.vocab_size()
    params.input_pad_id = input_vocab.pad_id
    params.target_vocab_size = target_vocab.vocab_size()
    params.target_start_id = target_vocab.start_id
    params.target_eos_id = target_vocab.eos_id
    params.target_pad_id = target_vocab.pad_id
    params.dtype = tf.float32

    if params.d_embed_dec is None:
        params.d_embed_dec = params.d_embed_enc
    # Round the embedding sizes down to a multiple of num_heads. Presumably this
    # lets the embedding split evenly across attention heads.
    print('Specified dimension of encoder embedding:', params.d_embed_enc)
    params.d_embed_enc -= params.d_embed_enc % params.num_heads  # round down
    print('Specified dimension of decoder embedding:', params.d_embed_dec)
    params.d_embed_dec -= params.d_embed_dec % params.num_heads  # round down
    print('Parameters:')
    for key, val in vars(params).items():
        print('{:25} : {}'.format(key, val))


def _train(params, dataset):
    """Train (or resume training) on the train split, validating on the val split.

    Fit errors are re-raised unless ``params.hypertune`` is set. In that case
    the error is printed and ``HypertuneCallback.fail()`` is called instead.
    """
    use_pos_enc = params.pos_enc is not None
    # Only the train/val splits are needed here; the test split is not loaded.
    if params.problem == 'ltl':
        # The "- 2" presumably reserves room for special start/end tokens -- TODO confirm.
        train_dataset, val_dataset = dataset.get_dataset(['train', 'val'], max_length_formula=params.max_encode_length - 2, max_length_trace=params.max_decode_length - 2, prepend_start_token=False, pos_enc=params.pos_enc)
    else:  # 'sat' (validated in _build_problem)
        train_dataset, val_dataset = dataset.get_dataset(splits=['train', 'val'], pos_enc=params.pos_enc)
    train_dataset = prepare_dataset_no_tf(train_dataset, params.batch_size, params.d_embed_enc, shuffle=True, pos_enc=use_pos_enc)
    val_dataset = prepare_dataset_no_tf(val_dataset, params.batch_size, params.d_embed_enc, shuffle=False, pos_enc=use_pos_enc)

    # Model & Training specification
    learning_rate = lr_schedules.TransformerSchedule(params.d_embed_enc, warmup_steps=params.warmup_steps)
    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate, beta_1=0.9, beta_2=0.98, epsilon=1e-9)
    model = transformer.create_model(vars(params), training=True, custom_pos_enc=use_pos_enc)
    latest_checkpoint = last_checkpoint(**vars(params))
    if latest_checkpoint:
        model.load_weights(latest_checkpoint).expect_partial()
        print(f'Loaded weights from checkpoint {latest_checkpoint}')
    elif params.force_load:
        sys.exit('Failed to load weights, no checkpoint found!')
    else:
        print('No checkpoint found, creating fresh parameters')
    sys.stdout.flush()

    callbacks = [checkpoint_callback(save_weights_only=True, save_best_only=False, **vars(params)),
                 tensorboard_callback(**vars(params)),
                 tf.keras.callbacks.EarlyStopping('val_accuracy', patience=4, restore_best_weights=True)]
    hypertune_callback = HypertuneCallback('val_accuracy')
    if params.hypertune:
        callbacks.append(hypertune_callback)

    # Train!
    log_params(**vars(params))
    model.compile(optimizer=optimizer)
    try:
        model.fit(train_dataset, epochs=params.epochs, validation_data=val_dataset, validation_freq=1, callbacks=callbacks, initial_epoch=params.initial_epoch, verbose=2, shuffle=False)
    except Exception as e:
        if not params.hypertune:
            raise  # fail with Exception
        # Under hyperparameter tuning, report the failure via the callback
        # instead of crashing the job.
        print('---- Exception occurred during training ----\n' + str(e))
        hypertune_callback.fail()


def _test(params, dataset, input_vocab, target_vocab, shuffle_buffer=100000, num_samples=30000, seed=42):
    """Evaluate the latest checkpoint on a random subset of the test split.

    The subset is ``num_samples`` elements drawn with a seeded shuffle
    (``shuffle_buffer`` / ``seed``), batched to ``params.batch_size``. It is
    passed to the problem-specific analysis helper with ``log_name='test.log'``.
    The process exits if no checkpoint is found.
    """
    use_pos_enc = params.pos_enc is not None
    if params.problem == 'ltl':
        test_dataset, = dataset.get_dataset(['test'], max_length_formula=params.max_encode_length - 2, max_length_trace=params.max_decode_length - 2, prepend_start_token=False, pos_enc=params.pos_enc)
    else:  # 'sat' (validated in _build_problem)
        test_dataset, = dataset.get_dataset(splits=['test'], pos_enc=params.pos_enc)

    prediction_model = transformer.create_model(vars(params), training=False, custom_pos_enc=use_pos_enc)
    latest_checkpoint = last_checkpoint(**vars(params))
    if not latest_checkpoint:
        sys.exit('Could not load weights from checkpoint')
    prediction_model.load_weights(latest_checkpoint).expect_partial()
    print(f'Loaded weights from checkpoint {latest_checkpoint}')
    sys.stdout.flush()

    # With a custom positional encoding, each element carries an extra
    # positional-encoding component whose last dimension is d_embed_enc.
    padded_shapes = ([None], [None, params.d_embed_enc], [None]) if use_pos_enc else ([None], [None])
    test_dataset = test_dataset.shuffle(shuffle_buffer, seed=seed).take(num_samples).padded_batch(params.batch_size, padded_shapes=padded_shapes)

    if params.problem == 'ltl':
        # The model returns a pair; only its first element is used as the prediction.
        if use_pos_enc:
            def pred_fn(x, pe):
                output, _ = prediction_model([x, pe], training=False)
                return output
        else:
            def pred_fn(x):
                output, _ = prediction_model(x, training=False)
                return output
        print('Starting check...')
        test_and_analyze_ltl(pred_fn, test_dataset, input_vocab, target_vocab, log_name='test.log', **vars(params))
    else:  # 'sat'
        test_and_analyze_sat(prediction_model, test_dataset, input_vocab, target_vocab, log_name='test.log', **vars(params))


if __name__ == '__main__':
    # Script entry point: run() is only invoked when this file is executed
    # directly, not when it is imported.
    run()
