# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pickle
from typing import Any

import torch
from torch.distributed.checkpoint.stateful import Stateful


class PowerSGDStateContainer(Stateful):
    """
    A tiny wrapper to make PowerSGDState checkpointable with DCP.

    PowerSGDState itself is picklable, but torch.distributed.checkpoint expects
    tensors or stateful objects. We serialize the state into a uint8 tensor.

    Note: The process_group is not serialized by PowerSGDState. After loading,
    callers must reassign the correct process group and re-register the hook.
    """

    def __init__(self, powersgd_state: Any | None = None) -> None:
        self.powersgd_state = powersgd_state

    def state_dict(self) -> dict[str, torch.Tensor]:
        if self.powersgd_state is None:
            return {"blob": torch.empty(0, dtype=torch.uint8)}
        data: bytes = pickle.dumps(
            self.powersgd_state, protocol=pickle.HIGHEST_PROTOCOL
        )

        return {"blob": torch.tensor(list(data), dtype=torch.uint8)}

    def load_state_dict(self, state_dict: dict[str, torch.Tensor]) -> None:
        blob: torch.Tensor = state_dict["blob"]
        if blob.numel() == 0:
            self.powersgd_state = None
            return
        data = bytes(blob.tolist())
        loaded = pickle.loads(data)

        if self.powersgd_state is not None and hasattr(
            self.powersgd_state, "__setstate__"
        ):
            getstate = (
                loaded.__getstate__()
                if hasattr(loaded, "__getstate__")
                else loaded.__dict__
            )
            self.powersgd_state.__setstate__(getstate)
        else:
            self.powersgd_state = loaded
