# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

from accelerate import Accelerator, DistributedType


class LocalSGD:
    """
    A helper class to support local SGD on top of Accelerator. It runs a given number of updates independently on each
    device, and averages model weights every `local_sgd_steps` calls to [`LocalSGD.step`] (and once more when the
    context manager exits).

    It should be used only in the multi-GPU (or multi-CPU) setup without extensions such as DeepSpeed. In particular,
    this is a simple implementation that cannot support scenarios such as model parallelism.

    It is meant to be used as a context manager. Entering the context enters the model's `no_sync()` context
    (presumably the `DistributedDataParallel` wrapper returned by `accelerator.prepare` — confirm for your setup):

    ```python
    with LocalSGD(accelerator=accelerator, model=model, local_sgd_steps=8) as local_sgd:
        for batch in dataloader:
            ...
            optimizer.step()
            local_sgd.step()
    ```

    Although we are not aware of the true origins of this simple approach, the idea of local SGD is quite old and goes
    back to at least:

    Zhang, J., De Sa, C., Mitliagkas, I., & Ré, C. (2016). [Parallel SGD: When does averaging help?. arXiv preprint
    arXiv:1606.07365.](https://arxiv.org/abs/1606.07365)

    We credit the term Local SGD to the following paper (but there might be earlier references we are not aware of).

    Stich, Sebastian Urban. ["Local SGD Converges Fast and Communicates Little." ICLR 2019-International Conference on
    Learning Representations. No. CONF. 2019.](https://arxiv.org/abs/1805.09767)
    """

    def __init__(self, accelerator: Accelerator, model: torch.nn.Module, local_sgd_steps: int, enabled: bool = True):
        """
        Constructor.

        Args:
            accelerator (`Accelerator`):
                Accelerator object.
            model (`torch.nn.Module`):
                The model whose parameters we need to average.
            local_sgd_steps (`int`):
                Number of local SGD steps (calls to `step()`) between parameter synchronizations. Must be at least 1
                when Local SGD is enabled.
            enabled (`bool`, *optional*, defaults to `True`):
                Local SGD is disabled if this parameter is set to `False`. It is also always disabled when
                `accelerator.distributed_type` is `DistributedType.NO` (single process), since there is nothing to
                average.

        Raises:
            NotImplementedError: If the accelerator uses an unsupported distributed type (e.g. DeepSpeed or
                Megatron-LM).
            ValueError: If Local SGD is enabled and `local_sgd_steps` is smaller than 1.
        """
        if accelerator.distributed_type not in [
            DistributedType.NO,
            DistributedType.MULTI_CPU,
            DistributedType.MULTI_GPU,
            DistributedType.MULTI_XPU,
            DistributedType.MULTI_MLU,
            DistributedType.MULTI_HPU,
            DistributedType.MULTI_SDAA,
            DistributedType.MULTI_MUSA,
            DistributedType.MULTI_NPU,
        ]:
            raise NotImplementedError("LocalSGD is supported only for CPUs and GPUs (no DeepSpeed or MegatronLM)")
        self.enabled = enabled and accelerator.distributed_type != DistributedType.NO
        # Validate only when enabled: a disabled instance never uses `local_sgd_steps`, and a value < 1 would
        # otherwise surface later as a ZeroDivisionError (0) or as silently-wrong cadence (negative) in `step()`.
        if self.enabled and local_sgd_steps < 1:
            raise ValueError(f"`local_sgd_steps` must be a positive integer, got {local_sgd_steps}.")
        self.num_steps = 0
        # Stored unconditionally so the object is fully initialised even when Local SGD is disabled.
        self.accelerator = accelerator
        self.model = model
        self.local_sgd_steps = local_sgd_steps

    def __enter__(self):
        """Enter the model's `no_sync()` context (only when enabled) and return `self`."""
        if self.enabled:
            self.model_sync_obj = self.model.no_sync()
            self.model_sync_obj.__enter__()

        return self

    def __exit__(self, exc_type, exc_value, tb):
        """
        Average model parameters one last time and leave the `no_sync()` context (only when enabled).

        Exceptions raised inside the `with` block are not suppressed.
        """
        if not self.enabled:
            return
        try:
            # Average all models on exit.
            self._sync_and_avg_model_params()
        finally:
            # Always leave the `no_sync` context, even if averaging raised, so it is never left half-open.
            self.model_sync_obj.__exit__(exc_type, exc_value, tb)

    def step(self):
        """
        This function makes a "step" and synchronizes model parameters if necessary.

        The step counter is advanced on every call (even when Local SGD is disabled); when enabled, parameters are
        averaged across processes every `local_sgd_steps` steps.
        """
        self.num_steps += 1
        if not self.enabled:
            return

        if self.num_steps % self.local_sgd_steps == 0:
            self._sync_and_avg_model_params()

    def _sync_and_avg_model_params(self):
        """
        Wait for all processes, then replace every model parameter with its mean across processes.
        """
        self.accelerator.wait_for_everyone()
        # NOTE(review): the reduction runs under `autocast()` as in the original implementation; confirm whether it
        # has any effect on a plain collective reduce.
        with self.accelerator.autocast():
            for param in self.model.parameters():
                param.data = self.accelerator.reduce(param.data, reduction="mean")
