import time
from functools import partial
from absl import app
from absl import flags
import jax.numpy as np
from jax import random
import neural_tangents as nt
import networks
import datasets

# Upper bounds on how many train / test examples may be requested
# (presumably the MNIST split sizes, since main() loads 'mnist' -- confirm).
available_train_size = 60000
available_test_size = 10000

# Default parallelism: 27 rounded down to a multiple of the device count.
device_count = 3
batch_size = (27 // device_count) * device_count

# The default dataset sizes below are rounded down to a multiple of
# batch_size * device_count. NOTE(review): presumably this keeps the data
# evenly divisible for the batched kernel computation -- the defaults are
# NOT recomputed if --batch_size / --device_count are overridden.
_examples_per_step = batch_size * device_count

flags.DEFINE_integer(
    'train_size',
    (available_train_size // _examples_per_step) * _examples_per_step,
    'Dataset size to use for training.')
flags.DEFINE_integer(
    'test_size',
    (available_test_size // _examples_per_step) * _examples_per_step,
    'Dataset size to use for testing.')
flags.DEFINE_integer(
    'device_count',
    device_count,
    'Device count for kernel computation.')
flags.DEFINE_integer(
    'batch_size',
    batch_size,
    'Batch size for kernel computation. 0 for no batching.')
flags.DEFINE_bool(
    'empirical_kernel',
    False,
    'If true use an empirical kernel, otherwise use an infinite kernel')
flags.DEFINE_integer(
    'width_factor',
    1,
    'Width factor for neural network, relevant when empirical_kernel=True')
flags.DEFINE_string(
    'kernel_type',
    'ntk',
    'the kernel type can be either ntk or nngp')

# Parsed flag values; populated once app.run() has processed the command line.
FLAGS = flags.FLAGS


def timestr():
    """Return the current UTC time formatted as ``YYYY-MM-DD_HH:MM:SS``."""
    utc_now = time.gmtime()
    stamp_format = "%Y-%m-%d_%H:%M:%S"
    return time.strftime(stamp_format, utc_now)


def _compute_and_save_kernel(kernel_fn, x1, x2, name):
    """Compute one kernel matrix, report its shape and timing, and save it.

    Keeping this in its own function means the computed matrix goes out of
    scope as soon as it has been saved, so the first matrix is not kept
    alive while the next one is being computed.

    Args:
        kernel_fn: Callable invoked as ``kernel_fn(x1, x2, FLAGS.kernel_type)``.
        x1: First set of inputs passed to ``kernel_fn``.
        x2: Second set of inputs passed to ``kernel_fn``.
        name: Label used in the log lines and as the output file prefix
            (the file name passed to ``np.save`` is ``'<name>.' + timestr()``).
    """
    print(f'compute {name}')
    # perf_counter is monotonic, so the duration is unaffected by wall-clock
    # adjustments that may happen during a long computation.
    start = time.perf_counter()
    kernel = kernel_fn(x1, x2, FLAGS.kernel_type)
    # Block until the result is ready so the measured duration includes the
    # computation itself rather than just the call.
    kernel.block_until_ready()
    duration = time.perf_counter() - start
    print(f'{name}.shape={kernel.shape}, duration={duration}')
    np.save(f'{name}.' + timestr(), kernel)


def main(unused_argv):
    """Compute and save the train/train and test/train kernel matrices.

    Loads the 'mnist' dataset with the sizes given by ``--train_size`` and
    ``--test_size``, builds the ``networks.lenet5`` kernel function (swapping
    in an empirical kernel with freshly initialised parameters when
    ``--empirical_kernel`` is set), wraps it with ``nt.batch`` and then
    computes, times and saves ``k_dd`` (train vs. train) and ``k_td``
    (test vs. train).

    Args:
        unused_argv: Positional command-line arguments; ignored.

    Returns:
        0 on completion.
    """
    print(f'FLAGS={FLAGS}')
    # Build data pipelines. The labels are loaded but not needed here.
    print('Loading data.')
    x_train, _, x_test, _ = \
        datasets.get_dataset('mnist', FLAGS.train_size, FLAGS.test_size, )

    # Build the infinite network.
    init_fn, apply_fn, kernel_fn = networks.lenet5(FLAGS.width_factor)

    # Generate empirical kernel if needed.
    if FLAGS.empirical_kernel:
        kernel_fn = nt.empirical_kernel_fn(apply_fn)
        # Fixed seed so the sampled parameters are the same on every run.
        _, rng_params = random.split(random.PRNGKey(10), 2)
        _, params = init_fn(rng_params, (-1, 28, 28, 1))
        # Bind the sampled parameters so both kernel variants are called
        # the same way below.
        kernel_fn = partial(kernel_fn, params=params)

    # Compute the kernel in batches, in parallel across devices.
    print('compute kernel_fn')
    kernel_fn = nt.batch(kernel_fn,
                         device_count=FLAGS.device_count,
                         batch_size=FLAGS.batch_size,
                         store_on_device=False)

    _compute_and_save_kernel(kernel_fn, x_train, x_train, 'k_dd')
    _compute_and_save_kernel(kernel_fn, x_test, x_train, 'k_td')

    return 0


# Script entry point: hand main() to absl's app.run only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    app.run(main)
