import json
import huggingface_hub
from datasets import Dataset, DatasetDict
from huggingface_hub import login
import argparse
import json
import os
import sys
import time
import shutil

# ============ HuggingFace timeout / retry configuration ============
HF_MAX_RETRIES = 5  # total number of upload attempts made by push_to_hub_with_retry
HF_RETRY_BASE_WAIT = 10  # base retry delay in seconds; the n-th retry waits n * this value

STORAGE_PATH = os.getenv("STORAGE_PATH")
HUGGINGFACENAME = os.getenv("HUGGINGFACENAME")
print(STORAGE_PATH)
if not STORAGE_PATH:
    # Every input/archive path is built as f"{STORAGE_PATH}/...". Without this
    # check an unset variable yields "None/..." paths: all shards look missing,
    # a stray ./None/ directory is created, and an empty dataset may be pushed.
    sys.exit("ERROR: STORAGE_PATH environment variable is not set")

# Read the Hugging Face access token from tokens.json (relative to the current
# working directory) and authenticate before any Hub operation.
with open('tokens.json', 'r') as f:
    token = json.load(f)['huggingface']
login(token=token)


def push_to_hub_with_retry(
    dataset,
    repo_name: str,
    config_name: str = None,
    private: bool = True,
    max_retries: int = HF_MAX_RETRIES
) -> bool:
    """Upload a dataset to the Hugging Face Hub, retrying on any exception.

    Between failed attempts the function sleeps for a linearly increasing
    delay: ``HF_RETRY_BASE_WAIT * n`` seconds after the n-th failure. No sleep
    happens after the final attempt.

    Args:
        dataset: The ``Dataset`` or ``DatasetDict`` to upload.
        repo_name: Target Hub repository id.
        config_name: Optional config name. It is only passed to
            ``push_to_hub`` when truthy.
        private: Whether the repository should be private.
        max_retries: Total number of attempts to make.

    Returns:
        True as soon as one attempt succeeds. False if every attempt raised,
        or if ``max_retries`` is not positive (nothing is attempted then).
    """
    # The keyword arguments never change between attempts, so build them once.
    push_kwargs = {"private": private}
    if config_name:
        push_kwargs["config_name"] = config_name

    for attempt_no in range(1, max_retries + 1):
        try:
            print(f"[HuggingFace] Uploading (attempt {attempt_no}/{max_retries})...")
            dataset.push_to_hub(repo_name, **push_kwargs)
            print("[HuggingFace] Upload successful!")
            return True
        except Exception as err:
            # Out of attempts: report and give up without sleeping.
            if attempt_no == max_retries:
                print(f"[HuggingFace] Upload failed after {max_retries} attempts: {err}")
                return False
            delay = HF_RETRY_BASE_WAIT * attempt_no
            print(f"[HuggingFace] Upload error: {err}")
            print(f"[HuggingFace] Retrying in {delay}s...")
            time.sleep(delay)

    return False
parser = argparse.ArgumentParser()
parser.add_argument("--repo_name", type=str, default="")
parser.add_argument("--max_score", type=float, default=0.7)
parser.add_argument("--min_score", type=float, default=0.3)
parser.add_argument("--experiment_name", type=str, default="Qwen_Qwen3-4B-Base_all")
args = parser.parse_args()

# Number of per-GPU result shards to merge. It is read from the environment so
# that 6-GPU, 8-GPU, etc. setups are all supported (defaults to 8).
n_gpus = int(os.getenv("TOTAL_GPU_COUNT", "8"))

# Merge the per-GPU result files. Merging is best effort: a missing or
# unreadable shard is reported and skipped, so one failed worker does not
# abort the whole merge. The exceptions are narrowed so that Ctrl-C /
# SystemExit are no longer swallowed, and a corrupt file is no longer
# reported as "not found".
datas = []
for i in range(n_gpus):
    result_file = f'{STORAGE_PATH}/generated_question/{args.experiment_name}_{i}_results.json'
    try:
        with open(result_file, 'r') as f:
            datas.extend(json.load(f))
    except FileNotFoundError:
        print(f"File {args.experiment_name}_{i}_results.json not found")
    except (OSError, ValueError, TypeError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        # TypeError covers a shard whose top-level JSON value is not iterable.
        print(f"File {args.experiment_name}_{i}_results.json could not be loaded: {e}")

# The generated result files are archived below rather than deleted, so they
# remain available for later inspection.

scores = [data['score'] for data in datas]
# Plot the score distribution as a histogram.
import matplotlib.pyplot as plt
plt.hist(scores, bins=11)

# Create the per-experiment archive directory and save the histogram there.
archive_dir = f'{STORAGE_PATH}/generated_question/archive/{args.experiment_name}'
os.makedirs(archive_dir, exist_ok=True)
plt.savefig(f'{archive_dir}/scores_distribution_{args.experiment_name}.png')

# Move the per-GPU result files into the archive directory.
for i in range(n_gpus):
    src_file = f'{STORAGE_PATH}/generated_question/{args.experiment_name}_{i}_results.json'
    if os.path.exists(src_file):
        dst_file = f'{archive_dir}/{args.experiment_name}_{i}_results.json'
        shutil.move(src_file, dst_file)
        print(f"Archived: {src_file} -> {dst_file}")

# Upload only when a target repository was requested.
if args.repo_name:
    if not HUGGINGFACENAME:
        # Otherwise the repo id would be "None/<repo_name>" and every retry
        # would fail after long sleeps.
        print("[Upload] ERROR: HUGGINGFACENAME environment variable is not set!")
        sys.exit(1)

    # Keep only questions whose score lies in [min_score, max_score] and that
    # have a usable answer: not empty, not the string 'None', not JSON null.
    filtered_datas = [
        {'problem': data['question'], 'answer': data['answer'], 'score': data['score']}
        for data in datas
        if args.min_score <= data['score'] <= args.max_score
        and data['answer'] not in ('', 'None', None)
    ]
    print(len(filtered_datas))
    if not filtered_datas:
        # Dataset.from_list([]) gives an empty, column-less split. Pushing it
        # would overwrite this config on the Hub with nothing.
        print("[Upload] ERROR: No questions passed the filter; refusing to upload an empty dataset!")
        sys.exit(1)

    train_dataset = Dataset.from_list(filtered_datas)
    dataset = DatasetDict({"train": train_dataset})
    config_name = args.experiment_name

    # Upload with the retry wrapper.
    upload_success = push_to_hub_with_retry(
        dataset,
        f"{HUGGINGFACENAME}/{args.repo_name}",
        config_name=config_name,
        private=True,
    )

    if not upload_success:
        print("[Upload] ERROR: Failed to upload dataset after all retries!")
        sys.exit(1)







