import pickle
import os

import numpy as np
import tensorflow as tf

import tqdm

import absl.app
import absl.flags
from absl import logging
from collections import OrderedDict

from parameterized_model import ParameterizedModel

from .misc_utils import AttrDict
from .misc_utils import define_flags_with_default, set_random_seed, TensorBoardLogger, print_flags
from .online_data_generator import OnlineDataGenerator

FLAGS = absl.flags.FLAGS


# Command-line flags with their default values, registered through the project
# helper `define_flags_with_default` (imported from misc_utils; not shown here).
# Flags without a note are not referenced in this part of the file.
flags_def = define_flags_with_default(
    inner_steps=1,  # number of inner-loop (adaptation) gradient steps per task
    inner_lr=3e-2,  # initial inner learning rate (init_lr of InnerLearningRate)
    outer_lr=1e-4,  # learning rate of the outer Adam optimizer
    # InnerLearningRate variant: 'fixed', 'table', 'manual', 'manual_learned_coef',
    # 'manual_theory_learned_coef', 'fixed_theory_learned_coef' or 'nn'.
    inner_lr_type='fixed',
    inner_lr_scaling='exp',  # maps raw learned lr values to rates: 'exp' or 'linear'
    inner_lr_scaling_coef=1.0,  # multiplier applied to raw lr values before scaling
    inner_clip_gradient=0.0,  # inner-loop gradient clip threshold; 0 disables clipping
    # Norm (True) or value (False) clipping of inner-loop gradients; see build_graph
    # for where it applies.
    inner_clip_gradient_norm=True,
    outer_clip_gradient=0.0,  # global-norm clip for outer gradients; 0 disables clipping
    meta_train_tasks=25,  # number of tasks (MetaPlaceholder sets) per meta-training step
    per_task_batch_size=4,
    # Dash-separated support-set sizes; the maximum sizes the 'table' learning
    # rate and the last entry is used in the output directory name.
    inner_batch_size='0-2-4-6-8-10-12-14-16-18-20',
    outer_batch_size=4,
    inner_step_likelihood_reg=0.0,  # weight of the inner-loss likelihood term (only when info_reg > 0)
    info_reg=0.0,  # weight of the KL regularizer; > 0 also makes the parameters stochastic
    info_reg_init_stddev=1e-4,  # initial stddev of the stochastic parameters
    info_reg_min_stddev=1e-6,  # lower clip bound for that stddev
    info_reg_max_stddev=1.0,  # upper clip bound for that stddev
    batchnorm=False,
    weight_decay=0.0,  # L2 penalty weight on the parameters; 0 disables it
    n_steps=100000,
    test_interval=10,
    test_batches=5,
    log_interval=5,
    print_interval=100,
    seed=42,  # passed to set_random_seed in main
    max_gpu_mem_frac=1.0,
    gpu_mem_growth=True,
    train_only_on_cur=False,
    cont_incl_cur=True,
    pretrain_steps=10000,
    task_steps=2000,
    batch_steps=10,
    cur_task=0,
    adaptive_advance=False,  # in the visible code only affects the output directory name
    uncertainty_advance=False,  # build a second test-time parameter sample (requires info_reg)
    advance_thresh=0.9,  # appended to the output directory name when an *_advance flag is set
    datasource='cont_rainbow_mnist',
    data_file='./rainbow_mnist.pkl',
    output_dir='/tmp/vs_maml_online',  # prefix of the run's output directory name
)

def parse_variabe_batch_size(batch_sizes):
    """Parse a dash-separated list of integers such as '0-2-4'.

    Args:
        batch_sizes: the string to parse.

    Returns:
        None for the empty string, otherwise the integers in their original order.
    """
    if batch_sizes != '':
        return list(map(int, batch_sizes.split('-')))
    return None


def convnet(input_tensor, n_logits=10):
    """Build the convolutional classifier used as the meta-learned model.

    Four 3x3 convolutions with 32 filters each (leaky-ReLU activations) are
    followed by a flatten and a dense output layer.

    Args:
        input_tensor: image batch; presumably [batch, 28, 28, 3] float32 as fed
            through MetaPlaceholder -- TODO confirm for other callers.
        n_logits: number of output units (classes). Previously this argument
            was ignored and 10 was hard-coded; the default keeps that behavior.
            Callers in this file one-hot encode labels with depth 10, so they
            rely on the default.

    Returns:
        Unnormalized logits with `n_logits` units per example.
    """
    x = input_tensor

    for _ in range(4):
        x = tf.layers.conv2d(x, 32, 3)
        x = tf.nn.leaky_relu(x)

    x = tf.layers.flatten(x)
    return tf.layers.dense(x, n_logits)

def lr_convnet(input_tensor):
    """Predict a single raw learning-rate value from a whole batch of images.

    Per-example features come from three 32-filter 3x3 convolutions and a
    128-unit dense layer (leaky-ReLU activations). They are averaged over the
    batch axis and passed through a 64-unit hidden layer to one output unit.
    That unit's kernel and bias start at zero, so the initial prediction is 0
    for any input.

    Returns:
        Tensor of shape [1].
    """
    features = input_tensor
    for _ in range(3):
        features = tf.nn.leaky_relu(tf.layers.conv2d(features, 32, 3))

    features = tf.nn.leaky_relu(tf.layers.dense(tf.layers.flatten(features), 128))

    # Pool over the batch axis so the output describes the set, not one example.
    pooled = tf.reduce_mean(features, axis=0, keepdims=True)
    hidden = tf.nn.leaky_relu(tf.layers.dense(pooled, 64))

    raw = tf.layers.dense(
        hidden, 1,
        kernel_initializer=tf.initializers.constant(0.0),
        bias_initializer=tf.initializers.constant(0.0),
    )
    return tf.squeeze(raw, axis=0)


class MetaPlaceholder(object):
    """Feedable inputs for one task: a support (pre-adapt) and a query (post-adapt) set.

    Images are float32 with shape [None, 28, 28, 3]; labels are int64 class ids
    with shape [None]. Calling an instance returns a dict that maps each
    placeholder to the value given for it, suitable as a feed_dict.
    """

    def __init__(self):
        def image_input():
            return tf.placeholder(dtype=tf.float32, shape=[None, 28, 28, 3])

        def label_input():
            return tf.placeholder(dtype=tf.int64, shape=[None])

        # Creation order is kept: pre images, pre labels, post images, post labels.
        self.pre_adapt_images = image_input()
        self.pre_adapt_labels = label_input()
        self.post_adapt_images = image_input()
        self.post_adapt_labels = label_input()

    def __call__(self, pre_adapt_images, pre_adapt_labels, post_adapt_images, post_adapt_labels):
        """Pair each placeholder with its value, in declaration order."""
        keys = (
            self.pre_adapt_images,
            self.pre_adapt_labels,
            self.post_adapt_images,
            self.post_adapt_labels,
        )
        values = (pre_adapt_images, pre_adapt_labels, post_adapt_images, post_adapt_labels)
        return dict(zip(keys, values))

class InnerLearningRate(object):
    """Inner-loop (task adaptation) learning rate for the MAML-style update.

    Depending on `lr_type` the rate is a constant, a learned tensor, or a
    learned function of the support (pre-adaptation) images. Variables and
    parameters created for the learned variants are collected in
    `trainables` so the outer optimizer can update them.
    """

    def __init__(self, init_lr, lr_type='fixed', lr_scaling='exp',
                 lr_scaling_coef=1.0, max_batch_size=20):
        """Build the learning-rate graph for `lr_type`.

        Args:
            init_lr: initial learning rate, in scaled (directly usable) units.
            lr_type: 'fixed', 'table', 'manual', 'manual_learned_coef',
                'manual_theory_learned_coef', 'fixed_theory_learned_coef' or 'nn'.
            lr_scaling: 'exp' or 'linear'; maps raw learned values to rates
                (see `scale_lr`).
            lr_scaling_coef: multiplier applied to raw values before scaling.
            max_batch_size: largest support-batch size with its own entry in
                the 'table' variant; larger batches share the last entry.

        Raises:
            ValueError: if `lr_type` is unsupported, or if `lr_scaling` is
                unsupported and the chosen variant needs it.
        """
        self.init_lr = init_lr
        self.lr_type = lr_type
        self.lr_scaling = lr_scaling
        self.lr_scaling_coef = lr_scaling_coef
        self.max_batch_size = max_batch_size
        self._trainables = []

        # Each builder sets `self.lr` (and any variables) for its variant.
        if lr_type == 'fixed':
            self.build_fixed_lr()
        elif lr_type == 'table':
            self.build_table_lr()
        elif lr_type == 'manual':
            self.build_manual_lr()
        elif lr_type in {'manual_learned_coef', 'manual_theory_learned_coef', 'fixed_theory_learned_coef'}:
            self.build_manual_learned_coef_lr()
        elif lr_type == 'nn':
            self.build_nn_lr()
        else:
            raise ValueError('Unsupported learning rate type!')

    @property
    def trainables(self):
        """List of variables/parameters the outer loop should optimize.

        The builders append to this same list object.
        """
        return self._trainables

    def unscaled_init_lr(self):
        """Return the raw value that `scale_lr` maps back to `init_lr`.

        'exp': log(init_lr) / coef; 'linear': init_lr / coef.

        Raises:
            ValueError: for any other `lr_scaling`.
        """
        if self.lr_scaling == 'exp':
            return np.log(self.init_lr) / self.lr_scaling_coef
        elif self.lr_scaling == 'linear':
            return self.init_lr / self.lr_scaling_coef
        else:
            raise ValueError('Unsupported learning rate scaling!')

    def scale_lr(self, unscaled_lr):
        """Map a raw tensor to a learning rate: exp(x * coef) or relu(x * coef).

        Both forms keep the result non-negative.

        Raises:
            ValueError: for any other `lr_scaling`.
        """
        if self.lr_scaling == 'exp':
            return tf.exp(unscaled_lr * self.lr_scaling_coef)
        elif self.lr_scaling == 'linear':
            return tf.nn.relu(unscaled_lr * self.lr_scaling_coef)
        else:
            raise ValueError('Unsupported learning rate scaling!')

    def build_fixed_lr(self):
        """Constant, non-trainable learning rate equal to `init_lr`."""
        self.lr = tf.constant(self.init_lr, dtype=tf.float32)

    def build_table_lr(self):
        """One learned rate per support-batch size 0..max_batch_size.

        Every entry is initialized from `unscaled_init_lr()`.
        """
        self.lr_variable = tf.Variable(
            tf.ones(self.max_batch_size + 1) * self.unscaled_init_lr(),
            dtype=tf.float32, trainable=True
        )
        self.lr = self.scale_lr(self.lr_variable)
        self.trainables.append(self.lr_variable)

    def build_manual_lr(self):
        """Two learned rates: [0] for an empty support set, [1] as the base rate otherwise."""
        self.lr_variable = tf.Variable(
            tf.ones(2) * self.unscaled_init_lr(),
            dtype=tf.float32, trainable=True
        )
        self.lr = self.scale_lr(self.lr_variable)
        self.trainables.append(self.lr_variable)

    def build_manual_learned_coef_lr(self):
        """Base rate(s) plus a learned coefficient used by the batch-size formulas in `get_lr`.

        'fixed_theory_learned_coef' uses the constant `init_lr` as its base rate;
        the other two variants learn a 2-entry rate like `build_manual_lr`.
        The coefficient is |v| (v starts at 1.0) for 'manual_theory_learned_coef'
        and |v| + 1 (v starts at 0.1) otherwise.
        """
        if self.lr_type == 'fixed_theory_learned_coef':
            self.lr = tf.constant(self.init_lr, dtype=tf.float32)
        else:
            self.lr_variable = tf.Variable(
                tf.ones(2) * self.unscaled_init_lr(),
                dtype=tf.float32, trainable=True
            )
            self.lr = self.scale_lr(self.lr_variable)
            self.trainables.append(self.lr_variable)
        if self.lr_type == 'manual_theory_learned_coef':
            self.manual_coef_var = tf.Variable(
                1.0,
                dtype=tf.float32, trainable=True
            )
            self.manual_coef = tf.abs(self.manual_coef_var)
        else:
            self.manual_coef_var = tf.Variable(
                0.1,
                dtype=tf.float32, trainable=True
            )
            # Keeps the coefficient >= 1 and starts it at 1.1, the constant the
            # 'manual' variant hard-codes.
            self.manual_coef = tf.abs(self.manual_coef_var) + 1.0
        self.trainables.append(self.manual_coef_var)

    def build_nn_lr(self):
        """Learned-network rate for non-empty support sets plus a learned zero-shot rate.

        The `lr_convnet` weights live in a separate ParameterizedModel, traced
        once on a dummy [1, 28, 28, 3] batch here. The zero-shot rate is a
        scalar initialized from `unscaled_init_lr()`.
        """
        self.pm = ParameterizedModel('learning_rate_net')
        with self.pm.build_template():
            lr_convnet(tf.zeros([1, 28, 28, 3]))
        self.trainables.append(self.pm.parameter)
        self.zero_shot_lr_variable = tf.Variable(
            self.unscaled_init_lr(),
            dtype=tf.float32, trainable=True
        )
        self.zero_shot_lr = self.scale_lr(self.zero_shot_lr_variable)
        self.trainables.append(self.zero_shot_lr_variable)

    def get_lr(self, batch):
        """Return the inner learning rate to use for support images `batch`.

        The support-set size is `tf.shape(batch)[0]`. An empty batch selects
        the zero-shot rate for the variants that have one.

        Args:
            batch: support images; presumably [n, 28, 28, 3] as fed through
                MetaPlaceholder.pre_adapt_images -- TODO confirm for other callers.

        Returns:
            Learning-rate tensor. It is a scalar for most variants; the 'nn'
            branch for non-empty batches yields lr_convnet's shape-[1] output.

        Raises:
            ValueError: if `lr_type` is unsupported.
        """
        if self.lr_type == 'fixed':
            return self.lr
        elif self.lr_type == 'table':
            batch_size = tf.shape(batch)[0]
            # Batches larger than max_batch_size reuse the last table entry.
            return self.lr[tf.minimum(batch_size, self.max_batch_size)]
        elif self.lr_type == 'manual':
            batch_size = tf.cast(tf.shape(batch)[0], tf.float32)
            # Non-empty: lr[1] * (1.1 - 1/n); empty: lr[0].
            return tf.cond(
                tf.greater(batch_size, 0),
                lambda :self.lr[1] * (tf.constant(1.1, dtype=tf.float32) - 1 / batch_size),
                lambda :self.lr[0]
            )
        elif self.lr_type == 'manual_learned_coef':
            batch_size = tf.cast(tf.shape(batch)[0], tf.float32)
            # As 'manual', but with the learned coefficient in place of 1.1.
            return tf.cond(
                tf.greater(batch_size, 0),
                lambda :self.lr[1] * (self.manual_coef - 1 / batch_size),
                lambda :self.lr[0]
            )
        elif self.lr_type == 'manual_theory_learned_coef':
            batch_size = tf.cast(tf.shape(batch)[0], tf.float32)
            # Non-empty: lr[1] * (1 - 1 / (1 + n * coef)); empty: lr[0].
            return tf.cond(
                tf.greater(batch_size, 0),
                lambda :self.lr[1] * (1. - 1. / (1 + batch_size*self.manual_coef)),
                lambda :self.lr[0]
            )
        elif self.lr_type == 'fixed_theory_learned_coef':
            batch_size = tf.cast(tf.shape(batch)[0], tf.float32)
            # Same formula on the constant base rate; an empty batch gets the full rate.
            return tf.cond(
                tf.greater(batch_size, 0),
                lambda :self.lr * (1. - 1. / (1 + batch_size*self.manual_coef)),
                lambda :self.lr
            )
        elif self.lr_type == 'nn':
            batch_size = tf.shape(batch)[0]
            def nn_lr():
                with self.pm.build_parameterized():
                    raw_nn_lr = lr_convnet(batch)
                # lr_convnet's output layer starts at zero, so adding the unscaled
                # init value makes the initial rate scale back to init_lr.
                return self.scale_lr(raw_nn_lr + self.unscaled_init_lr())
            def zero_shot_lr():
                return self.zero_shot_lr
            return tf.cond(
                tf.greater(batch_size, 0), nn_lr, zero_shot_lr
            )
        else:
            raise ValueError('Unsupported learning rate type!')


def merge_dict(dicts):
    """Combine several dicts into a new one; on key clashes the later dict wins.

    Args:
        dicts: iterable of mappings (anything `dict.update` accepts).

    Returns:
        A fresh dict; the inputs are not modified.
    """
    combined = dict()
    for partial in dicts:
        combined.update(partial)
    return combined

def average_dict(dicts):
    """Return the per-key mean of a sequence of dicts that share the same keys.

    Args:
        dicts: non-empty sequence of mappings with identical key sets. Values
            must support `0 + value` and division by an int (e.g. floats or
            numpy arrays).

    Returns:
        A new dict mapping each key to the mean of its values. Keys follow the
        order of the first dict, so the result is deterministic. Previously the
        order came from set iteration and could change between runs.

    Raises:
        ValueError: if `dicts` is empty or the dicts do not share the same keys.
            An explicit check replaces the previous `assert`, which is stripped
            under `python -O`.
    """
    if not dicts:
        raise ValueError('average_dict requires at least one dict')
    keys = list(dicts[0].keys())
    key_set = set(keys)
    for d in dicts:
        if set(d.keys()) != key_set:
            raise ValueError(
                'average_dict: all dicts must have the same keys, '
                'expected {} but got {}'.format(key_set, set(d.keys()))
            )
    count = len(dicts)
    # sum() starts at 0 and adds in input order, like the original accumulation.
    return {key: sum(d[key] for d in dicts) / count for key in keys}

def kl_qp_gaussian(mu_q, sigma_q, mu_p, sigma_p):
    """Element-averaged KL divergence between two diagonal Gaussians.

    Evaluates the per-element terms of
    KL(N(mu_q, diag(sigma_q^2)) || N(mu_p, diag(sigma_p^2))) and returns
    0.5 * mean(terms). That is the full KL divided by the number of elements,
    not the summed KL. 1e-16 is added to both variances so the logs and
    divisions stay finite when a sigma is zero.

    Args:
        mu_q, sigma_q: mean and standard deviation of q.
        mu_p, sigma_p: mean and standard deviation of p.

    Returns:
        Scalar tensor.
    """
    var_q = tf.square(sigma_q) + 1e-16
    var_p = tf.square(sigma_p) + 1e-16
    log_var_ratio = tf.math.log(var_p) - tf.math.log(var_q)
    per_element = log_var_ratio - 1.0 + var_q / var_p + tf.square(mu_q - mu_p) / var_p
    return 0.5 * tf.reduce_mean(per_element)


def _apply_convnet(pm, images, *parameter):
    """Run `convnet` on `images` inside `pm.build_parameterized(*parameter)`.

    `parameter` is optional. When omitted, `build_parameterized()` is called
    with no arguments; otherwise it receives the single parameter tensor given.
    """
    with pm.build_parameterized(*parameter):
        return convnet(images)


def _softmax_xent(logits, labels):
    """Mean softmax cross-entropy of 10-way `logits` against integer `labels`."""
    one_hot_labels = tf.one_hot(labels, 10)
    return tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(
        labels=one_hot_labels, logits=logits
    ))


def _match_rate(prediction, labels):
    """Fraction of positions where `prediction` equals `labels`, as float32."""
    return tf.reduce_mean(tf.cast(tf.equal(prediction, labels), tf.float32))


def _inner_step_likelihood_penalty(pm, parameter):
    """Likelihood term subtracted from the inner loss when FLAGS.info_reg > 0."""
    # NOTE(review): (2*pi) ** (-0.5 * parameter_size) underflows to 0.0 for a
    # network of this size, so the term is numerically zero -- confirm intent.
    return FLAGS.inner_step_likelihood_reg \
        * (2 * np.pi) ** (-0.5 * pm.parameter_size) \
        * tf.exp(-0.5 * tf.reduce_sum(tf.square(parameter)))


def _inner_adaptation_loss(pm, parameter, placeholder):
    """Support-set loss for one inner step evaluated at `parameter`.

    If no support labels are fed, only the likelihood penalty (when enabled)
    contributes. Zero-shot tasks are therefore handled inside the graph.
    """
    def loss_with_support():
        logits = _apply_convnet(pm, placeholder.pre_adapt_images, parameter)
        loss = _softmax_xent(logits, placeholder.pre_adapt_labels)
        if FLAGS.info_reg > 0:
            loss -= _inner_step_likelihood_penalty(pm, parameter)
        return loss

    def loss_without_support():
        loss = tf.constant(0.0, dtype=tf.float32)
        if FLAGS.info_reg > 0:
            loss -= _inner_step_likelihood_penalty(pm, parameter)
        return loss

    return tf.cond(
        tf.greater(tf.size(placeholder.pre_adapt_labels), 0),
        loss_with_support,
        loss_without_support,
    )


def _clip_inner_gradient(grad):
    """Apply the inner-loop clipping selected by the inner_clip_gradient* flags."""
    if FLAGS.inner_clip_gradient > 0:
        if FLAGS.inner_clip_gradient_norm:
            return tf.clip_by_norm(grad, clip_norm=FLAGS.inner_clip_gradient)
        return tf.clip_by_value(grad, -FLAGS.inner_clip_gradient, FLAGS.inner_clip_gradient)
    return grad


def _adapt_parameter(pm, inner_lr, initial_parameter, placeholder,
                     step_losses=None, step_parameters=None):
    """Take FLAGS.inner_steps gradient steps on `placeholder`'s support set.

    Args:
        pm: the ParameterizedModel wrapping `convnet`.
        inner_lr: InnerLearningRate that supplies the step size.
        initial_parameter: parameter tensor to start adapting from.
        placeholder: MetaPlaceholder whose pre_adapt_* tensors are the support set.
        step_losses: optional list; each step's inner loss is appended.
        step_parameters: optional list; the parameter after each step is appended.

    Returns:
        (adapted_parameter, last_lr). `last_lr` is the learning-rate tensor of
        the last applied step, or None if no step produced a gradient.
    """
    updated_parameter = initial_parameter
    last_lr = None
    for _ in range(FLAGS.inner_steps):
        loss = _inner_adaptation_loss(pm, updated_parameter, placeholder)
        if step_losses is not None:
            step_losses.append(loss)
        grad = tf.gradients(loss, updated_parameter)[0]
        # tf.gradients yields None when the loss does not depend on the parameter.
        if grad is not None:
            grad = _clip_inner_gradient(grad)
            # The rate is always derived from this task's own support images.
            last_lr = inner_lr.get_lr(placeholder.pre_adapt_images)
            updated_parameter -= last_lr * grad
        if step_parameters is not None:
            step_parameters.append(updated_parameter)
    return updated_parameter, last_lr


def build_graph():
    """Build the meta-training and meta-test TF1 graph.

    Returns:
        AttrDict with these entries:
          param: model parameter, optional stochastic-parameter stddev, the
              InnerLearningRate, and the list of outer-loop trainables.
          train: per-task placeholders, pre/post-adaptation metrics, losses,
              gradients and `train_op`.
          init_op: global variables initializer, created after the optimizer.
          test: one placeholder set with pre/post-adaptation metrics. When
              FLAGS.uncertainty_advance is set, it also holds accuracy and
              agreement for a second parameter sample.

    Raises:
        ValueError: if FLAGS.uncertainty_advance is set without FLAGS.info_reg > 0.
    """
    graph = AttrDict()
    pm = ParameterizedModel(name='maml_rainbow_mnist')

    # Trace convnet once on a dummy batch inside pm's template scope
    # (presumably this creates the parameter exposed as pm.parameter).
    with pm.build_template():
        convnet(tf.zeros([1, 28, 28, 3]))

    graph.param = AttrDict()
    graph.param.parameterized_model = pm
    graph.param.parameter = pm.parameter

    graph.param.trainables = [graph.param.parameter]

    if FLAGS.info_reg > 0:
        # Stochastic parameters: learned per-element log-stddev around the mean.
        graph.param.log_stddev = tf.Variable(
            tf.ones_like(graph.param.parameter) * np.log(FLAGS.info_reg_init_stddev),
            dtype=tf.float32, trainable=True
        )
        graph.param.stddev = tf.clip_by_value(
            tf.exp(graph.param.log_stddev), FLAGS.info_reg_min_stddev, FLAGS.info_reg_max_stddev
        )
        graph.param.sampled_parameter = graph.param.parameter + tf.random.normal(graph.param.parameter.shape) * graph.param.stddev
        graph.param.mean_stddev = tf.reduce_mean(graph.param.stddev)

        graph.param.trainables.append(graph.param.log_stddev)
    else:
        graph.param.sampled_parameter = graph.param.parameter
        graph.param.mean_stddev = tf.constant(0.0, dtype=tf.float32)

    graph.param.norm = tf.norm(graph.param.parameter)

    graph.param.inner_lr = InnerLearningRate(
        init_lr=FLAGS.inner_lr, lr_type=FLAGS.inner_lr_type,
        lr_scaling=FLAGS.inner_lr_scaling, lr_scaling_coef=FLAGS.inner_lr_scaling_coef,
        max_batch_size=max(parse_variabe_batch_size(FLAGS.inner_batch_size)),
    )
    graph.param.trainables.extend(graph.param.inner_lr.trainables)

    ### Meta-train ###

    graph.train = AttrDict()
    graph.train.placeholders = [MetaPlaceholder() for _ in range(FLAGS.meta_train_tasks)]

    # Query-set metrics before adaptation, starting from the sampled parameter.
    losses = []
    accuracies = []
    for placeholder in graph.train.placeholders:
        logits = _apply_convnet(pm, placeholder.post_adapt_images, graph.param.sampled_parameter)
        losses.append(_softmax_xent(logits, placeholder.post_adapt_labels))
        accuracies.append(_match_rate(tf.argmax(logits, axis=1), placeholder.post_adapt_labels))

    graph.train.pre_adapt_accuracy = tf.add_n(accuracies) / FLAGS.meta_train_tasks
    graph.train.pre_adapt_loss = tf.add_n(losses) / FLAGS.meta_train_tasks

    # Inner-loop adaptation of every task.
    graph.train.updated_parameters = []
    adaptation_distances = []
    for placeholder in graph.train.placeholders:
        updated_parameter, last_lr = _adapt_parameter(
            pm, graph.param.inner_lr, graph.param.sampled_parameter, placeholder
        )
        if last_lr is not None:
            graph.train.inner_lr = last_lr
        graph.train.updated_parameters.append(updated_parameter)
        adaptation_distances.append(tf.norm(updated_parameter - graph.param.sampled_parameter))

    graph.train.adaptation_distance = tf.add_n(adaptation_distances) / FLAGS.meta_train_tasks

    # Query-set metrics after adaptation.
    losses = []
    accuracies = []
    for placeholder, adapted in zip(graph.train.placeholders, graph.train.updated_parameters):
        logits = _apply_convnet(pm, placeholder.post_adapt_images, adapted)
        accuracies.append(_match_rate(tf.argmax(logits, axis=1), placeholder.post_adapt_labels))
        losses.append(_softmax_xent(logits, placeholder.post_adapt_labels))

    graph.train.post_adapt_accuracy = tf.add_n(accuracies) / FLAGS.meta_train_tasks
    graph.train.post_adapt_loss = tf.add_n(losses) / FLAGS.meta_train_tasks

    graph.train.adapt_loss_improvement = graph.train.pre_adapt_loss - graph.train.post_adapt_loss
    graph.train.adapt_accuracy_improvement = graph.train.post_adapt_accuracy - graph.train.pre_adapt_accuracy

    if FLAGS.info_reg > 0:
        # KL of the stochastic parameter distribution against a standard normal.
        graph.train.info_reg_loss = FLAGS.info_reg * kl_qp_gaussian(
            graph.param.parameter, graph.param.stddev,
            tf.zeros_like(graph.param.parameter), tf.ones_like(graph.param.parameter)
        )
    else:
        graph.train.info_reg_loss = tf.constant(0.0, dtype=tf.float32)

    if FLAGS.weight_decay > 0:
        graph.train.weight_decay_loss = FLAGS.weight_decay * tf.reduce_sum(
            tf.square(graph.param.parameter)
        )
    else:
        graph.train.weight_decay_loss = tf.constant(0.0, dtype=tf.float32)

    graph.train.total_loss = (
        graph.train.post_adapt_loss
        + graph.train.info_reg_loss
        + graph.train.weight_decay_loss
    )

    graph.train.optimizer = tf.train.AdamOptimizer(FLAGS.outer_lr)

    graph.train.gradients = tf.gradients(
        graph.train.total_loss, graph.param.trainables
    )
    graph.train.gradient_norm = tf.linalg.global_norm(graph.train.gradients)
    if FLAGS.outer_clip_gradient > 0:
        graph.train.clipped_gradients, _ = tf.clip_by_global_norm(
            graph.train.gradients, clip_norm=FLAGS.outer_clip_gradient
        )
    else:
        graph.train.clipped_gradients = graph.train.gradients
    graph.train.train_op = graph.train.optimizer.apply_gradients(
        zip(graph.train.clipped_gradients, graph.param.trainables)
    )

    # Created after apply_gradients so the optimizer's variables are included.
    graph.init_op = tf.global_variables_initializer()

    ### Test ###

    graph.test = AttrDict()
    graph.test.placeholder = MetaPlaceholder()
    test_placeholder = graph.test.placeholder

    # Pre-adaptation metrics use pm's own parameter (no explicit parameter passed).
    logits = _apply_convnet(pm, test_placeholder.post_adapt_images)
    graph.test.pre_adapt_loss = _softmax_xent(logits, test_placeholder.post_adapt_labels)
    graph.test.pre_adapt_accuracy = _match_rate(
        tf.argmax(logits, axis=1), test_placeholder.post_adapt_labels
    )

    graph.test.updated_parameters = []
    graph.test.updated_losses = []
    # Test-time clipping now honours FLAGS.inner_clip_gradient_norm, matching
    # training (previously it always clipped by norm).
    updated_parameter, last_lr = _adapt_parameter(
        pm, graph.param.inner_lr, graph.param.sampled_parameter, test_placeholder,
        step_losses=graph.test.updated_losses,
        step_parameters=graph.test.updated_parameters,
    )
    if last_lr is not None:
        graph.test.inner_lr = last_lr

    logits = _apply_convnet(pm, test_placeholder.post_adapt_images, updated_parameter)
    prediction = tf.argmax(logits, axis=1)
    graph.test.post_adapt_accuracy = _match_rate(prediction, test_placeholder.post_adapt_labels)
    graph.test.post_adapt_loss = _softmax_xent(logits, test_placeholder.post_adapt_labels)

    graph.test.adapt_loss_improvement = graph.test.pre_adapt_loss - graph.test.post_adapt_loss
    graph.test.adapt_accuracy_improvement = graph.test.post_adapt_accuracy - graph.test.pre_adapt_accuracy

    if FLAGS.uncertainty_advance:
        if not FLAGS.info_reg > 0:
            raise ValueError('uncertainty_advance requires info_reg > 0')
        # Adapt a second, independent parameter sample and measure how often
        # its predictions agree with the first one.
        graph.param.sampled_parameter_second = graph.param.parameter + tf.random.normal(graph.param.parameter.shape) * graph.param.stddev
        # Bug fix: the learning rate now comes from the *test* support set.
        # Previously it used a leftover training placeholder and overwrote
        # graph.train.inner_lr.
        second_parameter, _ = _adapt_parameter(
            pm, graph.param.inner_lr, graph.param.sampled_parameter_second, test_placeholder
        )

        logits = _apply_convnet(pm, test_placeholder.post_adapt_images, second_parameter)
        prediction_second = tf.argmax(logits, axis=1)
        graph.test.post_adapt_accuracy_second = _match_rate(
            prediction_second, test_placeholder.post_adapt_labels
        )
        graph.test.post_adapt_alignment_accuracy = _match_rate(prediction_second, prediction)

    return graph


def _build_output_dir(inner_batch_sizes):
    """Return the per-run output directory whose name encodes the run's FLAGS.

    ``FLAGS.output_dir`` is used as a plain string prefix: no path separator is
    inserted, so the default ``/tmp/vs_maml_online`` yields
    ``/tmp/vs_maml_onlinembs_...``. The naming scheme (including the second
    ``inner_lr_`` field, which holds ``FLAGS.inner_lr_type``) is kept unchanged
    so directories of existing runs keep their names.

    Args:
        inner_batch_sizes: non-empty list of inner (support) batch sizes; its
            last element is recorded as ``maxshot``.

    Returns:
        The output directory path as a string.
    """
    output_dir = FLAGS.output_dir + 'mbs_{}_inner_steps_{}__inner_lr_{}_inner_step_reg_{}_info_reg_{}_inner_lr_{}_seed_{}'.format(
        FLAGS.meta_train_tasks,
        FLAGS.inner_steps,
        FLAGS.inner_lr,
        FLAGS.inner_step_likelihood_reg,
        FLAGS.info_reg,
        FLAGS.inner_lr_type,
        FLAGS.seed,
    )
    # uncertainty_advance takes precedence over adaptive_advance in the name.
    if FLAGS.uncertainty_advance:
        output_dir += '_uncertain_advance_{}'.format(FLAGS.advance_thresh)
    elif FLAGS.adaptive_advance:
        output_dir += '_adaptive_advance_{}'.format(FLAGS.advance_thresh)
    if FLAGS.weight_decay > 0:
        output_dir += '_weight_decay_{}'.format(FLAGS.weight_decay)
    if FLAGS.outer_clip_gradient > 0.:
        output_dir += '_outer_grad_clip_{}'.format(FLAGS.outer_clip_gradient)
    if FLAGS.inner_clip_gradient > 0.:
        output_dir += '_inner_grad_clip_{}'.format(FLAGS.inner_clip_gradient)
        output_dir += '_norm' if FLAGS.inner_clip_gradient_norm else '_val'
    output_dir += '_maxshot_{}'.format(inner_batch_sizes[-1])
    if FLAGS.cur_task > 0:
        output_dir += '_cur_task_{}_pretrain_steps_{}'.format(FLAGS.cur_task, FLAGS.pretrain_steps)
    return output_dir


def main(_):
    """Meta-train online on a stream of tasks, logging evaluation metrics.

    Each iteration of the training loop may:
      1. switch to a new task (on schedule, or when an adaptive / uncertainty
         advance criterion fires) or add a batch of data to the current task;
      2. run a zero-shot evaluation at the start of a task segment;
      3. take one meta-training step;
      4. run a validation pass every ``FLAGS.batch_steps`` steps.

    Scalars go to TensorBoard through ``TensorBoardLogger``. On every task
    switch, ``switch_task_log.txt`` in the logger's output directory is
    rewritten with one ``Task <id>: <step>`` line per switch so far.

    Args:
        _: positional command-line arguments from ``absl.app.run`` (unused;
            configuration is read from ``FLAGS``).

    Raises:
        ValueError: if ``--inner_batch_size`` is empty.
        RuntimeError: if the data generator returns test batches that violate
            the expected shot counts.
        SystemExit: with a non-zero status when a training metric is NaN/inf.
    """
    set_random_seed(FLAGS.seed)

    graph = build_graph()

    inner_batch_sizes = parse_variabe_batch_size(FLAGS.inner_batch_size)
    if not inner_batch_sizes:
        # parse_variabe_batch_size returns None for an empty flag; fail with a
        # clear message instead of a TypeError when indexing it later.
        raise ValueError('--inner_batch_size must specify at least one batch size, '
                         'e.g. "0-2-4-8".')

    tb_logger = TensorBoardLogger(_build_output_dir(inner_batch_sizes))

    online_data_generator = OnlineDataGenerator(per_task_batch_size=FLAGS.per_task_batch_size)

    def train_ops(inner_batch_size):
        """Return the fetch dict for one meta-training step, keyed by TensorBoard tag."""
        prefix = 'train_{}_shot/'.format(inner_batch_size)
        return {
            'param/norm': graph.param.norm,
            'param/mean_stddev': graph.param.mean_stddev,
            'param/weight_decay_loss': graph.train.weight_decay_loss,
            prefix + 'pre_adapt_accuracy': graph.train.pre_adapt_accuracy,
            prefix + 'post_adapt_accuracy': graph.train.post_adapt_accuracy,
            prefix + 'post_adapt_loss': graph.train.post_adapt_loss,
            prefix + 'info_reg_loss': graph.train.info_reg_loss,
            prefix + 'gradient_norm': graph.train.gradient_norm,
            prefix + 'adaptation_distance': graph.train.adaptation_distance,
            prefix + 'adapt_loss_improvement': graph.train.adapt_loss_improvement,
            prefix + 'adapt_accuracy_improvement': graph.train.adapt_accuracy_improvement,
            prefix + 'inner_lr': graph.train.inner_lr,
            'train_op': graph.train.train_op,
        }

    def advance_triggered(zero_shot_acc, post_adapt_acc, align_acc):
        """Return True if the adaptive or uncertainty advance criterion is met."""
        if FLAGS.adaptive_advance and (zero_shot_acc >= FLAGS.advance_thresh
                                       or post_adapt_acc >= FLAGS.advance_thresh):
            return True
        return bool(FLAGS.uncertainty_advance and align_acc >= FLAGS.advance_thresh)

    sess_config = tf.ConfigProto()
    sess_config.gpu_options.allow_growth = FLAGS.gpu_mem_growth
    sess_config.gpu_options.per_process_gpu_memory_fraction = FLAGS.max_gpu_mem_frac

    pretrain_steps = FLAGS.pretrain_steps if FLAGS.cur_task != 0 else 0
    # Anchor of the task / batch schedule: task switches happen when
    # (train_step - schedule_offset) is a multiple of task_steps, batches are
    # added on multiples of batch_steps. Re-anchored on adaptive advances.
    schedule_offset = pretrain_steps + 1
    zero_shot_accuracy, post_adapt_accuracy, alignment_accuracy = 0., 0., 0.
    val_inner_batch_size, last_val_inner_batch_size = 0, 0
    # Running sum of the mean post-adaptation test loss over all evaluations.
    regret = 0.
    # Like `regret`, but validation passes are only counted when the previous
    # evaluation did not use the maximum shot count.
    variable_shot_regret = 0.
    switch_task_log = OrderedDict()
    with tf.Session(config=sess_config) as sess:
        sess.run(graph.init_op)

        # NOTE(review): range(1, n_steps) runs n_steps - 1 iterations; kept as-is.
        for train_step in range(1, FLAGS.n_steps):
            if train_step >= pretrain_steps + 1:
                scheduled_switch = (train_step >= FLAGS.task_steps // 2
                                    and (train_step - schedule_offset) % FLAGS.task_steps == 0)
                if scheduled_switch or advance_triggered(zero_shot_accuracy, post_adapt_accuracy, alignment_accuracy):
                    online_data_generator.add_task()
                    switch_task_log[online_data_generator.cur_task - 1] = train_step - 1
                    with open(os.path.join(tb_logger.output_dir, 'switch_task_log.txt'), 'w') as f:
                        for task, step in switch_task_log.items():
                            f.write('Task %d: %d\n' % (task, step))
                    # Clear the accuracies measured on the previous task so they
                    # cannot immediately trigger another switch. alignment_accuracy
                    # used to stay stale, so uncertainty_advance added a new task on
                    # every step until the next validation pass.
                    zero_shot_accuracy = 0.
                    alignment_accuracy = 0.

                    if FLAGS.adaptive_advance:
                        # Re-anchor so the zero-shot evaluation below runs now.
                        schedule_offset = train_step % FLAGS.task_steps

                elif ((train_step - schedule_offset) % FLAGS.batch_steps == 0
                      and (train_step - schedule_offset) % FLAGS.task_steps != 0
                      and FLAGS.datasource != 'cont_power_consumption'):
                    online_data_generator.add_batch()

                if (train_step - schedule_offset) % FLAGS.task_steps == 0:
                    # Zero-shot evaluation at the start of a task segment: the test
                    # batches must carry no pre-adaptation (support) examples.
                    pre_adapt_images, pre_adapt_labels, post_adapt_images, post_adapt_labels, inner_batch_size = \
                        online_data_generator.generate(inner_batch_sizes, train=False, test_zero_shot=True)
                    pre_adapt_accuracies = []
                    post_adapt_accuracies = []
                    post_adapt_losses = []
                    alignment_accuracies = []
                    for i in range(FLAGS.test_batches):
                        if len(pre_adapt_labels[i]) != 0:
                            # Replaces a leftover pdb breakpoint; run with absl's
                            # --pdb_post_mortem to debug interactively.
                            raise RuntimeError(
                                'Zero-shot test batch {} has {} pre-adaptation labels; '
                                'expected none.'.format(i, len(pre_adapt_labels[i])))
                        feed_dict = graph.test.placeholder(
                            pre_adapt_images[i], pre_adapt_labels[i],
                            post_adapt_images[i], post_adapt_labels[i]
                        )
                        pre_adapt_accuracy, post_adapt_accuracy, post_adapt_loss = sess.run(
                            [graph.test.pre_adapt_accuracy, graph.test.post_adapt_accuracy, graph.test.post_adapt_loss],
                            feed_dict
                        )
                        if FLAGS.uncertainty_advance:
                            alignment_accuracy = sess.run(graph.test.post_adapt_alignment_accuracy, feed_dict)
                        else:
                            alignment_accuracy = 0.
                        pre_adapt_accuracies.append(pre_adapt_accuracy)
                        post_adapt_accuracies.append(post_adapt_accuracy)
                        post_adapt_losses.append(post_adapt_loss)
                        alignment_accuracies.append(alignment_accuracy)
                    pre_adapt_accuracy = np.mean(pre_adapt_accuracies)
                    post_adapt_accuracy = np.mean(post_adapt_accuracies)
                    # Average over all test batches (previously only the last
                    # batch's loss was used here).
                    post_adapt_loss = np.mean(post_adapt_losses)
                    alignment_accuracy = np.mean(alignment_accuracies)
                    regret += post_adapt_loss
                    variable_shot_regret += post_adapt_loss
                    last_val_inner_batch_size = inner_batch_size

                    for tag, value in (
                        ('test_{}_shot/pre_adapt_accuracy'.format(inner_batch_size), pre_adapt_accuracy),
                        ('test_{}_shot/post_adapt_accuracy'.format(inner_batch_size), post_adapt_accuracy),
                        ('alignment_accuracy', alignment_accuracy),
                        ('regret', regret),
                        ('variable_shot_regret', variable_shot_regret),
                    ):
                        tb_logger.log_scaler(train_step, tag, value)
                    tb_logger.flush()
                    print('Initial pre-adapt acc is %.3f, post-adapt acc is %.3f, align is %.3f' % (pre_adapt_accuracy, post_adapt_accuracy, alignment_accuracy))
                    zero_shot_accuracy = pre_adapt_accuracy
                    if advance_triggered(zero_shot_accuracy, post_adapt_accuracy, alignment_accuracy):
                        # Already above threshold on the new task: skip training so
                        # the next iteration's advance check switches task again.
                        continue

            # One meta-training step over all tasks returned by the generator.
            pre_adapt_images, pre_adapt_labels, post_adapt_images, post_adapt_labels, inner_batch_size = \
                online_data_generator.generate(inner_batch_sizes, train_step=train_step)
            feed_dict = merge_dict([
                graph.train.placeholders[i](
                    pre_adapt_images[i], pre_adapt_labels[i],
                    post_adapt_images[i], post_adapt_labels[i]
                )
                for i in range(len(pre_adapt_images))
            ])

            train_metrics = sess.run(train_ops(inner_batch_size), feed_dict)
            del train_metrics['train_op']
            exist_nan = not all(np.all(np.isfinite(v)) for v in train_metrics.values())

            if train_step % FLAGS.log_interval < len(inner_batch_sizes) or exist_nan:
                tb_logger.log_dict(train_step, train_metrics)

            if exist_nan:
                print(train_metrics)
                # Non-zero exit status (bare exit() reported success).
                raise SystemExit('Non-finite training metrics at step {}; aborting.'.format(train_step))

            # Validation runs on the batch schedule (FLAGS.test_interval is not used).
            if (train_step - schedule_offset + 1) % FLAGS.batch_steps == 0:
                pre_adapt_images, pre_adapt_labels, post_adapt_images, post_adapt_labels, val_inner_batch_size = \
                    online_data_generator.generate(inner_batch_sizes, train=False)
                if not val_inner_batch_size > 0:
                    raise RuntimeError(
                        'Validation batch has inner batch size {}; expected a positive '
                        'shot count.'.format(val_inner_batch_size))
                pre_adapt_accuracies = []
                post_adapt_accuracies = []
                zero_shot_accuracies = []
                post_adapt_losses = []
                alignment_accuracies = []
                for i in range(FLAGS.test_batches):
                    feed_dict = graph.test.placeholder(
                        pre_adapt_images[i], pre_adapt_labels[i],
                        post_adapt_images[i], post_adapt_labels[i]
                    )
                    pre_adapt_accuracy, post_adapt_accuracy, post_adapt_loss = sess.run(
                        [graph.test.pre_adapt_accuracy, graph.test.post_adapt_accuracy, graph.test.post_adapt_loss],
                        feed_dict
                    )
                    # Same query set with the support set sliced away ([:0]).
                    zero_shot_feed_dict = graph.test.placeholder(
                        pre_adapt_images[i][:0], pre_adapt_labels[i][:0],
                        post_adapt_images[i], post_adapt_labels[i]
                    )
                    zero_shot_accuracy = sess.run(graph.test.post_adapt_accuracy, zero_shot_feed_dict)
                    if FLAGS.uncertainty_advance:
                        alignment_accuracy = sess.run(graph.test.post_adapt_alignment_accuracy, feed_dict)
                    else:
                        alignment_accuracy = 0.
                    pre_adapt_accuracies.append(pre_adapt_accuracy)
                    post_adapt_accuracies.append(post_adapt_accuracy)
                    zero_shot_accuracies.append(zero_shot_accuracy)
                    post_adapt_losses.append(post_adapt_loss)
                    alignment_accuracies.append(alignment_accuracy)
                pre_adapt_accuracy = np.mean(pre_adapt_accuracies)
                post_adapt_accuracy = np.mean(post_adapt_accuracies)
                zero_shot_accuracy = np.mean(zero_shot_accuracies)
                post_adapt_loss = np.mean(post_adapt_losses)
                alignment_accuracy = np.mean(alignment_accuracies)
                regret += post_adapt_loss
                counts_for_variable_shot = last_val_inner_batch_size != inner_batch_sizes[-1]
                if counts_for_variable_shot:
                    variable_shot_regret += post_adapt_loss
                scalars = [
                    ('test_{}_shot/pre_adapt_accuracy'.format(val_inner_batch_size), pre_adapt_accuracy),
                    ('test_{}_shot/post_adapt_accuracy'.format(val_inner_batch_size), post_adapt_accuracy),
                    ('alignment_accuracy', alignment_accuracy),
                    ('regret', regret),
                ]
                if counts_for_variable_shot:
                    scalars.append(('variable_shot_regret', variable_shot_regret))
                for tag, value in scalars:
                    tb_logger.log_scaler(train_step, tag, value)
                tb_logger.flush()
                last_val_inner_batch_size = val_inner_batch_size
            if train_step % FLAGS.print_interval == 0:
                print('Training step %d: 0-shot val acc is %.3f, %d-shot val acc is %.3f, align is %.3f' % (train_step, zero_shot_accuracy, val_inner_batch_size, post_adapt_accuracy, alignment_accuracy))
    print('Training completed!')

if __name__ == '__main__':
    # Script entry point: absl parses the command-line flags into FLAGS and
    # then invokes main() with the remaining positional arguments.
    absl.app.run(main=main)
