#!/usr/bin/env python3
"""
独立的FSDP Checkpoint合并脚本
兼容Python 3.8+
"""

import os
import re
import argparse
import logging
from pathlib import Path
from typing import Dict, List, Tuple, Optional
from concurrent.futures import ThreadPoolExecutor

import numpy as np
import torch
from tqdm import tqdm
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM, GenerationConfig

try:
    from torch.distributed.tensor import DTensor
except ImportError:
    from torch.distributed._tensor import DTensor

from torch.distributed._tensor import Placement, Shard

# Configure logging: INFO level, "<timestamp> - <level> - <message>" lines.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module-level logger shared by every function in this script.
logger = logging.getLogger(__name__)


def get_world_size(local_dir: str) -> int:
    """Infer the world size from the rank-0 checkpoint filename.

    Looks in ``local_dir`` for a file named exactly
    ``model_world_size_<N>_rank_0.pt`` and returns ``N``.

    Args:
        local_dir: Directory containing the per-rank checkpoint files.

    Returns:
        The integer ``N`` parsed from the matching filename.

    Raises:
        FileNotFoundError: If no file in ``local_dir`` has the expected name
            (``os.listdir`` raises the same type if the directory is missing).
    """
    pattern = re.compile(r"model_world_size_(\d+)_rank_0\.pt")
    # Sorted listing keeps the answer deterministic when more than one
    # candidate is present; os.listdir() order is arbitrary.
    for filename in sorted(os.listdir(local_dir)):
        # fullmatch (not match) so leftovers such as "..._rank_0.pt.bak" or
        # "..._rank_0.pt.tmp" are not mistaken for a real checkpoint.
        match = pattern.fullmatch(filename)
        if match:
            return int(match.group(1))
    raise FileNotFoundError(f"无法确定world_size: {local_dir}")


def load_rank_zero_state_dict(local_dir: str, world_size: int) -> dict:
    """Load the rank-0 shard file from ``local_dir`` onto the CPU.

    Args:
        local_dir: Directory containing the per-rank checkpoint files.
        world_size: World size encoded in the shard filenames.

    Returns:
        Whatever object ``torch.load`` deserialises from
        ``model_world_size_<world_size>_rank_0.pt``.
    """
    shard_name = f"model_world_size_{world_size}_rank_0.pt"
    path = Path(local_dir) / shard_name
    logger.info(f"加载rank 0 state_dict: {path}")
    # NOTE: weights_only=False allows arbitrary pickled objects to be
    # deserialised, so only point this at checkpoints you trust.
    rank_zero = torch.load(path, map_location="cpu", weights_only=False)
    return rank_zero


def extract_device_mesh_info(state_dict: dict, world_size: int) -> Tuple[np.ndarray, Tuple[str, ...]]:
    """Read the device-mesh layout from one entry of a state_dict.

    The entry with the lexicographically smallest key is inspected. If it is
    a ``DTensor``, its device mesh and mesh dimension names are returned;
    otherwise a fallback of ``np.array([world_size])`` and ``("fsdp",)`` is
    returned.

    Args:
        state_dict: A (rank-0) state_dict; must be non-empty.
        world_size: Value placed in the fallback mesh array.

    Returns:
        A ``(mesh, mesh_dim_names)`` pair.

    Raises:
        IndexError: If ``state_dict`` is empty.
    """
    first_key = sorted(state_dict)[0]
    pivot = state_dict[first_key]

    if not isinstance(pivot, DTensor):
        # NOTE(review): this array holds one element (the world size), so its
        # shape is (1,); a caller reading mesh.shape[-1] gets 1, not
        # world_size -- confirm that is intended.
        return np.array([world_size], dtype=np.int64), ("fsdp",)

    # NOTE(review): device_mesh.mesh is returned as-is; presumably a
    # torch.Tensor rather than the annotated np.ndarray -- confirm.
    device_mesh = pivot.device_mesh
    return device_mesh.mesh, device_mesh.mesh_dim_names


def merge_by_placement(tensors: List[torch.Tensor], placement: Placement) -> torch.Tensor:
    """Combine per-rank tensors according to a single placement.

    Args:
        tensors: The local tensors, one per shard, in rank order.
        placement: Placement describing how the tensors relate to each other.

    Returns:
        ``tensors[0]`` for a replicate placement; for a shard placement, the
        contiguous concatenation of ``tensors`` along ``placement.dim``.

    Raises:
        NotImplementedError: If the placement is neither replicate nor shard.
    """
    if placement.is_replicate():
        # Every rank holds the same data, so the first copy is enough.
        return tensors[0]
    if not placement.is_shard():
        raise NotImplementedError(f"不支持的placement: {placement}")
    merged = torch.cat(tensors, dim=placement.dim)
    return merged.contiguous()


def load_and_merge_state_dicts(
    local_dir: str,
    world_size: int,
    total_shards: int,
    mesh_shape: Tuple[int, ...],
    mesh_dim_names: Tuple[str, ...]
) -> Dict[str, torch.Tensor]:
    """Load every shard file and merge them into a single state_dict.

    Shard files ``model_world_size_<world_size>_rank_<r>.pt`` for
    ``r in range(total_shards)`` are loaded in parallel threads. For every key
    of the rank-0 shard, the matching entry is taken from each shard, cast to
    bfloat16 and merged:

    * ``DTensor`` entries are merged by their first remaining placement (see
      ``merge_by_placement``); the leading placement is dropped when the first
      mesh dimension is named ``"dp"`` or ``"ddp"``.
    * all other entries are concatenated along dim 0.

    The returned dict keeps the key order of the rank-0 shard, so the output
    is the same from run to run.

    Args:
        local_dir: Directory containing the per-rank checkpoint files.
        world_size: World size encoded in the shard filenames.
        total_shards: Number of shard files (ranks ``0..total_shards-1``) to load.
        mesh_shape: Shape of the sharding mesh; only 1-D is supported.
        mesh_dim_names: Names of the device-mesh dimensions.

    Returns:
        Mapping from parameter name to the merged bfloat16 tensor.

    Raises:
        NotImplementedError: If a ``DTensor`` entry is found and ``mesh_shape``
            is not 1-D (FSDP + TP), or its placement is unsupported.
        KeyError: If a shard lacks a key that the rank-0 shard has.
    """
    def process_one_shard(rank: int):
        path = Path(local_dir) / f"model_world_size_{world_size}_rank_{rank}.pt"
        # NOTE: weights_only=False allows arbitrary pickled objects to be
        # deserialised, so only use this on checkpoints you trust.
        return torch.load(path, map_location="cpu", weights_only=False)

    logger.info(f"加载 {total_shards} 个FSDP分片...")
    with ThreadPoolExecutor(max_workers=min(32, os.cpu_count() or 4)) as executor:
        futures = [executor.submit(process_one_shard, rank) for rank in range(total_shards)]
        # Results are collected in submission order, so list index == rank.
        model_state_dict_lst = [
            future.result() for future in tqdm(futures, desc="Loading shards")
        ]

    # Gather the per-rank pieces of every parameter.
    state_dict = {}
    param_placements = {}

    # Snapshot rank 0's keys as a list: entries are popped from the shards
    # while looping, and a list (unlike a set) keeps the checkpoint's own key
    # order instead of a hash-randomised one.
    for key in list(model_state_dict_lst[0]):
        pieces = []
        for model_state_shard in model_state_dict_lst:
            # pop() so each shard releases its reference as soon as it is used.
            tensor = model_state_shard.pop(key)
            if isinstance(tensor, DTensor):
                pieces.append(tensor._local_tensor.bfloat16())
                placements = tuple(tensor.placements)
                # Drop the placement of a leading "dp"/"ddp" mesh dimension
                # (presumably the data-parallel replica dimension -- confirm).
                if mesh_dim_names[0] in ("dp", "ddp"):
                    placements = placements[1:]
                # Remember the placements of the first shard that reports them.
                param_placements.setdefault(key, placements)
            else:
                pieces.append(tensor.bfloat16())
        state_dict[key] = pieces

    del model_state_dict_lst

    # Replace each list of pieces with the merged tensor.
    logger.info("合并tensors...")
    for key in sorted(state_dict):
        if key in param_placements:
            if len(mesh_shape) != 1:
                raise NotImplementedError("FSDP + TP暂不支持")
            state_dict[key] = merge_by_placement(state_dict[key], param_placements[key][0])
        else:
            state_dict[key] = torch.cat(state_dict[key], dim=0)

    return state_dict


def save_hf_model(state_dict: Dict[str, torch.Tensor], config_path: str, target_dir: str):
    """Write the merged weights to ``target_dir`` via ``save_pretrained``.

    The model is built from the config found at ``config_path`` and saved with
    the supplied ``state_dict``. Afterwards, auxiliary files from
    ``config_path`` (any ``*.json``, ``merges.txt``, and names starting with
    ``tokenizer``) are copied into ``target_dir`` unless a file of the same
    name already exists there.

    Args:
        state_dict: Merged parameter tensors to save.
        config_path: Directory holding the model config and tokenizer files.
        target_dir: Output directory; created if missing.
    """
    import shutil

    logger.info(f"保存模型到: {target_dir}")
    os.makedirs(target_dir, exist_ok=True)

    # Instantiate the architecture under init_empty_weights() and then call
    # to_empty(); the actual weights come from `state_dict` at save time.
    hf_config = AutoConfig.from_pretrained(config_path)
    with init_empty_weights():
        model = AutoModelForCausalLM.from_config(hf_config, torch_dtype=torch.bfloat16)
    model.to_empty(device="cpu")

    # Save the model
    model.save_pretrained(target_dir, state_dict=state_dict)
    del state_dict, model

    def _is_aux_file(name: str) -> bool:
        """Return True for config/tokenizer-style files worth copying."""
        return name.endswith('.json') or name == 'merges.txt' or name.startswith('tokenizer')

    # Copy tokenizer/config files that save_pretrained did not already write.
    for fname in filter(_is_aux_file, os.listdir(config_path)):
        src = os.path.join(config_path, fname)
        dst = os.path.join(target_dir, fname)
        if not os.path.isfile(src) or os.path.exists(dst):
            continue
        shutil.copy2(src, dst)

    logger.info("模型保存完成")


def merge_checkpoint(checkpoint_dir: str, target_dir: str):
    """Merge the sharded checkpoint in ``checkpoint_dir`` into ``target_dir``.

    Steps: determine the world size from the filenames, read the device-mesh
    layout from the rank-0 shard, load and merge all shards, then save the
    result in HuggingFace format.

    Args:
        checkpoint_dir: Directory with the ``model_world_size_*_rank_*.pt``
            shard files; also used as the source of the model config and
            tokenizer files.
        target_dir: Output directory for the merged model.
    """
    logger.info(f"开始合并checkpoint: {checkpoint_dir}")

    world_size = get_world_size(checkpoint_dir)
    logger.info(f"World size: {world_size}")

    rank_zero_state_dict = load_rank_zero_state_dict(checkpoint_dir, world_size)
    mesh, mesh_dim_names = extract_device_mesh_info(rank_zero_state_dict, world_size)
    # Rank 0 was only needed for the mesh metadata and is loaded again with
    # the other shards below; release it now so it does not add to peak memory.
    del rank_zero_state_dict
    logger.info(f"Device mesh: {mesh}, dim_names: {mesh_dim_names}")

    # Only the last mesh dimension is treated as the sharding dimension.
    # NOTE(review): for the non-DTensor fallback mesh (a one-element array)
    # this evaluates to 1 rather than world_size -- confirm that is intended.
    total_shards = mesh.shape[-1]
    mesh_shape = (total_shards,)
    logger.info(f"Total shards: {total_shards}")

    merged_state_dict = load_and_merge_state_dicts(
        checkpoint_dir, world_size, total_shards, mesh_shape, mesh_dim_names
    )

    save_hf_model(merged_state_dict, checkpoint_dir, target_dir)
    logger.info(f"合并完成: {target_dir}")


def main():
    """Parse the command-line arguments and run the checkpoint merge.

    Both ``--checkpoint_dir`` and ``--target_dir`` are required string options.
    """
    cli = argparse.ArgumentParser(description='合并FSDP Checkpoint')
    for flag in ('--checkpoint_dir', '--target_dir'):
        cli.add_argument(flag, type=str, required=True)
    opts = cli.parse_args()

    merge_checkpoint(opts.checkpoint_dir, opts.target_dir)


# Run the merge only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
