"""
Apply tensor parallelism and FSDP to a model.

References:
- https://github.com/Lightning-AI/pytorch-lightning/blob/master/examples/pytorch/tensor_parallel/parallelism.py
- https://github.com/pytorch/torchtitan/blob/main/torchtitan/parallelisms/parallelize_llama.py
"""

import torch 
from torch import nn 
from transformers import PreTrainedModel 

from torch .distributed ._composable .fsdp import MixedPrecisionPolicy ,CPUOffloadPolicy 
from torch .distributed ._composable .fsdp .fully_shard import fully_shard 
from torch .distributed ._tensor import Replicate ,Shard 
from torch .distributed .algorithms ._checkpoint .checkpoint_wrapper import (
checkpoint_wrapper ,
)
from torch .distributed .device_mesh import DeviceMesh 
from torch .distributed .tensor .parallel import (
ColwiseParallel ,
PrepareModuleInput ,
RowwiseParallel ,
SequenceParallel ,
parallelize_module ,
)


def apply_ac(model: nn.Module):
    """Wrap every child of ``model.layers`` with activation checkpointing.

    Each child module is passed through ``checkpoint_wrapper`` and registered
    back on ``model.layers`` under its original name, so the container keeps
    the same keys and ordering. A confirmation message is printed once all
    children have been replaced.

    Args:
        model: Module exposing a ``layers`` container of sub-modules.
    """
    blocks = model.layers
    for name, block in blocks.named_children():
        # Re-registering under an existing name swaps the child in place.
        blocks.register_module(name, checkpoint_wrapper(block))

    print("Applied activation checkpointing to the model")


def apply_compile(model: nn.Module):
    """Compile each child of ``model.layers`` separately with ``torch.compile``.

    Compiling per block (rather than the whole model) keeps compilation
    efficient because the blocks share the same structure. Alternatively, the
    whole model can be compiled after data parallelism has been applied.

    Every child is compiled with ``fullgraph=True`` and registered back on
    ``model.layers`` under its original name; a message is printed afterwards.

    Args:
        model: Module exposing a ``layers`` container of sub-modules.
    """
    blocks = model.layers
    for name, block in blocks.named_children():
        compiled_block = torch.compile(block, fullgraph=True)
        # Same name => the compiled module replaces the eager one in place.
        blocks.register_module(name, compiled_block)

    print("Compiling each TransformerBlock with torch.compile")


def apply_tensor_parallelism(
    model: PreTrainedModel, device_mesh: DeviceMesh, loss_parallel: bool = False
) -> PreTrainedModel:
    """Dispatch to the parallelization routine matching the model type.

    Args:
        model: Model whose ``config.model_type`` selects the parallelization
            routine.
        device_mesh: Device mesh forwarded to the model-specific routine.
        loss_parallel: Forwarded to the model-specific routine as its
            ``loss_parallel`` option.

    Returns:
        The result of ``apply_llama_tensor_parallelism`` for ``"llama"``
        models; any other model type is returned unchanged.
    """
    if model.config.model_type == "llama":
        # Must be passed by keyword: the third positional parameter of
        # apply_llama_tensor_parallelism is ``ac_enabled``, so a positional
        # call would toggle activation checkpointing instead of loss
        # parallelism.
        return apply_llama_tensor_parallelism(
            model, device_mesh, loss_parallel=loss_parallel
        )

    # No parallelization plan is defined for other model types.
    return model


def apply_llama_tensor_parallelism(
    model: PreTrainedModel,
    device_mesh: DeviceMesh,
    ac_enabled: bool = False,
    compile_enabled: bool = False,
    loss_parallel: bool = False,
    cpu_offload: bool = True,
    enable_float8: bool = False,
    pp_enabled: bool = False,
) -> PreTrainedModel:
    """Apply tensor parallelism, optional AC/compile, and FSDP to a Llama model.

    Reference of the model architecture:
    - https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L486

    The steps run in this order: tensor/sequence parallelism over the
    ``"tensor_parallel"`` sub-mesh (when its size is > 1), activation
    checkpointing, per-block ``torch.compile``, then ``fully_shard`` over the
    ``"data_parallel"`` sub-mesh (when its size is > 1).

    Args:
        model: Causal-LM wrapper whose decoder stack lives under
            ``model.model`` (``embed_tokens``, ``layers``, ``norm``) next to
            ``lm_head``.
        device_mesh: Mesh with ``"data_parallel"`` and ``"tensor_parallel"``
            dimensions. If ``None``, the model is returned untouched.
        ac_enabled: Wrap every decoder block with activation checkpointing.
        compile_enabled: Compile every decoder block with ``torch.compile``.
        loss_parallel: Keep the ``lm_head`` output sharded on its last
            dimension (and as a DTensor) instead of replicating it.
        cpu_offload: Add a ``CPUOffloadPolicy`` to the FSDP configuration.
        enable_float8: Use the float8 parallel styles from ``torchao`` for the
            per-block plan.
        pp_enabled: Disable resharding after forward for all FSDP units.

    Returns:
        The parallelized model (the return value of ``fully_shard`` when data
        parallelism is applied, otherwise the input model).
    """
    if device_mesh is None:
        return model
    dp_mesh = device_mesh["data_parallel"]
    tp_mesh = device_mesh["tensor_parallel"]

    def block_kwarg_layouts(hidden_states_layout):
        # Only ``hidden_states`` is given a layout; every other keyword
        # argument is marked ``None`` (no layout conversion requested).
        return {
            "hidden_states": hidden_states_layout,
            "attention_mask": None,
            "position_ids": None,
            "past_key_value": None,
            "output_attentions": None,
            "use_cache": None,
            "cache_position": None,
            "position_embeddings": None,
        }

    if tp_mesh.size() > 1:
        # Plan for the modules outside the repeated decoder blocks. The first
        # block's input is redistributed from replicated to sequence-sharded.
        parallelize_module(
            model,
            tp_mesh,
            {
                "model.embed_tokens": RowwiseParallel(input_layouts=Replicate()),
                "lm_head": ColwiseParallel(
                    input_layouts=Shard(1),
                    output_layouts=Shard(-1) if loss_parallel else Replicate(),
                    use_local_output=not loss_parallel,
                ),
                "model.norm": SequenceParallel(),
                "model.layers.0": PrepareModuleInput(
                    input_kwarg_layouts=block_kwarg_layouts(Replicate()),
                    desired_input_kwarg_layouts=block_kwarg_layouts(Shard(1)),
                    use_local_output=True,
                ),
            },
        )

        if enable_float8:
            # Imported lazily so torchao is only required when float8 is used.
            from torchao.float8.float8_tensor_parallel import (
                Float8ColwiseParallel,
                Float8RowwiseParallel,
                PrepareFloat8ModuleInput,
            )

            rowwise_parallel, colwise_parallel, prepare_module_input = (
                Float8RowwiseParallel,
                Float8ColwiseParallel,
                PrepareFloat8ModuleInput,
            )
        else:
            rowwise_parallel, colwise_parallel, prepare_module_input = (
                RowwiseParallel,
                ColwiseParallel,
                PrepareModuleInput,
            )

        tp_size = tp_mesh.size()
        for transformer_block in model.model.layers:
            layer_plan = {
                "self_attn": prepare_module_input(
                    input_kwarg_layouts=block_kwarg_layouts(Shard(1)),
                    desired_input_kwarg_layouts=block_kwarg_layouts(Replicate()),
                ),
                "self_attn.q_proj": colwise_parallel(),
                "self_attn.k_proj": colwise_parallel(),
                "self_attn.v_proj": colwise_parallel(),
                "self_attn.o_proj": rowwise_parallel(output_layouts=Shard(1)),
                "input_layernorm": SequenceParallel(),
                "mlp": prepare_module_input(
                    input_layouts=(Shard(1),),
                    desired_input_layouts=(Replicate(),),
                ),
                "mlp.gate_proj": colwise_parallel(),
                "mlp.down_proj": rowwise_parallel(output_layouts=Shard(1)),
                "mlp.up_proj": colwise_parallel(),
                "post_attention_layernorm": SequenceParallel(),
            }

            # Record the per-rank head counts on the attention module.
            # NOTE(review): confirm the installed transformers attention
            # implementation actually reads ``n_heads`` / ``n_kv_heads``; it
            # may use different attribute names, in which case these
            # assignments have no effect.
            attn_layer = transformer_block.self_attn
            attn_layer.n_heads = attn_layer.config.num_attention_heads // tp_size
            attn_layer.n_kv_heads = attn_layer.config.num_key_value_heads // tp_size

            parallelize_module(
                module=transformer_block,
                device_mesh=tp_mesh,
                parallelize_plan=layer_plan,
            )
        print(
            f"Applied {'Float8 ' if enable_float8 else ''}"
            "Tensor Parallelism to the model"
        )

    # apply_ac / apply_compile iterate ``<module>.layers``; on this wrapper the
    # decoder blocks live under ``model.model``, so pass the inner module.
    if ac_enabled:
        apply_ac(model.model)
    if compile_enabled:
        apply_compile(model.model)

    if dp_mesh.size() > 1:
        assert dp_mesh.ndim == 1

        mp_policy = MixedPrecisionPolicy(
            param_dtype=torch.bfloat16, reduce_dtype=torch.float32
        )

        fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
        if cpu_offload:
            fsdp_config["offload_policy"] = CPUOffloadPolicy()

        layers = model.model.layers
        last_layer_id = len(layers) - 1
        for layer_id, transformer_block in enumerate(layers):
            if pp_enabled:
                # Per the referenced torchtitan recipe: with pipeline
                # parallelism, skip resharding to avoid per-microbatch
                # all-gathers.
                reshard_after_forward = False
            else:
                # Per the referenced torchtitan recipe: the last block is not
                # resharded because it would be prefetched again immediately.
                reshard_after_forward = layer_id < last_layer_id
            fully_shard(
                transformer_block,
                **fsdp_config,
                reshard_after_forward=reshard_after_forward,
            )
            layers[layer_id] = transformer_block
        model = fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)

    return model
