import pickle

import numpy as np
import tensorflow as tf

import tqdm

import absl.app
import absl.flags
from absl import logging

from parameterized_model import ParameterizedModel

from .misc_utils import AttrDict
from .misc_utils import define_flags_with_default, set_random_seed, TensorBoardLogger, print_flags
from .data import datasets
from .model import models

# Global absl flag container; the rest of this module reads settings as FLAGS.<name>.
FLAGS = absl.flags.FLAGS


# Run configuration and default values.  define_flags_with_default lives in
# misc_utils; presumably it registers one absl flag per keyword — confirm there.
flags_def = define_flags_with_default(
    # Inner loop: number of adaptation steps and the inner learning rate,
    # forwarded to InnerLearningRate (type / scaling / scaling coefficient).
    inner_steps=1,
    inner_lr=3e-2,
    # Outer (meta) optimiser: Adam learning rate with staircase exponential decay.
    outer_lr=1e-4,
    outer_lr_decay_rate=1.0,
    outer_lr_decay_step=10000000,
    inner_lr_type='fixed',
    inner_lr_scaling='exp',
    inner_lr_scaling_coef=1.0,
    # Gradient clipping is applied only when the corresponding value is > 0.
    inner_clip_gradient=0.0,
    outer_clip_gradient=0.0,
    # Number of tasks (placeholder sets) used per meta-training step.
    meta_train_tasks=5,
    # String of '-'-separated integers (see parse_variabe_batch_size); training
    # cycles through these sizes, one per step.
    inner_batch_size='10',
    outer_batch_size=256,
    # Regularisation weights.  Note that inner_step_likelihood_reg is only
    # applied inside build_graph when info_reg > 0.
    inner_step_likelihood_reg=0.0,
    info_reg=0.0,
    info_reg_init_stddev=1e-4,
    info_reg_min_stddev=1e-6,
    info_reg_max_stddev=1.0,
    weight_decay=0.0,
    # Training schedule: total steps, evaluation period, batches averaged per
    # evaluation, and TensorBoard logging period.
    n_steps=10000,
    test_interval=1000,
    test_batches=5,
    log_interval=20,
    # Passed as the second argument to the dataset constructor.
    n_training_tasks=0,
    seed=42,
    # Copied into the tf.ConfigProto GPU options in main().
    max_gpu_mem_frac=1.0,
    gpu_mem_growth=True,
    # Key into the `datasets` and `models` registries.
    dataset='mnist',
    data_file='./rainbow_mnist.pkl',
    # Directory handed to TensorBoardLogger.
    output_dir='/tmp/vs_maml',
)


def parse_variabe_batch_size(batch_sizes):
    """Turn a '-'-separated string of integers into a list of ints.

    Args:
        batch_sizes: String such as '10' or '1-5-10'.

    Returns:
        The parsed integers in order, or None when the string is empty.
    """
    if batch_sizes != '':
        return list(map(int, batch_sizes.split('-')))
    return None


class MetaPlaceholder(object):
    """Bundle of the four input placeholders used for one task.

    Holds image/label placeholders for the batch used to adapt
    (`pre_adapt_*`) and for the batch used to evaluate (`post_adapt_*`).
    Calling the instance builds the matching feed dict.
    """

    def __init__(self):
        dataset_cls = datasets[FLAGS.dataset]

        # Placeholders are created in the same order as their attribute names
        # below; the batch dimension of every placeholder is left unspecified.
        self.pre_adapt_images = tf.placeholder(
            dtype=tf.float32, shape=dataset_cls.image_batch_shape(None)
        )
        self.pre_adapt_labels = tf.placeholder(dtype=tf.int64, shape=[None])
        self.post_adapt_images = tf.placeholder(
            dtype=tf.float32, shape=dataset_cls.image_batch_shape(None)
        )
        self.post_adapt_labels = tf.placeholder(dtype=tf.int64, shape=[None])

    def __call__(self, pre_adapt_images, pre_adapt_labels, post_adapt_images, post_adapt_labels):
        """Return a feed dict mapping each placeholder to the given value."""
        slots = (
            self.pre_adapt_images,
            self.pre_adapt_labels,
            self.post_adapt_images,
            self.post_adapt_labels,
        )
        values = (
            pre_adapt_images,
            pre_adapt_labels,
            post_adapt_images,
            post_adapt_labels,
        )
        return dict(zip(slots, values))


class InnerLearningRate(object):
    """Inner-loop learning rate that is either constant or learnable.

    Supported `lr_type` values:
      * 'fixed': a constant equal to `init_lr`.
      * 'table': one learnable entry per batch size 0..max_batch_size; larger
        batches use the last entry.
      * 'manual*_fixed_coef' / 'manual*_learned_coef': two learnable entries;
        entry 0 is used for an empty batch, entry 1 is multiplied by a factor
        that depends on the batch size (and, for the learned variants, on a
        learnable coefficient).

    Learnable entries are stored unscaled and mapped to a learning rate by
    `scale_lr` ('exp' or 'linear' scaling).  Every variable created here is
    collected in `trainables`.
    """

    def __init__(self, init_lr, lr_type='fixed', lr_scaling='exp',
                 lr_scaling_coef=1.0, max_batch_size=20):
        self.init_lr = init_lr
        self.lr_type = lr_type
        self.lr_scaling = lr_scaling
        self.lr_scaling_coef = lr_scaling_coef
        self.max_batch_size = max_batch_size
        self._trainables = []

        # NOTE(review): 'manual3_learned_coef' starts its raw coefficient at
        # init - minimum == 0, where the gradient of tf.abs is zero, so that
        # coefficient may never move away from `minimum` — confirm intent.
        builders = {
            'fixed': self.build_fixed_lr,
            'table': self.build_table_lr,
            'manual_fixed_coef': self.build_manual_fixed_coef_lr,
            'manual2_fixed_coef': self.build_manual_fixed_coef_lr,
            'manual_learned_coef':
                lambda: self.build_manual_learned_coef_lr(init=1.1, minimum=1.0),
            'manual2_learned_coef':
                lambda: self.build_manual_learned_coef_lr(init=1.1, minimum=1.0),
            'manual3_learned_coef':
                lambda: self.build_manual_learned_coef_lr(init=1.0, minimum=1.0),
        }
        builder = builders.get(lr_type)
        if builder is None:
            raise ValueError('Unsupported learning rate type!')
        builder()

    @property
    def trainables(self):
        """List of variables created by this object (empty for 'fixed')."""
        return self._trainables

    def unscaled_init_lr(self):
        """Return the raw value that `scale_lr` maps back to `init_lr`."""
        if self.lr_scaling == 'exp':
            return np.log(self.init_lr) / self.lr_scaling_coef
        if self.lr_scaling == 'linear':
            return self.init_lr / self.lr_scaling_coef
        raise ValueError('Unsupported learning rate scaling!')

    def scale_lr(self, unscaled_lr):
        """Map a raw (unscaled) value to a non-negative learning rate."""
        if self.lr_scaling == 'exp':
            return tf.exp(unscaled_lr * self.lr_scaling_coef)
        if self.lr_scaling == 'linear':
            return tf.nn.relu(unscaled_lr * self.lr_scaling_coef)
        raise ValueError('Unsupported learning rate scaling!')

    def _build_learnable_lr(self, size):
        # Create `size` raw entries, all initialised so that the scaled value
        # equals init_lr, and register the variable as trainable.
        initial_value = tf.ones(size) * self.unscaled_init_lr()
        self.lr_variable = tf.Variable(
            initial_value, dtype=tf.float32, trainable=True
        )
        self.lr = self.scale_lr(self.lr_variable)
        self.trainables.append(self.lr_variable)

    def build_fixed_lr(self):
        """Constant learning rate; creates no variables."""
        self.lr = tf.constant(self.init_lr, dtype=tf.float32)

    def build_table_lr(self):
        """One learnable entry for each batch size in 0..max_batch_size."""
        self._build_learnable_lr(self.max_batch_size + 1)

    def build_manual_fixed_coef_lr(self):
        """Two learnable entries: [empty-batch rate, base rate]."""
        self._build_learnable_lr(2)

    def build_manual_learned_coef_lr(self, init=1.1, minimum=1.0):
        """Two learnable entries plus a learnable coefficient >= `minimum`."""
        self._build_learnable_lr(2)
        self.manual_coef_var = tf.Variable(
            init - minimum,
            dtype=tf.float32, trainable=True
        )
        # abs() keeps the effective coefficient at or above `minimum`.
        self.manual_coef = tf.abs(self.manual_coef_var) + minimum
        self.trainables.append(self.manual_coef_var)

    def get_lr(self, batch):
        """Return the learning-rate tensor to use for `batch`.

        The batch size is read from the first dimension of `batch` at run
        time.  Raises ValueError for an unknown `lr_type`.
        """
        if self.lr_type == 'fixed':
            return self.lr

        if self.lr_type == 'table':
            batch_size = tf.shape(batch)[0]
            return self.lr[tf.minimum(batch_size, self.max_batch_size)]

        # Batch-size dependent multiplier applied to lr[1] for each manual type.
        factors = {
            'manual_fixed_coef':
                lambda n: tf.constant(1.1, dtype=tf.float32) - 1 / n,
            'manual2_fixed_coef':
                lambda n: tf.constant(1.0, dtype=tf.float32) - 1 / (1 + n),
            'manual_learned_coef':
                lambda n: self.manual_coef - 1 / n,
            'manual2_learned_coef':
                lambda n: self.manual_coef - 1 / (1 + n),
            'manual3_learned_coef':
                lambda n: 1.0 - 1 / (1 + n * self.manual_coef),
        }
        factor = factors.get(self.lr_type)
        if factor is None:
            raise ValueError('Unsupported learning rate type!')

        batch_size = tf.cast(tf.shape(batch)[0], tf.float32)
        return tf.cond(
            tf.greater(batch_size, 0),
            lambda: self.lr[1] * factor(batch_size),
            lambda: self.lr[0]
        )


def merge_dict(dicts):
    """Combine a sequence of dicts into one new dict.

    When a key occurs in several inputs, the value from the latest one wins.
    """
    return {key: value for mapping in dicts for key, value in mapping.items()}


def average_dict(dicts):
    """Average a non-empty list of dicts key by key.

    Every dict must have exactly the same key set as the first one
    (checked with an assertion).  Returns a new dict mapping each key to the
    mean of its values across `dicts`.
    """
    reference_keys = set(dicts[0].keys())
    for entry in dicts:
        assert set(entry.keys()) == reference_keys

    count = len(dicts)
    return {
        key: sum(entry[key] for entry in dicts) / count
        for key in reference_keys
    }


def kl_qp_gaussian(mu_q, sigma_q, mu_p, sigma_p):
    """Element-averaged KL( N(mu_q, diag(sigma_q^2)) || N(mu_p, diag(sigma_p^2)) ).

    The per-element KL terms are averaged with tf.reduce_mean, so the result
    is the mean over elements rather than the summed divergence.  A small
    constant is added to both variances before the logs and divisions.
    """
    eps = 1e-16
    var_q = tf.square(sigma_q) + eps
    var_p = tf.square(sigma_p) + eps

    log_var_ratio = tf.math.log(var_p) - tf.math.log(var_q)
    variance_ratio = var_q / var_p
    mean_shift = tf.square(mu_q - mu_p) / var_p

    per_element = log_var_ratio - 1.0 + variance_ratio + mean_shift
    return 0.5 * tf.reduce_mean(per_element)


def build_graph():
    """Build the meta-training and evaluation graph.

    Reads all hyper-parameters from FLAGS and returns an AttrDict with:
      * graph.param: the model parameter, optional learned noise stddev,
        the inner learning rate object and the list of outer trainables.
      * graph.train: per-task placeholders, pre/post-adaptation metrics,
        regularisation losses, gradients and the training op.
      * graph.test: a single placeholder set and pre/post-adaptation metrics
        built with training=False.
      * graph.init_op: global variable initialiser.

    Returns:
        The populated AttrDict.
    """
    graph = AttrDict()
    pm = ParameterizedModel(name='maml_rainbow_mnist')

    # Run the model once on a dummy single-image batch inside the template
    # context (presumably so ParameterizedModel can record the parameter
    # layout — confirm in parameterized_model).
    with pm.build_template():
        models[FLAGS.dataset](
            tf.zeros(datasets[FLAGS.dataset].image_batch_shape(1)),
            datasets[FLAGS.dataset].n_logits(), training=True
        )

    graph.param = AttrDict()
    graph.param.parameterized_model = pm
    graph.param.parameter = pm.parameter

    # Variables updated by the outer optimiser; extended below.
    graph.param.trainables = [graph.param.parameter]

    if FLAGS.info_reg > 0:
        # Learned per-element noise: stddev = clip(exp(log_stddev)), and the
        # parameter used by the model is the mean plus Gaussian noise.
        graph.param.log_stddev = tf.Variable(
            tf.ones_like(graph.param.parameter) * np.log(FLAGS.info_reg_init_stddev),
            dtype=tf.float32, trainable=True
        )
        graph.param.stddev = tf.clip_by_value(
            tf.exp(graph.param.log_stddev), FLAGS.info_reg_min_stddev, FLAGS.info_reg_max_stddev
        )
        graph.param.sampled_parameter = graph.param.parameter + tf.random.normal(graph.param.parameter.shape) * graph.param.stddev
        graph.param.mean_stddev = tf.reduce_mean(graph.param.stddev)

        graph.param.trainables.append(graph.param.log_stddev)
    else:
        # No noise: the sampled parameter is the parameter itself.
        graph.param.sampled_parameter = graph.param.parameter
        graph.param.mean_stddev = tf.constant(0.0, dtype=tf.float32)

    graph.param.norm = tf.norm(graph.param.parameter)

    # NOTE(review): parse_variabe_batch_size returns None for an empty flag,
    # in which case max(None) raises TypeError here.
    graph.param.inner_lr = InnerLearningRate(
        init_lr=FLAGS.inner_lr, lr_type=FLAGS.inner_lr_type,
        lr_scaling=FLAGS.inner_lr_scaling, lr_scaling_coef=FLAGS.inner_lr_scaling_coef,
        max_batch_size=max(parse_variabe_batch_size(FLAGS.inner_batch_size)),
    )
    for var in graph.param.inner_lr.trainables:
        graph.param.trainables.append(var)


    graph.train = AttrDict()
    graph.train.global_step = tf.train.get_or_create_global_step()

    # One independent set of input placeholders per meta-training task.
    graph.train.placeholders = [MetaPlaceholder() for _ in range(FLAGS.meta_train_tasks)]


    # --- Pre-adaptation metrics: evaluate the (sampled) initial parameter on
    # each task's post_adapt batch, averaged over tasks.
    accuracies = []
    losses = []
    for task_id in range(FLAGS.meta_train_tasks):
        with pm.build_parameterized(graph.param.sampled_parameter):
            logits = models[FLAGS.dataset](
                graph.train.placeholders[task_id].post_adapt_images, datasets[FLAGS.dataset].n_logits(),
                training=True
            )

        one_hot_labels = tf.one_hot(graph.train.placeholders[task_id].post_adapt_labels, datasets[FLAGS.dataset].n_logits())
        loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(
            labels=one_hot_labels, logits=logits
        ))
        losses.append(loss)
        prediction = tf.argmax(logits, axis=1)
        accuracy = tf.reduce_mean(tf.cast(tf.equal(prediction, graph.train.placeholders[task_id].post_adapt_labels), tf.float32))
        accuracies.append(accuracy)

    graph.train.pre_adapt_accuracy = tf.add_n(accuracies) / FLAGS.meta_train_tasks
    graph.train.pre_adapt_loss = tf.add_n(losses) / FLAGS.meta_train_tasks


    # --- Inner loop: for each task, take FLAGS.inner_steps gradient steps on
    # the pre_adapt batch, starting from the sampled parameter.
    graph.train.updated_parameters = []
    adaptation_distances = []

    for task_id in range(FLAGS.meta_train_tasks):
        updated_parameter = graph.param.sampled_parameter
        for inner_step in range(FLAGS.inner_steps):
            # Both closures read task_id / updated_parameter from the enclosing
            # loop; they are passed to tf.cond in this same iteration.
            def compute_loss_with_adaptation():
                with pm.build_parameterized(updated_parameter):
                    logits = models[FLAGS.dataset](
                        graph.train.placeholders[task_id].pre_adapt_images,
                        datasets[FLAGS.dataset].n_logits(),
                        training=True
                    )
                one_hot_labels = tf.one_hot(graph.train.placeholders[task_id].pre_adapt_labels, datasets[FLAGS.dataset].n_logits())
                loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(
                    labels=one_hot_labels, logits=logits
                ))
                # The likelihood term is weighted by inner_step_likelihood_reg
                # but is only added when info_reg > 0.
                if FLAGS.info_reg > 0:
                    loss -= FLAGS.inner_step_likelihood_reg \
                            * (2 * np.pi) ** (-0.5 * pm.parameter_size) \
                            * tf.exp(-0.5 * tf.reduce_sum(tf.square(updated_parameter)))
                return loss

            def compute_loss_without_adaptation():
                # Used when the pre_adapt batch is empty: no data term, only
                # the optional likelihood term.
                loss = tf.constant(0.0, dtype=tf.float32)
                if FLAGS.info_reg > 0:
                    loss -= FLAGS.inner_step_likelihood_reg \
                            * (2 * np.pi) ** (-0.5 * pm.parameter_size) \
                            * tf.exp(-0.5 * tf.reduce_sum(tf.square(updated_parameter)))
                return loss

            # Choose the branch at run time depending on whether any
            # pre_adapt labels were fed.
            loss = tf.cond(
                tf.greater(tf.size(graph.train.placeholders[task_id].pre_adapt_labels), 0),
                compute_loss_with_adaptation,
                compute_loss_without_adaptation,
            )

            grad = tf.gradients(loss, updated_parameter)[0]
            if grad is not None:
                if FLAGS.inner_clip_gradient > 0:
                    grad = tf.clip_by_norm(grad, clip_norm=FLAGS.inner_clip_gradient)

                inner_lr = graph.param.inner_lr.get_lr(
                    graph.train.placeholders[task_id].pre_adapt_images
                )
                updated_parameter -= inner_lr * grad
                # Overwritten on every step/task, so this ends up holding the
                # learning rate of the last task's last step.  It is never set
                # if inner_steps == 0 or no gradient exists.
                graph.train.inner_lr = inner_lr
        graph.train.updated_parameters.append(updated_parameter)
        adaptation_distances.append(tf.norm(updated_parameter - graph.param.sampled_parameter))

    graph.train.adaptation_distance = tf.add_n(adaptation_distances) / FLAGS.meta_train_tasks


    # --- Post-adaptation metrics: evaluate each task's adapted parameter on
    # its post_adapt batch, averaged over tasks.
    accuracies = []
    losses = []
    for task_id in range(FLAGS.meta_train_tasks):
        with pm.build_parameterized(graph.train.updated_parameters[task_id]):
            logits = models[FLAGS.dataset](
                graph.train.placeholders[task_id].post_adapt_images,
                datasets[FLAGS.dataset].n_logits(),
                training=True
            )
        prediction = tf.argmax(logits, axis=1)
        accuracy = tf.reduce_mean(tf.cast(tf.equal(prediction, graph.train.placeholders[task_id].post_adapt_labels), tf.float32))
        accuracies.append(accuracy)

        one_hot_labels = tf.one_hot(graph.train.placeholders[task_id].post_adapt_labels, datasets[FLAGS.dataset].n_logits())
        loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(
            labels=one_hot_labels, logits=logits
        ))
        losses.append(loss)

    graph.train.post_adapt_accuracy = tf.add_n(accuracies) / FLAGS.meta_train_tasks
    graph.train.post_adapt_loss = tf.add_n(losses) / FLAGS.meta_train_tasks

    # Positive values mean adaptation helped (lower loss / higher accuracy).
    graph.train.adapt_loss_improvement = graph.train.pre_adapt_loss - graph.train.post_adapt_loss
    graph.train.adapt_accuracy_improvement = graph.train.post_adapt_accuracy - graph.train.pre_adapt_accuracy

    # --- Outer-loss regularisers.
    if FLAGS.info_reg > 0:
        # KL between the learned Gaussian over parameters and a
        # zero-mean, unit-stddev Gaussian, scaled by info_reg.
        graph.train.info_reg_loss = FLAGS.info_reg * kl_qp_gaussian(
            graph.param.parameter, graph.param.stddev,
            tf.zeros_like(graph.param.parameter), tf.ones_like(graph.param.parameter)
        )
    else:
        graph.train.info_reg_loss = tf.constant(0.0, dtype=tf.float32)

    if FLAGS.weight_decay > 0:
        graph.train.weight_decay_loss = FLAGS.weight_decay * tf.reduce_sum(
            tf.square(graph.param.parameter)
        )
    else:
        graph.train.weight_decay_loss = tf.constant(0.0, dtype=tf.float32)

    graph.train.total_loss = (
        graph.train.post_adapt_loss
        + graph.train.info_reg_loss
        + graph.train.weight_decay_loss
    )

    # --- Outer optimiser: Adam with staircase exponential learning-rate decay.
    graph.train.optimizer = tf.train.AdamOptimizer(
        tf.train.exponential_decay(
            FLAGS.outer_lr, graph.train.global_step, FLAGS.outer_lr_decay_step,
            FLAGS.outer_lr_decay_rate, staircase=True
        )
    )


    graph.train.gradients = tf.gradients(
        graph.train.total_loss, graph.param.trainables
    )
    # Norm is reported before clipping.
    graph.train.gradient_norm = tf.linalg.global_norm(graph.train.gradients)
    if FLAGS.outer_clip_gradient > 0:
        graph.train.clipped_gradients, _ = tf.clip_by_global_norm(
            graph.train.gradients, clip_norm=FLAGS.outer_clip_gradient
        )
    else:
        graph.train.clipped_gradients = graph.train.gradients
    graph.train.train_op = graph.train.optimizer.apply_gradients(
        zip(graph.train.clipped_gradients, graph.param.trainables),
        global_step=graph.train.global_step
    )

    # NOTE(review): the initialiser is created before the test sub-graph is
    # built, so it only covers variables that exist at this point.
    graph.init_op = tf.global_variables_initializer()


    ### Test ###

    graph.test = AttrDict()
    graph.test.placeholder = MetaPlaceholder()

    # Pre-adaptation evaluation.  build_parameterized is called without an
    # argument here (presumably meaning the unperturbed pm.parameter — TODO
    # confirm), whereas training uses sampled_parameter.
    with pm.build_parameterized():
        logits = models[FLAGS.dataset](
            graph.test.placeholder.post_adapt_images,
            datasets[FLAGS.dataset].n_logits(),
            training=False
        )

    one_hot_labels = tf.one_hot(graph.test.placeholder.post_adapt_labels, datasets[FLAGS.dataset].n_logits())
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(
        labels=one_hot_labels, logits=logits
    ))
    graph.test.pre_adapt_loss = loss

    prediction = tf.argmax(logits, axis=1)
    graph.test.pre_adapt_accuracy = tf.reduce_mean(tf.cast(tf.equal(prediction, graph.test.placeholder.post_adapt_labels), tf.float32))


    graph.test.updated_parameters = []
    graph.test.updated_losses = []

    # Test-time adaptation mirrors the training inner loop but with
    # training=False.  It starts from sampled_parameter, so when info_reg > 0
    # the parameter noise is also applied at test time.
    updated_parameter = graph.param.sampled_parameter
    for _ in range(FLAGS.inner_steps):

        def compute_loss_with_adaptation():
            with pm.build_parameterized(updated_parameter):
                logits = models[FLAGS.dataset](
                    graph.test.placeholder.pre_adapt_images,
                    datasets[FLAGS.dataset].n_logits(),
                    training=False
                )
            one_hot_labels = tf.one_hot(graph.test.placeholder.pre_adapt_labels, datasets[FLAGS.dataset].n_logits())
            loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(
                labels=one_hot_labels, logits=logits
            ))
            if FLAGS.info_reg > 0:
                loss -= FLAGS.inner_step_likelihood_reg \
                        * (2 * np.pi) ** (-0.5 * pm.parameter_size) \
                        * tf.exp(-0.5 * tf.reduce_sum(tf.square(updated_parameter)))
            return loss

        def compute_loss_without_adaptation():
            loss = tf.constant(0.0, dtype=tf.float32)
            if FLAGS.info_reg > 0:
                loss -= FLAGS.inner_step_likelihood_reg \
                        * (2 * np.pi) ** (-0.5 * pm.parameter_size) \
                        * tf.exp(-0.5 * tf.reduce_sum(tf.square(updated_parameter)))
            return loss

        loss = tf.cond(
            tf.greater(tf.size(graph.test.placeholder.pre_adapt_labels), 0),
            compute_loss_with_adaptation,
            compute_loss_without_adaptation,
        )

        # Inner loss before each step, kept for inspection.
        graph.test.updated_losses.append(loss)
        grad = tf.gradients(loss, updated_parameter)[0]

        if grad is not None:
            if FLAGS.inner_clip_gradient > 0:
                grad = tf.clip_by_norm(grad, clip_norm=FLAGS.inner_clip_gradient)

            inner_lr = graph.param.inner_lr.get_lr(
                graph.test.placeholder.pre_adapt_images
            )
            updated_parameter -= inner_lr * grad
            # As in training, only set when at least one step produced a gradient.
            graph.test.inner_lr = inner_lr
        graph.test.updated_parameters.append(updated_parameter)

    # Post-adaptation evaluation with the final adapted parameter.
    with pm.build_parameterized(updated_parameter):
        logits = models[FLAGS.dataset](
            graph.test.placeholder.post_adapt_images,
            datasets[FLAGS.dataset].n_logits(),
            training=False
        )

    prediction = tf.argmax(logits, axis=1)
    accuracy = tf.reduce_mean(tf.cast(tf.equal(prediction, graph.test.placeholder.post_adapt_labels), tf.float32))

    one_hot_labels = tf.one_hot(graph.test.placeholder.post_adapt_labels, datasets[FLAGS.dataset].n_logits())
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(
        labels=one_hot_labels, logits=logits
    ))

    graph.test.post_adapt_accuracy = accuracy
    graph.test.post_adapt_loss = loss

    graph.test.adapt_loss_improvement = graph.test.pre_adapt_loss - graph.test.post_adapt_loss
    graph.test.adapt_accuracy_improvement = graph.test.post_adapt_accuracy - graph.test.pre_adapt_accuracy

    return graph


def main(_):
    """Train the meta-learner and periodically evaluate it.

    Builds the graph, then runs FLAGS.n_steps meta-training steps, cycling
    through the inner batch sizes parsed from FLAGS.inner_batch_size.
    Training metrics are sent to the TensorBoard logger roughly every
    FLAGS.log_interval steps, and the 'val' and 'test' partitions are
    evaluated every FLAGS.test_interval steps.

    Args:
        _: Unused positional argument supplied by absl.app.run.

    Raises:
        SystemExit: With status 1 as soon as any training metric is
            non-finite; the offending metrics are logged, flushed and printed
            first.
    """
    set_random_seed(FLAGS.seed)

    graph = build_graph()

    tb_logger = TensorBoardLogger(FLAGS.output_dir)

    # The data file used to be unpickled here into a variable that was never
    # read; that load has been dropped.  The dataset object below is handed
    # the same path (presumably it loads the file itself).
    dataset = datasets[FLAGS.dataset](FLAGS.data_file, FLAGS.n_training_tasks)

    inner_batch_sizes = parse_variabe_batch_size(FLAGS.inner_batch_size)

    def train_ops(inner_batch_size):
        # Fetches for one training step, keyed by TensorBoard tag.
        ops = {
            'param/norm': graph.param.norm,
            'param/mean_stddev': graph.param.mean_stddev,
            'param/weight_decay_loss': graph.train.weight_decay_loss,
            'train_{}_shot/pre_adapt_accuracy'.format(inner_batch_size): graph.train.pre_adapt_accuracy,
            'train_{}_shot/post_adapt_accuracy'.format(inner_batch_size): graph.train.post_adapt_accuracy,
            'train_{}_shot/post_adapt_loss'.format(inner_batch_size): graph.train.post_adapt_loss,
            'train_{}_shot/info_reg_loss'.format(inner_batch_size): graph.train.info_reg_loss,
            'train_{}_shot/gradient_norm'.format(inner_batch_size): graph.train.gradient_norm,
            'train_{}_shot/adaptation_distance'.format(inner_batch_size): graph.train.adaptation_distance,
            'train_{}_shot/adapt_loss_improvement'.format(inner_batch_size): graph.train.adapt_loss_improvement,
            'train_{}_shot/adapt_accuracy_improvement'.format(inner_batch_size): graph.train.adapt_accuracy_improvement,
            'train_{}_shot/inner_lr'.format(inner_batch_size): graph.train.inner_lr,
            'train_op': graph.train.train_op,
        }
        return ops

    def test_ops(inner_batch_size, partition='test'):
        # Fetches for one evaluation batch, keyed by TensorBoard tag.
        return {
            '{}_{}_shot/pre_adapt_accuracy'.format(partition, inner_batch_size): graph.test.pre_adapt_accuracy,
            '{}_{}_shot/pre_adapt_loss'.format(partition, inner_batch_size): graph.test.pre_adapt_loss,
            '{}_{}_shot/post_adapt_accuracy'.format(partition, inner_batch_size): graph.test.post_adapt_accuracy,
            '{}_{}_shot/post_adapt_loss'.format(partition, inner_batch_size): graph.test.post_adapt_loss,
            '{}_{}_shot/adapt_loss_improvement'.format(partition, inner_batch_size): graph.test.adapt_loss_improvement,
            '{}_{}_shot/adapt_accuracy_improvement'.format(partition, inner_batch_size): graph.test.adapt_accuracy_improvement,
            '{}_{}_shot/inner_lr'.format(partition, inner_batch_size): graph.test.inner_lr,
        }

    def test_network(sess, step, partition):
        # Evaluate every inner batch size on `partition`, averaging each
        # metric over FLAGS.test_batches sampled batches, then flush.
        for inner_batch_size in inner_batch_sizes:
            test_metrics = []
            for _ in range(FLAGS.test_batches):
                feed_dict = graph.test.placeholder(
                    *dataset.sample_from_task(
                        partition, inner_batch_size, FLAGS.outer_batch_size
                    )
                )
                test_metrics.append(
                    sess.run(test_ops(inner_batch_size, partition), feed_dict)
                )
            tb_logger.log_dict(step, average_dict(test_metrics))
        tb_logger.flush()

    sess_config = tf.ConfigProto()
    sess_config.gpu_options.allow_growth = FLAGS.gpu_mem_growth
    sess_config.gpu_options.per_process_gpu_memory_fraction = FLAGS.max_gpu_mem_frac

    with tf.Session(config=sess_config) as sess:
        sess.run(graph.init_op)

        for train_step in range(FLAGS.n_steps):
            # Cycle through the configured inner batch sizes, one per step.
            inner_batch_size = inner_batch_sizes[train_step % len(inner_batch_sizes)]
            data_batches = dataset.sample_multiple_tasks(
                'train', FLAGS.meta_train_tasks, inner_batch_size, FLAGS.outer_batch_size
            )

            # One feed dict per task placeholder set, merged into a single feed.
            feed_dict = merge_dict([
                graph.train.placeholders[i](*data_batch)
                for i, data_batch in enumerate(data_batches)
            ])

            train_metrics = sess.run(train_ops(inner_batch_size), feed_dict)
            # The train op fetch carries no metric value; drop it before the
            # finiteness check and logging.
            del train_metrics['train_op']
            exist_nan = not np.all([np.isfinite(v) for v in train_metrics.values()])

            # Log once per configured inner batch size at the start of every
            # log interval, and always log the step that diverged.
            if train_step % FLAGS.log_interval < len(inner_batch_sizes) or exist_nan:
                tb_logger.log_dict(train_step, train_metrics)

            if exist_nan:
                # Make sure the diverged step reaches disk, then fail with a
                # non-zero status so callers can detect the aborted run.
                tb_logger.flush()
                print(train_metrics)
                raise SystemExit(1)

            if train_step % FLAGS.test_interval == 0:
                test_network(sess, train_step, 'val')
                test_network(sess, train_step, 'test')

    # Training metrics logged after the last evaluation have not been flushed yet.
    tb_logger.flush()

    print('Training completed!')

# Script entry point: absl parses the command-line flags and then calls main.
# NOTE(review): this module uses package-relative imports (e.g.
# `from .misc_utils import ...`), so it presumably has to be launched as a
# module (`python -m <package>.<module>`) rather than by file path — confirm.
if __name__ == '__main__':
    absl.app.run(main)
