#%%
import json
from collections import Counter
import matplotlib.pyplot as plt

# Path to the JSON file with image-label records.
# main() expects it to hold a JSON array whose elements each have a "label"
# key; the exact label type/format is not checked here — TODO confirm against
# the script that produced this file.
INPUT_JSON = "/fs/scratch/PAS2099/Jiacheng/place365/output/places365_val_data.json"

def main(input_json=INPUT_JSON, save_path=None):
    """Summarize and plot how many images each label has.

    Reads a JSON array of records (each must contain a ``"label"`` key),
    prints the total image count, the number of distinct labels, and a
    per-label breakdown in sorted label order, then draws a bar chart in
    that same order so the printout and the chart line up.

    Args:
        input_json: Path of the JSON file to read. Defaults to ``INPUT_JSON``.
        save_path: Optional output path for the chart. When given, the figure
            is written there and closed instead of being shown (useful on
            headless nodes). When ``None`` (the default), ``plt.show()`` is
            called, matching the previous behaviour.

    Raises:
        KeyError: If a record has no ``"label"`` key.
    """
    # 1. Load data (explicit encoding so results don't depend on the locale)
    with open(input_json, "r", encoding="utf-8") as f:
        data = json.load(f)

    # 2. Count images per label
    counts = Counter(record["label"] for record in data)

    # 3. Print totals ("attributes" here means distinct labels)
    total_images = sum(counts.values())
    total_attributes = len(counts)
    print(f"Total number of images: {total_images}")
    print(f"Total number of attributes: {total_attributes}\n")

    # 4. Print each label with its count. Sort once and reuse this order for
    #    the chart as well, so the text listing and the bars agree.
    sorted_items = sorted(counts.items())
    for label, count in sorted_items:
        print(f"{label}: {count}")

    if not sorted_items:
        # Empty dataset: nothing meaningful to plot.
        return

    # 5. Plot bar chart of images per label, in sorted label order
    labels = [label for label, _ in sorted_items]
    values = [count for _, count in sorted_items]
    positions = range(len(labels))

    fig = plt.figure(figsize=(20, 10))  # large figure: many x tick labels
    plt.bar(positions, values)
    plt.xticks(positions, labels, rotation=90, fontsize=6)
    plt.xlabel("Label")
    plt.ylabel("Number of Images")
    plt.title("Number of Images per Label")
    plt.tight_layout()

    if save_path is None:
        plt.show()
    else:
        fig.savefig(save_path)
        # Free the figure; otherwise repeated calls accumulate open figures.
        plt.close(fig)

# Script entry point: run the label-count summary with the default INPUT_JSON
# path and display the chart. Also runs when this cell is executed directly.
if __name__ == "__main__":
    main()

# %%
