""" 'Fast' Normalization Functions

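For GroupNorm and LayerNorm, these functions bypass the float32 upcast that native AMP normally
applies to normalization layers, computing in the lower-precision autocast dtype instead.

Additionally, for LayerNorm, the APEX fused LN kernel is used if available (it likewise does not upcast).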

Hacked together by / Copyright 2022 Ross Wightman
"""
from typing import List, Optional

import torch
from torch.nn import functional as F

try:
    from apex.normalization.fused_layer_norm import fused_layer_norm_affine
    has_apex = True
except ImportError:
    has_apex = False


# Module-wide opt-in switch for 'fast' (i.e. lower precision) normalization, read via
# is_fast_norm() and changed via set_fast_norm(). Nothing in this file consults it directly;
# presumably callers check it before choosing the fast_* functions (verify at call sites).
# Kept off by default for now in case the reduced precision causes issues.
_USE_FAST_NORM: bool = False


def is_fast_norm() -> bool:
    """Return the current value of the module-wide 'fast' norm flag.

    The value is whatever was last stored by ``set_fast_norm`` (``False`` initially).
    """
    current = _USE_FAST_NORM
    return current


def set_fast_norm(enable: bool = True) -> None:
    """Set the module-wide 'fast' norm flag that ``is_fast_norm`` reports.

    Args:
        enable: New flag value. Stored exactly as given (it is not coerced to ``bool``).
    """
    global _USE_FAST_NORM
    _USE_FAST_NORM = enable


def _autocast_device_type(x: torch.Tensor) -> str:
    """Pick the autocast device type whose state governs ops on ``x``.

    CPU tensors follow CPU autocast. Every other device follows the CUDA autocast state,
    which is the only state this module consulted before CPU handling was added.
    """
    return 'cpu' if x.device.type == 'cpu' else 'cuda'


def _autocast_dtype(device_type: str) -> Optional[torch.dtype]:
    """Return the active autocast dtype for ``device_type`` ('cpu' or 'cuda'), or None if off."""
    try:
        enabled = torch.is_autocast_enabled(device_type)
    except TypeError:
        # PyTorch < 2.4: no device_type argument; the no-arg form reports CUDA only
        enabled = torch.is_autocast_cpu_enabled() if device_type == 'cpu' else torch.is_autocast_enabled()
    if not enabled:
        return None
    try:
        return torch.get_autocast_dtype(device_type)
    except (AttributeError, TypeError):
        # PyTorch < 2.4: per-device dtype getters
        return torch.get_autocast_cpu_dtype() if device_type == 'cpu' else torch.get_autocast_gpu_dtype()


def _cast_optional(t: Optional[torch.Tensor], dtype: torch.dtype) -> Optional[torch.Tensor]:
    """Cast ``t`` to ``dtype``, passing ``None`` (e.g. a missing affine param) through unchanged."""
    return None if t is None else t.to(dtype)


def fast_group_norm(
    x: torch.Tensor,
    num_groups: int,
    weight: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
    eps: float = 1e-5
) -> torch.Tensor:
    """Group norm that stays in the autocast dtype instead of upcasting to float32.

    Native AMP normally runs group_norm in float32. When autocast is active for the device
    governing ``x`` (CPU autocast for CPU tensors, CUDA autocast otherwise), ``x``, ``weight``
    and ``bias`` are cast to that autocast dtype and ``F.group_norm`` runs with autocast
    disabled, so no upcast happens.

    Args:
        x: Input tensor.
        num_groups: Number of groups passed to ``F.group_norm``.
        weight: Optional affine scale; may be None.
        bias: Optional affine shift; may be None.
        eps: Epsilon passed to ``F.group_norm``.

    Returns:
        The normalized tensor, in the autocast dtype when autocast is active.
    """
    if torch.jit.is_scripting():
        # currently cannot use is_autocast_enabled within torchscript
        return F.group_norm(x, num_groups, weight, bias, eps)

    device_type = _autocast_device_type(x)
    dt = _autocast_dtype(device_type)
    if dt is not None:
        # normally native AMP casts GN inputs to float32; here we use the low precision autocast dtype
        x, weight, bias = x.to(dt), _cast_optional(weight, dt), _cast_optional(bias, dt)

    with torch.autocast(device_type=device_type, enabled=False):
        return F.group_norm(x, num_groups, weight, bias, eps)


def fast_layer_norm(
    x: torch.Tensor,
    normalized_shape: List[int],
    weight: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
    eps: float = 1e-5
) -> torch.Tensor:
    """Layer norm that stays in low precision under autocast, using APEX fused LN when possible.

    When APEX is importable, ``x`` is a CUDA tensor, and both ``weight`` and ``bias`` are
    given, APEX ``fused_layer_norm_affine`` is used. It does not upcast either. Otherwise,
    when autocast is active for the device governing ``x``, inputs are cast to the autocast
    dtype and ``F.layer_norm`` runs with autocast disabled. This mirrors APEX behavior
    rather than native AMP's float32 upcast.

    Args:
        x: Input tensor.
        normalized_shape: Trailing shape to normalize over, as for ``F.layer_norm``.
        weight: Optional affine scale; may be None.
        bias: Optional affine shift; may be None.
        eps: Epsilon passed to the layer norm implementation.

    Returns:
        The normalized tensor.
    """
    if torch.jit.is_scripting():
        # currently cannot use is_autocast_enabled within torchscript
        return F.layer_norm(x, normalized_shape, weight, bias, eps)

    if has_apex and x.is_cuda and weight is not None and bias is not None:
        # APEX's affine kernel dereferences both weight and bias and only runs on CUDA;
        # non-affine or non-CUDA inputs fall through to the native path below
        return fused_layer_norm_affine(x, weight, bias, normalized_shape, eps)

    device_type = _autocast_device_type(x)
    dt = _autocast_dtype(device_type)
    if dt is not None:
        # normally native AMP casts LN inputs to float32; apex LN does not, this is behaving like Apex
        x, weight, bias = x.to(dt), _cast_optional(weight, dt), _cast_optional(bias, dt)

    with torch.autocast(device_type=device_type, enabled=False):
        return F.layer_norm(x, normalized_shape, weight, bias, eps)