# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
#-------------------------------------------------------------------------
#
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from . import FairseqOptimizer, register_optimizer
from apex.optimizers.fused_adam import FusedAdam
import torch.nn as nn

import torch.optim
import lpmm.optim
import transformers.optimization
import bitsandbytes as bnb 


@register_optimizer('fusedadam')
class FairseqAdam(FairseqOptimizer):
    def __init__(self, args, params):
        super().__init__(args, params)
        self._optimizer = FusedAdam(params, **self.optimizer_config)

    @property
    def optimizer_config(self):
        """
        Return a kwarg dictionary that will be used to override optimizer
        args stored in checkpoints. This allows us to load a checkpoint and
        resume training using a different set of optimizer args, e.g., with a
        different learning rate.
        """
        return {
            'lr': self.args.lr[0],
            'betas': self.args.adam_betas,
            'eps': self.args.adam_eps,
            'weight_decay': self.args.weight_decay,
        }


@register_optimizer('adam')
class Adam(FairseqOptimizer):
    def __init__(self, args, params):
        super().__init__(args, params)
        self._optimizer = torch.optim.Adam(params, **self.optimizer_config)
        print(f"Use torch.optim.Adam optimizer")

    @property
    def optimizer_config(self):
        """
        Return a kwarg dictionary that will be used to override optimizer
        args stored in checkpoints. This allows us to load a checkpoint and
        resume training using a different set of optimizer args, e.g., with a
        different learning rate.
        """
        return {
            'lr': self.args.lr[0],
            'betas': self.args.adam_betas,
            'eps': self.args.adam_eps,
            'weight_decay': self.args.weight_decay,
        }


@register_optimizer('lpmm')
class LpmmAdamW(FairseqOptimizer):
    def __init__(self, args, params, model):
        super().__init__(args, params)
        self._optimizer = lpmm.optim.AdamW(
            params, 
            lr=args.lr[0],
            betas=args.adam_betas,
            eps=args.adam_eps,
            weight_decay=args.weight_decay, 
            factor_second_moment=True,
            qconfig=args.qconfig,
        )
        # skipped = 0
        # for module in model.modules():
        #     if isinstance(module, nn.Embedding):
        #         skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
        #         print(f"skipped {module}: {skipped/2**20}M params")
        #         optimizer.override_quantize_enable(module, "weight", False)
        #         print(f"lpmm: will optimize {module} in fp32")
        # print(f"skipped: {skipped/2**20}M params")
        print(f"Use lpmm.optim.AdamW optimizer")


@register_optimizer('bnb')
class BnbAdamW8bit(FairseqOptimizer):
    def __init__(self, args, params, model):
        super().__init__(args, params)
        self._optimizer = bnb.optim.AdamW8bit(
            params, 
            lr=args.lr[0], 
            betas=args.adam_betas,
            eps=args.adam_eps,
            weight_decay=args.weight_decay, 
        )
        manager = bnb.optim.GlobalOptimManager.get_instance()
        skipped = 0
        for module in model.modules():
            if isinstance(module, nn.Embedding):
                skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
                print(f"skipped {module}: {skipped/2**20}M params")
                manager.register_module_override(module, "weight", {"optim_bits": 32})
                print(f"bitsandbytes: will optimize {module} in fp32")
        print(f"skipped: {skipped/2**20}M params")
        print(f"Use bnb.optim.AdamW8bit optimizer")


@register_optimizer('adafactor')
class Adafactor(FairseqOptimizer):
    def __init__(self, args, params):
        super().__init__(args, params)
        self._optimizer = transformers.optimization.Adafactor(
            params, 
            lr=args.lr[0],
            clip_threshold=1.0,
            decay_rate=-0.8,
            beta1=args.adam_betas[0],
            relative_step=False,
            scale_parameter=False,
            warmup_init=False
        )
        print(f"Use transformers.optimization.Adafactor optimizer")


@register_optimizer('sm3')
class SM3(FairseqOptimizer):
    def __init__(self, args, params):
        super().__init__(args, params)
        self._optimizer = lpmm.optim.SM3(
            params,
            lr=args.lr[0],
            momentum=args.adam_betas[0],
            beta=args.adam_betas[1],
            eps=args.adam_eps,
        )
        print(f"Use SM3 optimizer")