import torch
from ..multi_tensor_apply import multi_tensor_applier
from ._amp_state import _amp_state, master_params, maybe_print
from itertools import product

def scale_check_overflow_python(model_grad, master_grad, scale, check_overflow=False):
    """Copy ``model_grad`` into ``master_grad`` and multiply it by ``scale``.

    Pure-Python fallback used when the fused kernel is unavailable.

    Args:
        model_grad: incoming gradient tensor.
        master_grad: destination tensor, updated in place. May be the very
            same object as ``model_grad``, in which case the copy is skipped.
        scale: multiplier applied in place to ``master_grad`` (skipped when
            it equals 1.0).
        check_overflow: when true, the sum of ``model_grad`` (cast with
            ``.float()``) is inspected first; an inf or nan sum aborts the
            call before ``master_grad`` is touched.

    Returns:
        True if ``check_overflow`` is set and the sum was inf/nan, else False.
    """
    if check_overflow:
        # One reduction plus one host-side float is enough to detect any
        # inf/nan element: either propagates into the sum.
        grad_total = float(model_grad.float().sum())
        is_nan = grad_total != grad_total
        if is_nan or abs(grad_total) == float('inf'):
            return True

    same_tensor = master_grad is model_grad
    if not same_tensor:  # copy_ probably internally short-circuits this
        master_grad.copy_(model_grad)
    if scale != 1.0:
        master_grad.mul_(scale)
    return False

def axpby_check_overflow_python(model_grad, stashed_grad, master_grad, a, b, check_overflow=False):
    """Set ``master_grad`` to ``a*model_grad + b*stashed_grad``.

    Pure-Python fallback used when the fused kernel is unavailable.

    Args:
        model_grad: incoming gradient tensor; converted to ``master_grad``'s
            dtype before being combined.
        stashed_grad: previously stashed gradient; must already have the same
            dtype as ``master_grad`` (enforced with an ``assert``).
        master_grad: destination tensor; its ``.data`` is rebound to the
            newly computed result.
        a: multiplier for the converted ``model_grad``.
        b: multiplier for ``stashed_grad``.
        check_overflow: when true, the sum of ``model_grad`` (cast with
            ``.float()``) is inspected first; an inf or nan sum aborts the
            call before ``master_grad`` is touched.

    Returns:
        True if ``check_overflow`` is set and the sum was inf/nan, else False.
    """
    if check_overflow:
        # Only the incoming model grad is checked, not the stashed grad.
        grad_total = float(model_grad.float().sum())
        if grad_total != grad_total or abs(grad_total) == float('inf'):
            return True

    assert stashed_grad.dtype == master_grad.dtype
    # No copy_ into master_grad here: the result is a fresh tensor that
    # replaces master_grad's data outright.
    incoming = model_grad.data.to(master_grad.dtype)
    master_grad.data = a*incoming.data + b*stashed_grad.data
    return False

class LossScaler(object):
    """Holds a loss scale, unscales gradients, and optionally adapts the scale.

    With ``loss_scale="dynamic"`` the scale is divided by ``scale_factor``
    whenever an overflow was recorded and multiplied by ``scale_factor``
    after ``scale_window`` consecutive calls to :meth:`update_scale` without
    one. With any other ``loss_scale`` value the scale is static and never
    changes.

    Unscaling goes through the fused ``amp_C`` kernels when
    ``multi_tensor_applier.available`` is true, and through the module-level
    Python fallbacks otherwise.
    """
    # Class-wide state shared by every instance. The two ``warned_*`` flags
    # make each warning print only once; ``has_fused_kernel`` is re-evaluated
    # each time an instance is constructed.
    warned_no_fused_kernel = False
    warned_unscaling_non_fp32_grad = False
    has_fused_kernel = False

    def __init__(self,
                 loss_scale,
                 init_scale=2.**16,
                 scale_factor=2.,
                 scale_window=2000,
                 min_loss_scale=None,
                 max_loss_scale=2.**24):
        """Set up the scaler.

        Args:
            loss_scale: the string ``"dynamic"`` for dynamic scaling, or the
                static loss scale value to use as-is.
            init_scale: starting scale in dynamic mode (capped at
                ``max_loss_scale``).
            scale_factor: factor by which the dynamic scale is divided on
                overflow and multiplied after ``scale_window`` clean updates.
            scale_window: number of consecutive non-overflowing
                :meth:`update_scale` calls before the scale is grown.
            min_loss_scale: lower bound applied when shrinking; a falsy value
                (``None`` or 0) means no lower bound.
            max_loss_scale: upper bound applied when growing.
        """
        if loss_scale == "dynamic":
            self.dynamic = True
            self._loss_scale = min(max_loss_scale, init_scale)
        else:
            self.dynamic = False
            self._loss_scale = loss_scale
        self._max_loss_scale = max_loss_scale
        self._min_loss_scale = min_loss_scale
        # Previously accepted but ignored; update_scale() now honours it.
        self._scale_factor = scale_factor
        self._scale_seq_len = scale_window
        self._unskipped = 0
        self._has_overflow = False
        # Overflow flag handed to the fused kernels; read back in update_scale().
        self._overflow_buf = torch.cuda.IntTensor([0])
        if multi_tensor_applier.available:
            import amp_C
            LossScaler.has_fused_kernel = multi_tensor_applier.available
            LossScaler.multi_tensor_scale_cuda = amp_C.multi_tensor_scale
            LossScaler.multi_tensor_axpby_cuda = amp_C.multi_tensor_axpby
        else:
            if not LossScaler.warned_no_fused_kernel:
                maybe_print(
                    "Warning:  multi_tensor_applier fused unscale kernel is unavailable, "
                    "possibly because apex was installed without --cuda_ext --cpp_ext. "
                    "Using Python fallback.  Original ImportError was: " +
                    repr(multi_tensor_applier.import_err),
                    True)
            LossScaler.has_fused_kernel = False
            LossScaler.warned_no_fused_kernel = True

    def loss_scale(self):
        """Return the current loss scale."""
        return self._loss_scale

    def unscale_python(self, model_grads, master_grads, scale):
        """Python fallback: write ``model / scale`` into each master grad.

        Entries whose model grad is ``None`` are skipped. In dynamic mode the
        loop stops at the first grad whose overflow check fails, leaving
        ``self._has_overflow`` set.
        """
        for model, master in zip(model_grads, master_grads):
            if model is not None:
                if not LossScaler.warned_unscaling_non_fp32_grad:
                    if master.dtype != torch.float32:
                        maybe_print(
                            "Attempting to unscale a grad with type {} ".format(master.type()) +
                            "Unscaling non-fp32 grads may indicate an error. "
                            "When using Amp, you don't need to call .half() on your model.")
                        LossScaler.warned_unscaling_non_fp32_grad = True
                self._has_overflow = scale_check_overflow_python(model,
                                                                 master,
                                                                 1./scale,
                                                                 self.dynamic)
                if self._has_overflow and self.dynamic:
                    break

    # unused_scale keeps some of the old API alive for hopefully a short time.
    def unscale(self, model_grads, master_grads, unused_scale, models_are_masters=False, scale_override=None):
        """Unscale ``model_grads`` into ``master_grads``.

        Args:
            model_grads: sequence of incoming (scaled) gradients.
            master_grads: sequence of destination gradients.
            unused_scale: ignored; kept for backward compatibility.
            models_are_masters: when true, and the scale is 1.0 in static
                mode, the call returns without doing any work.
            scale_override: if not ``None``, used instead of the current loss
                scale.

        Does nothing if an overflow has already been recorded on this
        instance.
        """
        if self._has_overflow:
            return

        scale = self._loss_scale
        if scale_override is not None:
            scale = scale_override

        if scale == 1.0 and models_are_masters and not self.dynamic:
            return

        if LossScaler.has_fused_kernel:
            # No non-fp32 warning on this path (unlike unscale_with_stashed).
            multi_tensor_applier(LossScaler.multi_tensor_scale_cuda,
                                 self._overflow_buf,
                                 [model_grads, master_grads],
                                 1./scale)
        else:
            self.unscale_python(model_grads, master_grads, scale)

        # Reading self._overflow_buf is deferred to update_scale() so that the
        # fused path needs only one D2H memcopy and sync.

    def unscale_with_stashed_python(self,
                                    model_grads,
                                    stashed_master_grads,
                                    master_grads,
                                    a,
                                    b):
        """Python fallback: set each master grad to ``a*model + b*stashed``.

        Triples where both the model grad and the stashed grad are ``None``
        are skipped. In dynamic mode the loop stops at the first grad whose
        overflow check fails, leaving ``self._has_overflow`` set.
        """
        for model, stashed, master in zip(model_grads, stashed_master_grads, master_grads):
            if model is None and stashed is None:
                continue
            else:
                if not LossScaler.warned_unscaling_non_fp32_grad:
                    if master.dtype != torch.float32:
                        maybe_print(
                            "Attempting to unscale a grad with type {} ".format(master.type()) +
                            "Unscaling non-fp32 grads may indicate an error. "
                            "When using Amp, you don't need to call .half() on your model.")
                        LossScaler.warned_unscaling_non_fp32_grad = True
                self._has_overflow = axpby_check_overflow_python(model,
                                                                 stashed,
                                                                 master,
                                                                 a,
                                                                 b,
                                                                 self.dynamic)
                if self._has_overflow and self.dynamic:
                    break

    def unscale_with_stashed(self,
                             model_grads,
                             stashed_master_grads,
                             master_grads,
                             scale_override=None):
        """Combine incoming grads with stashed grads into ``master_grads``.

        Each master grad becomes
        ``(out_scale/grads_have_scale)*model + (out_scale/stashed_have_scale)*stashed``.
        By default ``grads_have_scale`` is the current loss scale and the
        other two are 1.0.

        Args:
            model_grads: sequence of incoming (scaled) gradients.
            stashed_master_grads: sequence of previously stashed gradients.
            master_grads: sequence of destination gradients.
            scale_override: optional 3-tuple
                ``(grads_have_scale, stashed_have_scale, out_scale)``
                replacing the defaults.

        Does nothing if an overflow has already been recorded on this
        instance.
        """
        if self._has_overflow:
            return

        grads_have_scale, stashed_have_scale, out_scale = self._loss_scale, 1.0, 1.0
        if scale_override is not None:
            grads_have_scale, stashed_have_scale, out_scale = scale_override

        if LossScaler.has_fused_kernel:
            if (not LossScaler.warned_unscaling_non_fp32_grad
                and master_grads[0].dtype == torch.float16):
                print("Warning:  unscaling grads that are not FP32. "
                      "Unscaling non-fp32 grads may indicate an error. "
                      "When using Amp, you don't need to call .half() on your model.")
                # Setting this to True unconditionally allows the possibility of an escape
                # if never-before-seen non-fp32 grads are created in some later iteration.
                LossScaler.warned_unscaling_non_fp32_grad = True
            multi_tensor_applier(LossScaler.multi_tensor_axpby_cuda,
                                 self._overflow_buf,
                                 [model_grads, stashed_master_grads, master_grads],
                                 out_scale/grads_have_scale,   # 1./scale,
                                 out_scale/stashed_have_scale, # 1.0,
                                 0) # check only arg 0, aka the incoming model grads, for infs
        else:
            self.unscale_with_stashed_python(model_grads,
                                             stashed_master_grads,
                                             master_grads,
                                             out_scale/grads_have_scale,
                                             out_scale/stashed_have_scale)

        # Reading self._overflow_buf is deferred to update_scale() so that the
        # fused path needs only one D2H memcopy and sync.

    def clear_overflow_state(self):
        """Reset the overflow flag (and the fused-kernel overflow buffer)."""
        self._has_overflow = False
        if self.has_fused_kernel:
            self._overflow_buf.zero_()

    # Separate so unscale() can be called more that once before updating.
    def update_scale(self):
        """Adjust the dynamic loss scale based on the recorded overflow state.

        In dynamic mode, an overflow divides the scale by ``scale_factor``
        (not going below ``min_loss_scale`` when one is set) and resets the
        clean-step counter. Otherwise the counter is incremented, and once it
        reaches ``scale_window`` the scale is multiplied by ``scale_factor``
        (capped at ``max_loss_scale``) and the counter is reset. In static
        mode the scale is never modified.

        Returns:
            True if the caller should skip this step (dynamic mode with an
            overflow), else False.
        """
        # If the fused kernel is available, we only need one D2H memcopy and sync.
        if LossScaler.has_fused_kernel and self.dynamic and not self._has_overflow:
            self._has_overflow = self._overflow_buf.item()

        if self._has_overflow and self.dynamic:
            should_skip = True
            shrunk_scale = self._loss_scale/self._scale_factor
            if self._min_loss_scale:
                self._loss_scale = max(self._min_loss_scale, shrunk_scale)
            else:
                self._loss_scale = shrunk_scale
            self._unskipped = 0
        else:
            should_skip = False
            self._unskipped += 1

        if self._unskipped == self._scale_seq_len and self.dynamic:
            self._loss_scale = min(self._max_loss_scale,
                                   self._loss_scale*self._scale_factor)
            self._unskipped = 0

        return should_skip
