import os
import pandas as pd
import shutil
import argparse

def main(bird_type, place_type, dataset_root_path, output_root_path):
    # 定义鸟类和地点的映射
    bird_type_mapping = {'land_bird': 0, 'water_bird': 1}
    place_type_mapping = {'land': 0, 'water': 1}

    # 检查输入的bird_type和place_type是否有效
    if bird_type not in bird_type_mapping:
        raise ValueError(f"无效的bird_type：{bird_type}，有效值为：{list(bird_type_mapping.keys())}")
    if place_type not in place_type_mapping:
        raise ValueError(f"无效的place_type：{place_type}，有效值为：{list(place_type_mapping.keys())}")

    # 将输入的鸟类和地点类型转换为对应的数值
    bird_type_value = bird_type_mapping[bird_type]
    place_type_value = place_type_mapping[place_type]

    # 读取CSV文件
    csv_file = os.path.join(dataset_root_path, 'metadata.csv')
    df = pd.read_csv(csv_file, sep=',')

    # 过滤数据
    filtered_df = df[
        (df['y'] == bird_type_value) &
        (df['place'] == place_type_value)
    ]

    # 获取数据集划分
    split_mapping = {0: 'train', 1: 'val', 2: 'test'}

    # 根据输入的bird_type和place_type生成新文件夹名称
    new_folder_name = f"{bird_type}_{place_type}"

    # 遍历过滤后的数据并复制文件
    for index, row in filtered_df.iterrows():
        split_folder = split_mapping[row['split']]
        # 在输出路径下创建新的目录结构
        dest_dir = os.path.join(output_root_path, new_folder_name, split_folder)

        # 获取图像的相对路径（包含标签文件夹）
        img_relative_path = row['img_filename']

        # 源文件路径
        src_file = os.path.join(dataset_root_path, img_relative_path)
        # 目标文件路径
        dest_file = os.path.join(dest_dir, img_relative_path)

        # 创建目标文件夹（如果不存在）
        os.makedirs(os.path.dirname(dest_file), exist_ok=True)

        # 复制文件
        shutil.copy(src_file, dest_file)

    print(f"数据集抽取并保存完成！共处理了{len(filtered_df)}个文件。")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='数据集抽取脚本')
    parser.add_argument('--bird_type', required=True, help='鸟类类型，例如：--bird_type water_bird')
    parser.add_argument('--place_type', required=True, help='地点类型，例如：--place_type water')
    parser.add_argument('--dataset_root_path', required=True, help='数据集根路径，例如：--dataset_root_path /path/to/dataset')
    parser.add_argument('--output_root_path', required=True, help='输出数据集根路径，例如：--output_root_path /path/to/output_dataset')

    args = parser.parse_args()
    main(args.bird_type, args.place_type, args.dataset_root_path, args.output_root_path)


    # --bird_type water_bird
    # --place_type water
    # --dataset_root_path /path/to/dataset
    # --output_root_path /path/to/output_dataset