import unittest
import os

import torch
from torch.optim import Optimizer
import apex
from apex.multi_tensor_apply import multi_tensor_applier
from itertools import product

class RefLAMB(Optimizer):
    r"""Implements Lamb algorithm.

    It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-6)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0.01)

    .. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes:
        https://arxiv.org/abs/1904.00962
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super(RefLAMB, self).__init__(params, defaults)
        if multi_tensor_applier.available:
            import amp_C
            self.multi_tensor_l2norm=amp_C.multi_tensor_l2norm
            # Skip buffer
            self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device=self.param_groups[0]["params"][0].device)
            self.multi_tensor_lamb = amp_C.multi_tensor_lamb
        else:
            raise RuntimeError('apex.optimizers.FusedLAMB requires cuda extensions')

    def step(self, closure=None):
        """Performs a single optimization step.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        # create separate grad lists for fp32, fp16, and bf16 params
        g_all_32, g_all_16, g_all_bf16 = [], [], []
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                if p.dtype == torch.float32:
                    g_all_32.append(p.grad.data)
                elif p.dtype == torch.float16:
                    g_all_16.append(p.grad.data)
                elif p.dtype == torch.bfloat16:
                    g_all_bf16.append(p.grad.data)
                else:
                    raise RuntimeError('FusedLAMB only support fp16, fp32, and bf16.')

        device = self.param_groups[0]["params"][0].device
        g_norm_32, g_norm_16, g_norm_bf16 = torch.zeros(1, device=device), torch.zeros(1, device=device), torch.zeros(1, device=device)
        # compute grad norm for two lists
        if len(g_all_32) > 0:
            g_norm_32 = multi_tensor_applier(self.multi_tensor_l2norm,
                                             self._dummy_overflow_buf,
                                             [g_all_32], False)[0]
        if len(g_all_16) > 0:
            g_norm_16 = multi_tensor_applier(self.multi_tensor_l2norm,
                                             self._dummy_overflow_buf,
                                             [g_all_16], False)[0]
        if len(g_all_bf16) > 0:
            g_norm_bf16 = multi_tensor_applier(self.multi_tensor_l2norm,
                                             self._dummy_overflow_buf,
                                             [g_all_bf16], False)[0]

        # blend two grad norms to get global grad norm
        global_grad_norm = multi_tensor_applier(self.multi_tensor_l2norm,
                                                self._dummy_overflow_buf,
                                                [[g_norm_32, g_norm_16, g_norm_bf16]],
                                                False)[0]

        max_grad_norm = 1.0
        clipped_ratio = max_grad_norm / max(global_grad_norm, max_grad_norm)

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                p.grad.data *= clipped_ratio
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instad.')

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['m'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['v'] = torch.zeros_like(p.data)

                m_t, v_t = state['m'], state['v']
                beta1, beta2 = group['betas']

                state['step'] += 1

                # m_t = beta1 * m + (1 - beta1) * g_t
                m_t.mul_(beta1).add_(grad, alpha=1-beta1)
                # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
                if len(g_all_16) > 0:
                    v_t.mul_(beta2)
                    v_t = v_t.to(torch.float32)
                    grad32 = grad.to(torch.float32)
                    v_t.addcmul_(grad32, grad32, value=1-beta2)
                else:
                    v_t.mul_(beta2).addcmul_(grad, grad, value=1-beta2)

                # Debiasing
                m_t_hat = m_t / (1.0 - beta1 ** state['step'])
                v_t_hat = v_t / (1.0 - beta2 ** state['step'])

                update = m_t_hat / v_t_hat.sqrt().add(group['eps'])

                if group['weight_decay'] != 0:
                    update.add_(p.data, alpha=group['weight_decay'])

                trust_ratio = 1.0
                w_norm = p.data.to(torch.float32).pow(2).sum().sqrt()
                g_norm = update.pow(2).sum().sqrt()
                if w_norm > 0 and g_norm > 0:
                    trust_ratio = w_norm / g_norm

                state['w_norm'] = w_norm
                state['g_norm'] = g_norm
                state['trust_ratio'] = trust_ratio

                step_size = group['lr']

                p.data.add_(update, alpha=-step_size*trust_ratio)

        return loss

class TestLamb(unittest.TestCase):
    def setUp(self, max_abs_diff=1e-3, max_rel_diff=1, iters=7):
        self.max_abs_diff = max_abs_diff
        self.max_rel_diff = max_rel_diff
        self.iters = iters
        torch.cuda.manual_seed(9876)


    def tearDown(self):
        pass

    def gen_param_optim(self, tensors, lamb_option):
        ref_param = []
        tst_param = []
        for tensor in tensors:
            ref_param.append(torch.nn.Parameter(tensor.clone()))
            tst_param.append(torch.nn.Parameter(tensor.clone()))

        ref_optim = self.ref_optim(ref_param, **lamb_option)
        tst_optim = self.tst_optim(tst_param, use_nvlamb=True, **lamb_option)

        return (ref_param, tst_param, ref_optim, tst_optim)

    def gen_grad(self, ref_param, tst_param):
        for p_ref, p_tst in zip(ref_param, tst_param):
            p_ref.grad = torch.rand_like(p_ref)
            p_tst.grad = p_ref.grad

    def gen_mixed_grad(self, ref_param, tst_param, scale=1.0):
        half_grads = []
        for p_ref, _ in zip(ref_param, tst_param):
            half_grads.append(torch.rand_like(p_ref).half())
            p_ref.grad = half_grads[-1].float() / scale
        return half_grads

    def gen_single_type_test(self, param_type=torch.float, device="cuda"):
        nelem = 18011
        tensor = torch.rand(nelem, dtype=param_type, device=device)
        weight_decay = [0, 0.01]

        for wd in weight_decay:
            lamb_option = {'lr':5e-4, 'betas':(0.9, 0.999), 'eps':1e-08, 'weight_decay':wd}
            ref_param, tst_param, ref_optim, tst_optim = \
                self.gen_param_optim([tensor], lamb_option)

            if isinstance(tst_optim, apex.optimizers.FusedMixedPrecisionLamb):
                if param_type != torch.float:
                    # joseli: This parameter is usually passed into the constructor, 
                    # but I do not want to change the testing interface.
                    # As long as this parameter is set before the first call to step(), 
                    # then it should act normally.
                    tst_optim.reduced_precision_dtype = param_type
            for i in range(self.iters):
                self.gen_grad(ref_param, tst_param)
                ref_optim.step()
                torch.cuda.synchronize()
                tst_optim.step()
                torch.cuda.synchronize()
                torch.testing.assert_close(tst_param, ref_param)

class TestFusedLAMB(TestLamb):
    def __init__(self, *args, **kwargs):
        super(TestLamb, self).__init__(*args, **kwargs)
        self.ref_optim = RefLAMB
        self.tst_optim = apex.optimizers.FusedLAMB


    def test_float(self):
        self.gen_single_type_test(param_type=torch.float)

    @unittest.skip("PyTorch optimizer is not numerically correct for fp16")
    def test_half(self):
        self.gen_single_type_test(param_type=torch.float16)

    @unittest.skipIf(torch.cuda.device_count()<2, "more than 1 GPU required")
    def test_multi_device(self):
        devices = ("cuda:0", "cuda:1")
        for current_dev, tensor_dev in product(devices, devices):
            with torch.cuda.device(current_dev):
                self.gen_single_type_test(param_type=torch.float, device=tensor_dev)

    def test_multi_params(self):
        sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]]
        weight_decay = [0, 0.01]

        for wd in weight_decay:
            lamb_option = {'lr':5e-4, 'betas':(0.9, 0.999), 'eps':1e-08, 'weight_decay':wd}
            tensors = []
            for size in sizes:
                tensors.append(torch.rand(size, dtype=torch.float, device='cuda'))
            ref_param, tst_param, ref_optim, tst_optim = \
                self.gen_param_optim(tensors, lamb_option)

            for i in range(self.iters):
                self.gen_grad(ref_param, tst_param)
                ref_optim.step()
                tst_optim.step()
                torch.testing.assert_close(tst_param, ref_param)

    def test_lamb_option(self):
        nelem = 1
        tensor = torch.rand(nelem, dtype=torch.float, device='cuda')
        weight_decay = [0, 0.01]

        for wd in weight_decay:
            lamb_option = {'lr':0.01, 'betas':(0.6, 0.9), 'eps':3e-06, 'weight_decay':wd}
            ref_param, tst_param, ref_optim, tst_optim = \
                self.gen_param_optim([tensor], lamb_option)

            for i in range(self.iters):
                self.gen_grad(ref_param, tst_param)
                ref_optim.step()
                tst_optim.step()
                torch.testing.assert_close(tst_param, ref_param)

class TestFusedMixedPrecisionLamb(TestLamb):
    def __init__(self, *args, **kwargs):
        super(TestLamb, self).__init__(*args, **kwargs)
        self.ref_optim = RefLAMB
        self.tst_optim = apex.optimizers.FusedMixedPrecisionLamb


    def test_float(self):
        self.gen_single_type_test(param_type=torch.float)

    def test_bfloat16(self):
        self.iters = 4
        self.gen_single_type_test(param_type=torch.bfloat16)

    def test_half(self):
        self.iters = 1
        self.gen_single_type_test(param_type=torch.float16)

    @unittest.skipIf(torch.cuda.device_count()<2, "more than 1 GPU required")
    def test_multi_device(self):
        devices = ("cuda:0", "cuda:1")
        for current_dev, tensor_dev in product(devices, devices):
            with torch.cuda.device(current_dev):
                self.gen_single_type_test(param_type=torch.float, device=tensor_dev)

    def test_multi_params(self):
        sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]]
        weight_decay = [0, 0.01]

        for wd in weight_decay:
            lamb_option = {'lr':5e-4, 'betas':(0.9, 0.999), 'eps':1e-08, 'weight_decay':wd}
            tensors = []
            for size in sizes:
                tensors.append(torch.rand(size, dtype=torch.float, device='cuda'))
            ref_param, tst_param, ref_optim, tst_optim = \
                self.gen_param_optim(tensors, lamb_option)

            for i in range(self.iters):
                self.gen_grad(ref_param, tst_param)
                ref_optim.step()
                tst_optim.step()
                torch.testing.assert_close(tst_param, ref_param)

    def test_lamb_option(self):
        nelem = 1
        tensor = torch.rand(nelem, dtype=torch.float, device='cuda')
        weight_decay = [0, 0.01]

        for wd in weight_decay:
            lamb_option = {'lr':0.01, 'betas':(0.6, 0.9), 'eps':3e-06, 'weight_decay':wd}
            ref_param, tst_param, ref_optim, tst_optim = \
                self.gen_param_optim([tensor], lamb_option)

            for i in range(self.iters):
                self.gen_grad(ref_param, tst_param)
                ref_optim.step()
                tst_optim.step()
                torch.testing.assert_close(tst_param, ref_param)

if __name__ == '__main__':
    script_path = os.path.dirname(os.path.realpath(__file__))
    unittest.main()
