import traceback
import os
import pandas as pd
import yaml
from math import sqrt
import shutil
import json
import argparse
import sys
import uuid

# Parse command-line arguments: a single positional path to the YAML config.
parser = argparse.ArgumentParser(description="Process datasets from YAML configuration.")
parser.add_argument("config_path", type=str, help="Path to the YAML configuration file")
args = parser.parse_args()

# Load the YAML configuration. YAML streams are Unicode (normally UTF-8);
# pass the encoding explicitly so that a non-UTF-8 locale default (e.g. on
# Windows) cannot mis-decode non-ASCII folder names or paths in the config.
with open(args.config_path, "r", encoding="utf-8") as file:
    config = yaml.safe_load(file)

# Upper bound on the number of rows any single dataset may contribute.
MAX_ROWS = 40000

# First pass: for every configured dataset, count its rows, apply its
# sampling strategy and the MAX_ROWS cap, and accumulate sqrt(row count)
# into totalnum (used later as the weighting denominator).
totalnum = 0
for dataset in config['datasets']:
    strategy = dataset['sampling_strategy']
    source_dir = os.path.join(os.getcwd(), dataset['folder_name'])

    # Count rows one Parquet file at a time; any unreadable file is fatal
    # for the whole run.
    row_count = 0
    for entry in os.listdir(source_dir):
        if not entry.endswith('.parquet'):
            continue
        parquet_path = os.path.join(source_dir, entry)
        try:
            row_count += len(pd.read_parquet(parquet_path))
        except Exception as e:
            print(f"Error reading {parquet_path}: {e}")
            sys.exit(1)

    if strategy.startswith("first:"):
        # "first:N%" keeps the leading N percent of rows.
        fraction = int(strategy.split(':')[1].strip('%')) / 100
        kept = int(row_count * fraction)
    else:
        # "all" and any unrecognized strategy keep every row.
        kept = row_count

    totalnum += sqrt(min(kept, MAX_ROWS))

# Second pass: load each dataset, sample it, scale its size by a
# sqrt-based weight, write its images to disk and export a JSON manifest.


def _load_parquet_folder(folder_path):
    """Read every ``.parquet`` file in *folder_path* and concatenate them.

    Files are visited in sorted name order: ``os.listdir`` order is
    arbitrary, and the "first:N%" strategy, the MAX_ROWS cap and the
    weighted truncation all keep leading rows, so a stable order is needed
    for the same rows to be selected on every machine. Unreadable files are
    reported and skipped rather than aborting the dataset.

    Returns the combined DataFrame, or None if no file could be read.
    """
    frames = []
    parquet_files = sorted(f for f in os.listdir(folder_path) if f.endswith('.parquet'))
    for parquet_file in parquet_files:
        file_path = os.path.join(folder_path, parquet_file)
        print(f"Processing Parquet file: {file_path}")  # report the file being processed
        try:
            frames.append(pd.read_parquet(file_path))
        except Exception as e:
            # A bad file only drops that file, not the whole dataset.
            print(f"Error reading {file_path}: {e}")
    if not frames:
        return None
    return pd.concat(frames, ignore_index=True) if len(frames) > 1 else frames[0]


def _apply_sampling_strategy(df, sampling_strategy):
    """Select rows of *df* according to *sampling_strategy*.

    "all"      -> every row, original order
    "first:N%" -> the leading N percent of rows (N parsed as an integer)
    otherwise  -> every row, randomly shuffled
    """
    if sampling_strategy == "all":
        return df
    if sampling_strategy.startswith("first:"):
        percentage = int(sampling_strategy.split(':')[1].strip('%')) / 100
        return df.head(int(len(df) * percentage))
    return df.sample(frac=1)  # random shuffle


def _write_images_and_build_records(sampled_df, images_dir):
    """Write each row's image bytes into *images_dir* and return JSON records.

    Rows whose image is None, lacks a "bytes" key, or whose "bytes" value is
    None are skipped with a message (previously a None value crashed the
    write and aborted the whole dataset).

    The record id is a UUID5 of ``str(row['conversations'])``. Rows that
    share identical conversations (e.g. one generic prompt reused across
    many images) would otherwise get the same id and file name, so a later
    row would overwrite an earlier row's image. Repeated ids are therefore
    disambiguated with a "#k" suffix in the UUID5 name.
    """
    records = []
    used_ids = set()
    for _, row in sampled_df.iterrows():
        image = row['image']
        image_bytes = image["bytes"] if image is not None and "bytes" in image else None
        if image_bytes is None:
            print("Image or 'bytes' key is None, skipping this entry")
            continue

        conversations_key = str(row['conversations'])
        unique_id = str(uuid.uuid5(uuid.NAMESPACE_DNS, conversations_key))
        suffix = 1
        while unique_id in used_ids:
            unique_id = str(uuid.uuid5(uuid.NAMESPACE_DNS, f"{conversations_key}#{suffix}"))
            suffix += 1
        used_ids.add(unique_id)

        # NOTE(review): bytes are written verbatim under a .png name whatever
        # their actual encoded format is — confirm downstream loaders sniff it.
        image_name = f"{unique_id}.png"
        with open(os.path.join(images_dir, image_name), 'wb') as out_file:
            out_file.write(image_bytes)

        records.append({
            'id': unique_id,
            'image': f"images/{image_name}",
            'conversations': list(row['conversations']),
        })
    return records


for dataset in config['datasets']:
    folder_name = dataset['folder_name']
    sampling_strategy = dataset['sampling_strategy']
    json_path = dataset['json_path']

    try:
        folder_path = os.path.join(os.getcwd(), folder_name)
        combined_df = _load_parquet_folder(folder_path)
        if combined_df is None:
            print(folder_name)  # no readable Parquet file for this dataset
            continue

        sampled_df = _apply_sampling_strategy(combined_df, sampling_strategy)
        del combined_df  # drop the extra reference to the full data

        # Cap each dataset at MAX_ROWS rows, keeping the leading ones.
        if len(sampled_df) > MAX_ROWS:
            sampled_df = sampled_df.head(MAX_ROWS)

        # weight = sqrt(n) / totalnum, where totalnum is the first-pass sum of
        # sqrt(capped row count); keep n * weight leading rows. totalnum can
        # only be 0 when every dataset was empty — avoid ZeroDivisionError.
        weight = sqrt(len(sampled_df)) / totalnum if totalnum else 0.0
        print(f"Dataset {folder_name} weight: {weight:.4f}")
        sampled_df = sampled_df.head(int(len(sampled_df) * weight))

        images_dir = os.path.join(os.getcwd(), folder_name, "images")
        os.makedirs(images_dir, exist_ok=True)
        records = _write_images_and_build_records(sampled_df, images_dir)
        del sampled_df  # free memory before moving on

        # The manifest is written into the dataset folder under the basename
        # of the configured json_path.
        json_output_path = os.path.join(os.getcwd(), folder_name, os.path.basename(json_path))
        os.makedirs(os.path.dirname(json_output_path), exist_ok=True)
        with open(json_output_path, 'w', encoding='utf-8') as json_file:
            json.dump(records, json_file, ensure_ascii=False, indent=4)

    except Exception as e:
        # Any failure skips only this dataset; processing continues.
        print(f"Error processing dataset {folder_name} ({json_path}): {e}")
        print(f"Error occurred at line {traceback.format_exc()}")
        print(totalnum)

