#%%
import os
import matplotlib.pyplot as plt
from collections import defaultdict

# Root of the merged dataset; each immediate subdirectory is treated as one
# attribute folder.
root_dir = '/fs/scratch/PAS2099/Jiacheng/Texture_tmp/merged_dataset_v4'
# Maps original-dataset prefix -> number of files carrying that prefix.
dataset_counts = defaultdict(int)

# Walk every attribute folder directly under root_dir.
with os.scandir(root_dir) as attr_entries:
    for attr_entry in attr_entries:
        if not attr_entry.is_dir():
            continue

        # Tally the files in this attribute folder by the part of the filename
        # before the first underscore (taken as the original dataset name).
        with os.scandir(attr_entry.path) as file_entries:
            for file_entry in file_entries:
                # Skip nested directories and other non-file entries so that
                # only actual files contribute to the counts.
                if not file_entry.is_file():
                    continue
                fname = file_entry.name
                if '_' not in fname:
                    # No underscore: the name carries no dataset prefix.
                    continue
                original_dataset = fname.split('_', 1)[0]
                dataset_counts[original_dataset] += 1

# Print the per-dataset totals, ordered alphabetically by dataset name.
print("Original dataset : Number of images")
for ds in sorted(dataset_counts):
    count = dataset_counts[ds]
    print(f"  {ds}: {count}")

# Plot the counts as a bar chart.
# Sort by dataset name so the bars appear in the same order as the printed
# summary; the dict's own order follows the directory-listing order, which is
# arbitrary and may differ between runs or filesystems.
sorted_names = sorted(dataset_counts)
sorted_values = [dataset_counts[name] for name in sorted_names]

plt.figure(figsize=(10, 5))
# Pass concrete lists rather than dict views for the bar positions and heights.
plt.bar(sorted_names, sorted_values, color='steelblue')
plt.xlabel('Original Dataset')
plt.ylabel('Number of Images')
plt.title('Original Dataset: Number of Images')
# Rotate tick labels so longer dataset names do not overlap.
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()

# %%
