import os
import sys
import time
import requests
from pathlib import Path
from huggingface_hub import snapshot_download

# Files fetched in mirror mode. Both weight formats (safetensors and bin) may
# be listed side by side: a file the mirror does not have is skipped at
# download time instead of aborting the run.
FILES_TO_DOWNLOAD = [
    # model configuration
    "config.json",
    # weights (either format may be absent from a given repo)
    "pytorch_model.bin",
    "model.safetensors",
    # tokenizer assets
    "tokenizer.json",
    "tokenizer_config.json",
    "special_tokens_map.json",
]

# ========== Download core ========== #
def download_file(url, save_path):
    """Stream ``url`` into ``save_path``, resuming a partial file when possible.

    When ``save_path`` already holds data, a ``Range`` request asks the server
    for the missing tail only. The existing bytes are kept only if the server
    confirms the range with ``206 Partial Content``; a plain ``200`` carries
    the whole file, so the local copy is rewritten from scratch.

    Args:
        url: HTTP(S) address of the file.
        save_path: Destination file, given as ``str`` or ``pathlib.Path``.

    Raises:
        requests.HTTPError: The server answered with an error status
            (``416`` is not an error here, see below).
        OSError: The body ended before ``Content-Length`` bytes arrived, or
            the local file could not be written.
        Exception: Anything else raised while downloading is printed and
            re-raised unchanged.
    """
    save_path = Path(save_path)
    try:
        downloaded = save_path.stat().st_size if save_path.exists() else 0
        # Only ask for a range when there is something to resume from.
        headers = {'Range': f'bytes={downloaded}-'} if downloaded > 0 else {}

        print(f"⬇️ 下载: {url} (从 {downloaded} 字节续传)")

        with requests.get(url, stream=True, timeout=60, headers=headers) as r:
            # 416 = requested range starts at/after the end of the remote
            # file, i.e. nothing is left to fetch.
            if r.status_code == 416:
                print(f"✅ 文件已完整下载: {save_path}")
                return
            r.raise_for_status()

            # Appending is only safe when the server honoured the Range
            # header (206). Appending a full 200 body to the partial file
            # would silently corrupt it.
            resuming = downloaded > 0 and r.status_code == 206
            total = int(r.headers.get('content-length', 0))
            print(f"📦 剩余大小: {total / (1024 * 1024):.2f} MB")

            written = 0
            with open(save_path, 'ab' if resuming else 'wb') as f:
                for chunk in r.iter_content(chunk_size=1024 * 1024):  # 1MB
                    if chunk:
                        f.write(chunk)
                        written += len(chunk)

            # Content-Length counts bytes on the wire; iter_content yields
            # decoded bytes, so the comparison is only meaningful when the
            # body was not content-encoded (gzip/deflate).
            if total and 'content-encoding' not in r.headers and written != total:
                raise IOError(
                    f"传输不完整: 期望 {total} 字节, 实际收到 {written} 字节"
                )

        print(f"✅ 成功: {save_path}")
    except Exception as e:
        print(f"❌ 下载失败: {url}\n   错误: {e}")
        raise

def download_with_retry(url, save_path, retries=5, retry_delay=15):
    """Call ``download_file`` until it succeeds or the attempts run out.

    Failures that a retry cannot fix (HTTP 4xx other than 408 and 429, for
    example 404 for a file the mirror does not host) are re-raised right away
    instead of sleeping through every remaining attempt.

    Args:
        url: HTTP(S) address of the file.
        save_path: Destination file, forwarded to ``download_file``.
        retries: Maximum number of attempts; must be at least 1.
        retry_delay: Seconds to wait between two attempts.

    Raises:
        ValueError: ``retries`` is smaller than 1 (previously this returned
            silently without downloading anything).
        Exception: The error of the last attempt, or of the first
            non-retryable one, re-raised unchanged.
    """
    if retries < 1:
        raise ValueError(f"retries must be >= 1, got {retries}")

    for attempt in range(1, retries + 1):
        try:
            download_file(url, save_path)
            return
        except Exception as exc:
            # requests.HTTPError carries the failed response. Read it with
            # getattr / "is not None": a 4xx Response object is itself falsy.
            response = getattr(exc, "response", None)
            status = getattr(response, "status_code", None) if response is not None else None
            permanent = (
                isinstance(status, int)
                and 400 <= status < 500
                and status not in (408, 429)  # timeout / rate limit may recover
            )
            if permanent:
                print(f"🚫 HTTP {status}，重试无效，放弃: {url}")
                raise
            if attempt == retries:
                print(f"❌ 多次尝试仍失败: {url}")
                raise
            print(f"⚠️ 第 {attempt} 次失败，{retry_delay} 秒后重试...")
            time.sleep(retry_delay)

# ========== Two download modes ========== #
def download_with_requests(repo_id, save_path, mirror_base):
    """Download every entry of ``FILES_TO_DOWNLOAD`` from a mirror site.

    Each file is streamed to ``<name>.part`` and renamed to its final name
    only after the download finished. A file under its final name is therefore
    always complete and can be skipped safely, while a leftover ``.part`` file
    from an interrupted run is picked up and resumed on the next run.

    A file that cannot be downloaded is skipped (the mirror may simply not
    host it), but the call fails if not a single file is available at the end.

    Args:
        repo_id: Repository identifier, e.g. ``"org/model"``.
        save_path: Target directory; created if missing.
        mirror_base: Base URL of the mirror, with or without trailing slash.

    Raises:
        RuntimeError: None of the listed files exists locally or could be
            downloaded.
    """
    print(f"🌐 使用镜像 {mirror_base} 下载...")
    output_dir = Path(save_path)
    output_dir.mkdir(parents=True, exist_ok=True)

    # A trailing slash on the mirror would otherwise put "//" into the URL.
    base_url = mirror_base.rstrip('/')
    available = 0

    for file in FILES_TO_DOWNLOAD:
        url = f"{base_url}/{repo_id}/resolve/main/{file}"
        save_file = output_dir / file
        if save_file.exists():
            print(f"✅ 已存在，跳过: {file}")
            available += 1
            continue

        partial_file = save_file.with_name(save_file.name + ".part")
        try:
            download_with_retry(url, partial_file)
            # Publish under the final name only once the transfer is done.
            partial_file.replace(save_file)
        except Exception:
            # Exception (not a bare except) so Ctrl-C / SystemExit still
            # abort the whole run instead of being treated as a missing file.
            print(f"🚫 跳过该文件（可能镜像不包含）: {file}")
            continue
        available += 1

    if available == 0:
        raise RuntimeError(f"未能从镜像 {mirror_base} 获取 {repo_id} 的任何文件")
def download_with_snapshot(repo_id, save_path):
    """Download a repository with ``huggingface_hub.snapshot_download``.

    The call is made with ``local_dir=save_path``,
    ``local_dir_use_symlinks=False`` and ``resume_download=True``; the elapsed
    wall-clock time is printed afterwards.

    Args:
        repo_id: Repository identifier passed to ``snapshot_download``.
        save_path: Directory passed as ``local_dir``.
    """
    print("🌐 使用 snapshot_download 从 huggingface.co 下载完整模型...")
    began_at = time.time()

    snapshot_options = {
        "repo_id": repo_id,
        "local_dir": save_path,
        "local_dir_use_symlinks": False,
        "resume_download": True,
    }
    snapshot_download(**snapshot_options)

    elapsed = time.time() - began_at
    print(f"✅ snapshot_download 完成！耗时 {elapsed:.1f} 秒")

# ========== Main entry point ========== #
def download_model(repo_id, save_path, hf_mirror_base=None):
    """Download a model, choosing the transport from ``hf_mirror_base``.

    With a truthy ``hf_mirror_base`` the files are fetched through
    ``download_with_requests``; otherwise ``download_with_snapshot`` is used.
    Any exception raised by either path is printed and the process exits
    with status 1.

    Args:
        repo_id: Repository identifier.
        save_path: Directory the model is stored in.
        hf_mirror_base: Optional base URL of a mirror site.
    """
    print(f"\n⬇️ 开始下载模型: {repo_id}")
    print(f"📁 保存路径: {save_path}\n")

    try:
        if not hf_mirror_base:
            download_with_snapshot(repo_id, save_path)
        else:
            download_with_requests(repo_id, save_path, hf_mirror_base)
    except Exception as err:
        print("❌ 下载失败")
        print(err)
        sys.exit(1)

# Command-line entry: <repo_id> <save_path> [hf_mirror_base]
if __name__ == "__main__":
    cli_args = sys.argv[1:]

    # Both positional arguments are mandatory; the mirror URL is optional.
    if len(cli_args) < 2:
        print("Usage: python download_model.py <repo_id> <save_path> [hf_mirror_base]")
        sys.exit(1)

    repo_id, save_path = cli_args[0], cli_args[1]
    hf_mirror = cli_args[2] if len(cli_args) > 2 else None

    download_model(repo_id, save_path, hf_mirror)
