import inspect

import torch
from torch.cuda.amp import GradScaler
from torch.testing._internal import common_utils
from torch.distributed.distributed_c10d import _coalescing_manager

from apex.contrib.optimizers.distributed_fused_lamb import DistributedFusedLAMB
from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase


def flat_dist_call(param_list: list[torch.Tensor], op, args):
    """Call ``op(tensor, *args)`` for every tensor in ``param_list``.

    All calls are issued inside torch's ``_coalescing_manager`` opened with
    ``async_ops=True``; once the context has exited, ``wait()`` is called on
    the object the manager yielded.

    Args:
        param_list: Tensors handed to ``op`` one at a time, in order.
        op: Callable invoked as ``op(tensor, *args)``.
        args: Extra positional arguments forwarded to every ``op`` call.
    """
    with _coalescing_manager(async_ops=True) as pending:
        for tensor in param_list:
            op(tensor, *args)

    # Wait only after leaving the context, as the original code does.
    pending.wait()


def get_init_weights_func():
    """Return a callable that sets every ``torch.nn.Linear`` weight to 1.0.

    The returned function takes a single module; modules that are not
    ``torch.nn.Linear`` are left untouched. It is suitable for
    ``Module.apply``.
    """
    def init_weights(m):
        if not isinstance(m, torch.nn.Linear):
            return
        # In-place fill on a parameter must happen with autograd disabled.
        with torch.no_grad():
            m.weight.fill_(1.0)

    return init_weights

class ModelFoo(torch.nn.Module):
    """Toy model: one bias-free 128->128 linear layer scored with MSE loss."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(128, 128, bias=False)
        self.loss = torch.nn.MSELoss()

    def forward(self, input_tensor, gt):
        """Return the MSE loss between ``self.linear(input_tensor)`` and ``gt``."""
        prediction = self.linear(input_tensor)
        return self.loss(prediction, gt)

# A test for the distributed fused LAMB optimizer: run several iterations and check that the loss decreases.
# There are two instances of the same test because the optimizer decides which collective operations to use based on `world_size`.
# If torch.distributed.get_world_size() == torch.cuda.device_count() it uses only `all_gather`.
# If torch.distributed.get_world_size() < torch.cuda.device_count() it uses both `all_gather` and `reduce_scatter`.
class NcclDistributedFusedLAMB(NcclDistributedTestBase):
    """Test ``DistributedFusedLAMB`` with a world size equal to the GPU count.

    A half-precision ``ModelFoo`` is trained for a few iterations on fixed
    random data, and the recorded losses must never increase.
    """

    @property
    def world_size(self) -> int:
        """World size for this test: the number of visible CUDA devices."""
        return torch.cuda.device_count()

    @common_utils.parametrize("no_copy", [False, True])
    @common_utils.parametrize("opt_kwargs", [
        dict(overlap_reductions=True, dwu_num_blocks=2, dwu_num_chunks=2,
             fused_norm=False, fuse_scale=False, clip_after_ar=True,
             full_ar=False),
        dict(overlap_reductions=False, dwu_num_blocks=1, dwu_num_chunks=1,
             fused_norm=True, fuse_scale=True, clip_after_ar=False),
    ])
    def test_distributed_fused_lamb(self, no_copy, opt_kwargs):
        """Run 10 optimizer steps and check that the loss is non-increasing.

        Args:
            no_copy: Value assigned to the optimizer's
                ``_reduce_scatter_no_copy`` and ``_all_gather_no_copy``
                attributes. The test is skipped when it is true and either
                ``torch.distributed.reduce_scatter`` or
                ``torch.distributed.all_gather`` has no ``no_copy`` argument.
            opt_kwargs: Extra keyword arguments for ``DistributedFusedLAMB``.
                When ``full_ar`` is absent, it defaults to whether the world
                size equals the CUDA device count.
        """
        if no_copy:
            # Both collectives must accept `no_copy` for this variant to run.
            for collective in (torch.distributed.reduce_scatter,
                               torch.distributed.all_gather):
                if 'no_copy' not in inspect.getfullargspec(collective).args:
                    self.skipTest("does not support no_copy")

        # A unittest assertion is not stripped under `python -O`.
        self.assertTrue(torch.distributed.is_initialized())
        gpu_count = torch.distributed.get_world_size()

        init_scale = 100
        lr = torch.tensor(0.1).cuda()
        grad_scaler = GradScaler(init_scale=init_scale, growth_interval=1000)

        model = ModelFoo()
        model = model.cuda().half()
        model.apply(get_init_weights_func())

        # Split parameters into a weight-decay group and a no-decay group
        # according to substrings of their names.
        param_optimizer = list(model.named_parameters())
        no_decay = ['bias', 'gamma', 'beta', 'LayerNorm']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]

        # Work on a copy: `opt_kwargs` is the very dict object given to
        # `parametrize`, shared by every generated test (including those
        # inherited by subclasses), so it must not be modified in place.
        lamb_kwargs = dict(opt_kwargs)
        lamb_kwargs.setdefault('full_ar', gpu_count == torch.cuda.device_count())

        # Aidyn-A: not sure what parameters are the best for testing purposes,
        # setting up whatever I think appropriate.
        optimizer = DistributedFusedLAMB(
            optimizer_grouped_parameters,
            lr=0.1,
            betas=(0.9, 0.9),
            eps=1e-6,
            max_grad_norm=1.0,
            dwu_group_size=gpu_count,
            dwu_num_rs_pg=1,
            dwu_num_ar_pg=1,
            dwu_num_ag_pg=1,
            use_nvlamb=False,
            set_param_views_to_flat_buffer=False,
            e5m2_allgather=False,
            **lamb_kwargs
        )
        optimizer.set_global_scale(init_scale)

        optimizer._reduce_scatter_no_copy = no_copy
        optimizer._all_gather_no_copy = no_copy

        # Broadcast the parameters from rank 0.
        flat_dist_call([param.data for param in model.parameters()], torch.distributed.broadcast, (0,))

        x = torch.randn(4096, 128, dtype=torch.float16).cuda()
        y = torch.randn(4096, 128, dtype=torch.float16).cuda()

        losses = []
        for _ in range(10):
            loss = model(x, y)
            optimizer._lazy_init_stage1()
            grad_scaler.scale(loss).backward()
            optimizer._lazy_init_stage2()
            optimizer._lr = lr
            optimizer.complete_reductions()
            # Hand the grad scaler's current scale to the optimizer before stepping.
            optimizer.set_global_scale(grad_scaler._get_scale_async())
            grad_scaler.step(optimizer)
            grad_scaler.update()
            optimizer.zero_grad(set_to_none=True)

            losses.append(loss.item())

        # Exact comparison: the recorded losses must already be in
        # non-increasing order.
        self.assertTrue(
            losses == sorted(losses, reverse=True),
            msg=f"loss is not monotonically non-increasing: {losses}",
        )

# Hand the class to torch's `instantiate_parametrized_tests` so that its
# `@common_utils.parametrize`-decorated test is turned into concrete tests.
common_utils.instantiate_parametrized_tests(NcclDistributedFusedLAMB)

class NcclDistributedFusedLAMB_partial_ar(NcclDistributedFusedLAMB):
    """Same tests as the parent class, run with one fewer rank than GPUs.

    With more than one visible GPU the world size is smaller than the device
    count, so the inherited test defaults ``full_ar`` to ``False``.
    """

    @property
    def world_size(self) -> int:
        """World size: CUDA device count minus one, but never below 1."""
        return max(1, torch.cuda.device_count() - 1)

if __name__ == "__main__":
    # Executed as a script: delegate to torch's test runner.
    common_utils.run_tests()

