# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict

from fairseq.distributed import utils


try:
    from fairscale.optim import OSS

    _has_fairscale = True
except ImportError:
    _has_fairscale = False


def shard_(optimizer, group):
    if not _has_fairscale:
        raise ImportError(
            "\n\nPlease install the fairscale package:" "\n\n  pip install fairscale"
        )

    class FairseqOSS(OSS):
        @property
        def disable_mem_eff_fp16_loading_hack(self):
            return True

        def __getattr__(self, name):
            if name.startswith("supports") and hasattr(self.optim, name):
                return getattr(self.optim, name)
            raise AttributeError(
                "'FairseqOSS' object has no attribute {0!r}".format(name)
            )

        def broadcast_global_state_dict(
            self, state_dict: Dict[str, Any]
        ) -> Dict[str, Any]:
            """
            Broadcasts the entire state_dict to all other ranks
            each rank is responsible to load their own partition of data
            """
            return utils.broadcast_object(
                state_dict,
                src_rank=0,
                group=self.group,
            )

    torch_optimizer = optimizer.optimizer
    optim_cls = type(torch_optimizer)

    optimizer.optimizer = FairseqOSS(
        torch_optimizer.param_groups,
        optim_cls,
        group=group,
        **optimizer.optimizer_config
    )
