import functools
import os

import torch
from torch import distributed as dist
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers import GPT2LMHeadModel, AutoModelForCausalLM, AutoTokenizer

from utils.lora import apply_lora

# Hugging Face access token, read from the environment so no secret lives in
# source control. When HF_TOKEN is unset this is None, which is the default
# value of `use_auth_token` in `from_pretrained`.
access_token = os.environ.get("HF_TOKEN")


def _transformer_layer_classes(model):
    """Return the module classes that auto-wrap should treat as transformer blocks.

    Hugging Face models name their block classes in ``_no_split_modules``
    (for example ``"GPT2Block"``). Every module in ``model`` whose class name
    appears there is included. The plain ``torch.nn`` defaults used
    previously are always kept, so ``torch.nn`` transformers are still wrapped.
    """
    layer_names = set()
    for module in model.modules():
        names = getattr(module, "_no_split_modules", None)
        if names:
            layer_names.update(names)

    layer_classes = {
        torch.nn.TransformerEncoderLayer,
        torch.nn.modules.activation.MultiheadAttention,
    }
    layer_classes.update(
        type(module)
        for module in model.modules()
        if type(module).__name__ in layer_names
    )
    return layer_classes


def init_model_and_fsdp(rank, args):
    """Load a causal LM, apply LoRA, and wrap it in FSDP on this rank's GPU.

    Rank 0 loads ``args.model_name`` first, downloading it if needed. Every
    other rank waits on a barrier and then loads the same checkpoint from the
    local Hugging Face cache only (``local_files_only=True``).

    Args:
        rank: Process rank. It is also used directly as the CUDA device index,
            so this assumes one process per GPU on a single node. TODO confirm
            for multi-node runs, where the local rank would be needed instead.
        args: Namespace-like object. Reads ``args.model_name`` and
            ``args.lora_r``.

    Returns:
        The FSDP-wrapped model on ``cuda:{rank}``.
    """
    torch.cuda.set_device(rank)
    device = torch.device(f"cuda:{rank}")

    # 1) Only rank 0 may hit the network; the other ranks read the cache it
    #    filled. Every rank must call dist.barrier() the same number of times,
    #    otherwise an extra barrier on rank 0 pairs with an unrelated collective.
    if rank == 0:
        base = AutoModelForCausalLM.from_pretrained(
            args.model_name,
            torch_dtype=torch.float16,
            trust_remote_code=True,
            use_auth_token=access_token,
        )
    dist.barrier()  # non-zero ranks block here until rank 0 has loaded
    if rank != 0:
        base = AutoModelForCausalLM.from_pretrained(
            args.model_name,
            torch_dtype=torch.float16,
            trust_remote_code=True,
            local_files_only=True,  # read straight from the local cache
        )

    # 2) Apply LoRA (intended to leave only the adapter weights trainable).
    model = apply_lora(base, r=args.lora_r, model_name=args.model_name)

    # 3) Enable gradient checkpointing.
    model.gradient_checkpointing_enable()
    # With frozen input embeddings, reentrant checkpointing receives inputs that
    # do not require grad, and no gradient reaches the LoRA weights inside the
    # checkpointed blocks. This HF hook makes the embedding output require grad.
    # It is harmless if apply_lora already did the same.
    if hasattr(model, "enable_input_require_grads"):
        model.enable_input_require_grads()

    # 4) Auto-wrap: each transformer block becomes its own FSDP unit.
    #    transformer_auto_wrap_policy is the policy callable itself, so its
    #    layer classes must be bound with functools.partial, not passed by
    #    calling it.
    auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls=_transformer_layer_classes(model),
    )

    # 5) Mixed precision. Weights were loaded in fp16 above, so the sharded
    #    parameters are already fp16 and param_dtype does not cast anything.
    #    NOTE(review): fp16 gradient reduction can overflow without loss
    #    scaling (e.g. ShardedGradScaler); confirm the training loop uses one.
    mp_policy = MixedPrecision(
        param_dtype=torch.float16,
        reduce_dtype=torch.float16,
        buffer_dtype=torch.float16,
    )

    # 6) Wrap with FSDP.
    #    use_orig_params=True is needed because LoRA leaves frozen base weights
    #    and trainable adapter weights inside the same wrapped block, which
    #    FSDP refuses to flatten together when use_orig_params=False.
    fsdp_model = FSDP(
        model,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        auto_wrap_policy=auto_wrap_policy,
        mixed_precision=mp_policy,
        device_id=device,  # FSDP moves the module onto this device itself
        forward_prefetch=True,  # prefetch the next all-gather during forward
        use_orig_params=True,
    )

    return fsdp_model
