#%%
import os
import matplotlib.pyplot as plt

# Root directory of the balanced dataset: one subdirectory per label,
# each holding that label's image files.
merged_root = '/fs/scratch/PAS2099/Jiacheng/Places_merge/output/balanced_data_v4'

# Lower-case file extensions that are treated as images.
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')


def count_images(label_dir, extensions=IMAGE_EXTENSIONS):
    """Return the number of image files directly inside *label_dir*.

    A file counts when it is a regular file (or a symlink to one) whose name
    ends with one of *extensions*, compared case-insensitively. Entries that
    are directories are skipped even if their name ends with an image
    extension.
    """
    with os.scandir(label_dir) as entries:
        return sum(
            1 for entry in entries
            if entry.is_file() and entry.name.lower().endswith(extensions)
        )


# Count images per label (labels visited in sorted order); labels whose
# directory contains no images are left out of `counts`.
counts = {}
for label in sorted(os.listdir(merged_root)):
    label_dir = os.path.join(merged_root, label)
    if not os.path.isdir(label_dir):
        continue  # stray files at the root are not labels
    n = count_images(label_dir)
    if n > 0:
        counts[label] = n

labels = list(counts.keys())
values = list(counts.values())

# Draw one bar chart per group of up to `chunk_size` labels so that the
# rotated x-axis tick labels stay readable.
chunk_size = 60
for i in range(0, len(labels), chunk_size):
    chunk_labels = labels[i:i + chunk_size]
    chunk_counts = [counts[lbl] for lbl in chunk_labels]

    x = range(len(chunk_labels))
    plt.figure(figsize=(20, 6))
    plt.bar(x, chunk_counts)
    plt.xticks(x, chunk_labels, rotation=90)
    plt.ylabel('Number of images')
    # 1-based, inclusive label positions covered by this chart.
    plt.title(f'Label indices {i+1}–{i+len(chunk_labels)} (out of {len(labels)})')
    plt.tight_layout()
    plt.show()
    # Release the figure: on non-inline/headless backends pyplot otherwise
    # keeps every figure alive, growing memory and triggering the
    # "more than 20 figures opened" warning for large label sets.
    plt.close()

# Print the overall totals, followed by the per-label image counts.
total_images = sum(counts.values())
total_attributes = len(counts)

# Fixed header lines: the two totals, then the column heading of the table.
summary_header = (
    f"1. Total number of images: {total_images}",
    f"2. Total number of attributes (labels): {total_attributes}",
    "3. Label : Number of images",
)
for line in summary_header:
    print(line)

# One indented "label : count" row per entry, in the iteration order of `counts`.
for label, n in counts.items():
    print(f"   {label} : {n}")

# %%

# %%
