import torch 

from torch .distributed .fsdp .api import (
MixedPrecision ,
)


# fp16 policy: parameters and buffers use float16, while gradient
# reduction is done in float32 to limit rounding error when gradients
# are combined across ranks. cast_forward_inputs is not passed, so it
# keeps MixedPrecision's default (unlike bfSixteen below).
fpSixteen = MixedPrecision(
    buffer_dtype=torch.float16,
    param_dtype=torch.float16,
    reduce_dtype=torch.float32,
)

# bf16 policy: parameters and buffers use bfloat16, gradient reduction
# stays in float32, and forward inputs are cast to the parameter dtype
# (cast_forward_inputs=True).
bfSixteen = MixedPrecision(
    cast_forward_inputs=True,
    buffer_dtype=torch.bfloat16,
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
)

# "Mixed" bf16 policy: despite the name, only buffers are kept in
# bfloat16; parameters and gradient reduction both remain in float32.
bfSixteen_mixed = MixedPrecision(
    buffer_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
    param_dtype=torch.float32,
)

# Full-precision policy: parameters, gradient reduction and buffers all
# use float32, so no reduced-precision dtype is requested anywhere.
fp32_policy = MixedPrecision(
    reduce_dtype=torch.float32,
    buffer_dtype=torch.float32,
    param_dtype=torch.float32,
)
