#%%
import os
import re
import matplotlib.pyplot as plt
from collections import defaultdict
from tqdm import tqdm

# 1. Root of the merged dataset; every sub-folder is treated as one attribute
merged_root = '/fs/scratch/PAS2099/Jiacheng/Texture_tmp/merged_dataset_v4'

# 2. Accumulators populated by the scan below
attribute_total_counts = {}  # attr -> number of image files in its folder
attribute_dataset_counts = defaultdict(lambda: defaultdict(int))  # attr -> ds -> count

# 3. Walk every attribute folder and tally its image files
for attr in tqdm(os.listdir(merged_root), desc="Processing attributes"):
    attr_dir = os.path.join(merged_root, attr)
    if os.path.isdir(attr_dir):
        # Keep only entries whose extension (compared case-insensitively)
        # is one of the recognised image types.
        image_names = [
            entry
            for entry in os.listdir(attr_dir)
            if os.path.splitext(entry)[1].lower()
            in {'.png', '.jpg', '.jpeg', '.bmp', '.tiff'}
        ]

        for image_name in image_names:
            # The dataset name is the leading alphanumeric run that is
            # immediately followed by "_". Files without such a prefix
            # still count toward the folder total, but are not attributed
            # to any dataset.
            prefix = re.match(r"([A-Za-z0-9]+)_", image_name)
            if prefix is not None:
                attribute_dataset_counts[attr][prefix.group(1)] += 1

        attribute_total_counts[attr] = len(image_names)

# 4. Overlapping attributes: those whose images carry at least two
#    distinct dataset prefixes
overlapped_attrs = [
    name
    for name, per_dataset in attribute_dataset_counts.items()
    if len(per_dataset) >= 2
]

# 5. Plot histogram of images per attribute, largest count first.
# Ties are broken alphabetically by attribute name so the bar order is
# reproducible instead of depending on the insertion order of
# attribute_total_counts.
sorted_items = sorted(attribute_total_counts.items(), key=lambda item: (-item[1], item[0]))
sorted_attrs = [name for name, _ in sorted_items]
sorted_counts = [count for _, count in sorted_items]

# Build the membership set once: testing against the list directly would
# rescan it for every bar.
overlap_lookup = set(overlapped_attrs)
colors = ['red' if name in overlap_lookup else 'steelblue' for name in sorted_attrs]

plt.figure(figsize=(16, 6))
plt.bar(sorted_attrs, sorted_counts, color=colors)
plt.xticks(rotation=90)  # attribute names are long; rotate so they stay legible
plt.xlabel("Attribute")
plt.ylabel("Number of Images")
plt.title("Images per Attribute (red = appears in multiple datasets)")
plt.tight_layout()
plt.show()

# 6. Print statistics
total_images = sum(attribute_total_counts.values())
total_attributes = len(attribute_total_counts)
total_overlapped = len(overlapped_attrs)

print(f"1. Total number of images: {total_images}")
print(f"2. Total number of attributes: {total_attributes}")
print(f"3. Total number of overlapped attributes: {total_overlapped}")
print()

# Per-attribute totals, listed alphabetically by attribute name.
print("4. Attribute: Number of images")
for attr, count in sorted(attribute_total_counts.items()):
    print(f"   {attr}: {count}")
print()

# Per-dataset breakdown for attributes found in more than one dataset.
# Datasets are sorted by name so the output is reproducible rather than
# following dict insertion order. The breakdown need not sum to the
# attribute total: image files without a "<dataset>_" prefix are counted
# in the total but belong to no dataset.
print("5. Overlapped attributes with dataset breakdown:")
for attr in sorted(overlapped_attrs):
    ds_counts = attribute_dataset_counts[attr]
    parts = [f"{ds}: {cnt}" for ds, cnt in sorted(ds_counts.items())]
    print(f"   {attr}: " + ", ".join(parts))

# %%
