# Usage of EDiT

## Initialization

```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from EDiT.main_method import patch_local_sgd_to_fsdp

# Apply the EDiT Local SGD patch to FSDP; the sections below assume it has been called.
patch_local_sgd_to_fsdp()
```

## Build Model

```python
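from EDiT import GTAConfig, LocalSGDConfig, OuterOptimizerConfig

# Fill in each config with the settings for your run.
local_sgd_config = LocalSGDConfig(...)
outer_optim_config = OuterOptimizerConfig(...)
gta_config = GTAConfig(...)

# Extra keyword arguments passed to the patched FSDP wrapper.
local_sgd_kwargs = {
    "use_local_sgd": True,
    "local_sgd_config": local_sgd_config,
    "gta_config": gta_config,
    "outer_optim_config": outer_optim_config,
}

# `model` is your existing module; `...` stands for your usual FSDP arguments.
model = FSDP(
    model,
    ...,
    **local_sgd_kwargs,
)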
```
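
For readers new to the idea behind `LocalSGDConfig` and `OuterOptimizerConfig`, the toy sketch below illustrates generic Local SGD with an outer update step in plain Python. It is not EDiT's implementation and does not use the EDiT API. The function `local_sgd_round`, its parameters, and the 1-D quadratic loss are all hypothetical and chosen only for illustration.

```python
def local_sgd_round(global_w, targets, inner_lr=0.1, local_steps=4, outer_lr=1.0):
    """One synchronization round of generic Local SGD on a 1-D toy problem.

    Hypothetical illustration only: each entry of `targets` plays the role of
    one worker whose local loss is (w - target) ** 2.
    """
    local_ws = []
    for target in targets:
        w = global_w  # every worker starts from the shared weights
        for _ in range(local_steps):
            grad = 2.0 * (w - target)  # gradient of (w - target) ** 2
            w -= inner_lr * grad       # inner (local) SGD step, no communication
        local_ws.append(w)

    # Synchronize: average the workers and form a "pseudo-gradient".
    avg_w = sum(local_ws) / len(local_ws)
    pseudo_grad = global_w - avg_w

    # Outer step: plain SGD here; outer_lr=1.0 reduces to simple averaging.
    return global_w - outer_lr * pseudo_grad


w = 0.0
for _ in range(50):
    w = local_sgd_round(w, targets=[1.0, 3.0])
print(w)  # approaches the midpoint of the two worker targets
```

In this sketch the workers communicate once per round rather than once per step, which is the general motivation for Local SGD style training. The outer step is where a different optimizer could be substituted.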

## Save/Load Checkpoint

```python
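# Save the Local SGD state of the wrapped model; `save_dir` is a path you define.
FSDP.save_local_sgd_state_dict(
    model=model,
    full_state_dict=False,
    cpu_offload=True,
    save_dir=save_dir,
    ckpt_name="EDiT_ckpt",
)

# Load it back from `load_dir`; `ckpt_name` here matches the name used when saving.
FSDP.load_local_sgd_state_dict(
    model=model,
    full_state_dict=False,
    load_dir=load_dir,
    ckpt_name="EDiT_ckpt",
)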
```