import logging
from functools import partial
from typing import Any, List, Optional
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torch.distributed.fsdp import register_fsdp_forward_method
from torch.distributed.fsdp._fully_shard._fsdp_state import FSDPState
from torch.utils.checkpoint import create_selective_checkpoint_contexts
from dinov3.utils import utils
logger = logging.getLogger("dinov3")
def map_modules_and_blocks(models: list[nn.ModuleDict], callable) -> None:
    """Apply ``callable`` in place to the sub-modules of every model.

    For each ``nn.ModuleDict`` in ``models``:

    * the entry stored under the key ``"backbone"`` is not replaced itself;
      instead every element of its ``blocks`` ``nn.ModuleList`` is replaced by
      ``callable(block, is_backbone_block=True)``;
    * every other entry is replaced by
      ``callable(entry, is_backbone_block=False)``.

    Args:
        models: Module dictionaries to rewrite in place.
        callable: Function taking a module positionally and the keyword
            ``is_backbone_block``; its return value is stored in place of
            the module it received.

    Raises:
        AssertionError: If a model is not an ``nn.ModuleDict`` or a backbone's
            ``blocks`` attribute is not an ``nn.ModuleList``.
    """
    for module_dict in models:
        assert isinstance(module_dict, nn.ModuleDict)
        for name in module_dict.keys():
            if name != "backbone":
                # Non-backbone entries are transformed as a whole.
                module_dict[name] = callable(module_dict[name], is_backbone_block=False)
                continue
            # The backbone is transformed block by block.
            assert isinstance(module_dict[name].blocks, nn.ModuleList)
            for index, block in enumerate(module_dict[name].blocks):
                module_dict[name].blocks[index] = callable(block, is_backbone_block=True)
def ac_compile_parallelize(
    trained_model: nn.ModuleDict,
    inference_only_models: List[nn.ModuleDict],
    cfg: Any,
    trained_model_process_group: Optional[dist.ProcessGroup] = None,
    inference_only_models_process_groups: Optional[List[dist.ProcessGroup]] = None,
) -> None:
    """
    Wrap the models in place for distributed training.

    Order of the wrappers:
    1/ Activation checkpointing on blocks
    2/ Compile blocks
    3/ FSDP blocks + global model

    Activation checkpointing is applied only to the backbone blocks of
    ``trained_model``; compilation and FSDP sharding are applied to
    ``trained_model`` and to every model in ``inference_only_models``.

    Args:
        trained_model: Module dict that must contain a ``"backbone"`` entry.
        inference_only_models: Additional module dicts that are compiled and
            sharded but not activation-checkpointed.
        cfg: Configuration object. The fields read here are
            ``cfg.train.checkpointing``, ``cfg.train.checkpointing_full``,
            ``cfg.train.compile``, ``cfg.train.cudagraphs``,
            ``cfg.compute_precision.param_dtype`` and
            ``cfg.compute_precision.reduce_dtype`` (the last two must be one
            of ``"fp16"``, ``"bf16"``, ``"fp32"``).
        trained_model_process_group: Process group used to build the device
            mesh of ``trained_model``; ``None`` selects a 1-D mesh over the
            whole world.
        inference_only_models_process_groups: One process group (or ``None``)
            per inference-only model, in the same order. ``None`` means every
            inference-only model uses the whole-world mesh.

    Raises:
        AssertionError: If ``trained_model`` is not an ``nn.ModuleDict`` with a
            ``"backbone"`` key, or a backbone's ``blocks`` is not an
            ``nn.ModuleList``.
        NotImplementedError: If ``utils.has_batchnorms(trained_model)`` is
            truthy.
        KeyError: If a configured dtype name is not in the dtype map.
    """
    assert (
        isinstance(trained_model, nn.ModuleDict) and "backbone" in trained_model.keys()
    ), f"{trained_model} does not contain a backbone?"
    logger.info("DISTRIBUTED FSDP -- preparing model for distributed training")
    if utils.has_batchnorms(trained_model):
        raise NotImplementedError

    # 1/ Activation checkpointing on the trained backbone's blocks.
    from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper

    backbone = trained_model.backbone
    if cfg.train.checkpointing:
        if cfg.train.checkpointing_full:
            _checkpointing_wrapper = checkpoint_wrapper
            logger.info("using selective checkpointing on backbone with full checkpointing policy")
        else:
            # Ops handed to the selective checkpoint context as the save list.
            _save_list = [
                torch.ops.aten.mm.default,
                torch.ops.aten._scaled_mm.default,
                torch.ops.aten._scaled_dot_product_efficient_attention.default,
                torch.ops.aten._scaled_dot_product_flash_attention.default,
                torch.ops._c10d_functional.reduce_scatter_tensor.default,
            ]
            _checkpointing_wrapper = partial(
                checkpoint_wrapper,
                context_fn=partial(create_selective_checkpoint_contexts, _save_list),
                preserve_rng_state=True,
            )
            logger.info("using selective checkpointing on backbone with selective policy")
        for i, b in enumerate(backbone.blocks):
            backbone.blocks[i] = _checkpointing_wrapper(b)

    # Pair every model with its process group (None -> whole-world mesh).
    all_models = [trained_model] + inference_only_models
    if inference_only_models_process_groups is None:
        # Pad based on the number of inference-only models; the groups list
        # itself is None here and has no length.
        inference_pgs = [None] * len(inference_only_models)
    else:
        inference_pgs = list(inference_only_models_process_groups)
    all_pgs = [trained_model_process_group] + inference_pgs

    # 2/ Compile blocks.
    def wrap_compile_block(m: nn.Module, is_backbone_block: bool) -> nn.Module:
        """Compile ``m`` in place when ``cfg.train.compile`` is set and return it."""
        if cfg.train.compile:
            if is_backbone_block and cfg.train.cudagraphs:
                m.compile(fullgraph=True, dynamic=False, options={"triton.cudagraphs": True})
            else:
                m.compile()
        return m

    map_modules_and_blocks(all_models, wrap_compile_block)

    # 3/ FSDP blocks + global model.
    # NOTE(review): this mesh is rebound inside the loop below before it is
    # ever read; the call is kept in case init_device_mesh has side effects
    # that later code relies on -- confirm whether it can be dropped.
    world_mesh = init_device_mesh(
        "cuda",
        mesh_shape=(dist.get_world_size(),),
        mesh_dim_names=("dp",),
    )
    DTYPE_MAP = {
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
        "fp32": torch.float32,
    }
    mp_policy = MixedPrecisionPolicy(
        param_dtype=DTYPE_MAP[cfg.compute_precision.param_dtype],
        reduce_dtype=DTYPE_MAP[cfg.compute_precision.reduce_dtype],
    )
    for m, pg in zip(all_models, all_pgs):
        if pg is None:
            world_mesh = init_device_mesh(
                "cuda",
                mesh_shape=(dist.get_world_size(),),
                mesh_dim_names=("dp",),
            )
        else:
            world_mesh = DeviceMesh.from_group(pg, "cuda")
        fsdp_config = {"mesh": world_mesh, "mp_policy": mp_policy}
        for k in m.keys():
            if k != "backbone":
                # Non-backbone entries are sharded as a single FSDP unit.
                m[k] = fully_shard(m[k], **fsdp_config, reshard_after_forward=True)
                m[k].set_reduce_scatter_divide_factor(1)
                continue
            # Backbone: shard each block, then the backbone module itself.
            blocks = m[k].blocks
            assert isinstance(blocks, nn.ModuleList)
            for block_id, block in enumerate(blocks):
                block_reshard: int | bool = True
                blocks[block_id] = fully_shard(block, **fsdp_config, reshard_after_forward=block_reshard)
                blocks[block_id].set_reduce_scatter_divide_factor(1)
            prev_block: FSDPState
            next_block: FSDPState
            # Chain consecutive blocks: each block forward-prefetches its
            # successor and each successor backward-prefetches its predecessor.
            for prev_block, next_block in zip(blocks, blocks[1:]):
                prev_block.set_modules_to_forward_prefetch([next_block])
                next_block.set_modules_to_backward_prefetch([prev_block])
            fully_shard(m.backbone, **fsdp_config, reshard_after_forward=True).set_reduce_scatter_divide_factor(1)
            # Register get_intermediate_layers as an FSDP forward method of the backbone.
            register_fsdp_forward_method(m.backbone, "get_intermediate_layers")

    # Allocate (uninitialized) storage for every model on the GPU.
    for model in all_models:
        model.to_empty(device="cuda")

    # Run FSDP lazy init on inference-only models, restoring the
    # post_forward_mesh_info value captured beforehand.
    # NOTE(review): presumably _lazy_init overwrites post_forward_mesh_info -- confirm.
    for model in inference_only_models:
        for k in model.keys():
            fsdp_state: FSDPState = model[k]._get_fsdp_state()
            if not fsdp_state._fsdp_param_group:
                continue
            mi = fsdp_state._fsdp_param_group.post_forward_mesh_info
            fsdp_state._lazy_init()
            fsdp_state._fsdp_param_group.post_forward_mesh_info = mi
