# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tracks the running statistics per mini-batch instead of micro-batch."""
from typing import Optional, TypeVar, cast

import torch
from torch import Tensor, nn
import torch.nn.functional as F
from torch.nn.modules.batchnorm import _BatchNorm

from .checkpoint import is_recomputing

__all__ = ["DeferredBatchNorm"]


TModule = TypeVar("TModule", bound=nn.Module)


class DeferredBatchNorm(_BatchNorm):
    """A BatchNorm layer tracks multiple micro-batches to update running
    statistics per mini-batch.
    """

    sum: Tensor
    sum_squares: Tensor

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: Optional[float] = 0.1,
        affine: bool = True,
        chunks: int = 1,
    ) -> None:
        super().__init__(num_features, eps, momentum, affine, track_running_stats=True)

        self.register_buffer("sum", torch.zeros_like(self.running_mean))
        self.register_buffer("sum_squares", torch.zeros_like(self.running_var))

        self.counter = 0
        self.tracked = 0
        self.chunks = chunks

    def _check_input_dim(self, input: Tensor) -> None:
        # It's the typical _check_input_dim() implementation in PyTorch.
        if input.dim() <= 2:
            raise ValueError("expected at least 3D input (got %dD input)" % input.dim())

    def _track(self, input: Tensor) -> bool:
        """Tracks statistics of a micro-batch."""
        # Dimensions except channel. For example, (0, 2, 3) is for BatchNorm2d.
        dim = [0]
        dim.extend(range(2, input.dim()))

        with torch.no_grad():
            self.sum += input.sum(dim)
            self.sum_squares += (input**2).sum(dim)

        size = input.size().numel() // input.size(1)
        self.counter += size
        self.tracked += 1

        return self.tracked == self.chunks

    def _commit(self) -> None:
        """Updates the running statistics of a mini-batch."""
        exponential_average_factor = 0.0
        self.num_batches_tracked += 1
        if self.momentum is None:  # use cumulative moving average
            exponential_average_factor = 1.0 / float(self.num_batches_tracked)
        else:  # use exponential moving average
            exponential_average_factor = self.momentum

        mean = self.sum / self.counter
        var = self.sum_squares / self.counter - mean**2

        # Calculate the exponential moving average here.
        m = exponential_average_factor

        self.running_mean *= 1 - m
        self.running_mean += mean * m

        self.running_var *= 1 - m
        self.running_var += var * m

        self.sum.zero_()
        self.sum_squares.zero_()
        self.counter = 0
        self.tracked = 0

    def forward(self, input: Tensor) -> Tensor:  # type: ignore
        if not self.training:
            # Don't train parameters on the evaluation mode.
            return F.batch_norm(
                input,
                running_mean=self.running_mean,
                running_var=self.running_var,
                weight=self.weight,
                bias=self.bias,
                training=False,
                momentum=0.0,
                eps=self.eps,
            )

        if not is_recomputing():
            # Track a micro-batch on the training mode
            # but not under a recomputation.
            tracked_enough = self._track(input)

            # Update the running statistics for a mini-batch
            # if it has tracked enough micro-batches.
            if tracked_enough:
                self._commit()

        # Normalize a micro-batch and train the parameters.
        return F.batch_norm(
            input,
            running_mean=None,
            running_var=None,
            weight=self.weight,
            bias=self.bias,
            training=True,
            momentum=0.0,
            eps=self.eps,
        )

    @classmethod
    def convert_deferred_batch_norm(cls, module: TModule, chunks: int = 1) -> TModule:
        """Converts a :class:`nn.BatchNorm` or underlying
        :class:`nn.BatchNorm`s into :class:`DeferredBatchNorm`::

            from torchvision.models.resnet import resnet101
            from torchpipe.batchnorm import DeferredBatchNorm
            model = resnet101()
            model = DeferredBatchNorm.convert_deferred_batch_norm(model)

        """
        if isinstance(module, DeferredBatchNorm) and module.chunks is chunks:
            return cast(TModule, module)

        module_output: nn.Module = module

        if isinstance(module, _BatchNorm) and module.track_running_stats:
            module_output = DeferredBatchNorm(module.num_features, module.eps, module.momentum, module.affine, chunks)
            if module.affine:
                module_output.register_parameter("weight", module.weight)
                module_output.register_parameter("bias", module.bias)
            module_output.register_buffer("running_mean", module.running_mean)
            module_output.register_buffer("running_var", module.running_var)
            module_output.register_buffer("num_batches_tracked", module.num_batches_tracked)

        for name, child in module.named_children():
            module_output.add_module(name, cls.convert_deferred_batch_norm(child, chunks))

        return cast(TModule, module_output)
