from __future__ import absolute_import, division, print_function, \
    unicode_literals

import argparse

from data import get_data_generator
from algorithms import get_algorithm


def main(args):
    """Build the data generator and algorithm from ``args``, then run it.

    Args:
        args: Parsed command-line namespace. ``args.data_name`` is passed to
            ``get_data_generator``; the full namespace is forwarded to both
            ``get_data_generator`` and ``get_algorithm``.
    """
    data_generator = get_data_generator(args, args.data_name)
    # The algorithm is constructed against the data generator built above and
    # executed immediately; its return value is not used.
    get_algorithm(args, data_generator).run()


def build_parser():
    """Create the command-line parser for this experiment script.

    Kept separate from the ``__main__`` guard so the parser and its defaults
    can be constructed and inspected (e.g. in tests or other scripts) without
    running an experiment.

    Returns:
        argparse.ArgumentParser: parser whose ``parse_args()`` result is the
        ``args`` namespace handed to ``main``.
    """
    parser = argparse.ArgumentParser()

    # Bookkeeping and reproducibility.
    parser.add_argument('--experiment_id', type=str, default='default',
                        help='Experiment ID.')
    parser.add_argument('--data_random_seed', type=int, default=42,
                        help='Random seed for data.')
    parser.add_argument('--parameter_random_seed', type=int, default=7,
                        help='Random seed for model parameters.')

    # What to run.
    parser.add_argument('--data_name', type=str, default='color',
                        help='Data name.')
    parser.add_argument('--algorithm', type=str, default='color',
                        help='Algorithm / loss type.')

    # Model size and training / inference schedule.
    parser.add_argument('--n_hidden_nodes', type=int, default=32,
                        help='Number of nodes in hidden layer.')
    parser.add_argument('--batch_size', type=int, default=500,
                        help='Training batch size.')
    parser.add_argument('--test_size', type=int, default=500,
                        help='Test batch size.')
    parser.add_argument('--steps', type=int, default=1000,
                        help='Training steps.')
    parser.add_argument('--inference_steps', type=int, default=100,
                        help='Inference steps.')
    parser.add_argument('--log_steps', type=int, default=100,
                        help='Log steps.')
    parser.add_argument('--lr', type=float, default=0.001,
                        help='Learning rate.')
    parser.add_argument('--test_learning_rate', type=float, default=0.1,
                        help='Test learning rate.')
    # Float literal so the default has the same type as a value parsed with
    # type=float (argparse does not apply `type` to non-string defaults).
    # NOTE(review): a negative value presumably disables clipping — confirm
    # against the algorithm implementation.
    parser.add_argument('--max_gradient_norm', type=float, default=-1.0,
                        help='Max gradient norm.')

    # Loss hyper-parameters.
    parser.add_argument('--sigma', type=float, default=1.0,
                        help='Entropy noise deviation.')
    parser.add_argument('--alpha', type=float, default=1.0,
                        help='Entropy regularization weight.')
    parser.add_argument('--beta', type=float, default=0.3,
                        help='Reconstruction loss weight.')
    parser.add_argument('--gamma', type=float, default=0.1,
                        help='Manifold regularization loss weight.')

    # Boolean switches; `store_true` flags already default to False.
    parser.add_argument('--random_initialize', action='store_true',
                        help='Randomize hidden reps to infer.')
    parser.add_argument('--use_entropy_regularization_train',
                        action='store_true',
                        help='Use entropy regularization in training.')
    parser.add_argument('--use_entropy_regularization_inference',
                        action='store_true',
                        help='Use entropy regularization in inference.')
    parser.add_argument('--use_reconstruction_train', action='store_true',
                        help='Use reconstruction in train.')
    parser.add_argument('--use_reconstruction_inference', action='store_true',
                        help='Use reconstruction in inference.')
    parser.add_argument('--use_manifold_regularization', action='store_true',
                        help='Use manifold regularization in inference.')

    parser.add_argument('--memory_size', type=int, default=500,
                        help='Number of samples in memory.')
    return parser


if __name__ == '__main__':
    main(build_parser().parse_args())
