import gc
import io
import logging
import pickle
import shutil
import traceback
from abc import ABCMeta, abstractmethod
from collections import defaultdict
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
from contextlib import contextmanager
from copy import deepcopy
from dataclasses import dataclass, field, replace
from functools import reduce
from multiprocessing import shared_memory
from pathlib import Path
from typing import Any, Dict, Generator, List, Optional, Set, Tuple, cast

import numpy as np
import torch
import torch.distributed.checkpoint as dist_cp
import torch.multiprocessing as mp
import torch.nn as nn
from packaging import version
from torch.distributed import _remote_device
from torch.distributed._shard._utils import narrow_tensor_by_index
from torch.distributed._shard.metadata import ShardMetadata
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed.checkpoint.filesystem import WriteResult, _StorageInfo
from torch.distributed.checkpoint.metadata import Metadata, MetadataIndex
from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
from torch.distributed.checkpoint.planner import LoadItemType, ReadItem
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType
from torch.distributed.fsdp.api import (
    FullOptimStateDictConfig,
    FullStateDictConfig,
    ShardedOptimStateDictConfig,
    ShardedStateDictConfig,
)
from torch.futures import Future
from torch.nn.parallel import DistributedDataParallel as DDP

try:
    from torch.distributed.fsdp.flat_param import FlatParamHandle  
except ModuleNotFoundError:
    from torch.distributed.fsdp._flat_param import FlatParamHandle  

from olmo import util

from .aliases import PathOrStr
from .config import BaseConfig, ShardedCheckpointerType, TrainConfig
from .exceptions import OLMoCheckpointError
from .optim import Optimizer, fix_optim_state_dict
from .safetensors_util import safetensors_file_to_state_dict
from .torch_util import (
    SingleAccelerator,
    barrier,
    gc_cuda,
    get_fs_local_rank,
    get_global_rank,
    get_local_rank,
    get_local_world_size,
    get_world_size,
)
from .util import (
    _get_s3_client,
    default_thread_count,
    dir_is_empty,
    get_bytes_range,
    get_progress_bar,
    resource_path,
    upload,
    wait_for,
)

# Names exported by `from <this module> import *`.
__all__ = [
    "save_fsdp_model_and_optim_state",
    "load_fsdp_model_and_optim_state",
    "load_fsdp_optim_state",
    "save_state_dict",
    "load_state_dict",
    "load_model_state",
    "RemoteFileSystemWriter",
    "RemoteFileSystemReader",
    "Checkpointer",
    "FullCheckpointer",
    "TorchNewStyleShardedCheckpointer",
    "TorchLegacyShardedCheckpointer",
    "LocalShardedCheckpointer",
    "build_sharded_checkpointer",
]


# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# Name of the sub-folder of a checkpoint directory in which the sharded model and
# optimizer state are written by `save_fsdp_model_and_optim_state()` and read back
# by `load_fsdp_model_and_optim_state()` / `load_model_state()`.
MODEL_AND_OPTIM_FOLDER = "model_and_optim"


def save_fsdp_model_and_optim_state(
    checkpoint_dir: PathOrStr,
    fsdp_model: FSDP,
    optim: Optimizer,
    *,
    upload_to: Optional[str] = None,
    save_overwrite: bool = False,
):
    """
    Save the sharded model and optimizer state of an FSDP model with
    :func:`torch.distributed.checkpoint.save_state_dict`.

    The state is written into the ``MODEL_AND_OPTIM_FOLDER`` sub-folder of ``checkpoint_dir``
    through a :class:`RemoteFileSystemWriter`.

    :param checkpoint_dir: The local directory to save into.
    :param fsdp_model: The FSDP-wrapped model whose state should be saved.
    :param optim: The optimizer whose state should be saved alongside the model.
    :param upload_to: Optional remote location; when given, the writer is told to upload to
        the ``MODEL_AND_OPTIM_FOLDER`` sub-folder of it.
    :param save_overwrite: When ``True``, an existing target folder is removed first (by the
        file-system-local rank 0); otherwise a non-empty target folder is an error.

    :raises FileExistsError: If the target folder is non-empty and ``save_overwrite`` is ``False``.
    """
    state_dir = Path(checkpoint_dir) / MODEL_AND_OPTIM_FOLDER
    remote_target = None if upload_to is None else f"{upload_to.rstrip('/')}/{MODEL_AND_OPTIM_FOLDER}"

    # Clear or validate the target folder.
    if not save_overwrite:
        if not dir_is_empty(state_dir):
            raise FileExistsError(state_dir)
    elif get_fs_local_rank() == 0:
        shutil.rmtree(state_dir, ignore_errors=True)
    barrier()

    # Only one rank per file system creates the folder; everyone waits until it is done.
    if get_fs_local_rank() == 0:
        state_dir.mkdir(exist_ok=True, parents=True)
    barrier()

    with FSDP.state_dict_type(
        fsdp_model,
        state_dict_type=StateDictType.SHARDED_STATE_DICT,
        state_dict_config=ShardedStateDictConfig(offload_to_cpu=True),
        optim_state_dict_config=ShardedOptimStateDictConfig(offload_to_cpu=True),
    ):
        model_state = fsdp_model.state_dict()
        optim_state = FSDP.optim_state_dict(fsdp_model, optim)
        writer = RemoteFileSystemWriter(
            state_dir,
            upload_to=remote_target,
            save_overwrite=save_overwrite,
        )
        dist_cp.save_state_dict({"model": model_state, "optim": optim_state}, writer)


def load_fsdp_model_and_optim_state(
    checkpoint_dir: PathOrStr,
    fsdp_model: FSDP,
    optim: Optimizer,
    *,
    local_cache: Optional[PathOrStr] = None,
    load_optimizer_state: bool = True,
):
    """
    Load sharded model state, and optionally optimizer state, into an FSDP model in place.

    This reads from the ``MODEL_AND_OPTIM_FOLDER`` sub-folder of ``checkpoint_dir`` through a
    :class:`RemoteFileSystemReader`, i.e. the layout produced by
    :func:`save_fsdp_model_and_optim_state`.

    :param checkpoint_dir: The checkpoint location (any trailing ``/`` is stripped).
    :param fsdp_model: The FSDP-wrapped model to load state into.
    :param optim: The optimizer to load state into.
    :param local_cache: Optional local directory that is passed on to the reader as its cache
        (its ``MODEL_AND_OPTIM_FOLDER`` sub-folder is used).
    :param load_optimizer_state: When ``False``, only the model state is loaded.
    """
    state_location = f"{str(checkpoint_dir).rstrip('/')}/{MODEL_AND_OPTIM_FOLDER}"
    state_cache = None if local_cache is None else Path(local_cache) / MODEL_AND_OPTIM_FOLDER

    with FSDP.state_dict_type(
        fsdp_model,
        state_dict_type=StateDictType.SHARDED_STATE_DICT,
        state_dict_config=ShardedStateDictConfig(offload_to_cpu=True),
        optim_state_dict_config=ShardedOptimStateDictConfig(offload_to_cpu=True),
    ):
        # Model state: load into the model's own sharded state dict, then hand it back.
        log.info("Loading model state...")
        model_state = {"model": fsdp_model.state_dict()}
        dist_cp.load_state_dict(
            model_state,
            RemoteFileSystemReader(state_location, local_cache=state_cache),
        )
        fsdp_model.load_state_dict(model_state["model"])

        if load_optimizer_state:
            # Optimizer state: read with a fresh reader over the same location.
            log.info("Loading sharded optimizer state...")
            optim_state = load_sharded_optimizer_state_dict(
                model_state_dict=model_state["model"],
                optimizer_key="optim",
                storage_reader=RemoteFileSystemReader(state_location, local_cache=state_cache),
            )
            # The model state is no longer needed; drop our reference before continuing.
            del model_state

            # Move every optimizer state entry to CPU, then free cached accelerator memory.
            for param_state in optim_state["optim"]["state"].values():
                for name, value in param_state.items():
                    param_state[name] = value.cpu()
            gc_cuda()

            load_fsdp_optim_state(fsdp_model, optim, optim_state["optim"])


def load_fsdp_optim_state(fsdp_model: FSDP, optim: Optimizer, optim_state: Dict[str, Any]):
    """
    Convert an FSDP optimizer state dict into its loadable form and load it into ``optim``.

    :param fsdp_model: The FSDP-wrapped model the optimizer belongs to.
    :param optim: The optimizer to load state into.
    :param optim_state: The optimizer state dict to convert with
        ``FSDP.optim_state_dict_to_load`` before loading.
    """
    log.info("Flattening sharded optimizer state...")

    # `FSDP.optim_state_dict_to_load` takes its arguments in a different order before
    # PyTorch 2.1.0, so pick the call form that matches the installed version.
    uses_old_arg_order = version.parse(torch.__version__) < version.parse("2.1.0")
    if uses_old_arg_order:
        flattened_osd = FSDP.optim_state_dict_to_load(optim_state, fsdp_model, optim)
    else:
        flattened_osd = FSDP.optim_state_dict_to_load(fsdp_model, optim, optim_state)

    # Drop our reference to the input state and free cached accelerator memory.
    del optim_state
    gc_cuda()

    log.info("Loading flattened optimizer state...")

    # Put every state entry on CPU before it is handed to the optimizer.
    for param_state in flattened_osd["state"].values():
        for name, value in param_state.items():
            param_state[name] = value.cpu()
    gc_cuda()

    optim.load_state_dict(fix_optim_state_dict(optim, flattened_osd))


def save_state_dict(
    checkpoint_dir: PathOrStr,
    fname: str,
    state_dict: Dict[str, Any],
    *,
    upload_to: Optional[str] = None,
    save_overwrite: bool = False,
    synchronize: bool = True,
):
    """
    Save a state dictionary to ``checkpoint_dir / fname`` with :func:`torch.save`, optionally
    uploading the resulting file.

    :param checkpoint_dir: The local directory to save into.
    :param fname: File name (may contain sub-folders) relative to ``checkpoint_dir``.
    :param state_dict: The object to serialize.
    :param upload_to: Optional remote location; the file is uploaded to ``<upload_to>/<fname>``.
    :param save_overwrite: When ``True``, an existing file is removed first; otherwise an
        existing file is an error.
    :param synchronize: When ``True``, call ``barrier()`` before and after creating the parent
        directory.

    :raises FileExistsError: If the target file exists and ``save_overwrite`` is ``False``.
    """
    destination = Path(checkpoint_dir) / fname

    if not save_overwrite:
        if destination.is_file():
            raise FileExistsError(destination)
    else:
        destination.unlink(missing_ok=True)

    if synchronize:
        barrier()
    destination.parent.mkdir(exist_ok=True, parents=True)
    if synchronize:
        barrier()

    torch.save(state_dict, destination)

    if upload_to is None:
        return
    upload_target = f"{upload_to.rstrip('/')}/{fname}"
    log.info(f"Uploading {destination} to {upload_target}...")
    upload(destination, upload_target, save_overwrite=save_overwrite)


def load_state_dict(
    checkpoint_dir: PathOrStr,
    fname: str,
    *,
    local_cache: Optional[PathOrStr] = None,
    map_location: Optional[str] = None,
):
    """
    Load a state dictionary saved under ``checkpoint_dir``.

    For a ``*.pt`` file name, a sibling ``*.safetensors`` file is tried first; if that raises
    ``FileNotFoundError`` the ``.pt`` file is loaded with :func:`torch.load` instead.

    :param checkpoint_dir: The checkpoint location (any trailing ``/`` is stripped).
    :param fname: The file name relative to ``checkpoint_dir``.
    :param local_cache: Passed through to ``resource_path`` as its ``local_cache`` argument.
    :param map_location: Passed through to the loader.

    :returns: The loaded object.
    """
    root = str(checkpoint_dir).rstrip("/")

    if fname.endswith(".pt"):
        # Swap the "pt" extension for "safetensors" (the "." is kept).
        safetensors_fname = fname[:-2] + "safetensors"
        try:
            safetensors_path = resource_path(root, safetensors_fname, local_cache=local_cache)
            return safetensors_file_to_state_dict(safetensors_path, map_location=map_location)
        except FileNotFoundError:
            # No safetensors version available; fall through to the regular file.
            pass

    torch_path = resource_path(root, fname, local_cache=local_cache)
    # NOTE: `weights_only=False` unpickles arbitrary objects; only load trusted checkpoints.
    return torch.load(torch_path, map_location=map_location, weights_only=False)


def load_model_state(checkpoint_dir: PathOrStr, model: torch.nn.Module):
    """
    Load model state into ``model`` in place from the ``MODEL_AND_OPTIM_FOLDER`` sub-folder of
    ``checkpoint_dir``, calling ``dist_cp.load_state_dict`` with ``no_dist=True``.

    :param checkpoint_dir: The checkpoint location (any trailing ``/`` is stripped).
    :param model: The module to load the state into.
    """
    reader = RemoteFileSystemReader(f"{str(checkpoint_dir).rstrip('/')}/{MODEL_AND_OPTIM_FOLDER}")
    wrapped_state = {"model": model.state_dict()}
    dist_cp.load_state_dict(wrapped_state, reader, no_dist=True)
    model.load_state_dict(wrapped_state["model"])


class RemoteFileSystemWriter(dist_cp.FileSystemWriter):
    """
    A :class:`~torch.distributed.checkpoint.FileSystemWriter` that, after writing locally,
    can also upload the written files to a remote location given by ``upload_to``.
    """

    def __init__(
        self,
        path: PathOrStr,
        single_file_per_rank: bool = True,
        sync_files: bool = True,
        thread_count: Optional[int] = None,
        per_thread_copy_ahead: int = 10_000_000,
        upload_to: Optional[str] = None,
        save_overwrite: bool = False,
    ) -> None:
        """
        :param path: Local directory to write into.
        :param single_file_per_rank: Forwarded to the parent writer.
        :param sync_files: Forwarded to the parent writer.
        :param thread_count: Number of threads; must be positive when given. Defaults to 1.
        :param per_thread_copy_ahead: Forwarded to the parent writer.
        :param upload_to: Optional remote location to upload written files to.
        :param save_overwrite: Forwarded to ``upload`` for every uploaded file.

        :raises ValueError: If ``thread_count`` is given and is not positive.
        """
        if thread_count is not None and thread_count <= 0:
            raise ValueError("thread count must be at least 1")
        # A missing thread count falls back to a single thread here (not `default_thread_count()`).
        effective_threads = thread_count or 1
        super().__init__(
            path,
            single_file_per_rank=single_file_per_rank,
            sync_files=sync_files,
            thread_count=effective_threads,
            per_thread_copy_ahead=per_thread_copy_ahead,
        )
        self.upload_to = upload_to.rstrip("/") if upload_to is not None else None
        self.save_overwrite = save_overwrite

    def write_data(
        self,
        plan: dist_cp.SavePlan,
        planner: dist_cp.SavePlanner,
    ) -> Future[List[WriteResult]]:
        """
        Write the data locally via the parent class and, when ``upload_to`` is set, wait for
        the write to finish and upload every distinct file it produced.

        :raises OLMoCheckpointError: If any upload fails; the message carries the formatted
            traceback of the original error.
        """
        fut = super().write_data(plan, planner)
        if self.upload_to is None:
            return fut

        # Several write results may point at the same file, so collect distinct paths.
        relative_paths = {result.storage_data.relative_path for result in fut.wait()}

        # Call `_get_s3_client` for the matching scheme before starting the worker threads
        # (presumably so the threads reuse an already-created client; confirm against
        # `_get_s3_client`).
        for scheme in ("s3", "r2", "weka"):
            if self.upload_to.startswith(f"{scheme}://"):
                _get_s3_client(scheme)
                break

        with ThreadPoolExecutor(max_workers=self.thread_count) as pool:
            pending = []
            for relative_path in relative_paths:
                source = self.path / relative_path
                target = f"{self.upload_to}/{relative_path}"
                log.info(f"Uploading {source} to {target}...")
                pending.append(pool.submit(upload, source, target, save_overwrite=self.save_overwrite))
            for finished in as_completed(pending):
                try:
                    finished.result()
                except BaseException:
                    # Re-raise as a plain project error that only carries the traceback text.
                    raise OLMoCheckpointError(f"Original error:\n{traceback.format_exc()}")
        return fut

    def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
        """
        Finish the save via the parent class and, when ``upload_to`` is set, upload the
        ``.metadata`` file as well.
        """
        super().finish(metadata, results)
        if self.upload_to is None:
            return
        source = self.path / ".metadata"
        target = f"{self.upload_to}/.metadata"
        log.info(f"Uploading {source} to {target}...")
        upload(source, target, save_overwrite=self.save_overwrite)


class RemoteFileSystemReader(dist_cp.StorageReader):
    """
    A :class:`~torch.distributed.checkpoint.StorageReader` that fetches byte ranges of
    checkpoint files with ``get_bytes_range``, preferring files in an optional local cache
    directory over the location given by ``path``.
    """

    def __init__(
        self, path: PathOrStr, *, local_cache: Optional[PathOrStr] = None, thread_count: Optional[int] = None
    ):
        """
        :param path: The checkpoint location (any trailing ``/`` is stripped).
        :param local_cache: Optional local directory checked first for each file.
        :param thread_count: Number of reader threads; must be positive when given. Defaults
            to ``default_thread_count()``.

        :raises ValueError: If ``thread_count`` is given and is not positive.
        """
        super().__init__()
        if thread_count is not None and thread_count <= 0:
            raise ValueError("thread count must be at least 1")
        self.path = str(path).rstrip("/")
        self.cache = Path(local_cache) if local_cache is not None else None
        self.thread_count = thread_count or default_thread_count()
        self.storage_data: Dict[MetadataIndex, _StorageInfo] = dict()
        self._metadata: Optional[Metadata] = None

    def _get_bytes(self, relative_path: str, offset: int, length: int) -> bytes:
        """Return ``length`` bytes at ``offset`` of ``relative_path``, from the cache if the file is there."""
        if self.cache is not None:
            cached_file = self.cache / relative_path
            if cached_file.is_file():
                return get_bytes_range(cached_file, offset, length)
        return get_bytes_range(f"{self.path}/{relative_path}", offset, length)

    def _get_content_for_read(self, read_item: ReadItem) -> Tuple[ReadItem, bytes]:
        """Fetch the raw bytes for ``read_item`` and return them paired with the item."""
        info = self.storage_data[read_item.storage_index]
        return (read_item, self._get_bytes(info.relative_path, info.offset, info.length))

    def read_data(self, plan: dist_cp.LoadPlan, planner: dist_cp.LoadPlanner) -> Future[None]:
        """
        Fetch the bytes for every item in ``plan`` using a thread pool, then hand them to
        ``planner`` one by one on the calling thread.

        :returns: An already-completed future.
        :raises OLMoCheckpointError: If fetching any item fails; the message carries the
            formatted traceback of the original error.
        """
        # Call `_get_s3_client` for the matching scheme before starting the worker threads
        # (presumably so the threads reuse an already-created client; confirm against
        # `_get_s3_client`).
        if isinstance(self.path, str):
            for scheme in ("s3", "r2", "weka"):
                if self.path.startswith(f"{scheme}://"):
                    _get_s3_client(scheme)
                    break

        with ThreadPoolExecutor(max_workers=self.thread_count) as pool:
            pending = [pool.submit(self._get_content_for_read, item) for item in plan.items]
            fetched = []
            for finished in as_completed(pending):
                try:
                    fetched.append(finished.result())
                except BaseException:
                    # Re-raise as a plain project error that only carries the traceback text.
                    raise OLMoCheckpointError(f"Original error:\n{traceback.format_exc()}")

        for read_item, content in fetched:
            stream = io.BytesIO(content)
            stream.seek(0)
            if read_item.type == LoadItemType.BYTE_IO:
                planner.load_bytes(read_item, stream)
                continue

            # Tensor item: deserialize, cut out the requested region and copy it into place.
            loaded = cast(torch.Tensor, torch.load(stream, map_location="cpu"))
            loaded = narrow_tensor_by_index(loaded, read_item.storage_offsets, read_item.lengths)
            target_tensor = planner.resolve_tensor(read_item).detach()

            assert (
                target_tensor.size() == loaded.size()
            ), f"req {read_item.storage_index} mismatch sizes {target_tensor.size()} vs {loaded.size()}"
            target_tensor.copy_(loaded)
            planner.commit_tensor(read_item, target_tensor)

        fut: Future = Future()
        fut.set_result(None)
        return fut

    def read_metadata(self) -> Metadata:
        """Return the checkpoint metadata, unpickling the ``.metadata`` file on first use."""
        if self._metadata is not None:
            return self._metadata
        metadata_path = resource_path(self.path, ".metadata", local_cache=self.cache)
        with metadata_path.open("rb") as metadata_file:
            self._metadata = pickle.load(metadata_file)
        return self._metadata

    def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None:
        """Remember the storage layout from ``metadata``; ``is_coordinator`` is unused."""
        del is_coordinator
        self.storage_data = metadata.storage_data
        assert self.storage_data is not None

    def prepare_local_plan(self, plan: dist_cp.LoadPlan) -> dist_cp.LoadPlan:
        """Return ``plan`` unchanged."""
        return plan

    def prepare_global_plan(self, global_plan: List[dist_cp.LoadPlan]) -> List[dist_cp.LoadPlan]:
        """Return ``global_plan`` unchanged."""
        return global_plan


class Checkpointer(metaclass=ABCMeta):
    """
    Abstract base class for checkpointers.

    Subclasses implement :meth:`save_checkpoint` and :meth:`restore_checkpoint`. This base
    class provides the temporary-directory handling and config saving they share.
    """

    def __init__(self, cfg: TrainConfig, thread_count: Optional[int] = None):
        """
        :param cfg: The training config; ``cfg.save_overwrite`` is consulted when saving.
        :param thread_count: Thread count to use; defaults to ``default_thread_count()``.
        """
        self.cfg = cfg
        self.thread_count = thread_count or default_thread_count()

    @abstractmethod
    def save_checkpoint(
        self,
        dir: PathOrStr,
        dist_model: nn.Module,
        optim: Optimizer,
        train_state: Dict[str, Any],
        *,
        upload_to: Optional[str] = None,
    ) -> None:
        """
        Save a checkpoint to the local directory ``dir``.

        :param dir: The local directory to save the checkpoint to.
        :param dist_model: The (possibly wrapped) model to save.
        :param optim: The optimizer whose state should be saved.
        :param train_state: Additional trainer state to save.
        :param upload_to: Optional remote location to upload the checkpoint files to.
        """
        raise NotImplementedError

    @abstractmethod
    def restore_checkpoint(
        self,
        load_path: PathOrStr,
        dist_model: nn.Module,
        optim: Optimizer,
        *,
        local_cache: Optional[PathOrStr] = None,
        load_optimizer_state: bool = True,
    ) -> Dict[str, Any]:
        """
        Restore model (and optionally optimizer) state in place from ``load_path``.

        :param load_path: The checkpoint location.
        :param dist_model: The model to load state into.
        :param optim: The optimizer to load state into.
        :param local_cache: Optional local directory to look for checkpoint files in first.
        :param load_optimizer_state: Whether to also restore the optimizer state.

        :returns: The trainer state stored in the checkpoint.
        """
        raise NotImplementedError

    def unshard_checkpoint(
        self,
        load_path: PathOrStr,
        *,
        local_cache: Optional[PathOrStr] = None,
        load_optimizer_state: bool = True,
        load_trainer_state: bool = True,
        device: Optional[torch.device] = None,
    ) -> Tuple[Dict[str, torch.Tensor], Optional[Dict[str, Any]], Optional[Dict[str, Any]]]:
        """
        Unshard a checkpoint. Not abstract, but the base implementation is unsupported.

        :returns: Per the signature, a ``(model state, optimizer state, trainer state)`` tuple.
        :raises NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError

    @contextmanager
    def _temporary_wd(self, dir: PathOrStr) -> Generator[Path, None, None]:
        """
        Context manager that yields a temporary ``<dir>-tmp`` directory to write a checkpoint
        into, and renames it to ``dir`` once the body has finished on all ranks.

        :param dir: The final checkpoint directory.

        :raises FileExistsError: If ``dir`` is non-empty and ``cfg.save_overwrite`` is false.
        """
        # The final directory must be empty/missing unless overwriting is allowed.
        checkpoint_dir = Path(dir)
        if not dir_is_empty(checkpoint_dir):
            if self.cfg.save_overwrite:
                if get_fs_local_rank() == 0:
                    shutil.rmtree(checkpoint_dir, ignore_errors=True)
            else:
                raise FileExistsError(checkpoint_dir)
        # The final directory is not created here: the temporary directory is renamed onto
        # it at the end.
        barrier()

        # Prepare a fresh temporary directory; any leftover one is simply removed.
        checkpoint_dir_tmp = checkpoint_dir.with_name(checkpoint_dir.name + "-tmp")
        if get_fs_local_rank() == 0:
            shutil.rmtree(checkpoint_dir_tmp, ignore_errors=True)
            checkpoint_dir_tmp.mkdir(exist_ok=True, parents=True)

        # Only one rank creates the directory, so every rank waits until it can actually see
        # it (it may not be visible immediately, e.g. on a shared network file system).
        wait_for(lambda: checkpoint_dir_tmp.exists(), "Waiting for checkpoint directory", timeout=10.0)

        barrier()

        # Hand the temporary directory to the caller to write into.
        yield checkpoint_dir_tmp

        barrier()

        # All ranks are done writing: move the temporary directory into its final place.
        if get_fs_local_rank() == 0:
            try:
                checkpoint_dir_tmp.replace(checkpoint_dir)
            except FileNotFoundError:
                # The temporary directory is already gone. That is fine if the final
                # directory exists (presumably another rank sharing the same storage did the
                # rename first); otherwise it is a real error.
                if not checkpoint_dir.exists():
                    raise

        # As above, wait until the renamed directory is visible to this rank.
        wait_for(lambda: checkpoint_dir.exists(), "Waiting for checkpoint directory", timeout=10.0)

        barrier()

    def _save_config(self, dir: PathOrStr, *, upload_to: Optional[str] = None) -> None:
        """
        Save the training config as ``config.yaml`` in ``dir`` (global rank 0 only) and
        optionally upload it.

        :param dir: The local directory to write ``config.yaml`` into.
        :param upload_to: Optional remote location; the file goes to ``<upload_to>/config.yaml``.
        """
        if get_global_rank() == 0:
            log.info("Saving config...")
            self.cfg.save(config_path := Path(dir) / "config.yaml")
            if upload_to is not None:
                # Strip a trailing slash, as the other upload paths in this module do, so
                # the config lands next to the other checkpoint files rather than under a
                # "//" key.
                upload_target = f"{upload_to.rstrip('/')}/config.yaml"
                log.info(f"Uploading {config_path} to {upload_target}")
                upload(config_path, upload_target, save_overwrite=self.cfg.save_overwrite)


class FullCheckpointer(Checkpointer):
    """
    A :class:`Checkpointer` that saves a single full (unsharded) model state dict
    (``model.pt``), optimizer state dict (``optim.pt``) and trainer state (``train.pt``).
    Only global rank 0 writes the files.
    """

    def save_checkpoint(
        self,
        dir: PathOrStr,
        dist_model: nn.Module,
        optim: Optimizer,
        trainer_state: Dict[str, Any],
        *,
        upload_to: Optional[str] = None,
    ) -> None:
        """
        Save full model, optimizer and trainer state plus the config into ``dir``.

        :param dir: The local checkpoint directory (written via a temporary directory).
        :param dist_model: An FSDP, DDP or ``SingleAccelerator`` wrapped model. For any other
            type only a log message is emitted and no model/optimizer state is written.
        :param optim: The optimizer whose state should be saved.
        :param trainer_state: Trainer state, saved as ``train.pt`` by global rank 0.
        :param upload_to: Optional remote location to upload the saved files to.
        """
        with self._temporary_wd(dir) as checkpoint_dir:
            if isinstance(dist_model, FSDP):
                with FSDP.state_dict_type(
                    dist_model,
                    state_dict_type=StateDictType.FULL_STATE_DICT,
                    state_dict_config=FullStateDictConfig(rank0_only=True, offload_to_cpu=True),
                    optim_state_dict_config=FullOptimStateDictConfig(rank0_only=True, offload_to_cpu=True),
                ):
                    # Every rank must take part in gathering the state dicts, even though
                    # the write helpers only save from global rank 0. Model first.
                    model_state_dict = dist_model.state_dict()
                    self._write_model_dict(
                        model_state_dict, checkpoint_dir, upload_to, save_overwrite=self.cfg.save_overwrite
                    )

                    # Then the optimizer state.
                    optim_state_dict = FSDP.optim_state_dict(dist_model, optim)
                    self._write_optim_dict(
                        optim_state_dict, checkpoint_dir, upload_to, save_overwrite=self.cfg.save_overwrite
                    )
            elif isinstance(dist_model, (DDP, SingleAccelerator)):
                # Not sharded: take the state dict of the wrapped module directly. The write
                # helpers still only save from global rank 0.
                model_state_dict = dist_model.module.state_dict()
                self._write_model_dict(
                    model_state_dict, checkpoint_dir, upload_to, save_overwrite=self.cfg.save_overwrite
                )

                # Then the optimizer state.
                optim_state_dict = optim.state_dict()
                self._write_optim_dict(
                    optim_state_dict, checkpoint_dir, upload_to, save_overwrite=self.cfg.save_overwrite
                )
            else:
                # NOTE(review): an unsupported model type is only logged here, whereas
                # `restore_checkpoint` raises NotImplementedError for the same situation.
                log.info(
                    "`FullCheckpointer.save_checkpoint` only supported for FSDP, DDP, and SingleAccelerator distributed strategies!"
                )

            # Trainer state is written by global rank 0 only, without extra barriers.
            if get_global_rank() == 0:
                log.info("Saving trainer state...")
                save_state_dict(
                    checkpoint_dir,
                    "train.pt",
                    trainer_state,
                    upload_to=upload_to,
                    save_overwrite=self.cfg.save_overwrite,
                    synchronize=False,
                )
            # Finally the config.
            self._save_config(checkpoint_dir, upload_to=upload_to)

    def restore_checkpoint(
        self,
        load_path: PathOrStr,
        dist_model: nn.Module,
        optim: Optimizer,
        *,
        local_cache: Optional[PathOrStr] = None,
        load_optimizer_state: bool = True,
    ) -> Dict[str, Any]:
        """
        Restore model (and optionally optimizer) state in place from a full checkpoint.

        :param load_path: The checkpoint location containing ``model.pt`` etc.
        :param dist_model: An FSDP, DDP or ``SingleAccelerator`` wrapped model.
        :param optim: The optimizer to load state into.
        :param local_cache: Optional local directory to look for checkpoint files in first.
        :param load_optimizer_state: Whether to also load ``optim.pt``.

        :returns: The trainer state loaded from ``train.pt`` (or ``other.pt`` as a fallback).
        :raises NotImplementedError: If ``dist_model`` is not one of the supported wrappers.
        :raises ValueError: If, for FSDP, any flat parameter still contains NaNs after loading.
        """
        if isinstance(dist_model, FSDP):
            with FSDP.state_dict_type(
                dist_model,
                state_dict_type=StateDictType.FULL_STATE_DICT,
                state_dict_config=FullStateDictConfig(rank0_only=False, offload_to_cpu=True),
                optim_state_dict_config=FullOptimStateDictConfig(rank0_only=False, offload_to_cpu=True),
            ):
                with torch.no_grad():
                    # Fill all flat parameters with NaN first, so that anything the copy
                    # loop below fails to overwrite is caught by the check further down.
                    for module_name, module in dist_model.named_modules():
                        if not isinstance(module, FSDP):
                            continue
                        for param in module.params:
                            param.fill_(torch.nan)

                    # Load the full state dict on CPU and pass it through the wrapped
                    # model's `_make_state_dict_compatible`, which also returns
                    # `og_keys_to_new` (used below to rename optimizer state keys).
                    state_dict_to_load = load_state_dict(
                        load_path, "model.pt", local_cache=local_cache, map_location="cpu"
                    )
                    (
                        state_dict_to_load,
                        og_keys_to_new,
                    ) = dist_model._fsdp_wrapped_module._make_state_dict_compatible(state_dict_to_load)

                    # Copy, for every original parameter that has elements in this rank's
                    # shard, the matching slice of the flattened full tensor into the local
                    # flat parameter.
                    for module_name, module in dist_model.named_modules():
                        if not isinstance(module, FSDP):
                            continue
                        for param in module.params:
                            assert param._is_flat_param
                            for fqn, spi in zip(param._fqns, param._shard_param_infos):
                                if not spi.in_shard:
                                    continue
                                # Build the state dict key: module path + parameter name,
                                # without the FSDP wrapper prefix or leading dots.
                                key = f"{module_name}.{fqn}"
                                key = key.replace("_fsdp_wrapped_module.", "")
                                key = key.lstrip(".")
                                t = state_dict_to_load[key]
                                t = t.flatten()
                                # `intra_param_end_idx` is used as an inclusive index, hence `+ 1`.
                                param[spi.offset_in_shard : spi.offset_in_shard + spi.numel_in_shard].copy_(
                                    t[spi.intra_param_start_idx : spi.intra_param_end_idx + 1]
                                )

                    # Sanity check: no NaN from the initial fill may be left over.
                    for module_name, module in dist_model.named_modules():
                        if not isinstance(module, FSDP):
                            continue
                        for param in module.params:
                            if torch.isnan(param).any():
                                raise ValueError(
                                    f"Module '{module_name}' contains NaNs, this is likely a bug restoring from full checkpoints"
                                )

                # Optimizer state.
                if load_optimizer_state:
                    optim_state_dict_to_load = load_state_dict(
                        load_path, "optim.pt", local_cache=local_cache, map_location="cpu"
                    )
                    optim_state_dict_to_load = self._make_optim_state_dict_compatible(
                        optim_state_dict_to_load,
                        og_keys_to_new,
                    )
                    gc.collect()
                    torch.cuda.empty_cache()
                    barrier()
                    # Ranks on the same node take turns loading the optimizer state, with a
                    # barrier after every turn (presumably to limit peak memory; confirm).
                    for turn in range(get_local_world_size()):
                        log.info("Loading optimizer state turn %d ...", turn)
                        if turn == get_local_rank():
                            load_fsdp_optim_state(dist_model, optim, optim_state_dict_to_load)
                            gc.collect()
                            torch.cuda.empty_cache()
                        barrier()
                    del optim_state_dict_to_load
        elif isinstance(dist_model, (DDP, SingleAccelerator)):
            # Not sharded: load the model state straight into the wrapped module.
            with torch.no_grad():
                state_dict_to_load = load_state_dict(
                    load_path, "model.pt", local_cache=local_cache, map_location="cpu"
                )
                dist_model.module.load_state_dict(state_dict_to_load, strict=True)

            # Optimizer state.
            if load_optimizer_state:
                optim_state_dict_to_load = load_state_dict(
                    load_path, "optim.pt", local_cache=local_cache, map_location="cpu"
                )
                optim.load_state_dict(optim_state_dict_to_load)

            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            barrier()
        else:
            raise NotImplementedError(
                "`FullCheckpointer.restore_checkpoint` only supported for FSDP, DDP, and SingleAccelerator distributed strategies!"
            )

        # Trainer state.
        try:
            trainer_state = load_state_dict(load_path, "train.pt", local_cache=local_cache)
        except FileNotFoundError:
            # Fall back to "other.pt" (presumably an older file name for the same state).
            trainer_state = load_state_dict(load_path, "other.pt", local_cache=local_cache)
        barrier()
        return trainer_state

    def _write_model_dict(self, model_state_dict, checkpoint_dir, upload_to, save_overwrite):
        """Save ``model_state_dict`` as ``model.pt`` from global rank 0, then sync all ranks."""
        if get_global_rank() == 0:
            log.info("Saving model state...")
            save_state_dict(
                checkpoint_dir,
                "model.pt",
                model_state_dict,
                upload_to=upload_to,
                save_overwrite=save_overwrite,
                synchronize=False,
            )

        # Drop this function's reference to the state dict before waiting at the barrier.
        del model_state_dict
        barrier()

    def _write_optim_dict(self, optim_state_dict, checkpoint_dir, upload_to, save_overwrite):
        """Save ``optim_state_dict`` as ``optim.pt`` from global rank 0, then sync all ranks."""
        if get_global_rank() == 0:
            log.info("Saving optim state...")
            save_state_dict(
                checkpoint_dir,
                "optim.pt",
                optim_state_dict,
                upload_to=upload_to,
                save_overwrite=save_overwrite,
                synchronize=False,
            )

        # Drop this function's reference to the state dict before waiting at the barrier.
        del optim_state_dict
        barrier()

    def _make_optim_state_dict_compatible(
        self, optim_state_dict: Dict[str, Any], og_keys_to_new: Dict[str, Set[str]]
    ) -> Dict[str, Any]:
        """
        Rewrite a full optimizer state dict in place so that it is keyed by parameter names
        without the FSDP wrapper prefix, renamed according to ``og_keys_to_new``.

        :param optim_state_dict: Optimizer state dict with ``"state"`` and ``"param_groups"``
            entries; every group must carry ``"param_names"`` alongside ``"params"``.
        :param og_keys_to_new: Maps each original parameter name to its new name(s).

        :returns: The same (mutated) ``optim_state_dict``.
        """
        # The state dict comes in two forms: keyed by integer parameter ids, or keyed by
        # fully qualified parameter names. Convert the integer form to the name form first.
        if isinstance(optim_state_dict["param_groups"][0]["params"][0], int):
            id_to_fqn: Dict[int, str] = {}
            for group in optim_state_dict["param_groups"]:
                new_param_names = []
                for fqn, id in zip(group["param_names"], group["params"]):
                    fqn = fqn.replace("_fsdp_wrapped_module.", "")
                    id_to_fqn[id] = fqn
                    new_param_names.append(fqn)
                group["param_names"] = new_param_names
                group["params"] = new_param_names
            # Iterate over a snapshot of the keys, since entries are popped and re-inserted.
            for id in list(optim_state_dict["state"].keys()):
                optim_state_dict["state"][id_to_fqn[id]] = optim_state_dict["state"].pop(id)
        else:
            # Already keyed by name: just strip the FSDP wrapper prefix everywhere.
            for group in optim_state_dict["param_groups"]:
                group["param_names"] = [fqn.replace("_fsdp_wrapped_module.", "") for fqn in group["param_names"]]
                group["params"] = [fqn.replace("_fsdp_wrapped_module.", "") for fqn in group["params"]]
                assert group["param_names"] == group["params"]
            for key in list(optim_state_dict["state"].keys()):
                optim_state_dict["state"][key.replace("_fsdp_wrapped_module.", "")] = optim_state_dict[
                    "state"
                ].pop(key)

        # Now rename parameters according to `og_keys_to_new`.
        # First in the per-parameter state; parameters without state are skipped.
        for og_key, new_keys in og_keys_to_new.items():
            og_state = optim_state_dict["state"].pop(og_key, None)
            if og_state is None:
                continue
            for i, new_key in enumerate(new_keys):
                # The last new key reuses the original state object; the others get copies.
                if i == len(new_keys) - 1:
                    optim_state_dict["state"][new_key] = og_state
                else:
                    optim_state_dict["state"][new_key] = deepcopy(og_state)
        # Then in the param groups.
        for group in optim_state_dict["param_groups"]:
            og_names = group["params"]
            new_names = []
            for og_key in og_names:
                for new_key in og_keys_to_new[og_key]:
                    new_names.append(new_key)
            group["params"] = new_names
            group["param_names"] = new_names

        return optim_state_dict

    def load_checkpoint(
        self,
        load_path: PathOrStr,
        *,
        local_cache: Optional[PathOrStr] = None,
        load_optimizer_state: bool = True,
        device: Optional[torch.device] = None,
    ) -> Tuple[Dict[str, torch.Tensor], Optional[Dict[str, Any]]]:
        """
        Load the raw model state dict, and optionally the optimizer state dict, from a full
        checkpoint without applying them to any model or optimizer.

        :param load_path: The checkpoint location containing ``model.pt`` / ``optim.pt``.
        :param local_cache: Optional local directory to look for checkpoint files in first.
        :param load_optimizer_state: Whether to also load ``optim.pt``.
        :param device: Device to map the loaded tensors to; defaults to CPU.

        :returns: ``(model_state, optim_state)``, where ``optim_state`` is ``None`` when it
            was not requested.
        """
        device = device if device is not None else torch.device("cpu")
        model_state = load_state_dict(load_path, "model.pt", local_cache=local_cache, map_location=device)
        optim_state = None
        if load_optimizer_state:
            optim_state = load_state_dict(load_path, "optim.pt", local_cache=local_cache, map_location=device)
        return model_state, optim_state


class TorchNewStyleShardedCheckpointer(Checkpointer):
    """
    A sharded :class:`Checkpointer` built on :func:`save_fsdp_model_and_optim_state` and
    :func:`load_fsdp_model_and_optim_state`. Trainer state is stored per rank under
    ``train/rank{N}.pt``. Only FSDP models are supported.
    """

    def save_checkpoint(
        self,
        dir: PathOrStr,
        dist_model: nn.Module,
        optim: Optimizer,
        trainer_state: Dict[str, Any],
        *,
        upload_to: Optional[str] = None,
    ) -> None:
        """
        Save sharded model/optimizer state, this rank's trainer state, and the config.

        :param dir: The local checkpoint directory (written via a temporary directory).
        :param dist_model: The FSDP-wrapped model to save.
        :param optim: The optimizer whose state should be saved.
        :param trainer_state: This rank's trainer state.
        :param upload_to: Optional remote location to upload the saved files to.
        """
        assert isinstance(
            dist_model, FSDP
        ), f"{self.__class__.__name__} is being called to save a model where `distributed_strategy` is not FSDP."
        overwrite = self.cfg.save_overwrite
        with self._temporary_wd(dir) as tmp_dir:
            # Sharded model and optimizer state.
            save_fsdp_model_and_optim_state(
                tmp_dir,
                dist_model,
                optim,
                upload_to=upload_to,
                save_overwrite=overwrite,
            )

            # Every rank writes its own trainer state file.
            log.info("Saving trainer state...")
            rank_fname = f"train/rank{get_global_rank()}.pt"
            save_state_dict(
                tmp_dir,
                rank_fname,
                trainer_state,
                upload_to=upload_to,
                save_overwrite=overwrite,
            )

            # Config.
            self._save_config(tmp_dir, upload_to=upload_to)

    def restore_checkpoint(
        self,
        load_path: PathOrStr,
        dist_model: nn.Module,
        optim: Optimizer,
        *,
        local_cache: Optional[PathOrStr] = None,
        load_optimizer_state: bool = True,
    ) -> Dict[str, Any]:
        """
        Restore sharded model (and optionally optimizer) state in place and return the
        trainer state for this rank.

        :param load_path: The checkpoint location.
        :param dist_model: The FSDP-wrapped model to load state into.
        :param optim: The optimizer to load state into.
        :param local_cache: Optional local directory to look for checkpoint files in first.
        :param load_optimizer_state: Whether to also restore the optimizer state.

        :returns: This rank's trainer state, or rank 0's if this rank has no file of its own.
        """
        log.info("Loading model and optimizer state...")
        assert isinstance(
            dist_model, FSDP
        ), f"{self.__class__.__name__} is being called to load a model where `distributed_strategy` is not FSDP."

        load_fsdp_model_and_optim_state(
            load_path,
            dist_model,
            optim,
            local_cache=local_cache,
            load_optimizer_state=load_optimizer_state,
        )

        log.info("Loading trainer state...")
        try:
            rank_fname = f"train/rank{get_global_rank()}.pt"
            restored_state = load_state_dict(load_path, rank_fname, local_cache=local_cache)
        except FileNotFoundError:
            # No trainer state file for this rank; fall back to rank 0's file.
            restored_state = load_state_dict(load_path, "train/rank0.pt", local_cache=local_cache)
        barrier()
        return restored_state


class TorchLegacyShardedCheckpointer(Checkpointer):
    

    def __init__(self, cfg: TrainConfig, thread_count: Optional[int] = None, use_shared_mem_impl: bool = False):
        """
        :param cfg: The training config, forwarded to the base ``Checkpointer`` initializer.
        :param thread_count: Forwarded to the base ``Checkpointer`` initializer.
        :param use_shared_mem_impl: When true, ``_unshard()`` delegates to
            ``_unshard_using_shared_memory()`` instead of loading every shard in-process.
        """
        super().__init__(cfg, thread_count)
        self.use_shared_mem_impl = use_shared_mem_impl

    def save_checkpoint(
        self,
        dir: PathOrStr,
        dist_model: nn.Module,
        optim: Optimizer,
        trainer_state: Dict[str, Any],
        *,
        upload_to: Optional[str] = None,
    ) -> None:
        """
        Write this rank's sharded model state, sharded optimizer state and trainer state into a
        single ``rank{N}.pt`` file, then save the config.

        :param dir: Checkpoint directory, handed to ``self._temporary_wd``.
        :param dist_model: The model to save. Must be an FSDP instance.
        :param optim: The optimizer whose state is collected with ``FSDP.optim_state_dict``.
        :param trainer_state: Extra entries merged into the saved dictionary after the
            ``"model"`` and ``"optim"`` entries.
        :param upload_to: Forwarded to ``save_state_dict`` and ``self._save_config``.
        """
        assert isinstance(
            dist_model, FSDP
        ), f"{self.__class__.__name__} is being called to save a model where `distributed_strategy` is not FSDP."
        with self._temporary_wd(dir) as checkpoint_dir:
            sharded_state_ctx = FSDP.state_dict_type(
                dist_model,
                state_dict_type=StateDictType.SHARDED_STATE_DICT,
                state_dict_config=ShardedStateDictConfig(offload_to_cpu=True),
                optim_state_dict_config=ShardedOptimStateDictConfig(offload_to_cpu=True),
            )
            with sharded_state_ctx:
                # Model first, then optimizer, then the trainer state entries on top.
                payload: Dict[str, Any] = {"model": dist_model.state_dict()}
                payload["optim"] = FSDP.optim_state_dict(dist_model, optim)
                payload.update(trainer_state)
                save_state_dict(
                    checkpoint_dir,
                    f"rank{get_global_rank()}.pt",
                    payload,
                    upload_to=upload_to,
                    save_overwrite=self.cfg.save_overwrite,
                )

            # The config is saved outside of the FSDP state-dict context.
            self._save_config(checkpoint_dir, upload_to=upload_to)

    def restore_checkpoint(
        self,
        load_path: PathOrStr,
        dist_model: nn.Module,
        optim: Optimizer,
        *,
        local_cache: Optional[PathOrStr] = None,
        load_optimizer_state: bool = True,
    ) -> Dict[str, Any]:
        """
        Load this rank's ``rank{N}.pt`` file and restore the model (and optionally the optimizer).

        :param load_path: Location of the checkpoint to restore from.
        :param dist_model: The model to load into. Must be an FSDP instance.
        :param optim: The optimizer to restore when ``load_optimizer_state`` is true.
        :param local_cache: Forwarded to ``load_state_dict``.
        :param load_optimizer_state: Whether to load the ``"optim"`` entry into ``optim``.
        :returns: Whatever remains of the loaded dictionary once the ``"model"`` and ``"optim"``
            entries have been removed.
        """
        assert isinstance(
            dist_model, FSDP
        ), f"{self.__class__.__name__} is being called to load a model where `distributed_strategy` is not FSDP."
        sharded_state_ctx = FSDP.state_dict_type(
            dist_model,
            state_dict_type=StateDictType.SHARDED_STATE_DICT,
            state_dict_config=ShardedStateDictConfig(offload_to_cpu=True),
            optim_state_dict_config=ShardedOptimStateDictConfig(offload_to_cpu=True),
        )
        with sharded_state_ctx:
            # Everything for this rank lives in a single file.
            state_dict = load_state_dict(
                load_path, f"rank{get_global_rank()}.pt", local_cache=local_cache, map_location="cpu"
            )

            log.info("Loading model state...")
            dist_model.load_state_dict(state_dict["model"])
            del state_dict["model"]

            if load_optimizer_state:
                log.info("Loading optimizer state...")
                load_fsdp_optim_state(dist_model, optim, state_dict["optim"])
            # The optimizer entry is dropped whether or not it was loaded.
            del state_dict["optim"]

        barrier()
        return state_dict

    def unshard_checkpoint(
        self,
        load_path: PathOrStr,
        *,
        local_cache: Optional[PathOrStr] = None,
        load_optimizer_state: bool = True,
        load_trainer_state: bool = True,
        device: Optional[torch.device] = None,
    ) -> Tuple[Dict[str, torch.Tensor], Optional[Dict[str, Any]], Optional[Dict[str, Any]]]:
        """
        Unshard a checkpoint into full model, optimizer and trainer state.

        :param load_path: Directory containing the ``rank*.pt`` shard files.
        :param local_cache: Must be ``None``; only local files are supported.
        :param load_optimizer_state: When false, ``None`` is returned in place of the optimizer state.
        :param load_trainer_state: When false, ``None`` is returned in place of the trainer state.
        :param device: Target device for the unsharded tensors; CPU when not given.
        :returns: ``(model_state, optim_state, trainer_state)``. The trainer state is whatever is
            left of the unsharded dictionary after ``"model"`` and ``"optim"`` are popped (the
            ``"rng"`` key is skipped while unsharding).
        """
        assert local_cache is None, "this method currently only supports local files"

        target_device = torch.device("cpu") if device is None else device
        remaining_state = self._unshard(load_path, target_device, skip_keys={"rng"})

        model_state = remaining_state.pop("model")
        optim_state = remaining_state.pop("optim")
        if not load_optimizer_state:
            optim_state = None
        trainer_state = remaining_state if load_trainer_state else None

        return model_state, optim_state, trainer_state

    def _copy_sharded_tensors_to_shared_mem(self, state: Dict, world_size: int, rank: int, key: Tuple):
        key = tuple() if key is None else key
        if isinstance(state, (list, tuple, set)):
            for i, sub_state in enumerate(state):
                self._copy_sharded_tensors_to_shared_mem(sub_state, world_size, rank, key + (i,))
        elif isinstance(state, dict):
            for name in state.keys():
                self._copy_sharded_tensors_to_shared_mem(state[name], world_size, rank, key + (name,))
        elif isinstance(state, ShardedTensor):
            self._copy_sharded_tensor_to_shared_mem(state, world_size, rank, key)
            return
        else:
            return

    def _get_shard_placement_and_rank_sizes(
        self, shards_metadata: List[ShardMetadata], world_size: int
    ) -> Tuple[Dict[ShardMetadata, Tuple[int, int]], List[int]]:
        def shard_size(shard_md):
            return reduce((lambda x, y: x * y), shard_md.shard_sizes)  

        rank_sizes = [0 for _ in range(world_size)]
        shard_placement: Dict[ShardMetadata, Tuple[int, int]] = {}
        for shard_md in shards_metadata:
            shard_rank = cast(_remote_device, shard_md.placement).rank()
            assert shard_rank is not None
            if shard_rank >= world_size:
                raise RuntimeError(f"Shard rank {shard_rank} exceeds world size {world_size}")

            shard_placement[shard_md] = (shard_rank, rank_sizes[shard_rank])
            rank_sizes[shard_rank] += shard_size(shard_md)

        return shard_placement, rank_sizes

    def _copy_sharded_tensor_to_shared_mem(
        self, sharded_tensor: ShardedTensor, world_size: int, rank: int, key: Tuple
    ) -> Any:
        shard0_md = sharded_tensor.metadata()
        shard_placement, rank_sizes = self._get_shard_placement_and_rank_sizes(
            shard0_md.shards_metadata, world_size
        )

        rank_size = rank_sizes[rank]
        assert rank_size >= 0
        if rank_size == 0:
            return

        assert shard0_md.tensor_properties.dtype == torch.float32, "Expected sharded tensor to be fp32"
        numpy_type = np.float32

        sharded_memory_name = "-".join(key + (str(rank),))

        shm = shared_memory.SharedMemory(
            create=True, size=rank_size * np.dtype(numpy_type).itemsize, name=sharded_memory_name
        )
        np_arr = np.ndarray((rank_size,), dtype=numpy_type, buffer=shm.buf)

        for local_shard in sharded_tensor.local_shards():
            shard_rank = cast(_remote_device, local_shard.metadata.placement).rank()
            assert shard_rank == rank

            src = local_shard.tensor.flatten()
            shard_offset = shard_placement[local_shard.metadata][1]

            np_arr[shard_offset : shard_offset + src.numel()] = src.numpy()

        shm.close()

    def _copy_sharded_data_to_shared_mem(self, world_size: int, shard_filepath: Path):
        """
        Load one ``rank{N}.pt`` shard file and copy all of its sharded tensors to shared memory.

        :param world_size: Total number of shard files.
        :param shard_filepath: Path of the shard file. The rank is parsed from the file name by
            stripping the 4-character prefix and 3-character suffix.
        """
        rank_digits = shard_filepath.name[4:-3]
        shard_number = int(rank_digits)
        log.info("Starting unsharding shard number %d to shared memory", shard_number)

        with self._patch_sharded_tensor_load():
            shard = torch.load(shard_filepath, map_location="cpu")
            log.debug("Done loading shard number %d", shard_number)

        # The shard directory (with "/" replaced) is the root component of every segment name.
        root_key = (str(shard_filepath.parent).replace("/", "_"),)
        self._copy_sharded_tensors_to_shared_mem(shard, world_size, shard_number, root_key)
        log.info("Done unsharding shard number %d to shared memory", shard_number)

    def _unshard_using_sharded_mem(
        self, state: Any, world_size: int, device: torch.device, shard_dir: PathOrStr
    ) -> Any:
        return self._unshard_state_using_shared_mem(state, world_size, device, (str(shard_dir).replace("/", "_"),))

    def _unshard_state_using_shared_mem(
        self, state: Any, world_size: int, device: torch.device, key: Tuple
    ) -> Any:
        if isinstance(state, (list, tuple, set)):
            return state.__class__(
                self._unshard_state_using_shared_mem(sub_state, world_size, device, key + (i,))
                for i, sub_state in enumerate(state)
            )
        elif isinstance(state, dict):
            return {
                name: self._unshard_state_using_shared_mem(state[name], world_size, device, key + (name,))
                for name in state.keys()
            }
        elif isinstance(state, ShardedTensor):
            return self._unshard_tensor_using_shared_mem(state, world_size, device, key)
        elif isinstance(state, torch.Tensor):
            return state.to(device=device)
        else:
            return state

    def _unshard_tensor_using_shared_mem(
        self, sharded_tensor: ShardedTensor, world_size: int, device: torch.device, key: Tuple
    ) -> torch.Tensor:
        shard0_md = sharded_tensor.metadata()

        def shard_size(shard_md):
            return reduce((lambda x, y: x * y), shard_md.shard_sizes)  

        shard_placement, rank_sizes = self._get_shard_placement_and_rank_sizes(
            shard0_md.shards_metadata, world_size
        )

        assert shard0_md.tensor_properties.dtype == torch.float32, "Expected sharded tensor to be fp32"
        numpy_type = np.float32

        out = torch.empty(
            *sharded_tensor.metadata().size, dtype=sharded_tensor.metadata().tensor_properties.dtype, device=device
        )
        dims = len(sharded_tensor.metadata().size)
        for shard_md, (rank, rank_offset) in shard_placement.items():
            if rank >= world_size:
                raise RuntimeError(f"Shard rank {rank} exceeds world size {world_size}")

            sharded_memory_name = "-".join(key + (str(rank),))
            shm = shared_memory.SharedMemory(name=sharded_memory_name)

            rank_size = rank_sizes[rank]
            assert rank_size >= 0
            if rank_size == 0:
                continue

            np_arr = np.ndarray((rank_size,), dtype=numpy_type, buffer=shm.buf)

            tensor = torch.from_numpy(np_arr)[rank_offset : rank_offset + shard_size(shard_md)]
            tensor = tensor.view(shard_md.shard_sizes)

            out_narrow_view = out
            for dim in range(dims):
                out_narrow_view = out_narrow_view.narrow(
                    dim,
                    shard_md.shard_offsets[dim],
                    shard_md.shard_sizes[dim],
                )

            out_narrow_view.copy_(tensor)

            shm.close()
            shm.unlink()

        return out

    @contextmanager
    def _patch_sharded_tensor_load(self):
        """
        Context manager that temporarily replaces ``torch._tensor._rebuild_from_type_v2`` with a
        variant that restores ``ShardedTensor`` objects by assigning their pickled state fields
        directly rather than calling ``__setstate__``. The original function is put back on
        exit, including when the body raises.

        NOTE(review): presumably needed because ``ShardedTensor.__setstate__`` expects an
        initialized distributed process group - confirm. The replacement is assigned to the
        module attribute, so it is visible to all code in the process while the context is active.
        """

        def _rebuild_from_type_v2_monkey(func, new_type, args, state):
            ret = func(*args)
            if type(ret) is not new_type:
                ret = ret.as_subclass(new_type)

            # Sharded tensors: set the attributes straight from the 5-element state tuple
            # (the third element is ignored) and skip ``__setstate__`` entirely.
            if isinstance(ret, ShardedTensor):
                ret._local_shards, ret._metadata, _, ret._sharding_spec, ret._init_rrefs = state
                return ret

            # Everything else: call the class's own ``__setstate__`` when it overrides the one
            # from ``torch.Tensor``, otherwise fall back to ``torch._utils._set_obj_state``.
            # NOTE(review): this appears to mirror the stock ``_rebuild_from_type_v2`` logic;
            # re-check it against the installed torch version when upgrading.
            if getattr(ret.__class__, "__setstate__", torch.Tensor.__setstate__) is not torch.Tensor.__setstate__:
                ret.__setstate__(state)
            else:
                ret = torch._utils._set_obj_state(ret, state)
            return ret

        original_rebuild_from_type_v2 = torch._tensor._rebuild_from_type_v2
        try:
            torch._tensor._rebuild_from_type_v2 = _rebuild_from_type_v2_monkey
            yield
        finally:
            torch._tensor._rebuild_from_type_v2 = original_rebuild_from_type_v2

    def _unshard_using_shared_memory(
        self, input_dir: PathOrStr, device: torch.device, skip_keys: Optional[Set[str]] = None
    ):
        

        input_dir = Path(input_dir)
        skip_keys = skip_keys or set()

        shard_filepaths = list(input_dir.glob("rank*.pt"))
        world_size = len(shard_filepaths)
        if world_size == 0:
            raise RuntimeError("No shards found for unsharding")

        log.info("Number of shards: %d", world_size)
        shard_size_gb = shard_filepaths[0].stat().st_size / (1024 * 1024 * 1024)
        min_ram_required_estimate_gb = shard_size_gb * world_size
        log.info(
            "Shards are %.2fGB each, at least %.2fGB RAM is required", shard_size_gb, min_ram_required_estimate_gb
        )

        log.info("Copying sharded tensors to shared memory using multiple processes")
        
        
        
        executor = ProcessPoolExecutor(
            mp_context=mp.get_context("spawn"), initializer=util.prepare_cli_environment
        )
        futures = []
        for shard_filepath in shard_filepaths:
            shard_rank = int(shard_filepath.name[4:-3])

            if shard_rank >= world_size:
                raise RuntimeError(
                    f"Shard rank {shard_rank} of file {shard_filepath} exceeds world size {world_size}"
                )

            futures.append(executor.submit(self._copy_sharded_data_to_shared_mem, world_size, shard_filepath))

        for f in as_completed(futures):
            f.result()
        executor.shutdown()

        log.info("Loading a shard on the main process to be unsharded state")
        with self._patch_sharded_tensor_load():
            state = torch.load(shard_filepaths[0], map_location="cpu")

        for key in skip_keys:
            if key in state:
                del state[key]

        log.info("Unsharding from %d shards ...", world_size)
        return self._unshard_using_sharded_mem(state, world_size, device, input_dir)

    def _unshard(self, input_dir: PathOrStr, device: torch.device, skip_keys: Optional[Set[str]] = None):
        if self.use_shared_mem_impl:
            return self._unshard_using_shared_memory(input_dir, device, skip_keys)

        input_dir = Path(input_dir)
        skip_keys = skip_keys or set()

        with self._patch_sharded_tensor_load():
            
            executor = ThreadPoolExecutor()
            shards_dict = {}
            for shard_name in input_dir.glob("rank*.pt"):
                log.info("Loading %s ...", shard_name)
                shard_number = int(shard_name.name[4:-3])  
                shards_dict[shard_number] = executor.submit(torch.load, shard_name, map_location="cpu")
            shards = [None] * len(shards_dict)
            for rank, shard_future in shards_dict.items():
                shard = shard_future.result()
                for key in skip_keys:
                    if key in shard:
                        del shard[key]
                shards[rank] = shard
            assert all(shard is not None for shard in shards)
            executor.shutdown()
            del shards_dict

            log.info("Unsharding from %d shards ...", len(shards))

            unsharded_state_dict = self._unshard_object(shards, device=device)
            
            del shards

            return unsharded_state_dict

    def _unshard_object(self, os: List[Any], device: torch.device) -> Any:
        rank0_item = os[0]
        assert all(type(o) is type(rank0_item) for o in os)
        if isinstance(rank0_item, str):
            assert all(o == rank0_item for o in os)
            return rank0_item
        elif isinstance(rank0_item, (list, tuple, set)):
            assert all(len(o) == len(rank0_item) for o in os)
            return rank0_item.__class__(self._unshard_object(o, device=device) for o in zip(*os))
        elif isinstance(rank0_item, dict):
            assert all(o.keys() == rank0_item.keys() for o in os)
            return {key: self._unshard_object([o[key] for o in os], device=device) for key in rank0_item.keys()}
        elif isinstance(rank0_item, ShardedTensor):
            return self._gather(os, device=device)
        else:
            assert all(self._objects_are_equal(o, rank0_item) for o in os)
            return rank0_item

    def _gather(self, shards: List[ShardedTensor], device: torch.device) -> torch.Tensor:
        world_size = len(shards)
        shard0_md = shards[0].metadata()
        
        assert all(shard.metadata() == shard0_md for shard in shards)
        
        assert all(
            shard_md.placement.rank() == rank  
            for rank, shard_md in enumerate(shard0_md.shards_metadata)
        )

        def shard_size(shard_md):
            return reduce((lambda x, y: x * y), shard_md.shard_sizes)  

        rank_sizes = [0 for _ in range(world_size)]
        max_rank_size = 0
        shard_placement: Dict[ShardMetadata, Tuple[int, int]] = {}
        for shard_md in shard0_md.shards_metadata:
            shard_rank = cast(_remote_device, shard_md.placement).rank()
            assert shard_rank is not None

            shard_placement[shard_md] = (shard_rank, rank_sizes[shard_rank])
            rank_sizes[shard_rank] += shard_size(shard_md)
            max_rank_size = max(max_rank_size, rank_sizes[shard_rank])

        gather_list: List[torch.Tensor] = [torch.empty((max_rank_size,)) for _ in range(world_size)]

        datas = []
        with torch.no_grad():
            for shard in shards:
                data = torch.empty(max_rank_size)

                for local_shard in shard.local_shards():
                    src = local_shard.tensor.flatten()
                    shard_offset = shard_placement[local_shard.metadata][1]
                    data[shard_offset : shard_offset + src.numel()].copy_(src)

                datas.append(data)

        
        for rank, data in enumerate(datas):
            gather_list[rank].copy_(data)

        full_size = shard0_md.size
        out = torch.empty(*full_size, dtype=shard0_md.tensor_properties.dtype, device=device)
        dims = len(full_size)
        for shard_md in shard0_md.shards_metadata:
            rank, rank_offset = shard_placement[shard_md]
            tensor = gather_list[rank]
            tensor = tensor[rank_offset : rank_offset + shard_size(shard_md)]
            tensor = tensor.view(shard_md.shard_sizes)

            out_narrow_view = out
            for dim in range(dims):
                out_narrow_view = out_narrow_view.narrow(
                    dim,
                    shard_md.shard_offsets[dim],
                    shard_md.shard_sizes[dim],
                )

            out_narrow_view.copy_(tensor)

        return out

    def _objects_are_equal(self, a: Any, b: Any) -> bool:
        if type(a) is not type(b):
            return False
        if isinstance(a, np.ndarray):
            return np.array_equal(a, b)
        elif isinstance(a, torch.Tensor):
            return torch.equal(a, b)
        else:
            return a == b


@dataclass
class _LocalShardedCheckpointerMetadata(BaseConfig):
    """
    Metadata that ``LocalShardedCheckpointer`` writes to ``metadata.yaml`` when saving and
    reads back when restoring or unsharding a checkpoint.
    """

    # Defaults to ``get_world_size()`` at the time the instance is created.
    world_size: int = field(default_factory=get_world_size)


@dataclass
class _FlatParamShard:
    """One rank's contiguous slice of a flattened parameter."""

    # Shape of the full (unsharded) parameter.
    full_shape: torch.Size
    # (first, last) indices of this slice within the flattened full parameter; both inclusive.
    shard_offsets: Tuple[int, int]
    # The slice's values, or ``None`` when only the bookkeeping above is kept.
    shard_data: Optional[torch.Tensor]

    def copy_into(self, full_tensor: torch.Tensor) -> None:
        """
        Write this shard's data into the matching region of ``full_tensor``.

        :param full_tensor: The full parameter; it is viewed as 1-D and the elements from the
            first to the last shard offset (inclusive) are overwritten in place.
        """
        assert self.shard_data is not None
        first_index, last_index = self.shard_offsets[0], self.shard_offsets[1]
        flattened = full_tensor.view(-1)
        destination = flattened[first_index : last_index + 1]
        assert self.shard_data.shape == destination.shape
        destination.copy_(self.shard_data)


class LocalShardedCheckpointer(Checkpointer):
    

    
    # Names of flat parameter attributes that are saved next to the flat parameter data
    # (when the attribute exists) and asserted to be unchanged when a checkpoint is loaded.
    # NOTE(review): which of these exist presumably depends on the installed torch version -
    # the save/load code guards each one with `hasattr`.
    _FLAT_PARAM_METADATA_TO_SAVE = (
        "_fqns",
        "_shard_param_offsets",
        "_shard_indices",
        "_numels",
        "_numels_with_padding",
        "_shapes",
        "_shard_numel_padded",
        "_shard_param_infos",
    )

    def _fsdp_modules(self, fsdp_model: FSDP) -> List[Tuple[str, FSDP]]:
        
        modules = []
        for name, module in fsdp_model.named_modules():
            if isinstance(module, FSDP):
                modules.append((name, module))
        return modules

    def _prepare_fsdp_model(self, fsdp_model: FSDP) -> None:
        """
        Synchronize CUDA (when available) and run FSDP's lazy initialization on ``fsdp_model``.

        NOTE(review): presumably this guarantees the flat parameters are materialized and up to
        date before they are read or written directly - confirm.
        """
        from torch.distributed.fsdp._runtime_utils import _lazy_init as run_lazy_init

        cuda_present = torch.cuda.is_available()
        if cuda_present:
            torch.cuda.synchronize()
        # The model is passed as both arguments, exactly as the private torch helper is called
        # elsewhere for a root FSDP module.
        run_lazy_init(fsdp_model, fsdp_model)

    def _fsdp_handles(self, fsdp_model: FSDP) -> List[FlatParamHandle]:
        """
        Return the flat parameter handles of a single FSDP module.

        Where the handles live depends on the installed torch version: before 2.1.0 the module's
        ``_handles`` attribute is returned as is; from 2.1.0 up to (not including) 2.3.0 the
        module has at most one handle in ``_handle``.

        :raises NotImplementedError: For torch 2.3.0 and newer.
        """
        installed = version.parse(torch.__version__)

        if installed < version.parse("2.1.0"):
            return fsdp_model._handles  # type: ignore

        if installed >= version.parse("2.3.0"):
            # Not supported for these versions.
            raise NotImplementedError

        # A missing attribute and a `None` handle both mean "no handles".
        single_handle = getattr(fsdp_model, "_handle", None)
        if single_handle is None:
            return []
        return [single_handle]  # type: ignore

    @torch.no_grad()
    def _get_flat_param_state_to_save(self, fsdp_model: FSDP) -> Dict[str, Any]:
        """
        Collect the local flat parameter data and metadata of every FSDP module.

        :returns: ``{"modules": [...]}`` with one entry per FSDP module, each of the form
            ``{"handles": [...], "name": <module name>}``. Every handle entry holds the detached
            flat parameter under ``"flat_param.data"`` plus one ``"flat_param.<attr>"`` item for
            each attribute in ``_FLAT_PARAM_METADATA_TO_SAVE`` that the flat parameter has.
        """
        self._prepare_fsdp_model(fsdp_model)

        def snapshot(handle: FlatParamHandle) -> Dict[str, Any]:
            flat_param = handle.flat_param
            entry: Dict[str, Any] = {"flat_param.data": flat_param.detach()}
            entry.update(
                {
                    f"flat_param.{attr}": getattr(flat_param, attr)
                    for attr in self._FLAT_PARAM_METADATA_TO_SAVE
                    if hasattr(flat_param, attr)
                }
            )
            return entry

        modules = [
            {
                "handles": [snapshot(handle) for handle in self._fsdp_handles(fsdp_module)],
                "name": module_fqn,
            }
            for module_fqn, fsdp_module in self._fsdp_modules(fsdp_model)
        ]
        return {"modules": modules}

    @torch.no_grad()
    def _load_flat_param_state(self, fsdp_model: FSDP, model_state: Dict[str, Any]):
        """
        Copy saved flat parameter data back into the flat parameters of ``fsdp_model``.

        The saved state must contain the same number of modules, and the same number of handles
        per module, as the model. Every saved metadata attribute that also exists on the live
        flat parameter must be equal to it.

        :param fsdp_model: The model to load into.
        :param model_state: State in the layout produced by ``_get_flat_param_state_to_save``.
        """
        self._prepare_fsdp_model(fsdp_model)
        live_modules = self._fsdp_modules(fsdp_model)
        saved_modules = model_state["modules"]
        assert len(saved_modules) == len(live_modules)

        for (_, live_module), saved_module in zip(live_modules, saved_modules):
            live_handles = self._fsdp_handles(live_module)
            saved_handles = saved_module["handles"]
            assert len(live_handles) == len(saved_handles)

            for live_handle, saved in zip(live_handles, saved_handles):
                flat_param = live_handle.flat_param
                # The sharding layout has to be the same as when the checkpoint was saved.
                for attr in self._FLAT_PARAM_METADATA_TO_SAVE:
                    if not hasattr(flat_param, attr):
                        continue
                    assert getattr(flat_param, attr) == saved[f"flat_param.{attr}"]
                # Overwrite the local flat parameter in place.
                flat_param.copy_(saved["flat_param.data"])

    def _save_metadata(self, dir: PathOrStr, *, upload_to: Optional[str] = None) -> None:
        """
        Write ``metadata.yaml`` into ``dir`` and optionally upload it.

        Only processes with filesystem-local rank 0 write the file, and only global rank 0
        uploads it.

        :param dir: Directory to write ``metadata.yaml`` into.
        :param upload_to: When given, the file is uploaded to ``{upload_to}/metadata.yaml``.
        """
        if get_fs_local_rank() != 0:
            return

        log.info("Saving metadata...")
        metadata = _LocalShardedCheckpointerMetadata()
        metadata_path = Path(dir) / "metadata.yaml"
        metadata.save(metadata_path)

        should_upload = upload_to is not None and get_global_rank() == 0
        if should_upload:
            upload_target = f"{upload_to}/metadata.yaml"
            log.info(f"Uploading {metadata_path} to {upload_target}")
            upload(metadata_path, upload_target, save_overwrite=self.cfg.save_overwrite)

    def _load_metadata(
        self, load_path: PathOrStr, *, local_cache: Optional[PathOrStr] = None
    ) -> _LocalShardedCheckpointerMetadata:
        """
        Load the checkpoint's ``metadata.yaml``.

        :param load_path: Location of the checkpoint.
        :param local_cache: Forwarded to ``resource_path``.
        """
        return _LocalShardedCheckpointerMetadata.load(
            resource_path(load_path, "metadata.yaml", local_cache=local_cache)
        )

    def save_checkpoint(
        self,
        dir: PathOrStr,
        dist_model: nn.Module,
        optim: Optimizer,
        trainer_state: Dict[str, Any],
        *,
        upload_to: Optional[str] = None,
    ) -> None:
        """
        Save this rank's flat parameters, optimizer state and trainer state, each to its own
        ``{model,optim,train}/rank{N}.pt`` file, followed by the metadata and the config.

        :param dir: Checkpoint directory, handed to ``self._temporary_wd``.
        :param dist_model: The model to save. Must be an FSDP instance.
        :param optim: The optimizer; its ``state_dict()`` is saved as is.
        :param trainer_state: The trainer state to save.
        :param upload_to: Forwarded to the individual save helpers.
        """
        assert isinstance(
            dist_model, FSDP
        ), f"{self.__class__.__name__} is being called to save a model where `distributed_strategy` is not FSDP."

        with self._temporary_wd(dir) as checkpoint_dir:
            # (log message, sub-directory, state producer). Each state is only produced right
            # before it is written, after the previous one has been saved.
            stages = (
                (
                    "Saving local FSDP flat params data...",
                    "model",
                    lambda: self._get_flat_param_state_to_save(dist_model),
                ),
                ("Saving local optimizer state...", "optim", lambda: optim.state_dict()),
                ("Saving trainer state...", "train", lambda: trainer_state),
            )
            for message, subdir, produce_state in stages:
                log.info(message)
                save_state_dict(
                    checkpoint_dir,
                    f"{subdir}/rank{get_global_rank()}.pt",
                    produce_state(),
                    upload_to=upload_to,
                    save_overwrite=self.cfg.save_overwrite,
                )

            # Metadata first, config last.
            self._save_metadata(checkpoint_dir, upload_to=upload_to)
            self._save_config(checkpoint_dir, upload_to=upload_to)

    def restore_checkpoint(
        self,
        load_path: PathOrStr,
        dist_model: nn.Module,
        optim: Optimizer,
        *,
        local_cache: Optional[PathOrStr] = None,
        load_optimizer_state: bool = True,
    ) -> Dict[str, Any]:
        """
        Restore this rank's flat parameters, optionally its optimizer state, and return its
        trainer state.

        The world size recorded in the checkpoint's metadata must equal the current world size.

        :param load_path: Location of the checkpoint to restore from.
        :param dist_model: The model to load into. Must be an FSDP instance.
        :param optim: The optimizer to restore when ``load_optimizer_state`` is true.
        :param local_cache: Forwarded to the loading helpers.
        :param load_optimizer_state: Whether to load this rank's optimizer state.
        :returns: The trainer state from this rank's ``train/rank{N}.pt`` file.
        """
        # The checkpoint can only be loaded with the world size it was saved with.
        saved_metadata = self._load_metadata(load_path, local_cache=local_cache)
        assert saved_metadata.world_size == get_world_size()

        log.info("Loading local FSDP flat params data...")
        assert isinstance(
            dist_model, FSDP
        ), f"{self.__class__.__name__} is being called to load a model where `distributed_strategy` is not FSDP."

        flat_param_state = load_state_dict(
            load_path, f"model/rank{get_global_rank()}.pt", local_cache=local_cache, map_location="cpu"
        )
        self._load_flat_param_state(dist_model, flat_param_state)
        del flat_param_state

        if load_optimizer_state:
            log.info("Loading local optimizer state...")
            optim_state = load_state_dict(
                load_path, f"optim/rank{get_global_rank()}.pt", local_cache=local_cache, map_location="cpu"
            )
            # Strip the "grad_norm_exp_avg" entry from every parameter's state and drop
            # parameters whose state ends up empty.
            # NOTE(review): presumably this entry is not something the optimizer can load
            # back for every parameter - confirm the original motivation.
            per_param_state = optim_state["state"]
            for param_id, param_state in list(per_param_state.items()):
                param_state.pop("grad_norm_exp_avg", None)
                if not param_state:
                    del per_param_state[param_id]
            optim.load_state_dict(optim_state)
            del optim_state

        log.info("Loading local trainer state...")
        restored_trainer_state = load_state_dict(
            load_path, f"train/rank{get_global_rank()}.pt", local_cache=local_cache
        )
        barrier()
        return restored_trainer_state

    def _iter_flat_param_shards(
        self, model_state: Dict[str, Any]
    ) -> Generator[Tuple[str, _FlatParamShard], None, None]:
        """
        Iterate over the parameter shards contained in one rank's saved flat parameter state.

        Two metadata layouts are understood: one with ``_shard_indices`` /
        ``_shard_param_offsets`` and one with ``_shard_param_infos``.

        :param model_state: State in the layout produced by ``_get_flat_param_state_to_save``.
        :returns: A generator of ``(fully qualified parameter name, shard)`` pairs. Names have
            ``"_fsdp_wrapped_module."`` removed from the module prefix.
        """
        for module_data in model_state["modules"]:
            module_prefix = module_data["name"].replace("_fsdp_wrapped_module.", "")

            def qualify(relative_fqn: str, prefix: str = module_prefix) -> str:
                return f"{prefix}.{relative_fqn}" if prefix else relative_fqn

            for handle in module_data["handles"]:
                flat_data = handle["flat_param.data"]
                num_padding = handle["flat_param._shard_numel_padded"]
                if num_padding > 0:
                    # Any padding at the end of the flat data is expected to be all zeros.
                    assert (flat_data[-num_padding:] == 0).all()

                if "flat_param._shard_indices" in handle:
                    # Layout 1: the first shard index says which parameter this rank's data
                    # starts with; the offsets give the inclusive (start, end) range covered
                    # within each parameter, and the data is consumed front to back.
                    first_param = handle["flat_param._shard_indices"][0]
                    cursor = 0
                    triples = zip(
                        handle["flat_param._fqns"][first_param:],
                        handle["flat_param._shapes"][first_param:],
                        handle["flat_param._shard_param_offsets"],
                    )
                    for relative_fqn, full_shape, (offset_start, offset_end) in triples:
                        shard_numel = offset_end - offset_start + 1
                        cursor_end = cursor + shard_numel
                        shard = _FlatParamShard(
                            full_shape=full_shape,
                            shard_offsets=(offset_start, offset_end),
                            shard_data=flat_data[cursor:cursor_end],
                        )
                        cursor = cursor_end
                        yield qualify(relative_fqn), shard
                    continue

                # Layout 2: every parameter has an info record that says whether it is part of
                # this shard and, if so, where its data sits.
                records = zip(
                    handle["flat_param._fqns"],
                    handle["flat_param._shapes"],
                    handle["flat_param._shard_param_infos"],
                )
                for relative_fqn, full_shape, info in records:
                    if not info.in_shard:
                        continue
                    data_begin = info.offset_in_shard
                    data_end = info.offset_in_shard + info.numel_in_shard
                    shard = _FlatParamShard(
                        full_shape=full_shape,
                        shard_offsets=(info.intra_param_start_idx, info.intra_param_end_idx),
                        shard_data=flat_data[data_begin:data_end],
                    )
                    yield qualify(relative_fqn), shard

    def unshard_checkpoint(
        self,
        load_path: PathOrStr,
        *,
        local_cache: Optional[PathOrStr] = None,
        load_optimizer_state: bool = True,
        load_trainer_state: bool = True,
        device: Optional[torch.device] = None,
    ) -> Tuple[Dict[str, torch.Tensor], Optional[Dict[str, Any]], Optional[Dict[str, Any]]]:
        device = device or torch.device("cpu")
        metadata = self._load_metadata(load_path, local_cache=local_cache)

        
        log.info("Gathering model state dicts...")
        model_state_paths = self._gather_state_dict_paths(
            load_path, "model", metadata.world_size, local_cache=local_cache
        )

        
        log.info("Materializing full parameters...")
        full_model_state: Dict[str, torch.Tensor] = {}
        
        
        flat_params_data: Dict[int, Dict[str, _FlatParamShard]] = defaultdict(dict)
        for rank, path in enumerate(model_state_paths):
            log.info(f"Loading shards from rank {rank}...")
            model_state = torch.load(path, map_location="cpu")
            for root_fqn, flat_param_shard in self._iter_flat_param_shards(model_state):
                if root_fqn not in full_model_state:
                    log.info(
                        f"Materializing full parameter '{root_fqn}' with shape {flat_param_shard.full_shape}..."
                    )
                    assert flat_param_shard.shard_data is not None
                    full_model_state[root_fqn] = torch.empty(
                        flat_param_shard.full_shape, dtype=flat_param_shard.shard_data.dtype, device=device
                    )
                    
                    
                    full_model_state[root_fqn].fill_(torch.nan)
                
                full_param = full_model_state[root_fqn]
                log.info(f"Loading rank {rank} shard for '{root_fqn}'...")
                flat_param_shard.copy_into(full_param)
                flat_params_data[rank][root_fqn] = replace(flat_param_shard, shard_data=None)

        log.info("Validating full parameters...")
        for key, tensor in full_model_state.items():
            if torch.isnan(tensor).any():
                raise ValueError(f"Parameter '{key}' contains NaNs, this is likely a bug with the unsharder")

        trainer_state: Optional[Dict[str, Any]] = None
        if load_trainer_state:
            trainer_state = load_state_dict(load_path, "train/rank0.pt", local_cache=local_cache)

        if not load_optimizer_state:
            return full_model_state, None, trainer_state

        log.info("Gathering optim state dicts...")
        optim_state_paths = self._gather_state_dict_paths(
            load_path, "optim", metadata.world_size, local_cache=local_cache
        )

        log.info("Materializing full optim state...")
        full_optim_state: Dict[str, Any] = {"state": defaultdict(dict)}
        fqn_to_id: Dict[str, int] = {}
        id_to_fqn: Dict[int, str] = {}
        for rank, path in enumerate(optim_state_paths):
            log.info(f"Loading sharded optim state from rank {rank}...")
            optim_state = torch.load(path, map_location="cpu")

            
            
            
            if "param_groups" not in full_optim_state:
                full_optim_state["param_groups"] = optim_state["param_groups"]
            else:
                assert full_optim_state["param_groups"] == optim_state["param_groups"]

            
            if not fqn_to_id or not id_to_fqn:
                for group in full_optim_state["param_groups"]:
                    for fqn, id in zip(group["param_names"], group["params"]):
                        fqn = fqn.replace("_fsdp_wrapped_module.", "")
                        fqn_to_id[fqn] = id
                        id_to_fqn[id] = fqn

            
            for id, shard_state in optim_state["state"].items():
                fqn = id_to_fqn[id]
                flat_param_shard = flat_params_data[rank].get(fqn)  
                full_state = full_optim_state["state"][id]
                for key, shard_value in shard_state.items():
                    assert isinstance(shard_value, torch.Tensor)
                    if shard_value.shape == torch.Size([]):
                        
                        
                        assert key in ("step", "grad_norm_exp_avg")  
                        if key not in full_state:
                            full_state[key] = shard_value.to(device)
                        else:
                            assert full_state[key] == shard_value
                    else:
                        
                        
                        assert flat_param_shard is not None, f"missing flat_params_data for {fqn} from rank {rank}"
                        if key not in full_state:
                            log.info(
                                f"Materializing full state '{key}' for '{fqn}' with shape {flat_param_shard.full_shape}..."
                            )
                            full_state[key] = torch.empty(
                                flat_param_shard.full_shape, dtype=shard_value.dtype, device=device
                            )
                        full_state_value = full_state[key]

                        
                        log.info(f"Loading rank {rank} shard state of '{key}' for '{fqn}'...")
                        replace(flat_param_shard, shard_data=shard_value).copy_into(full_state_value)

        
        for group in full_optim_state["param_groups"]:
            group["param_names"] = [n.replace("_fsdp_wrapped_module.", "") for n in group["param_names"]]

        return full_model_state, full_optim_state, trainer_state

    def _get_state_dict_path(
        self,
        load_path: PathOrStr,
        state_dict_type: str,
        rank: int,
        *,
        local_cache: Optional[PathOrStr] = None,
        progress=None,
    ) -> Tuple[int, Path]:
        """
        Look up the state dict file of a single rank.

        The file is requested from ``resource_path`` as ``{state_dict_type}/rank{rank}.pt``
        relative to ``load_path`` (with any trailing slashes removed).

        :param load_path: Root location of the checkpoint.
        :param state_dict_type: Name of the sub-folder holding the per-rank files.
        :param rank: The rank whose file is wanted.
        :param local_cache: Forwarded unchanged to ``resource_path``.
        :param progress: Forwarded unchanged to ``resource_path``.

        :returns: ``(rank, path)`` so the caller can match the result back to its rank.
        """
        rank_file = f"{state_dict_type}/rank{rank}.pt"
        checkpoint_root = str(load_path).rstrip("/")
        resolved = resource_path(checkpoint_root, rank_file, local_cache=local_cache, progress=progress)
        return rank, resolved

    def _gather_state_dict_paths(
        self,
        load_path: PathOrStr,
        state_dict_type: str,
        world_size: int,
        *,
        local_cache: Optional[PathOrStr] = None,
    ) -> List[Path]:
        """
        Collect the per-rank state dict paths for ranks ``0 .. world_size - 1``.

        One ``_get_state_dict_path`` call per rank is submitted to a thread pool sized by
        ``self.thread_count``; all calls share a single progress bar object.

        :param load_path: Root location of the checkpoint.
        :param state_dict_type: Name of the sub-folder holding the per-rank files.
        :param world_size: Number of ranks to gather paths for.
        :param local_cache: Forwarded unchanged to ``_get_state_dict_path``.

        :returns: The paths ordered by rank, regardless of completion order.
        """
        progress = get_progress_bar()
        with ThreadPoolExecutor(max_workers=self.thread_count) as executor:
            pending = [
                executor.submit(
                    self._get_state_dict_path,
                    load_path,
                    state_dict_type,
                    rank,
                    local_cache=local_cache,
                    progress=progress,
                )
                for rank in range(world_size)
            ]

            # Each future yields a (rank, path) pair; key by rank since futures
            # finish in arbitrary order.
            path_by_rank: Dict[int, Path] = dict(done.result() for done in as_completed(pending))

        return [path_by_rank[rank] for rank in range(world_size)]


class OlmoCoreCheckpointer(Checkpointer):
    """
    A checkpointer that hands model and optimizer state off to the functions in
    ``olmo_core.distributed.checkpoint`` and stores trainer state in per-rank
    ``train/rank{N}.pt`` files.
    """

    def save_checkpoint(
        self,
        dir: PathOrStr,
        dist_model: nn.Module,
        optim: Optimizer,
        trainer_state: Dict[str, Any],
        *,
        upload_to: Optional[str] = None,
    ) -> None:
        """
        Save model, optimizer and trainer state plus the config into ``dir``.

        :param dir: Target checkpoint directory.
        :param dist_model: The model whose state is saved.
        :param optim: The optimizer whose state is saved.
        :param trainer_state: This rank's trainer state.
        :param upload_to: If given, every locally created model/optim file is also
            uploaded beneath this prefix, and it is forwarded to the trainer state
            and config saving calls.
        """
        # Imported lazily so olmo_core is only needed when this checkpointer is used.
        from olmo_core.distributed.checkpoint import save_model_and_optim_state

        subdir_names = ("model", "optim", "train")

        with self._temporary_wd(dir) as checkpoint_dir:
            log.info("Saving model and optim state...")

            # Only the process with filesystem-local rank 0 creates the folders...
            if get_fs_local_rank() == 0:
                for subdir_name in subdir_names:
                    (checkpoint_dir / subdir_name).mkdir(exist_ok=True, parents=True)

            # ...while every process waits until it can see each of them.
            for subdir_name in subdir_names:
                wait_for(
                    (checkpoint_dir / subdir_name).exists,
                    f"Waiting for checkpoint {subdir_name} directory",
                    timeout=10.0,
                )

            created_files = save_model_and_optim_state(checkpoint_dir, dist_model, optim)
            if upload_to is not None:
                remote_root = upload_to.rstrip("/")
                for created in created_files:
                    local_file = Path(created)
                    upload_target = f"{remote_root}/{local_file.relative_to(checkpoint_dir)}"
                    log.info(f"Uploading {local_file} to {upload_target}...")
                    upload(local_file, upload_target, save_overwrite=self.cfg.save_overwrite)

            log.info("Saving trainer state...")
            save_state_dict(
                checkpoint_dir,
                f"train/rank{get_global_rank()}.pt",
                trainer_state,
                upload_to=upload_to,
                save_overwrite=self.cfg.save_overwrite,
            )

            self._save_config(checkpoint_dir, upload_to=upload_to)

    def restore_checkpoint(
        self,
        load_path: PathOrStr,
        dist_model: nn.Module,
        optim: Optimizer,
        *,
        local_cache: Optional[PathOrStr] = None,
        load_optimizer_state: bool = True,
    ) -> Dict[str, Any]:
        """
        Load model (and optionally optimizer) state in place and return the trainer state.

        :param load_path: Location of the checkpoint.
        :param dist_model: The model to load state into.
        :param optim: The optimizer to load state into.
        :param local_cache: Forwarded to the trainer state loading call.
        :param load_optimizer_state: When ``False``, ``None`` is passed in place of
            the optimizer to the loading function.

        :returns: The trainer state for this rank, or rank 0's if this rank has no file.
        """
        from olmo_core.distributed.checkpoint import load_model_and_optim_state

        log.info("Loading model and optim state...")
        optim_to_restore = optim if load_optimizer_state else None
        load_model_and_optim_state(load_path, dist_model, optim_to_restore)

        log.info("Loading trainer state...")
        try:
            own_rank_file = f"train/rank{get_global_rank()}.pt"
            trainer_state = load_state_dict(load_path, own_rank_file, local_cache=local_cache)
        except FileNotFoundError:
            # No trainer state file exists for this rank, so use rank 0's instead.
            # NOTE(review): presumably this covers restoring with a different world
            # size than the checkpoint was saved with; confirm.
            trainer_state = load_state_dict(load_path, "train/rank0.pt", local_cache=local_cache)

        barrier()
        return trainer_state

    def unshard_checkpoint(
        self,
        load_path: PathOrStr,
        *,
        local_cache: Optional[PathOrStr] = None,
        load_optimizer_state: bool = True,
        load_trainer_state: bool = True,
        device: Optional[torch.device] = None,
    ) -> Tuple[Dict[str, torch.Tensor], Optional[Dict[str, Any]], Optional[Dict[str, Any]]]:
        """
        Produce unsharded state from the checkpoint at ``load_path``.

        :param load_path: Location of the checkpoint.
        :param local_cache: Forwarded to the trainer state loading call.
        :param load_optimizer_state: Whether to also unshard the optimizer state.
        :param load_trainer_state: Whether to also load rank 0's trainer state.
        :param device: Forwarded to the model/optimizer unsharding functions.

        :returns: ``(model_state, optim_state, train_state)``; the last two are
            ``None`` when their corresponding flag is ``False``.
        """
        from olmo_core.distributed.checkpoint import (
            unshard_model_state,
            unshard_optim_state,
        )

        model_state = unshard_model_state(load_path, device=device)
        optim_state: Optional[Dict[str, Any]] = (
            cast(Dict[str, Any], unshard_optim_state(load_path, device=device)) if load_optimizer_state else None
        )
        train_state: Optional[Dict[str, Any]] = (
            load_state_dict(load_path, "train/rank0.pt", local_cache=local_cache) if load_trainer_state else None
        )
        return model_state, optim_state, train_state


def build_sharded_checkpointer(
    cfg: TrainConfig, *, name: Optional[ShardedCheckpointerType] = None, use_shared_mem_impl: bool = False
) -> Checkpointer:
    """
    Construct the sharded checkpointer implementation selected by ``name``.

    :param cfg: The training config passed to the checkpointer's constructor.
    :param name: Which implementation to build. Falls back to
        ``cfg.sharded_checkpointer`` when not given.
    :param use_shared_mem_impl: Only forwarded to the ``torch_legacy`` implementation.

    :raises NotImplementedError: If the selected type matches none of the known
        implementations.
    """
    selected = name if name else cfg.sharded_checkpointer

    if selected == ShardedCheckpointerType.torch_new:
        return TorchNewStyleShardedCheckpointer(cfg)
    if selected == ShardedCheckpointerType.torch_legacy:
        return TorchLegacyShardedCheckpointer(cfg, use_shared_mem_impl=use_shared_mem_impl)
    if selected == ShardedCheckpointerType.local:
        return LocalShardedCheckpointer(cfg)
    if selected == ShardedCheckpointerType.olmo_core:
        return OlmoCoreCheckpointer(cfg)

    raise NotImplementedError(selected)
