import os 


def fsdp_auto_wrap_policy(model, transformer_layer_name):
    """Build an FSDP auto-wrap policy for a PEFT model.

    The returned policy wraps a module when either sub-policy matches:

    * it is a leaf module (no child modules) that has a ``weight`` attribute
      whose ``requires_grad`` flag is set, or
    * it is an instance of ``PrefixEncoder``, ``PromptEncoder``,
      ``PromptEmbedding`` or the transformer layer class given by
      ``transformer_layer_name``.

    Args:
        model: The model that will be wrapped. It is only consulted when
            ``transformer_layer_name`` is a string, to find the class with
            that name among ``model.modules()``.
        transformer_layer_name: The transformer block class to wrap, or that
            class's ``__name__`` as a string. Any non-string value (a class,
            or a tuple of classes) is forwarded to
            ``transformer_auto_wrap_policy`` unchanged.

    Returns:
        A ``functools.partial`` of torch's ``_or_policy`` combining the two
        sub-policies, intended to be passed as FSDP's ``auto_wrap_policy``.

    Raises:
        ValueError: If ``transformer_layer_name`` is a string and no module in
            ``model`` is an instance of a class with that name.
    """
    import functools

    # NOTE: ``_or_policy`` is a private torch helper; it may move between
    # torch releases.
    from torch.distributed.fsdp.wrap import (
        _or_policy,
        lambda_auto_wrap_policy,
        transformer_auto_wrap_policy,
    )

    from peft.tuners import PrefixEncoder, PromptEmbedding, PromptEncoder

    def lambda_policy_fn(module):
        # Wrap leaf modules that own a trainable ``weight`` on their own.
        is_leaf = next(module.children(), None) is None
        weight = getattr(module, "weight", None)
        return is_leaf and weight is not None and bool(weight.requires_grad)

    def _resolve_layer_cls(layer):
        # Turn a class *name* into the class object found in ``model``.
        # Without this, a string would be handed to ``isinstance`` and only
        # fail (TypeError) once FSDP starts applying the policy.
        if not isinstance(layer, str):
            return layer
        for submodule in model.modules():
            if type(submodule).__name__ == layer:
                return type(submodule)
        raise ValueError(
            f"Could not find a module class named {layer!r} in the model to use "
            "as the FSDP transformer layer class."
        )

    lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)
    transformer_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls=(
            PrefixEncoder,
            PromptEncoder,
            PromptEmbedding,
            _resolve_layer_cls(transformer_layer_name),
        ),
    )

    # A module is wrapped if either sub-policy selects it.
    auto_wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, transformer_wrap_policy])
    return auto_wrap_policy


def use_mpi() -> None:
    """Copy Open MPI's rank/size variables into RANK, LOCAL_RANK and WORLD_SIZE.

    Reads ``OMPI_COMM_WORLD_RANK``, ``OMPI_COMM_WORLD_LOCAL_RANK`` and
    ``OMPI_COMM_WORLD_SIZE``, falling back to 0, 0 and 1 when a variable is
    unset. Each value is parsed with ``int`` and written back as its decimal
    string. All three are parsed before anything is written, so a non-integer
    value raises ``ValueError`` and leaves ``os.environ`` untouched.
    """
    # (source variable, default when unset, destination variable)
    translations = (
        ("OMPI_COMM_WORLD_RANK", 0, "RANK"),
        ("OMPI_COMM_WORLD_LOCAL_RANK", 0, "LOCAL_RANK"),
        ("OMPI_COMM_WORLD_SIZE", 1, "WORLD_SIZE"),
    )
    exported = {
        target: str(int(os.getenv(source, fallback)))
        for source, fallback, target in translations
    }
    os.environ.update(exported)
