# Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, visit
# https://nvlabs.github.io/stylegan2/license.html

"""Helper wrapper for a Tensorflow optimizer."""

import numpy as np
import tensorflow as tf

from collections import OrderedDict
from typing import List, Union

from . import autosummary
from . import tfutil
from .. import util

from .tfutil import TfExpression, TfExpressionEx

# Locate the NCCL ops module; its import path differs between TensorFlow releases.
# Only ImportError triggers the fallback, so that unrelated failures
# (e.g. KeyboardInterrupt or SystemExit) are not silently swallowed.
try:
    # TensorFlow 1.13
    from tensorflow.python.ops import nccl_ops
except ImportError:
    # Older TensorFlow versions
    import tensorflow.contrib.nccl as nccl_ops
class Optimizer:
    """A Wrapper for tf.train.Optimizer.

    Automatically takes care of:
    - Gradient averaging for multi-GPU training.
    - Gradient accumulation for arbitrarily large minibatches.
    - Dynamic loss scaling and typecasts for FP16 training.
    - Ignoring corrupted gradients that contain NaNs/Infs.
    - Reporting statistics.
    - Well-chosen default settings.

    Call order enforced by the assertions in this class: register_gradients()
    may be called any number of times, after which apply_updates() is called
    exactly once to build the training op. Registering further gradients or
    calling apply_updates() a second time trips an assertion.
    """

    def __init__(self,
        name:                   str             = "Train",                  # Name string that will appear in TensorFlow graph.
        tf_optimizer:           str             = "tf.train.AdamOptimizer", # Underlying optimizer class.
        learning_rate:          TfExpressionEx  = 0.001,                    # Learning rate. Can vary over time.
        minibatch_multiplier:   TfExpressionEx  = None,                     # Treat N consecutive minibatches as one by accumulating gradients.
        share:                  "Optimizer"     = None,                     # Share internal state with a previously created optimizer?
        use_loss_scaling:       bool            = False,                    # Enable dynamic loss scaling for robust mixed-precision training?
        loss_scaling_init:      float           = 64.0,                     # Log2 of initial loss scaling factor.
        loss_scaling_inc:       float           = 0.0005,                   # Log2 of per-minibatch loss scaling increment when there is no overflow.
        loss_scaling_dec:       float           = 1.0,                      # Log2 of per-minibatch loss scaling decrement when there is an overflow.
        report_mem_usage:       bool            = False,                    # Report fine-grained memory usage statistics in TensorBoard?
        **kwargs):
        """Store the configuration and obtain a unique scope name from the default graph.

        Per-device objects (the underlying optimizer instance, the loss scaling
        variable) are not created here; they are created by _get_device() the
        first time a device is used. Extra keyword arguments are stored and
        later forwarded to the constructor of the underlying optimizer class.
        """

        # Public fields.
        self.name                   = name
        self.learning_rate          = learning_rate
        self.minibatch_multiplier   = minibatch_multiplier
        self.id                     = self.name.replace("/", ".")   # Used as prefix for name scopes and autosummary tags.
        self.scope                  = tf.get_default_graph().unique_name(self.id)
        self.optimizer_class        = util.get_obj_by_name(tf_optimizer)  # Presumably resolves the dotted name to the object itself -- see util.get_obj_by_name.
        self.optimizer_kwargs       = dict(kwargs)
        self.use_loss_scaling       = use_loss_scaling
        self.loss_scaling_init      = loss_scaling_init
        self.loss_scaling_inc       = loss_scaling_inc
        self.loss_scaling_dec       = loss_scaling_dec

        # Private fields.
        self._updates_applied       = False         # Set by apply_updates(); guards against late registration / double application.
        self._devices               = OrderedDict() # device_name => EasyDict()
        self._shared_optimizers     = OrderedDict() # device_name => optimizer_class
        self._gradient_shapes       = None          # [shape, ...]
        self._report_mem_usage      = report_mem_usage

        # Validate arguments.
        assert callable(self.optimizer_class)

        # Share internal state if requested.
        # Sharing is only allowed between optimizers configured identically; note that the
        # learning rate is compared by identity, so both must be given the very same object.
        if share is not None:
            assert isinstance(share, Optimizer)
            assert self.optimizer_class is share.optimizer_class
            assert self.learning_rate is share.learning_rate
            assert self.optimizer_kwargs == share.optimizer_kwargs
            self._shared_optimizers = share._shared_optimizers # pylint: disable=protected-access

    def _get_device(self, device_name: str):
        """Get internal state for the given TensorFlow device.

        The state is created on the first request for a given device name and
        cached in self._devices, so later calls return the same EasyDict.
        """
        tfutil.assert_tf_initialized()
        if device_name in self._devices:
            return self._devices[device_name]

        # Initialize fields.
        device = util.EasyDict()
        device.name             = device_name
        device.optimizer        = None          # Underlying optimizer:     optimizer_class
        device.loss_scaling_var = None          # Log2 of loss scaling:     tf.Variable
        device.grad_raw         = OrderedDict() # Raw gradients:            var => [grad, ...]
        device.grad_clean       = OrderedDict() # Clean gradients:          var => grad
        device.grad_acc_vars    = OrderedDict() # Accumulation sums:        var => tf.Variable
        device.grad_acc_count   = None          # Accumulation counter:     tf.Variable
        device.grad_acc         = OrderedDict() # Accumulated gradients:    var => grad

        # Setup TensorFlow objects.
        with tfutil.absolute_name_scope(self.scope + "/Devices"), tf.device(device_name), tf.control_dependencies(None):
            # Reuse the optimizer instance already registered for this device, if any. The dict
            # may be the one owned by the optimizer passed as `share` to the constructor.
            if device_name not in self._shared_optimizers:
                optimizer_name = self.scope.replace("/", "_") + "_opt%d" % len(self._shared_optimizers)
                self._shared_optimizers[device_name] = self.optimizer_class(name=optimizer_name, learning_rate=self.learning_rate, **self.optimizer_kwargs)
            device.optimizer = self._shared_optimizers[device_name]
            # The variable holds the log2 of the scale; see apply_loss_scaling() / undo_loss_scaling().
            if self.use_loss_scaling:
                device.loss_scaling_var = tf.Variable(np.float32(self.loss_scaling_init), trainable=False, name="loss_scaling_var")

        # Register device.
        self._devices[device_name] = device
        return device

    def register_gradients(self, loss: TfExpression, trainable_vars: Union[List, dict]) -> None:
        """Register the gradients of the given loss function with respect to the given variables.
        Intended to be called once per GPU.

        Args:
            loss: Loss expression. Its `device` attribute selects the per-device state
                that the gradients are recorded in.
            trainable_vars: Non-empty list of variables, or a dict whose values are the
                variables. All of them must be placed on the same device as `loss`, and
                their number and shapes must match those of every earlier call.

        The loss is cast to float32 and multiplied by the dynamic loss scale (when
        enabled) before differentiation. Gradients are only recorded here; they are
        cleaned up, combined and applied by apply_updates(), which must not have been
        called yet.
        """
        tfutil.assert_tf_initialized()
        assert not self._updates_applied
        device = self._get_device(loss.device)

        # Validate trainables.
        if isinstance(trainable_vars, dict):
            trainable_vars = list(trainable_vars.values())  # allow passing in Network.trainables as vars
        assert isinstance(trainable_vars, list) and len(trainable_vars) >= 1
        assert all(tfutil.is_tf_expression(expr) for expr in trainable_vars + [loss])
        assert all(var.device == device.name for var in trainable_vars)

        # Validate shapes.
        # The first call fixes the reference shapes; every later call (on any device) must match them
        # position by position, which is what allows apply_updates() to pair up variables across devices.
        if self._gradient_shapes is None:
            self._gradient_shapes = [var.shape.as_list() for var in trainable_vars]
        assert len(trainable_vars) == len(self._gradient_shapes)
        assert all(var.shape.as_list() == var_shape for var, var_shape in zip(trainable_vars, self._gradient_shapes))

        # Report memory usage if requested.
        # The flag is cleared immediately, so the statistic is set up at most once per optimizer.
        deps = []
        if self._report_mem_usage:
            self._report_mem_usage = False
            try:
                with tf.name_scope(self.id + '_mem'), tf.device(device.name), tf.control_dependencies([loss]):
                    deps.append(autosummary.autosummary(self.id + "/mem_usage_gb", tf.contrib.memory_stats.BytesInUse() / 2**30))
            except tf.errors.NotFoundError:
                # Best effort: reporting is skipped on this error (presumably the memory stats op being unavailable -- TODO confirm).
                pass

        # Compute gradients.
        # `deps` makes the gradient computation depend on the memory usage summary, when there is one.
        with tf.name_scope(self.id + "_grad"), tf.device(device.name), tf.control_dependencies(deps):
            loss = self.apply_loss_scaling(tf.cast(loss, tf.float32))
            gate = tf.train.Optimizer.GATE_NONE  # disable gating to reduce memory usage
            grad_list = device.optimizer.compute_gradients(loss=loss, var_list=trainable_vars, gate_gradients=gate)

        # Register gradients.
        # A variable can collect several gradients on one device if this method is called more than once for it.
        for grad, var in grad_list:
            if var not in device.grad_raw:
                device.grad_raw[var] = []
            device.grad_raw[var].append(grad)

    def apply_updates(self, allow_no_op: bool = False) -> tf.Operation:
        """Construct training op to update the registered variables based on their gradients.

        Args:
            allow_no_op: If True and no device state has been set up (e.g. because
                register_gradients() was never called), return a plain no-op instead
                of building the update graph.

        Returns:
            A single op named "TrainingOp", created under this optimizer's scope, that
            groups all update, bookkeeping and statistics ops.

        May only be called once. Besides building the graph, this runs the initializers
        of the optimizer state and of the loss scaling / accumulation variables.
        """
        tfutil.assert_tf_initialized()
        assert not self._updates_applied
        self._updates_applied = True
        all_ops = []

        # Check for no-op.
        if allow_no_op and len(self._devices) == 0:
            with tfutil.absolute_name_scope(self.scope):
                return tf.no_op(name='TrainingOp')

        # Clean up gradients.
        for device_idx, device in enumerate(self._devices.values()):
            with tfutil.absolute_name_scope(self.scope + "/Clean%d" % device_idx), tf.device(device.name):
                for var, grad in device.grad_raw.items():

                    # Filter out disconnected gradients and convert to float32.
                    grad = [g for g in grad if g is not None]
                    grad = [tf.cast(g, tf.float32) for g in grad]

                    # Sum within the device.
                    if len(grad) == 0:
                        grad = tf.zeros(var.shape)  # No gradients => zero.
                    elif len(grad) == 1:
                        grad = grad[0]              # Single gradient => use as is.
                    else:
                        grad = tf.add_n(grad)       # Multiple gradients => sum.

                    # Scale as needed.
                    # Turns the sums into averages: over the gradients registered for this variable on this
                    # device (counting the unfiltered list, i.e. including disconnected ones), over all devices,
                    # and over the accumulated minibatches. The loss scaling factor is divided out as well.
                    scale = 1.0 / len(device.grad_raw[var]) / len(self._devices)
                    scale = tf.constant(scale, dtype=tf.float32, name="scale")
                    if self.minibatch_multiplier is not None:
                        scale /= tf.cast(self.minibatch_multiplier, tf.float32)
                    scale = self.undo_loss_scaling(scale)
                    device.grad_clean[var] = grad * scale

        # Sum gradients across devices.
        # Variables are paired up across devices by position, i.e. by the order in which they were registered.
        # After the all-reduce every device holds the same summed gradient for its copy of the variable.
        if len(self._devices) > 1:
            with tfutil.absolute_name_scope(self.scope + "/Broadcast"), tf.device(None):
                for all_vars in zip(*[device.grad_clean.keys() for device in self._devices.values()]):
                    if len(all_vars) > 0 and all(dim > 0 for dim in all_vars[0].shape.as_list()): # NCCL does not support zero-sized tensors.
                        all_grads = [device.grad_clean[var] for device, var in zip(self._devices.values(), all_vars)]
                        all_grads = nccl_ops.all_sum(all_grads)
                        for device, var, grad in zip(self._devices.values(), all_vars, all_grads):
                            device.grad_clean[var] = grad

        # Apply updates separately on each device.
        for device_idx, device in enumerate(self._devices.values()):
            with tfutil.absolute_name_scope(self.scope + "/Apply%d" % device_idx), tf.device(device.name):
                # pylint: disable=cell-var-from-loop
                # (The lambdas below capture loop variables, but each is passed to tf.cond within the same iteration.)

                # Accumulate gradients over time.
                if self.minibatch_multiplier is None:
                    # No accumulation: every step is a complete one and the clean gradients are used directly.
                    acc_ok = tf.constant(True, name='acc_ok')
                    device.grad_acc = OrderedDict(device.grad_clean)
                else:
                    # Create variables.
                    with tf.control_dependencies(None):
                        for var in device.grad_clean.keys():
                            device.grad_acc_vars[var] = tf.Variable(tf.zeros(var.shape), trainable=False, name="grad_acc_var")
                        device.grad_acc_count = tf.Variable(tf.zeros([]), trainable=False, name="grad_acc_count")

                    # Track counter.
                    # count_cur includes the current minibatch; acc_ok becomes true on the step that completes
                    # an accumulation cycle, at which point the counter is reset instead of incremented.
                    count_cur = device.grad_acc_count + 1.0
                    count_inc_op = lambda: tf.assign(device.grad_acc_count, count_cur)
                    count_reset_op = lambda: tf.assign(device.grad_acc_count, tf.zeros([]))
                    acc_ok = (count_cur >= tf.cast(self.minibatch_multiplier, tf.float32))
                    all_ops.append(tf.cond(acc_ok, count_reset_op, count_inc_op))

                    # Track gradients.
                    # acc_cur is the running sum including the current gradient; it is what gets applied on a
                    # completing step. The stored sum is then zeroed, otherwise it is overwritten with acc_cur.
                    for var, grad in device.grad_clean.items():
                        acc_var = device.grad_acc_vars[var]
                        acc_cur = acc_var + grad
                        device.grad_acc[var] = acc_cur
                        with tf.control_dependencies([acc_cur]):
                            acc_inc_op = lambda: tf.assign(acc_var, acc_cur)
                            acc_reset_op = lambda: tf.assign(acc_var, tf.zeros(var.shape))
                            all_ops.append(tf.cond(acc_ok, acc_reset_op, acc_inc_op))

                # No overflow => apply gradients.
                # all_ok requires both a completed accumulation cycle and finite values in every accumulated gradient.
                # Gradients are cast back to each variable's dtype right before being handed to the optimizer.
                all_ok = tf.reduce_all(tf.stack([acc_ok] + [tf.reduce_all(tf.is_finite(g)) for g in device.grad_acc.values()]))
                apply_op = lambda: device.optimizer.apply_gradients([(tf.cast(grad, var.dtype), var) for var, grad in device.grad_acc.items()])
                all_ops.append(tf.cond(all_ok, apply_op, tf.no_op))

                # Adjust loss scaling.
                # Only touched on steps that complete an accumulation cycle: the log2 scale is raised by
                # loss_scaling_inc after a successful update and lowered by loss_scaling_dec after an overflow.
                if self.use_loss_scaling:
                    ls_inc_op = lambda: tf.assign_add(device.loss_scaling_var, self.loss_scaling_inc)
                    ls_dec_op = lambda: tf.assign_sub(device.loss_scaling_var, self.loss_scaling_dec)
                    ls_update_op = lambda: tf.group(tf.cond(all_ok, ls_inc_op, ls_dec_op))
                    all_ops.append(tf.cond(acc_ok, ls_update_op, tf.no_op))

                # Last device => report statistics.
                # The overflow statistic is 0 when all_ok and 1 otherwise; `condition=acc_ok` presumably
                # restricts it to steps that complete an accumulation cycle -- see autosummary.autosummary.
                if device_idx == len(self._devices) - 1:
                    all_ops.append(autosummary.autosummary(self.id + "/learning_rate", self.learning_rate))
                    all_ops.append(autosummary.autosummary(self.id + "/overflow_frequency", tf.where(all_ok, 0, 1), condition=acc_ok))
                    if self.use_loss_scaling:
                        all_ops.append(autosummary.autosummary(self.id + "/loss_scaling_log2", device.loss_scaling_var))

        # Initialize variables.
        # Done after the apply ops above have been built (presumably that is when the underlying
        # optimizer creates its state variables -- TODO confirm for the optimizer class in use).
        self.reset_optimizer_state()
        if self.use_loss_scaling:
            tfutil.init_uninitialized_vars([device.loss_scaling_var for device in self._devices.values()])
        if self.minibatch_multiplier is not None:
            tfutil.run([var.initializer for device in self._devices.values() for var in list(device.grad_acc_vars.values()) + [device.grad_acc_count]])

        # Group everything into a single op.
        with tfutil.absolute_name_scope(self.scope):
            return tf.group(*all_ops, name="TrainingOp")

    def reset_optimizer_state(self) -> None:
        """Reset internal state of the underlying optimizer.

        Runs the initializer of every variable reported by `variables()` of the
        optimizer instance on each device known to this wrapper.
        """
        tfutil.assert_tf_initialized()
        tfutil.run([var.initializer for device in self._devices.values() for var in device.optimizer.variables()])

    def get_loss_scaling_var(self, device: str) -> Union[tf.Variable, None]:
        """Get or create variable representing log2 of the current dynamic loss scaling factor.

        Sets up the per-device state for `device` if it does not exist yet.
        Returns None when loss scaling is disabled, since no variable is created in that case.
        """
        return self._get_device(device).loss_scaling_var

    def apply_loss_scaling(self, value: TfExpression) -> TfExpression:
        """Apply dynamic loss scaling for the given expression.

        Multiplies `value` by 2**loss_scaling_var of the device that `value` is placed on.
        Returns `value` unchanged when loss scaling is disabled.
        """
        assert tfutil.is_tf_expression(value)
        if not self.use_loss_scaling:
            return value
        return value * tfutil.exp2(self.get_loss_scaling_var(value.device))

    def undo_loss_scaling(self, value: TfExpression) -> TfExpression:
        """Undo the effect of dynamic loss scaling for the given expression.

        Multiplies `value` by 2**(-loss_scaling_var) of the device that `value` is placed on.
        Returns `value` unchanged when loss scaling is disabled.
        """
        assert tfutil.is_tf_expression(value)
        if not self.use_loss_scaling:
            return value
        return value * tfutil.exp2(-self.get_loss_scaling_var(value.device)) # pylint: disable=invalid-unary-operand-type


class SimpleAdam:
    """Bare-bones Adam optimizer exposing only the methods that dnnlib.tflib.Optimizer calls.

    Meant as a simplified stand-in for tf.train.AdamOptimizer that behaves identically
    when it is driven through dnnlib.tflib.Optimizer.
    """

    def __init__(self, name="Adam", learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8):
        """Remember the hyperparameters; state variables are created later by apply_gradients()."""
        self.all_state_vars = []
        self.name = name
        self.learning_rate = learning_rate
        self.beta1, self.beta2 = beta1, beta2
        self.epsilon = epsilon

    def variables(self):
        """Return the list of all state variables created so far by apply_gradients()."""
        return self.all_state_vars

    def compute_gradients(self, loss, var_list, gate_gradients=tf.train.Optimizer.GATE_NONE):
        """Return [(gradient, variable), ...] of `loss` w.r.t. `var_list`. Only GATE_NONE is accepted."""
        assert gate_gradients == tf.train.Optimizer.GATE_NONE
        gradients = tf.gradients(loss, var_list)
        return [(gradient, variable) for gradient, variable in zip(gradients, var_list)]

    def apply_gradients(self, grads_and_vars):
        """Build and return one grouped op that performs an Adam step for every (gradient, variable) pair.

        Every call creates fresh state variables (two beta-power trackers plus first and
        second moment estimates per variable) and appends them to the list returned by variables().
        """
        with tf.name_scope(self.name):
            created_vars = []
            step_ops = []

            # Running powers of beta1/beta2, used to correct the learning rate for startup bias.
            with tf.control_dependencies(None):
                beta1_power = tf.Variable(initial_value=1, trainable=False, dtype=tf.float32)
                beta2_power = tf.Variable(initial_value=1, trainable=False, dtype=tf.float32)
            created_vars.extend([beta1_power, beta2_power])

            next_beta1_power = beta1_power * self.beta1
            next_beta2_power = beta2_power * self.beta2
            step_ops.append(tf.assign(beta1_power, next_beta1_power))
            step_ops.append(tf.assign(beta2_power, next_beta2_power))

            # Bias-corrected step size.
            scaled_rate = self.learning_rate * tf.sqrt(1 - next_beta2_power)
            rate_divisor = 1 - next_beta1_power
            step_size = scaled_rate / rate_divisor

            # One set of moment estimates and update ops per variable.
            for gradient, variable in grads_and_vars:
                with tf.control_dependencies(None):
                    first_moment = tf.Variable(initial_value=tf.zeros_like(variable), trainable=False, dtype=tf.float32)
                    second_moment = tf.Variable(initial_value=tf.zeros_like(variable), trainable=False, dtype=tf.float32)
                created_vars.extend([first_moment, second_moment])

                # Exponential moving averages of the gradient and of its square.
                first_decayed = self.beta1 * first_moment
                first_fresh = (1 - self.beta1) * gradient
                next_first_moment = first_decayed + first_fresh

                second_decayed = self.beta2 * second_moment
                second_fresh = (1 - self.beta2) * tf.square(gradient)
                next_second_moment = second_decayed + second_fresh

                # Parameter step.
                numerator = step_size * next_first_moment
                denominator = tf.sqrt(next_second_moment) + self.epsilon
                delta = numerator / denominator

                step_ops.append(tf.assign(first_moment, next_first_moment))
                step_ops.append(tf.assign(second_moment, next_second_moment))
                step_ops.append(tf.assign_sub(variable, delta))

            # Publish the new state variables and bundle all ops into one.
            self.all_state_vars.extend(created_vars)
            return tf.group(*step_ops)
