import csv

def read_csv(file_path):
    """读取CSV文件并返回包含所有行的列表"""
    with open(file_path, 'r', encoding='utf-8') as csvfile:
        reader = csv.DictReader(csvfile)
        return [row for row in reader]

def write_csv(file_path, data, headers):
    """将数据写入CSV文件"""
    with open(file_path, 'w', newline='', encoding='utf-8') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=headers)
        writer.writeheader()
        writer.writerows(data)

# 定义输入和输出文件路径
file1 = 'hmdb_51_base_labels.csv'
file2 = 'hmdb51_labels.csv'
output_file = 'output.csv'

# 读取两个CSV文件
data1 = read_csv(file1)
data2 = read_csv(file2)

# 获取第一个CSV文件中的所有name值
names1 = {row['name'] for row in data1}

# 从第二个CSV文件中挑选出name重复的项
duplicate_rows=[]
count=0
for row in data2:
    if row['name'] in names1:
        duplicate_rows.append(row)
        count+=1
print('total count: ', count)
# duplicate_rows = [row for row in data2 if row['name'] in names1]

# 写入新的CSV文件
headers = ['id', 'name']
write_csv(output_file, duplicate_rows, headers)

print("文件处理完成，重复项已写入新的CSV文件。")