import time
import tensorflow as tf
import tensorflow_quantum as tfq
import argparse

from constants import N_QUBITS, N_LAYERS, SEED, BATCH_SIZE, EPOCHS, EXTRA_COMPRESSION_FACTOR
from data_processing import  get_quantum_tensors
from gates import create_model, accuracy

def get_args():
    """Parse the command-line options of the quantum-neural-network script.

    Numeric defaults come from the ``constants`` module. The split/fold
    options default to ``None``. The three boolean switches default to off.

    Returns:
        argparse.Namespace with one attribute per long option name
        (e.g. ``args.nqubits``, ``args.network_architecture``).
    """
    parser = argparse.ArgumentParser(description='quantum neural network')

    def add_int(long_flag, short_flag, default):
        # Optional integer-valued option.
        parser.add_argument(long_flag, short_flag, type=int, default=default, help='')

    def add_switch(long_flag, short_flag):
        # Optional on/off flag; False unless given on the command line.
        parser.add_argument(long_flag, short_flag, action='store_true', default=False, help='')

    # Registration order is kept identical so the --help listing is unchanged.
    add_int('--nlayers', '-l', N_LAYERS)
    add_int('--nqubits', '-q', N_QUBITS)
    add_int('--nepochs', '-e', EPOCHS)
    add_int('--batch_size', '-b', BATCH_SIZE)
    add_int('--seed', '-d', SEED)
    add_int('--splits', '-s', None)
    add_int('--nsplit', '-i', None)
    add_switch('--load_tensors', '-r')
    add_switch('--save_tensors', '-w')
    add_switch('--parallel', '-p')
    add_int('--extra_compression_factor', '-x', EXTRA_COMPRESSION_FACTOR)
    parser.add_argument('--network_architecture', '-a', type=str, default='CRADL',
                        help='CRADL, CRAML or CRADML')
    add_int('--fold', '-f', None)
    add_int('--folds', '-k', None)

    return parser.parse_args()


def run(nqubits,
        nlayers,
        splits,
        nsplit,
        load_tensors,
        save_tensors,
        parallel,
        extra_compression_factor,
        seed,
        batch_size,
        nepochs,
        network_type,
        fold_of_folds):
    """Build, train and evaluate the PQC classifier.

    Seeds TensorFlow and builds a Keras ``Sequential`` model that wraps the
    circuit/readout pair returned by ``create_model`` in a
    ``tfq.layers.PQC`` layer. It then loads train/test data through
    ``get_quantum_tensors``, fits with a hinge loss and the Adam optimizer,
    and evaluates on the test split.

    Args:
        nqubits: number of qubits, forwarded to ``create_model`` and
            ``get_quantum_tensors``.
        nlayers: number of layers, forwarded to ``create_model``.
        splits, nsplit, load_tensors, save_tensors, parallel,
        extra_compression_factor: forwarded unchanged to
            ``get_quantum_tensors``.
        seed: global TensorFlow random seed.
        batch_size: mini-batch size for ``model.fit``.
        nepochs: number of training epochs.
        network_type: architecture name passed to ``create_model``
            (the CLI documents 'CRADL', 'CRAML' or 'CRADML').
        fold_of_folds: ``None`` or a ``(fold, folds)`` tuple, forwarded
            unchanged to ``get_quantum_tensors``.

    Returns:
        List of ``(metric_name, value)`` pairs from ``model.evaluate`` on the
        test data.

    Note:
        When ``save_tensors`` is set and ``splits`` is not ``None``, the
        function only produces the tensors. It then calls ``sys.exit()``
        (raising ``SystemExit``) instead of training and returning.
    """
    tf.random.set_seed(seed)

    # NOTE(review): model construction precedes data loading, as in the
    # original. The order is kept in case either step consumes the seeded TF RNG.
    model_circuit, model_readout = create_model(n_qubits=nqubits, n_layers=nlayers, network_type=network_type)
    model = tf.keras.Sequential([
        # Scalar tf.string input: one serialized circuit per example (TFQ convention).
        tf.keras.layers.Input(shape=(), dtype=tf.string),
        tfq.layers.PQC(model_circuit, model_readout),
    ])

    x_train, y_train, x_test, y_test = \
        get_quantum_tensors(nqubits=nqubits,
            splits=splits,
            load_tensors=load_tensors,
            save_tensors=save_tensors,
            nsplit=nsplit,
            parallel=parallel,
            extra_compression_factor=extra_compression_factor,
            fold_of_folds=fold_of_folds)

    # Save-only mode: the tensors have been written, so stop before training.
    if save_tensors and splits is not None:
        import sys
        sys.exit()

    model.compile(
        loss=tf.keras.losses.Hinge(),
        optimizer=tf.keras.optimizers.Adam(),
        metrics=[accuracy],
    )

    # summary() prints the table itself and returns None; wrapping it in
    # print() would emit a stray "None" line.
    model.summary()

    # perf_counter is monotonic, so the measured duration is immune to
    # wall-clock adjustments during a long fit.
    t1 = time.perf_counter()
    model.fit(
        x_train, y_train,
        batch_size=batch_size,
        epochs=nepochs,
        verbose=1,
        validation_data=(x_test, y_test),
    )
    print('time to fit model : {}'.format(time.perf_counter() - t1))

    qnn_results = model.evaluate(x_test, y_test)
    return list(zip(model.metrics_names, qnn_results))


def resolve_fold_of_folds(fold, folds):
    """Combine the ``--fold`` / ``--folds`` CLI values into the form ``run`` expects.

    Args:
        fold: value of ``--fold``, or ``None`` if not given.
        folds: value of ``--folds``, or ``None`` if not given.

    Returns:
        ``(fold, folds)`` when both are given, ``None`` when neither is.
        A fold of ``0`` is kept, because only ``None`` means "not given".

    Raises:
        SystemExit: if exactly one of the two values is given. Previously that
            case was silently ignored and training ran on un-folded data.
    """
    if fold is None and folds is None:
        return None
    if fold is None or folds is None:
        raise SystemExit('error: --fold and --folds must be given together')
    return fold, folds


if __name__ == '__main__':
    args = get_args()
    run(nqubits=args.nqubits,
        nlayers=args.nlayers,
        splits=args.splits,
        nsplit=args.nsplit,
        load_tensors=args.load_tensors,
        save_tensors=args.save_tensors,
        parallel=args.parallel,
        extra_compression_factor=args.extra_compression_factor,
        seed=args.seed,
        batch_size=args.batch_size,
        nepochs=args.nepochs,
        network_type=args.network_architecture,
        fold_of_folds=resolve_fold_of_folds(args.fold, args.folds))

