# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang


import torch.nn as nn
from torch.distributed import DeviceMesh
from torch.distributed.tensor import distribute_module
from torch.distributed.tensor.parallel import ParallelStyle
from torch.distributed.tensor.placement_types import Placement

try:
    from torch.distributed.tensor import DTensor
except (ImportError, AttributeError):
    DTensor = None


class PrepareModuleWeight(ParallelStyle):
    def __init__(self, *, layouts: Placement | None = None):
        super().__init__()
        self.layouts = layouts

    def _replicate_module_fn(
        self,
        name: str,
        module: nn.Module,
        device_mesh: DeviceMesh,
    ):
        for p_name, param in module.named_parameters():
            replicated_param = nn.Parameter(
                DTensor.from_local(param, device_mesh, [self.layouts], run_check=False),
            )
            module.register_parameter(p_name, replicated_param)

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        return distribute_module(
            module,
            device_mesh,
            partition_fn=self._replicate_module_fn,
            input_fn=None,
            output_fn=None,
        )
