# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import glob
import os
import re
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Tuple

import numpy as np
import torch
from torch.distributed._tensor import DTensor, Placement, Shard
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoModelForTokenClassification,
    AutoModelForVision2Seq,
    PretrainedConfig,
    PreTrainedModel,
)


def cleanup_checkpoint_shards(local_dir: str, verbose: bool = True) -> dict:
    """Delete per-rank shard files from a checkpoint directory.

    Nothing is deleted (and all-zero statistics are returned) unless
    ``<local_dir>/huggingface`` exists and contains at least one
    ``*.safetensors`` or ``*.bin`` file. This guards against removing the
    shards when no merged model file is present.

    The following files directly inside ``local_dir`` are removed:
    ``model_world_size_*.pt``, ``optim_world_size_*.pt`` and
    ``extra_state_*.pt``.

    Args:
        local_dir: Checkpoint directory (e.g. ``.../global_step_15/actor``).
        verbose: Whether to print a line per removed file plus start/end
            messages. The "skipping cleanup" warnings are printed regardless.

    Returns:
        A dict with the keys ``model_shards_removed``,
        ``optim_shards_removed``, ``extra_states_removed`` and
        ``bytes_freed`` (sum of the sizes of all removed files).
    """
    stats = {
        "model_shards_removed": 0,
        "optim_shards_removed": 0,
        "extra_states_removed": 0,
        "bytes_freed": 0,
    }

    # The merged model is expected under <local_dir>/huggingface.
    hf_path = os.path.join(local_dir, "huggingface")
    if not os.path.exists(hf_path):
        print(f"[WARNING] huggingface目录不存在，跳过清理: {hf_path}")
        return stats

    # glob.escape() keeps directory names that contain glob metacharacters
    # ("[", "]", "*", "?") from being interpreted as part of the pattern.
    escaped_hf_path = glob.escape(hf_path)
    model_files = glob.glob(os.path.join(escaped_hf_path, "*.safetensors")) + glob.glob(
        os.path.join(escaped_hf_path, "*.bin")
    )
    if not model_files:
        print(f"[WARNING] huggingface目录中没有模型文件，跳过清理: {hf_path}")
        return stats

    if verbose:
        print(f"[Cleanup] 开始清理checkpoint分片: {local_dir}")

    # (file pattern, stats counter, log label, whether the log line shows the size)
    targets = (
        ("model_world_size_*.pt", "model_shards_removed", "删除模型分片", True),
        ("optim_world_size_*.pt", "optim_shards_removed", "删除优化器状态", True),
        ("extra_state_*.pt", "extra_states_removed", "删除额外状态", False),
    )
    escaped_local_dir = glob.escape(local_dir)
    for pattern, counter_key, label, show_size in targets:
        for path in glob.glob(os.path.join(escaped_local_dir, pattern)):
            # Read the size before removal so it can be added to bytes_freed.
            size = os.path.getsize(path)
            os.remove(path)
            stats[counter_key] += 1
            stats["bytes_freed"] += size
            if verbose:
                size_suffix = f" ({size / 1e9:.2f}GB)" if show_size else ""
                print(f"  {label}: {os.path.basename(path)}{size_suffix}")

    if verbose:
        print(f"[Cleanup] 完成! 共释放 {stats['bytes_freed'] / 1e9:.2f}GB")

    return stats


def merge_by_placement(tensors: List[torch.Tensor], placement: Placement):
    """Combine the per-shard tensors of one parameter according to its placement.

    Args:
        tensors: The local tensors collected for one parameter.
        placement: The placement that describes how the tensors relate.

    Returns:
        ``tensors[0]`` for a replicate placement, or the contiguous
        concatenation of ``tensors`` along ``placement.dim`` for a shard
        placement.

    Raises:
        NotImplementedError: If the placement is partial.
        ValueError: If the placement is none of replicate, partial or shard.
    """
    if placement.is_replicate():
        # Replicated: the first entry is returned unchanged.
        return tensors[0]

    if placement.is_partial():
        raise NotImplementedError("Partial placement is not supported yet")

    if not placement.is_shard():
        raise ValueError(f"Unsupported placement: {placement}")

    # Sharded: stitch the pieces back together along the sharded dimension.
    merged = torch.cat(tensors, dim=placement.dim)
    return merged.contiguous()


def upload_model_to_huggingface(local_path: str, remote_path: str):
    """Upload the folder ``local_path`` to the Hugging Face model repo ``remote_path``.

    ``create_repo`` is called with ``private=False`` and ``exist_ok=True``
    before the folder is uploaded with ``repo_type="model"``.

    Args:
        local_path: Local folder whose contents are uploaded.
        remote_path: Repository id on the Hugging Face Hub.
    """
    # Imported here rather than at module level, so huggingface_hub is only
    # needed when an upload is actually requested.
    from huggingface_hub import HfApi

    hub_client = HfApi()
    hub_client.create_repo(repo_id=remote_path, exist_ok=True, private=False)
    hub_client.upload_folder(folder_path=local_path, repo_id=remote_path, repo_type="model")


if __name__ == "__main__":
    # Script entry point: merge the per-rank model shards found in --local_dir
    # into a single state dict, save it as a Hugging Face checkpoint under
    # <local_dir>/huggingface, then optionally upload it and/or delete the
    # shard files.
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_dir", required=True, type=str, help="The path for your saved model")
    parser.add_argument("--hf_upload_path", default=False, type=str, help="The path of the huggingface repo to upload")
    parser.add_argument("--cleanup", action="store_true", help="合并后删除分片文件和优化器状态以节省存储空间")
    args = parser.parse_args()
    local_dir: str = args.local_dir

    assert not local_dir.endswith("huggingface"), "The local_dir should not end with huggingface."

    # Locate the rank-0 shard; its file name carries the world size, and its
    # contents are used below to discover the device mesh.
    rank = 0
    world_size = 0
    for filename in os.listdir(local_dir):
        match = re.match(r"model_world_size_(\d+)_rank_0\.pt", filename)
        if match:
            # Kept as the matched string so it can be re-inserted verbatim
            # into the shard file names below.
            world_size = match.group(1)
            break

    assert world_size, "No model file with the proper format."

    rank0_weight_path = os.path.join(local_dir, f"model_world_size_{world_size}_rank_{rank}.pt")
    # NOTE(security): weights_only=False unpickles arbitrary objects; only run
    # this script on checkpoints from a trusted source.
    state_dict = torch.load(rank0_weight_path, map_location="cpu", weights_only=False)
    # Any entry will do for inspecting the sharding layout; take the first key
    # in sorted order so the choice is deterministic.
    pivot_key = sorted(state_dict.keys())[0]
    weight = state_dict[pivot_key]
    if isinstance(weight, DTensor):
        # get sharding info
        device_mesh = weight.device_mesh
        mesh = device_mesh.mesh
        mesh_dim_names = device_mesh.mesh_dim_names
    else:
        # for non-DTensor: treat the checkpoint as a 1-D "fsdp" mesh
        mesh = np.array([int(world_size)], dtype=np.int64)
        mesh_dim_names = ("fsdp",)

    print(f"Got device mesh {mesh}, mesh_dim_names {mesh_dim_names}")

    assert mesh_dim_names in (("fsdp",), ("ddp", "fsdp")), f"Unsupported mesh_dim_names {mesh_dim_names}."

    # With the assertion above active, "tp" can never be present; the branch
    # is kept for when tensor-parallel support is added.
    if "tp" in mesh_dim_names:
        # fsdp * tp
        total_shards = mesh.shape[-1] * mesh.shape[-2]
        mesh_shape = (mesh.shape[-2], mesh.shape[-1])
    else:
        # fsdp
        total_shards = mesh.shape[-1]
        mesh_shape = (mesh.shape[-1],)

    print(f"Processing {total_shards} model shards in total.")
    # Slot 0 is the already-loaded rank-0 state dict; the remaining slots are
    # placeholders that the worker threads overwrite by rank index.
    model_state_dict_lst = []
    model_state_dict_lst.append(state_dict)
    model_state_dict_lst.extend([""] * (total_shards - 1))

    def process_one_shard(rank, model_state_dict_lst):
        """Load the shard file of ``rank`` and store it at index ``rank`` of the list."""
        model_path = os.path.join(local_dir, f"model_world_size_{world_size}_rank_{rank}.pt")
        state_dict = torch.load(model_path, map_location="cpu", weights_only=False)
        model_state_dict_lst[rank] = state_dict
        return state_dict

    # os.cpu_count() may return None, hence the "or 1" fallback.
    with ThreadPoolExecutor(max_workers=min(32, os.cpu_count() or 1)) as executor:
        futures = [
            executor.submit(process_one_shard, shard_rank, model_state_dict_lst)
            for shard_rank in range(1, total_shards)
        ]
        # Future.result() re-raises any exception from the worker, so a shard
        # that failed to load aborts the merge instead of leaving a
        # placeholder behind.
        for future in futures:
            future.result()

    state_dict: Dict[str, List[torch.Tensor]] = {}
    param_placements: Dict[str, List[Placement]] = {}
    keys = set(model_state_dict_lst[0].keys())
    for key in keys:
        state_dict[key] = []
        for shard_rank, model_state_dict in enumerate(model_state_dict_lst):
            try:
                # pop() drops the reference from the per-rank dict as the
                # tensor is moved into the merged structure.
                tensor = model_state_dict.pop(key)
            except KeyError as err:
                # Continuing here would re-append the tensor of the previous
                # rank and silently produce a wrong merged weight.
                raise KeyError(f"Cannot find key {key} in rank {shard_rank}.") from err

            if isinstance(tensor, DTensor):
                state_dict[key].append(tensor._local_tensor.bfloat16())
                placements = tuple(tensor.placements)
                # replicated placement at ddp dimension can be discarded
                if mesh_dim_names[0] == "ddp":
                    placements = placements[1:]

                # Every rank must report the same placements for a given key.
                if key not in param_placements:
                    param_placements[key] = placements
                else:
                    assert param_placements[key] == placements
            else:
                state_dict[key].append(tensor.bfloat16())

    del model_state_dict_lst

    for key in sorted(state_dict):
        if not isinstance(state_dict[key], list):
            print(f"No need to merge key {key}")
            continue

        if key in param_placements:
            # merge shards
            placements: Tuple[Shard] = param_placements[key]
            if len(mesh_shape) == 1:
                # 1-D list, FSDP without TP
                assert len(placements) == 1
                shards = state_dict[key]
                state_dict[key] = merge_by_placement(shards, placements[0])
            else:
                # 2-D list, FSDP + TP
                raise NotImplementedError("FSDP + TP is not supported yet.")
        else:
            # Keys without recorded placements (non-DTensor entries) are
            # concatenated along dim 0.
            state_dict[key] = torch.cat(state_dict[key], dim=0)

    print("Merge completed.")
    hf_path = os.path.join(local_dir, "huggingface")
    config: PretrainedConfig = AutoConfig.from_pretrained(hf_path)
    # "or" also covers an ``architectures`` attribute that exists but is None,
    # which the getattr default alone would not.
    architectures: List[str] = getattr(config, "architectures", None) or ["Unknown"]

    # Pick the auto class from the first architecture name in the config.
    if "ForTokenClassification" in architectures[0]:
        AutoClass = AutoModelForTokenClassification
    elif "ForCausalLM" in architectures[0]:
        AutoClass = AutoModelForCausalLM
    elif "ForConditionalGeneration" in architectures[0]:
        AutoClass = AutoModelForVision2Seq
    else:
        raise NotImplementedError(f"Unknown architecture {architectures}.")

    # Build the model on the meta device, then allocate empty CPU storage;
    # the values that get saved come from ``state_dict`` passed below.
    with torch.device("meta"):
        model: PreTrainedModel = AutoClass.from_config(config, torch_dtype=torch.bfloat16)

    assert isinstance(model, PreTrainedModel)
    model.to_empty(device="cpu")

    print(f"Saving model to {hf_path}...")
    model.save_pretrained(hf_path, state_dict=state_dict)
    del state_dict, model

    if args.hf_upload_path:
        upload_model_to_huggingface(hf_path, args.hf_upload_path)

    # Remove shard files and optimizer state once the merged model is saved.
    if args.cleanup:
        print("\n" + "=" * 50)
        print("开始清理checkpoint分片文件...")
        print("=" * 50)
        cleanup_stats = cleanup_checkpoint_shards(local_dir, verbose=True)
        print("\n清理统计:")
        print(f"  - 模型分片删除: {cleanup_stats['model_shards_removed']} 个")
        print(f"  - 优化器状态删除: {cleanup_stats['optim_shards_removed']} 个")
        print(f"  - 额外状态删除: {cleanup_stats['extra_states_removed']} 个")
        print(f"  - 总共释放空间: {cleanup_stats['bytes_freed'] / 1e9:.2f} GB")