# MeshFlow

## Installation

```shell
conda create -n meshflow python=3.8
conda activate meshflow
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.6 -c pytorch -c nvidia
pip install pyyaml rich mip
pip install flax protobuf==4.22.1
```

For PyTorch's DTensor, we use the version from the PyTorch [tau](https://github.com/pytorch/tau) project, pinned to a specific commit:

```shell
git clone https://github.com/pytorch/tau.git && cd ./tau
git checkout 43fe1f7
python setup.py develop
```

`patched_aot_function` and `aten.split.Tensor` need to be patched in the tau checkout with the following diff:

```diff
diff --git a/spmd/compiler/api.py b/spmd/compiler/api.py
index b173956..e08bb09 100644
--- a/spmd/compiler/api.py
+++ b/spmd/compiler/api.py
@@ -35,7 +35,7 @@ logger: logging.Logger = logging.getLogger(__name__)

 # patch aot_function so that we can pass the full (non-sharded) input to capture the graph
 # pyre-fixme
-functorch._src.aot_autograd.aot_function = patched_aot_function
+# functorch._src.aot_autograd.aot_function = patched_aot_function


 class TrainingPhase(Enum):
diff --git a/spmd/tensor/ops/tp_sharding_ops.py b/spmd/tensor/ops/tp_sharding_ops.py
index c744e6e..539e3e7 100644
--- a/spmd/tensor/ops/tp_sharding_ops.py
+++ b/spmd/tensor/ops/tp_sharding_ops.py
@@ -3,6 +3,7 @@
 import torch
 import torch.utils._pytree as pytree
 from typing import List
+import spmd
 from spmd.tensor.api import DTensor
 from spmd.tensor.utils import unwrap_local_tensor
 from spmd.tensor.ops.utils import unwrap_single_placement, register_impl
@@ -32,6 +33,9 @@ def dist_cat(tensor_list: List[DTensor], dim: int = 0) -> DTensor:
 def dist_split(self: DTensor, split_size_or_sections, dim=0) -> List[DTensor]:
     local_mat = pytree.tree_map(unwrap_local_tensor, self)
     mat_placement = pytree.tree_map(unwrap_single_placement, self)
+    if isinstance(mat_placement, spmd.Replicate):
+        tensor_list = local_mat.split(split_size_or_sections, dim=dim)
+        return [DTensor.from_local(tensor, self.device_mesh, [spmd.Replicate()]) for tensor in tensor_list]
     sharding_dim = mat_placement.dim
     world_size = self.device_mesh.size(dim=0)
     if dim < 0:
diff --git a/spmd/tensor/utils.py b/spmd/tensor/utils.py
index 848b993..39ec54e 100644
--- a/spmd/tensor/utils.py
+++ b/spmd/tensor/utils.py
@@ -44,10 +44,13 @@ def wrap(res: object, spec: OutputSpecType) -> object:
         assert spec is not None and isinstance(
             spec, tuple
         ), f"output spec does not match with output! Expected tuple, got {spec}"
-        return tuple(
-            dtensor.DTensor(e, s.mesh, s.placements, size=s.shape)
-            for e, s in zip(res, spec)
-        )
+        dtensor_res = []
+        for e, s in zip(res, spec):
+            if e is not None:
+                dtensor_res.append(dtensor.DTensor(e, s.mesh, s.placements, size=s.shape))
+            else:
+                dtensor_res.append(None)
+        return tuple(dtensor_res)
     else:
         # if the res contains only non tensor values, we simply return it without rewrapping
         return res
```

Then we install meshflow:

```shell
git clone https://github.com/Shenggan/meshflow.git && cd ./meshflow
python setup.py develop
```

## Example

```shell
python examples/torch/test_simple.py
python examples/torch/test_model.py
torchrun --nproc_per_node 2 --master_port 26543 ./examples/torch/test_sharding_simple.py
torchrun --nproc_per_node 2 --master_port 26543 ./examples/torch/test_sharding_model.py
```

```shell
python examples/jax/test_simple.py
python examples/jax/test_model.py
mpirun -np 2 python ./examples/jax/test_sharding_simple.py
```

## Benchmark

```shell
torchrun --nproc_per_node 2 --master_port 26543 ./benchmark/bench_torch.py
torchrun --nproc_per_node 2 --master_port 26543 ./benchmark/bench_torch_tp.py
```

```shell
mpirun -np 2 python ./benchmark/benchmark_jax.py
mpirun -np 2 python ./benchmark/benchmark_jax_alpa.py
```
