# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import tempfile
import shutil
import torch
import copy
import torch.distributed
from torch.distributed import init_device_mesh
from verl.utils.distributed import initialize_global_process_group
from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import Qwen2MoeConfig

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy, MixedPrecision, \
    CPUOffload


def _build_fsdp_model(model_name, torch_dtype, mixed_precision, device_mesh):
    """Load ``model_name`` with the given ``torch_dtype`` and wrap it in FULL_SHARD FSDP."""
    model = AutoModelForCausalLM.from_pretrained(model_name,
                                                 torch_dtype=torch_dtype,
                                                 attn_implementation='flash_attention_2')
    return FSDP(model,
                use_orig_params=False,
                device_id=torch.cuda.current_device(),
                sharding_strategy=ShardingStrategy.FULL_SHARD,
                mixed_precision=mixed_precision,
                device_mesh=device_mesh)


def _run_update_step(model, optimizer, lr_scheduler, input_ids, attention_mask):
    """Run one forward/backward pass, then one optimizer step and one LR-scheduler step."""
    outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    loss = outputs.logits.mean()
    loss.backward()
    optimizer.step()
    lr_scheduler.step()
    optimizer.zero_grad()


def test_fsdp_ckpt():
    """Check that a bf16-loaded model upcast to fp32 matches an fp32-loaded model under FSDP.

    Two copies of the same pretrained model are loaded, one with
    ``torch_dtype=torch.float32`` and one with ``torch_dtype=torch.bfloat16``.
    Both are wrapped with FSDP using the same mixed-precision policy, and the
    bf16 copy is cast to fp32 after wrapping. Each copy then takes a single
    AdamW step on the same batch, and the logits produced on a second batch
    must be bit-identical (``atol=0``, ``rtol=0``).

    Despite the function name, no checkpoint is saved or restored here.

    Raises:
        AssertionError: if fewer than 2 CUDA devices are visible, or if the
            logits of the two models differ.
    """
    assert torch.cuda.device_count() >= 2, "need at least 2 gpus for test"
    # Only the world size is needed; the local and global rank are unused here.
    _, _, world_size = initialize_global_process_group()

    try:
        device_mesh = init_device_mesh('cuda', mesh_shape=(world_size,), mesh_dim_names=('dp',))

        # NOTE(review): absolute filesystem path; the test presumably only runs
        # on machines where this location is available — confirm.
        model_name = '/cpfs01/data/shared/Group-m6/menrui.mr/hf-ckpts-public/QWen2.5-Tiny-GQA-exp9-s2'

        # A single policy is shared by both wrappers so they cannot drift apart.
        mixed_precision = MixedPrecision(param_dtype=torch.bfloat16,
                                         reduce_dtype=torch.float32,
                                         buffer_dtype=torch.float32)

        model_fp32 = _build_fsdp_model(model_name, torch.float32, mixed_precision, device_mesh)

        model_bf16 = _build_fsdp_model(model_name, torch.bfloat16, mixed_precision, device_mesh)
        # Cast the bf16-loaded, already-wrapped model to fp32; the rest of the
        # test expects it to behave exactly like the fp32-loaded model.
        model_bf16.to(torch.float32)

        optimizer_fp32 = torch.optim.AdamW(model_fp32.parameters(), lr=1e-4)
        lr_scheduler_fp32 = torch.optim.lr_scheduler.StepLR(optimizer_fp32, step_size=1, gamma=0.9)

        optimizer_bf16 = torch.optim.AdamW(model_bf16.parameters(), lr=1e-4)
        lr_scheduler_bf16 = torch.optim.lr_scheduler.StepLR(optimizer_bf16, step_size=1, gamma=0.9)

        # Generate sample input
        batch_size = 2
        seq_len = 32
        # NOTE(review): assumes the model's vocabulary has at least 32000 entries — confirm.
        vocab_size = 32000

        # First input: used for the single training update of both models
        input_ids1 = torch.randint(0, vocab_size, (batch_size, seq_len), device='cuda')
        attention_mask1 = torch.ones_like(input_ids1)

        # Step 1: update the bf16-loaded model
        _run_update_step(model_bf16, optimizer_bf16, lr_scheduler_bf16, input_ids1, attention_mask1)

        # Step 2: update the fp32-loaded model on the same batch
        _run_update_step(model_fp32, optimizer_fp32, lr_scheduler_fp32, input_ids1, attention_mask1)

        # Second input: used only to compare the two updated models
        input_ids2 = torch.randint(0, vocab_size, (batch_size, seq_len), device='cuda')
        attention_mask2 = torch.ones_like(input_ids2)

        with torch.no_grad():
            logits_fp32 = model_fp32(input_ids=input_ids2, attention_mask=attention_mask2).logits
            logits_bf16 = model_bf16(input_ids=input_ids2, attention_mask=attention_mask2).logits

        # Zero tolerances: the two models must agree exactly, not approximately.
        torch.testing.assert_close(logits_fp32, logits_bf16, atol=0.0, rtol=0.0)
        print("bf16/fp32 load test passed!")
    finally:
        # Tear down the process group created above so it is not leaked,
        # whether the test passed or failed.
        if torch.distributed.is_initialized():
            torch.distributed.destroy_process_group()


# Script entry point: run the test when this file is executed directly.
if __name__ == "__main__":
    test_fsdp_ckpt()
