import os
import yaml
import requests
import argparse
from huggingface_hub import list_repo_files
from tqdm import tqdm
import time

# 1. Read the YAML configuration file
def load_yaml_config(yaml_file_path):
    """Parse the YAML file at ``yaml_file_path`` and return its content.

    Args:
        yaml_file_path: Path to the YAML configuration file.

    Returns:
        Whatever ``yaml.safe_load`` produces for the file's content.
    """
    # Decode explicitly as UTF-8 instead of the platform's locale encoding, so
    # non-ASCII configuration content is read the same way on every system.
    with open(yaml_file_path, "r", encoding="utf-8") as file:
        return yaml.safe_load(file)

# 2. List every file of a dataset hosted on Hugging Face
def list_dataset_files(dataset_name):
    """Return the file listing of the dataset repository ``dataset_name``.

    Any error raised while listing is printed and an empty list is returned
    instead, so the caller never sees an exception from this function.
    """
    try:
        # repo_type="dataset" makes the lookup target a dataset repository.
        return list_repo_files(repo_id=dataset_name, repo_type="dataset")
    except Exception as err:
        print(f"Error listing files for dataset {dataset_name}: {err}")
    return []

def download_folder_files(yaml_config, dataset_name):
    """Download the files of every configured folder of a Hugging Face dataset.

    For each entry of ``yaml_config['datasets']``, the entry's ``folder_name``
    is compared with the path components of every file in the dataset
    repository. Matching files are saved into ``./<folder_name>`` (relative to
    the current working directory). Files that already exist locally are
    skipped.

    Args:
        yaml_config: Parsed configuration; must provide a ``datasets``
            iterable whose items each have a ``folder_name`` key.
        dataset_name: Hugging Face dataset repository id.
    """
    # The repository listing does not depend on the folder being processed, so
    # fetch it once instead of issuing one listing request per folder.
    dataset_files = list_dataset_files(dataset_name)

    for dataset_config in tqdm(yaml_config['datasets']):
        folder_name = dataset_config['folder_name']

        # Keep the files that have folder_name as an exact path component.
        folder_files = [file for file in dataset_files if folder_name in file.split('/')]

        # Create the local folder if it does not exist yet.
        local_folder = f"./{folder_name}"
        os.makedirs(local_folder, exist_ok=True)

        for file_name in folder_files:
            # NOTE(review): only the basename is kept, so two repository files
            # with the same name in different sub-folders map to the same local
            # path and the second one is skipped below.
            local_file_path = os.path.join(local_folder, os.path.basename(file_name))

            # Skip files that are already present locally.
            if os.path.exists(local_file_path):
                print(f"Skipping {local_file_path} (already exists).")
                continue

            # NOTE(review): the request goes through a hard-coded third-party
            # proxy in front of huggingface.co — confirm this is intended.
            file_url = f"https://px.winniexi.us.kg/proxy/https://huggingface.co/datasets/{dataset_name}/resolve/main/{file_name}?download=true"
            download_file(file_url, local_file_path)

# Download a single file (with timeout and retry handling)
def download_file(url, file_path, retries=3, timeout=30):
    """Stream ``url`` into ``file_path``, retrying on request errors.

    The response body is first written to ``<file_path>.part`` and only moved
    to ``file_path`` once the whole transfer succeeded. An interrupted
    download therefore never leaves a truncated file at ``file_path``, which a
    later "already exists" check would mistake for a complete download.

    Args:
        url: Address to download from.
        file_path: Local destination path.
        retries: Maximum number of attempts.
        timeout: Timeout in seconds passed to ``requests.get``.

    Returns:
        None. Success and failure are reported through printed messages; after
        the last failed attempt the file is skipped without raising.
    """
    partial_path = f"{file_path}.part"

    for attempt in range(retries):
        try:
            print(f"Downloading {file_path} from {url} (attempt {attempt + 1}/{retries})...")
            with requests.get(url, stream=True, timeout=timeout) as response:
                response.raise_for_status()  # Treat HTTP error statuses as failures.

                # Write the body chunk by chunk into the temporary file.
                with open(partial_path, "wb") as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        f.write(chunk)

            # Promote the finished download to its final name.
            os.replace(partial_path, file_path)
            print(f"Downloaded {file_path}")
            return  # Download succeeded.
        except requests.exceptions.RequestException as e:
            print(f"Failed to download {file_path}: {e}")
            if attempt < retries - 1:  # Attempts remain.
                wait_time = 2 ** attempt  # Exponential backoff: 1s, 2s, 4s, ...
                print(f"Retrying in {wait_time} seconds...")
                time.sleep(wait_time)
            else:
                print(f"Exceeded maximum retries for {file_path}. Skipping.")
                return
        finally:
            # Drop whatever a failed attempt left behind; after a successful
            # os.replace the temporary file is already gone.
            try:
                os.remove(partial_path)
            except FileNotFoundError:
                pass

# 4. Main function: download all files of the folders listed in the YAML config
def main(yaml_file_path, dataset_name):
    """Load the YAML configuration and download the folders it lists.

    Args:
        yaml_file_path: Path to the YAML configuration file.
        dataset_name: Hugging Face dataset repository id.
    """
    download_folder_files(load_yaml_config(yaml_file_path), dataset_name)

if __name__ == "__main__":
    # Read the YAML configuration path and the dataset name from the command line.
    cli = argparse.ArgumentParser(
        description="Download selected dataset files from Hugging Face Hub."
    )
    cli.add_argument(
        'yaml_file_path',
        type=str,
        help="Path to the YAML configuration file.",
    )
    # The dataset name is optional and falls back to the default repository.
    cli.add_argument(
        'dataset_name',
        type=str,
        nargs='?',
        default="lmms-lab/LLaVA-OneVision-Data",
        help="Hugging Face dataset name (default: lmms-lab/LLaVA-OneVision-Data).",
    )
    options = cli.parse_args()

    # Run the download with the parsed arguments.
    main(options.yaml_file_path, options.dataset_name)
