import os

import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import SGD, Adam

# FP8 recipe will be used to test precision-aware-optimizer.
from transformer_engine.pytorch.fp8 import fp8_autocast

from megatron.core.distributed import DistributedDataParallel, DistributedDataParallelConfig
from megatron.core.optimizer import ChainedOptimizer, OptimizerConfig, get_megatron_optimizer
from megatron.core.process_groups_config import GradCommProcessGroups, ModelCommProcessGroups
from megatron.core.transformer import TransformerConfig
from megatron.core.utils import is_te_min_version, is_torch_min_version
from tests.unit_tests.test_utilities import Utils
from tests.unit_tests.test_utils import _deinit_distributed, _init_distributed

# Probe the optional Transformer Engine FP8 recipes once at import time so the FP8-parametrized
# tests can pick a recipe that is actually usable in this environment.
try:
    # Check if FP8 block scaling is available.
    from transformer_engine.common.recipe import Float8BlockScaling, Format
    from transformer_engine.pytorch.fp8 import check_fp8_block_scaling_support

    fp8_block_scaling_available, reason_for_no_fp8_block_scaling = check_fp8_block_scaling_support()
except Exception:
    # Best-effort probe: any failure (missing symbol, failing support check) simply means the
    # block-scaling recipe cannot be used here.
    fp8_block_scaling_available = False
    reason_for_no_fp8_block_scaling = "FP8 block scaled GEMM requires Hopper and CUDA >= 12.9."

# The delayed-scaling recipe is the fallback whenever block scaling is unavailable. That
# includes the case where the probe above imported fine but reported "unsupported", so this
# import must not be nested inside the failure path of the probe.
try:
    from transformer_engine.common.recipe import DelayedScaling

    delayed_scaling_available = True
except ImportError:
    delayed_scaling_available = False


class Net(nn.Module):
    """Small convolutional classifier used as a toy model by the optimizer tests.

    Two 5x5 convolutions (each followed by ReLU and a shared 2x2 max-pool) feed three
    fully-connected layers that produce 10 outputs per sample. The first linear layer
    takes ``16 * 5 * 5`` features, which is what a 3x32x32 input yields after the
    convolution/pooling stack.
    """

    def __init__(self):
        super().__init__()
        # Submodules are registered in the same order as they are applied, so the order
        # of ``parameters()`` is conv1, conv2, fc1, fc2, fc3 (the pool has no parameters).
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        """Return the 10 output scores for each sample in the batch ``x``."""
        features = x
        for conv in (self.conv1, self.conv2):
            features = self.pool(F.relu(conv(features)))

        # Collapse everything except the batch dimension before the linear layers.
        hidden = features.flatten(start_dim=1)
        for linear in (self.fc1, self.fc2):
            hidden = F.relu(linear(hidden))

        # The last layer is left without an activation.
        return self.fc3(hidden)


def test_chained_optimizer():
    """Check that ChainedOptimizer exposes the wrapped optimizers' data by reference.

    Both ``param_groups`` and ``state`` of the chained optimizer are modified through the
    chained object, and the modification must be visible on the underlying optimizers.
    """
    net = Net()
    all_params = list(net.parameters())
    # The first two parameters go to Adam, the remaining ones to SGD with momentum.
    optimizer_1 = Adam(all_params[:2], lr=0.01)
    optimizer_2 = SGD(all_params[2:], lr=0.1, momentum=0.9)
    chained_optimizer = ChainedOptimizer([optimizer_1, optimizer_2])

    # Param groups: a write through the chained optimizer must land in optimizer_1.
    assert optimizer_1.param_groups[0]["lr"] == 0.01
    chained_optimizer.param_groups[0]["lr"] = 0.02
    assert optimizer_1.param_groups[0]["lr"] == 0.02

    # State, step 1: no state exists before the first step; stepping the underlying
    # optimizers directly must make state show up on the chained optimizer.
    assert len(chained_optimizer.state) == 0
    sample = torch.randn(1, 3, 32, 32)
    net(sample).sum().backward()
    optimizer_1.step()
    optimizer_2.step()
    assert len(chained_optimizer.state) > 0

    # State, step 2: the freshly created state tensors start out off the GPU ...
    assert not list(optimizer_1.state.values())[0]["exp_avg"].is_cuda
    assert not list(optimizer_2.state.values())[0]["momentum_buffer"].is_cuda

    def move_tensors_to_cuda(state):
        """Replace every tensor in ``state`` (recursing into dicts) by a CUDA copy, in place."""
        for key, value in state.items():
            if isinstance(value, dict):
                move_tensors_to_cuda(value)
            elif isinstance(value, torch.Tensor):
                state[key] = value.to("cuda")
        return state

    # ... and are moved to the GPU exclusively through the chained optimizer's state.
    for param, param_state in chained_optimizer.state.items():
        chained_optimizer.state[param] = move_tensors_to_cuda(param_state)

    # The move must be observable on the underlying optimizers.
    assert list(optimizer_1.state.values())[0]["exp_avg"].is_cuda
    assert list(optimizer_2.state.values())[0]["momentum_buffer"].is_cuda


def test_precision_aware_fused_adam():
    """Check that TE ``FusedAdam`` with BF16 params tracks an FP32 run bit-for-bit.

    An FP32 reference parameter and a BF16 parameter (optimized with ``master_weights=True``
    and ``use_decoupled_grad=True``) start from the same values and receive identical
    gradients for 1000 steps. After every step the master weights must equal the FP32
    reference exactly, and the BF16 parameter must equal the BF16 cast of the reference.

    The test is reported as skipped (rather than silently passing) when the installed
    Transformer Engine has no ``FusedAdam`` or its ``FusedAdam`` lacks the precision-aware
    arguments.
    """
    try:
        from transformer_engine.pytorch.optimizers import FusedAdam
    except ImportError:
        # Older versions of TE don't have FusedAdam.
        pytest.skip("Transformer Engine FusedAdam is not available")

    import inspect

    adam_args = inspect.signature(FusedAdam).parameters
    required_args = (
        "master_weight_dtype",
        "exp_avg_dtype",
        "exp_avg_sq_dtype",
        "use_decoupled_grad",
    )
    missing_args = [name for name in required_args if name not in adam_args]
    if missing_args:
        # Skip the test if TE doesn't support precision aware FusedAdam.
        pytest.skip(f"TE FusedAdam does not support precision-aware arguments: {missing_args}")

    tensor = torch.rand(278011, dtype=torch.bfloat16).cuda()
    params_1 = [torch.nn.Parameter(tensor.float())]  # FP32 reference
    params_2 = [torch.nn.Parameter(tensor.clone())]  # BF16

    options = {"lr": 1, "betas": (0.1, 0.25), "eps": 1e-08, "weight_decay": 0, "amsgrad": False}

    optimizer_1 = FusedAdam(params_1, **options)
    optimizer_2 = FusedAdam(params_2, master_weights=True, use_decoupled_grad=True, **options)

    for _ in range(1000):
        # Feed both optimizers the same FP32 gradient; the BF16 side reads it from
        # ``decoupled_grad`` because it was built with use_decoupled_grad=True.
        for p_1, p_2 in zip(params_1, params_2):
            p_1.grad = torch.rand_like(p_1)
            p_2.decoupled_grad = p_1.grad.clone()

        optimizer_1.step()
        optimizer_2.step()

        # Master weights of the BF16 run vs. the FP32 reference parameters.
        master_params = [optimizer_2.get_unscaled_state(p, "master_param") for p in params_2]
        for p_1, p_2 in zip(params_1, master_params):
            bytes_1 = p_1.data.view(torch.uint8)
            bytes_2 = p_2.data.view(torch.uint8)
            # Make sure bit-wise matched
            assert torch.equal(bytes_1, bytes_2)

        # BF16 model parameters vs. the BF16 cast of the FP32 reference parameters.
        for p_1, p_2 in zip(params_1, params_2):
            bytes_1 = p_1.data.bfloat16().view(torch.uint8)
            bytes_2 = p_2.data.view(torch.uint8)
            # Make sure bit-wise matched
            assert torch.equal(bytes_1, bytes_2)


@pytest.mark.skipif(
    not is_te_min_version("1.13.0"), reason="TE 1.13.0 is required for precision aware optimizer"
)
@pytest.mark.parametrize("precision", ['bf16', 'fp8'])
@pytest.mark.parametrize("main_params_dtype", [torch.float32, torch.float16])
@pytest.mark.parametrize("main_grads_dtype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize(
    # use the same dtype for exp_avg and exp_avg_sq to reduce the number of tests
    "moment_dtype",
    [torch.float32, torch.float16, torch.bfloat16, torch.uint8],
)
def test_precision_aware_optimizer(
    precision: str,
    main_params_dtype: torch.dtype,
    main_grads_dtype: torch.dtype,
    moment_dtype: torch.dtype,
):
    """Compare the precision-aware optimizer against a baseline with float32 states.

    Two identical bf16 ``torch.nn.Linear(100, 100)`` models (weights filled with 1.0) are
    wrapped in DistributedDataParallel. The baseline optimizer is built with
    ``use_precision_aware_optimizer=False`` and float32 states; the test optimizer is built
    with ``use_precision_aware_optimizer=True`` and the parametrized dtypes. After a single
    forward/backward/step on the same input, loss and grad norm of both runs must agree
    within tolerance, and the test optimizer's state dict must load back into it.

    Args:
        precision: 'bf16' or 'fp8'; 'fp8' runs the forward pass under ``fp8_autocast``.
        main_params_dtype: dtype passed as ``main_params_dtype`` of the test optimizer.
        main_grads_dtype: dtype passed as ``main_grads_dtype`` of the test optimizer.
        moment_dtype: dtype passed as both ``exp_avg_dtype`` and ``exp_avg_sq_dtype``.
    """
    # Skip because bf16 optimizer states are not supported before TE 2.3.0
    if (moment_dtype == torch.bfloat16) and not is_te_min_version("2.3.0"):
        pytest.skip("bfloat16 for moment_dtype requires TE >= 2.3.0")

    if precision == 'fp8':
        if fp8_block_scaling_available:
            fp8_recipe = "blockwise"
            fp8_recipe_settings = Float8BlockScaling(fp8_format=Format.E4M3)
        else:
            # Import locally so this branch does not depend on whether the module-level
            # availability probe happened to bind DelayedScaling.
            from transformer_engine.common.recipe import DelayedScaling

            fp8_recipe = "delayed"
            fp8_recipe_settings = DelayedScaling()
    else:
        fp8_recipe = None
        fp8_recipe_settings = None

    world = int(os.getenv('WORLD_SIZE', '1'))
    rank = int(os.getenv('RANK', '0'))

    # Setup: distributed, model, mock_args.
    _init_distributed(world, rank)
    Utils.initialize_model_parallel()

    # First create baseline model with float32 optimizer states
    baseline_model = torch.nn.Linear(100, 100, bias=False, dtype=torch.bfloat16, device='cuda')
    baseline_model.requires_grad_(True)
    baseline_model.weight.data.fill_(1.0)
    baseline_ddp_config = DistributedDataParallelConfig(use_distributed_optimizer=True)
    baseline_model = DistributedDataParallel(
        TransformerConfig(num_attention_heads=1, num_layers=1), baseline_ddp_config, baseline_model
    )
    baseline_optimizer_config = OptimizerConfig(
        optimizer='adam',
        lr=0.01,
        bf16=True,
        use_distributed_optimizer=True,
        use_precision_aware_optimizer=False,
        main_params_dtype=torch.float32,
        main_grads_dtype=torch.float32,
        exp_avg_dtype=torch.float32,
        exp_avg_sq_dtype=torch.float32,
    )
    baseline_optim = get_megatron_optimizer(baseline_optimizer_config, [baseline_model])

    # Create test model with specified dtypes for optimizer states
    test_model = torch.nn.Linear(100, 100, bias=False, dtype=torch.bfloat16, device='cuda')
    test_model.requires_grad_(True)
    test_model.weight.data.fill_(1.0)
    ddp_config = DistributedDataParallelConfig(use_distributed_optimizer=True)
    test_model = DistributedDataParallel(
        TransformerConfig(num_attention_heads=1, num_layers=1), ddp_config, test_model
    )
    test_optimizer_config = OptimizerConfig(
        optimizer='adam',
        lr=0.01,
        bf16=True,
        fp8_recipe=fp8_recipe,
        use_distributed_optimizer=True,
        use_precision_aware_optimizer=True,
        main_params_dtype=main_params_dtype,
        main_grads_dtype=main_grads_dtype,
        exp_avg_dtype=moment_dtype,
        exp_avg_sq_dtype=moment_dtype,
    )
    test_optim = get_megatron_optimizer(test_optimizer_config, [test_model])

    # Use same input for both models
    batch = torch.randn(8, 100, dtype=torch.bfloat16, device='cuda')

    def run_model(model, batch, optim, fp8_recipe, fp8_recipe_settings):
        """Run one forward/backward/optimizer step and return (loss value, grad norm)."""
        if not fp8_recipe:
            output = model(batch)
        else:
            with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe_settings):
                output = model(batch)
        loss = output.sum()
        loss.backward()
        optim.step()
        return loss.item(), optim.get_grad_norm()

    # Run baseline model and test model
    baseline_loss, baseline_grad_norm = run_model(
        baseline_model, batch, baseline_optim, fp8_recipe, fp8_recipe_settings
    )
    test_loss, test_grad_norm = run_model(
        test_model, batch, test_optim, fp8_recipe, fp8_recipe_settings
    )

    rtol = 1e-3  # relative tolerance
    atol = 1e-5  # absolute tolerance

    def assert_within_tolerance(label, test_value, baseline_value):
        """Pass when the values agree within either the relative or the absolute tolerance."""
        abs_diff = abs(test_value - baseline_value)
        rel_diff = abs_diff / (abs(baseline_value) + 1e-7)  # avoid div by 0
        assert rel_diff <= rtol or abs_diff <= atol, (
            f"{label} mismatch: baseline={baseline_value}, test={test_value}, "
            f"rel_diff={rel_diff}, abs_diff={abs_diff}"
        )

    # Compare grad norms and losses - allow small difference due to precision
    assert_within_tolerance("Grad norm", test_grad_norm, baseline_grad_norm)
    assert_within_tolerance("Loss", test_loss, baseline_loss)

    # Save and reload state dict for the test model
    state_dict = test_optim.state_dict()
    test_optim.load_state_dict(state_dict)


@pytest.mark.parametrize("use_distributed_optimizer", [False, True])
@pytest.mark.parametrize("precision", ['bf16', 'fp32'])
def test_optim_sharded_state_dict(use_distributed_optimizer: bool, precision: str):
    """Check that the optimizer's sharded state dict never carries ``common_step=None``.

    A bf16 ``torch.nn.Linear(100, 100)`` wrapped in DistributedDataParallel gets an Adam
    optimizer configured for the requested precision; the sharded state dict is then
    built from the model's sharded state dict and inspected.
    """
    rank = int(os.getenv('RANK', '0'))
    world = int(os.getenv('WORLD_SIZE', '1'))

    # Setup: distributed, model, mock_args.
    _init_distributed(world, rank)
    Utils.initialize_model_parallel()

    linear = torch.nn.Linear(100, 100, bias=False, dtype=torch.bfloat16, device='cuda')
    linear.requires_grad_(True)
    linear.weight.data.fill_(1.0)
    ddp_config = DistributedDataParallelConfig(use_distributed_optimizer=use_distributed_optimizer)
    transformer_config = TransformerConfig(num_attention_heads=1, num_layers=1)
    model = DistributedDataParallel(transformer_config, ddp_config, linear)
    assert all(param.requires_grad for param in model.parameters())

    # Only the precision flags differ between the two parametrized configurations.
    if precision == 'bf16':
        precision_flags = {'bf16': True}
    elif precision == 'fp32':
        precision_flags = {'bf16': False, 'fp16': False}
    optimizer_config = OptimizerConfig(
        optimizer='adam', use_distributed_optimizer=use_distributed_optimizer, **precision_flags
    )
    optim = get_megatron_optimizer(optimizer_config, [model])

    sharded_state_dict = optim.sharded_state_dict(model.sharded_state_dict())

    # The check only applies when the dict actually has an 'optimizer' -> 'state' entry.
    if 'optimizer' in sharded_state_dict:
        optimizer_entry = sharded_state_dict['optimizer']
        if 'state' in optimizer_entry:
            optimizer_state = optimizer_entry['state']
            has_empty_common_step = (
                'common_step' in optimizer_state and optimizer_state['common_step'] is None
            )
            assert (
                not has_empty_common_step
            ), "Found 'optimizer.state.common_step=None' in sharded state dict."


def test_optimizer_reload_model_params():
    """Check that ``reload_model_params`` refreshes the optimizer's main params.

    Main params must keep their initial value until ``reload_model_params()`` is called,
    then follow the model params, and finally follow an explicitly supplied state dict
    without the model params themselves being touched.
    """
    rank = int(os.getenv('RANK', '0'))
    world = int(os.getenv('WORLD_SIZE', '1'))
    _init_distributed(world, rank)
    Utils.initialize_model_parallel()

    def fill_model_params(module, value):
        """Overwrite every parameter of ``module`` with ``value``."""
        for param in module.parameters():
            param.data.fill_(value)

    def assert_main_params_equal(value):
        """Every main param of the optimizer must be float32 and exactly equal to ``value``."""
        for group in optim.param_groups:
            for main_param in group['params']:
                assert main_param.dtype == torch.float32
                expected = torch.full_like(main_param, value)
                torch.testing.assert_close(main_param, expected, atol=0, rtol=0)

    model = Net().bfloat16().cuda()
    # All model params start at 1 when the optimizer is created.
    fill_model_params(model, 1.0)
    ddp_config = DistributedDataParallelConfig(use_distributed_optimizer=True)
    model = DistributedDataParallel(
        TransformerConfig(num_attention_heads=1, num_layers=1), ddp_config, model
    )
    optimizer_config = OptimizerConfig(optimizer='adam', bf16=True, use_distributed_optimizer=True)
    optim = get_megatron_optimizer(optimizer_config, [model])

    # Changing the model params to 2 must not affect main params until
    # reload_model_params() has been called, so they are still expected to be 1.
    fill_model_params(model, 2.0)
    assert_main_params_equal(1.0)

    # Copy model params to main_params, so main_params should be 2 now.
    optim.reload_model_params()
    assert_main_params_equal(2.0)

    # Build a state dict shaped like the model's, with every entry set to 3.
    new_state_dict = {
        name: torch.full_like(param, 3.0) for name, param in model.state_dict().items()
    }

    # Loading from that state dict moves main_params to 3 while the model params stay at 2.
    optim.reload_model_params(new_state_dict)
    for param in model.parameters():
        torch.testing.assert_close(param, torch.full_like(param, 2.0), atol=0, rtol=0)
    assert_main_params_equal(3.0)


@pytest.mark.skipif(
    not is_torch_min_version("2.4.0"),
    reason="torch.distributed.init_device_mesh requires torch >= 2.4.0",
)
@pytest.mark.parametrize(
    "world_size, tp_size, cp_size, dp_size",
    [
        (1, 1, 1, 1),  # Single GPU, no parallelism
        (2, 1, 2, 1),  # 2 GPUs, 1 TP, 2 CP
        (2, 2, 1, 1),  # 2 GPUs, 2 TP, 1 CP
        (8, 8, 1, 1),  # 8 GPUs, 8 TP, 1 CP
        (8, 2, 4, 1),  # 8 GPUs, 2 TP, 4 CP
        (8, 4, 2, 1),  # 8 GPUs, 4 TP, 2 CP
        (8, 1, 1, 8),  # 8 GPUs, 1 TP, 1 CP, 8 DP
        (8, 2, 1, 4),  # 8 GPUs, 2 TP, 1 CP, 4 DP
        (8, 2, 2, 2),  # 8 GPUs, 2 TP, 2 CP, 2 DP
    ],
)
def test_get_megatron_optimizer_with_custom_process_groups(world_size, tp_size, cp_size, dp_size):
    """
    Build an optimizer from process groups taken from a device mesh and check it trains.

    The groups are handed to get_megatron_optimizer through the grad_comm_pgs and
    model_comm_pgs parameters. The resulting optimizer must expose param groups and must
    change the model weight after one step. On a single GPU it is additionally compared
    with an optimizer built from the default process groups.
    """
    # Skip layouts that do not match the number of visible GPUs.
    actual_world_size = torch.cuda.device_count()
    if actual_world_size != world_size:
        pytest.skip(f"Test requires world_size={world_size}, but got {actual_world_size}")

    # Default model-parallel state is initialized before any custom group is created.
    Utils.initialize_model_parallel(
        tensor_model_parallel_size=tp_size, context_parallel_size=cp_size
    )

    # Five-dimensional mesh; the pp and ep dimensions have size 1 in this test.
    mesh_shape = (1, dp_size, 1, cp_size, tp_size)
    mesh_dim_names = ("pp", "dp", "ep", "cp", "tp")
    device_mesh = torch.distributed.init_device_mesh(
        "cuda", mesh_shape, mesh_dim_names=mesh_dim_names
    )

    # One process group per mesh dimension that the optimizer needs.
    dp_group, cp_group, tp_group, pp_group = (
        device_mesh.get_group(mesh_dim=dim_name) for dim_name in ("dp", "cp", "tp", "pp")
    )

    # Combined groups come from flattened sub-meshes: first dp+cp, then pp+tp (model parallel).
    dp_cp_group = device_mesh["dp", "cp"]._flatten().get_group()
    mp_group = device_mesh["pp", "tp"]._flatten().get_group()

    # Expert parallelism is not used here, so the expert-related groups are set to None.
    grad_comm_pgs = GradCommProcessGroups()
    for attr_name, group in (("dp", dp_group), ("dp_cp", dp_cp_group), ("expt_dp", None)):
        setattr(grad_comm_pgs, attr_name, group)

    model_comm_pgs = ModelCommProcessGroups()
    for attr_name, group in (
        ("tp", tp_group),
        ("cp", cp_group),
        ("pp", pp_group),
        ("mp", mp_group),
        ("tp_ep_pp", None),
    ):
        setattr(model_comm_pgs, attr_name, group)

    # A single bias-free linear layer, weights filled with 1.0, wrapped in DDP.
    linear = torch.nn.Linear(100, 100, bias=False, device='cuda')
    linear.requires_grad_(True)
    linear.weight.data.fill_(1.0)
    ddp_config = DistributedDataParallelConfig(use_distributed_optimizer=True)
    model = DistributedDataParallel(
        TransformerConfig(num_attention_heads=1, num_layers=1), ddp_config, linear
    )
    assert all(param.requires_grad for param in model.parameters())
    model_chunks = [model]

    optimizer_config = OptimizerConfig(
        optimizer='adam',
        lr=0.001,
        weight_decay=0.01,
        adam_beta1=0.9,
        adam_beta2=0.999,
        adam_eps=1e-8,
    )

    # Test 1: the optimizer can be created from the custom process groups.
    optimizer = get_megatron_optimizer(
        config=optimizer_config,
        model_chunks=model_chunks,
        use_gloo_process_groups=False,  # Required when using custom process groups
        grad_comm_pgs=grad_comm_pgs,
        model_comm_pgs=model_comm_pgs,
    )

    assert optimizer is not None, "Optimizer should not be None"
    assert hasattr(optimizer, 'param_groups'), "Optimizer should have param_groups"
    assert len(optimizer.param_groups) > 0, "Optimizer should have at least one parameter group"

    input_tensor = torch.randn(32, 100, device='cuda', requires_grad=True)

    def forward_backward():
        """Run the model on the fixed input and backpropagate the summed output."""
        model(input_tensor).sum().backward()

    # Test 2: a forward and backward pass works with the optimizer in place.
    forward_backward()

    # Test 3: after clearing grads and a fresh backward pass, a step must change the params.
    optimizer.zero_grad()
    forward_backward()

    wrapped = model.module
    # Snapshot the parameters before the step.
    original_weight = wrapped.weight.data.clone()
    original_bias = None if wrapped.bias is None else wrapped.bias.data.clone()

    optimizer.step()

    assert not torch.equal(
        wrapped.weight.data, original_weight
    ), "Weight should be updated after optimizer step"
    if wrapped.bias is not None:
        assert not torch.equal(
            wrapped.bias.data, original_bias
        ), "Bias should be updated after optimizer step"

    # Test 4: on a single GPU only (to avoid complex setup), an optimizer built with the
    # default process groups must have the same number of param groups.
    if world_size == 1:
        default_optimizer = get_megatron_optimizer(
            config=optimizer_config, model_chunks=model_chunks
        )
        assert len(optimizer.param_groups) == len(
            default_optimizer.param_groups
        ), "Custom and default optimizers should have same number of parameter groups"


def test_get_megatron_optimizer_custom_process_groups_validation():
    """
    Check that get_megatron_optimizer rejects incomplete custom process group setups.

    Each case passes a deliberately incomplete or inconsistent combination of
    grad_comm_pgs / model_comm_pgs and expects a specific exception type and message.
    """
    Utils.initialize_model_parallel(tensor_model_parallel_size=1)

    # A single bias-free linear layer, weights filled with 1.0, wrapped in DDP.
    linear = torch.nn.Linear(100, 100, bias=False, device='cuda')
    linear.requires_grad_(True)
    linear.weight.data.fill_(1.0)
    ddp_config = DistributedDataParallelConfig(use_distributed_optimizer=True)
    model = DistributedDataParallel(
        TransformerConfig(num_attention_heads=1, num_layers=1), ddp_config, linear
    )
    assert all(param.requires_grad for param in model.parameters())
    model_chunks = [model]
    optimizer_config = OptimizerConfig(optimizer='adam', lr=0.001)

    def expect_rejection(exc_type, message_pattern, **process_group_kwargs):
        """Build the optimizer with the given extra kwargs and require the given failure."""
        with pytest.raises(exc_type, match=message_pattern):
            get_megatron_optimizer(
                config=optimizer_config, model_chunks=model_chunks, **process_group_kwargs
            )

    # Test 1: grad_comm_pgs without model_comm_pgs is rejected.
    grad_comm_pgs = GradCommProcessGroups()
    grad_comm_pgs.dp = torch.distributed.new_group()
    expect_rejection(
        ValueError,
        "Grad and model comm process groups must be provided or both must be None",
        grad_comm_pgs=grad_comm_pgs,
        model_comm_pgs=None,  # Missing model_comm_pgs
    )

    # Test 2: grad_comm_pgs that never had 'dp' assigned is rejected.
    grad_comm_pgs_no_dp = GradCommProcessGroups()
    model_comm_pgs = ModelCommProcessGroups()
    expect_rejection(
        ValueError,
        "dp process group is required",
        grad_comm_pgs=grad_comm_pgs_no_dp,
        model_comm_pgs=model_comm_pgs,
    )

    # Test 3: 'dp' is present but 'expt_dp' was never assigned.
    grad_comm_pgs_no_expt_dp = GradCommProcessGroups()
    grad_comm_pgs_no_expt_dp.dp = torch.distributed.new_group()
    expect_rejection(
        AssertionError,
        "expt_dp process group is required",
        grad_comm_pgs=grad_comm_pgs_no_expt_dp,
        model_comm_pgs=model_comm_pgs,
    )

    # Test 4: grad_comm_pgs is complete, but model_comm_pgs never had 'mp' assigned.
    grad_comm_pgs_complete = GradCommProcessGroups()
    grad_comm_pgs_complete.dp = torch.distributed.new_group()
    grad_comm_pgs_complete.expt_dp = None  # Explicitly set to None as allowed
    model_comm_pgs_no_mp = ModelCommProcessGroups()
    expect_rejection(
        AssertionError,
        "mp process group is required",
        grad_comm_pgs=grad_comm_pgs_complete,
        model_comm_pgs=model_comm_pgs_no_mp,
    )

    # Test 5: 'mp' is assigned, but 'tp_ep_pp' never was.
    model_comm_pgs_no_tp_ep_pp = ModelCommProcessGroups()
    model_comm_pgs_no_tp_ep_pp.mp = None  # Explicitly set to None as allowed
    expect_rejection(
        AssertionError,
        "tp_ep_pp process group is required",
        grad_comm_pgs=grad_comm_pgs_complete,
        model_comm_pgs=model_comm_pgs_no_tp_ep_pp,
    )

    # Test 6: with both configurations complete, requesting Gloo process groups is rejected.
    model_comm_pgs_complete = ModelCommProcessGroups()
    model_comm_pgs_complete.mp = None  # Explicitly set to None as allowed
    model_comm_pgs_complete.tp_ep_pp = None  # Explicitly set to None as allowed
    expect_rejection(
        AssertionError,
        "Gloo process groups are not supported",
        use_gloo_process_groups=True,  # Should be False when using custom groups
        grad_comm_pgs=grad_comm_pgs_complete,
        model_comm_pgs=model_comm_pgs_complete,
    )
