"""
ShardingManager for FSDP + HF Rollout (including EmbodiedHFRollout).
Manages model loading/offloading between training (actor) and inference (rollout).
"""

from loguru import logger
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP

from distflow.utils.extras.device import get_torch_device
from distflow.utils.model_utils.fsdp_utils import load_fsdp_model_to_gpu, offload_fsdp_model_to_cpu
from distflow.workers.sharding_manager.base import BaseShardingManager


class FSDPHFShardingManager(BaseShardingManager):
    """
    Context manager that moves models on/off the device around HF rollout.

    Intended for HuggingFace-based rollout objects (HFRollout or
    EmbodiedHFRollout). Entering the context reloads whatever was previously
    offloaded; leaving it offloads again and empties the device cache.

    Two flags select which models are moved:
    - ``offload_param``: the FSDP-wrapped actor module
    - ``offload_embedding``: ``rollout.embedding_model`` (EmbodiedHFRollout)

    The enter/exit structure is meant to stay consistent with
    MultiAgentFSDPVLLMShardingManager and MultiAgentFSDPSGLangShardingManager.
    """

    def __init__(
        self,
        module: FSDP,
        rollout,
        offload_param: bool = False,
        offload_embedding: bool = False,
    ):
        """
        Store the managed objects and the offload configuration.

        Args:
            module: The FSDP-wrapped actor model (actor_module_fsdp).
            rollout: The rollout object (HFRollout or EmbodiedHFRollout). Its
                ``embedding_model`` attribute is only accessed when
                ``offload_embedding`` is true.
            offload_param: Move the actor module on exit/enter when true.
            offload_embedding: Move ``rollout.embedding_model`` on exit/enter
                when true.
        """
        # Offload configuration.
        self.offload_embedding = offload_embedding
        self.offload_param = offload_param

        # Objects whose placement is managed.
        self.rollout = rollout
        self.module = module

        # "Asleep" means the models have been offloaded by __exit__. The
        # manager treats them as already resident on the device right after
        # construction, so the first __enter__ is a no-op.
        self.is_asleep = False

        init_message = (
            "FSDPHFShardingManager initialized: "
            f"offload_param={offload_param}, offload_embedding={offload_embedding}"
        )
        logger.info(init_message)

    def __enter__(self):
        """
        Prepare for rollout generation.

        If a previous ``__exit__`` put the manager to sleep, reload the
        configured models and mark the manager awake. Otherwise do nothing.
        Always returns ``None``.
        """
        if self.is_asleep:
            # Actor weights come back first ...
            if self.offload_param:
                load_fsdp_model_to_gpu(self.module)

            # ... followed by the embedding model (EmbodiedHFRollout only).
            if self.offload_embedding:
                embedding_model = self.rollout.embedding_model
                embedding_model.load_to_device()

            self.is_asleep = False

    def __exit__(self, exc_type, exc_value, traceback):
        """
        Release device memory once rollout generation is finished.

        Unless the manager is already asleep, offload the configured models,
        empty the device cache, and mark the manager asleep. The cache is
        emptied and the state flipped even when neither offload flag is set.
        Returns ``None``, so an exception raised inside the ``with`` body is
        not suppressed.
        """
        if not self.is_asleep:
            # Embedding model goes first (EmbodiedHFRollout only).
            if self.offload_embedding:
                embedding_model = self.rollout.embedding_model
                embedding_model.offload_to_host()

            # Then the actor module.
            if self.offload_param:
                offload_fsdp_model_to_cpu(self.module)

            # Finally empty the device cache after the offloads above.
            device_api = get_torch_device()
            device_api.empty_cache()

            self.is_asleep = True

