#%%
import os
from collections import Counter
import matplotlib.pyplot as plt

# Root of the AID dataset. main() treats every immediate sub-directory of this
# path as one label and counts the image files directly inside it.
# NOTE(review): hard-coded cluster scratch path — adjust when running elsewhere.
DATA_DIR: str = "/fs/scratch/PAS2099/Jiacheng/AID/data"

def main(data_dir=None, *, extensions=(".jpg", ".jpeg", ".png"), show_plot=True):
    """Count images per label, print a summary, and plot a bar chart.

    Each immediate sub-directory of ``data_dir`` is treated as one label. Its
    image count is the number of entries whose name ends (case-insensitively)
    with one of ``extensions``. Non-directory entries at the top level are
    ignored.

    Labels are reported in alphabetical order in both the printout and the
    bar chart.

    Args:
        data_dir: Dataset root to scan. Defaults to the module-level
            ``DATA_DIR`` when ``None``.
        extensions: Filename suffix (str) or iterable of suffixes counted as
            images; matching is case-insensitive.
        show_plot: When False, skip the matplotlib bar chart (useful on
            headless machines and in tests).

    Returns:
        Counter mapping label -> number of images. Labels with no images are
        kept with a count of 0.

    Raises:
        OSError: (e.g. FileNotFoundError) propagated from ``os.listdir`` when
            ``data_dir`` or a label directory cannot be listed.
    """
    if data_dir is None:
        data_dir = DATA_DIR
    # str.endswith needs a str or a tuple; normalise once, lower-cased so the
    # comparison against f.lower() is case-insensitive.
    if isinstance(extensions, str):
        extensions = (extensions,)
    suffixes = tuple(ext.lower() for ext in extensions)

    # 1. Count images per label
    label_counts = Counter()
    for label in os.listdir(data_dir):
        label_dir = os.path.join(data_dir, label)
        if not os.path.isdir(label_dir):
            continue
        label_counts[label] = sum(
            1 for f in os.listdir(label_dir) if f.lower().endswith(suffixes)
        )
    total_images = sum(label_counts.values())

    # os.listdir returns entries in arbitrary order. Sort once so the printed
    # list and the bar chart share the same (alphabetical) label order.
    sorted_items = sorted(label_counts.items())

    # 2. Print summaries ("attributes" = number of label directories found)
    print(f"Total number of attributes: {len(label_counts)}")
    print(f"Total number of images: {total_images}\n")

    # 3. Print each label with its image count
    for label, count in sorted_items:
        print(f"{label}: {count}")

    # 4. Plot bar chart of images per label, in the same order as printed
    if show_plot:
        labels = [label for label, _ in sorted_items]
        counts = [count for _, count in sorted_items]

        plt.figure(figsize=(12, 6))
        plt.bar(labels, counts)
        plt.xticks(rotation=90)
        plt.xlabel("Label")
        plt.ylabel("Number of images")
        plt.title("AID Dataset: Number of Images per Label")
        plt.tight_layout()
        plt.show()

    return label_counts

# Run the analysis only when this file/cell is executed as the top-level
# module, not when it is imported.
if __name__ == "__main__":
    main()

# %%
import os
from collections import Counter
import matplotlib.pyplot as plt

# Root of the AID dataset (re-declared so this cell runs standalone). main()
# treats every immediate sub-directory as one label and counts its image files.
# NOTE(review): hard-coded cluster scratch path — adjust when running elsewhere.
DATA_DIR: str = "/fs/scratch/PAS2099/Jiacheng/AID/data"

def main(data_dir=None, *, extensions=(".jpg", ".jpeg", ".png"), show_plot=True):
    """Count images per label, print them by descending count, and plot them.

    Each immediate sub-directory of ``data_dir`` is treated as one label. Its
    image count is the number of entries whose name ends (case-insensitively)
    with one of ``extensions``. Non-directory entries at the top level are
    ignored.

    Labels are ordered by descending image count. Ties are broken
    alphabetically, so the output is deterministic regardless of
    ``os.listdir`` order.

    Args:
        data_dir: Dataset root to scan. Defaults to the module-level
            ``DATA_DIR`` when ``None``.
        extensions: Filename suffix (str) or iterable of suffixes counted as
            images; matching is case-insensitive.
        show_plot: When False, skip the matplotlib bar chart (useful on
            headless machines and in tests).

    Returns:
        Counter mapping label -> number of images. Labels with no images are
        kept with a count of 0.

    Raises:
        OSError: (e.g. FileNotFoundError) propagated from ``os.listdir`` when
            ``data_dir`` or a label directory cannot be listed.
    """
    if data_dir is None:
        data_dir = DATA_DIR
    # str.endswith needs a str or a tuple; normalise once, lower-cased so the
    # comparison against f.lower() is case-insensitive.
    if isinstance(extensions, str):
        extensions = (extensions,)
    suffixes = tuple(ext.lower() for ext in extensions)

    # 1. Count images per label
    label_counts = Counter()
    for label in os.listdir(data_dir):
        label_dir = os.path.join(data_dir, label)
        if not os.path.isdir(label_dir):
            continue
        label_counts[label] = sum(
            1 for f in os.listdir(label_dir) if f.lower().endswith(suffixes)
        )
    total_images = sum(label_counts.values())

    # 2. Sort labels by descending count; the label name breaks ties so the
    #    order does not depend on arbitrary os.listdir ordering.
    sorted_items = sorted(label_counts.items(), key=lambda kv: (-kv[1], kv[0]))
    # Build the columns with comprehensions rather than `zip(*sorted_items)`
    # unpacking, which raises ValueError when no label directories exist.
    sorted_labels = [label for label, _ in sorted_items]
    sorted_counts = [count for _, count in sorted_items]

    # 3. Print summaries ("attributes" = number of label directories found)
    print(f"Total number of attributes: {len(label_counts)}")
    print(f"Total number of images: {total_images}\n")

    # 4. Print each label with its image count (sorted)
    for label, count in sorted_items:
        print(f"{label}: {count}")

    # 5. Plot sorted histogram
    if show_plot:
        plt.figure(figsize=(12, 6))
        plt.bar(sorted_labels, sorted_counts)
        plt.xticks(rotation=90)
        plt.xlabel("Label")
        plt.ylabel("Number of images")
        plt.title("AID Dataset: Number of Images per Label (Sorted)")
        plt.tight_layout()
        plt.show()

    return label_counts

# Run the analysis only when this file/cell is executed as the top-level
# module. This calls the most recently defined main(), i.e. the sorted variant
# defined in this cell.
if __name__ == "__main__":
    main()

# %%
