import pandas as pd


def read_txt_with_pandas(file_path, delimiter=None, skip_rows=0):
    """
    使用 pandas 读取以空格分隔的 TXT 文件

    参数:
    file_path (str): TXT 文件的路径
    delimiter (str): 分隔符，默认为 None（使用任意空白字符）
    skip_rows (int): 要跳过的起始行数

    返回:
    DataFrame: 包含数据的 pandas DataFrame
    """
    try:
        # 读取文件
        df = pd.read_csv(
            file_path,
            delimiter=delimiter if delimiter else r'\s+',
            header=None,  # 没有列标题
            names=['ID', 'train'],  # 指定列名
            skiprows=skip_rows,  # 跳过的行数
            engine='python'  # 使用 Python 引擎处理正则表达式分隔符
        )

        return df

    except Exception as e:
        print(f"读取文件时发生错误: {e}")
        return None


# 示例用法
if __name__ == "__main__":
    # 文件路径
    file_path = "/home/datasets/FGVC/CUB_200_2011/train_test_split.txt"  # 替换为您的文件路径

    # 读取文件
    train_df = read_txt_with_pandas(file_path)
    zero_count = (train_df['train'] == 1).sum()

    print(f"Value 列中值为 0 的元素总数: {zero_count}")


    loca_path = "/home/datasets/FGVC/CUB_200_2011/images.txt"  # 替换为您的文件路径
    loca_df = read_txt_with_pandas(loca_path)
    loca_df.columns=['ID', 'path']
    loca_df_dict = dict(zip(loca_df['ID'], loca_df['path']))

    classes_path = "/home/datasets/FGVC/CUB_200_2011/image_class_labels.txt"  # 替换为您的文件路径
    class_df = read_txt_with_pandas(classes_path)
    class_df.columns=['ID', 'labels']
    class_df_dict = dict(zip(class_df['ID'], class_df['labels']))

    train_classes=[]
    text_data = []
    for _,row in train_df.iterrows():
        images = loca_df_dict[row['ID']]
        labels = class_df_dict[row['ID']]
        if row['train'] == 0:
            train_classes.append(images + " " + str(labels))
        else:
            text_data.append(images + " " +str(labels))

    print(len(train_classes))

    with open("trainval.txt", 'w') as file:
        for item in train_classes:
            file.write(f"{item}\n")  # 写入每个元素并添加换行符

    with open("test.txt", 'w') as file:
        for item in text_data:
            file.write(f"{item}\n")  # 写入每个元素并添加换行符