#%%
import os
import re
import shutil
import matplotlib.pyplot as plt
from tqdm import tqdm

# -------------------------
# 1. Define paths
# -------------------------
# Shared parent folder of all source datasets and of the merged output.
texture_root = '/fs/scratch/PAS2099/Jiacheng/Texture_tmp'

# Dataset name -> folder whose immediate sub-directories are the attribute
# (class) folders to merge.
datasets = {
    'FMD':         os.path.join(texture_root, 'FMD', 'image'),
    'KTH':         os.path.join(texture_root, 'KTH_merged'),
    'Kylberg':     os.path.join(texture_root, 'Kylberg', 'kylberg_dataset'),
    'Textual_dtd': os.path.join(texture_root, 'Textual_dtd', 'images'),
}
merged_root = os.path.join(texture_root, 'merged_dataset')
# Accepted image extensions, compared against the lowercased file suffix.
# Both TIFF spellings are accepted ('.tif' was previously missing).
valid_exts = {'.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff'}
os.makedirs(merged_root, exist_ok=True)

# -------------------------
# 2. Prepare counters
# -------------------------
# Per dataset: {attribute: number of images copied so far}.
dataset_attr_counts = {ds_name: dict() for ds_name in datasets.keys()}
# {attribute: total images over all datasets}; populated in step 4.
attribute_total_counts = dict()
# {attribute: set of dataset names that have a folder mapping to it}.
attribute_datasets = dict()

# -------------------------
# 3. Merge images
# -------------------------
for ds_name, root_dir in tqdm(datasets.items(), desc='Datasets', unit='ds'):
    if not os.path.isdir(root_dir):
        raise FileNotFoundError(f"Missing folder for {ds_name}: {root_dir}")
    ds_counts = dataset_attr_counts[ds_name]

    # os.listdir order is arbitrary; sort so the per-attribute numbering (and
    # therefore the output file names) is reproducible between runs. Hidden
    # entries (e.g. '.ipynb_checkpoints') are not attribute folders.
    entries = sorted(e for e in os.listdir(root_dir) if not e.startswith('.'))
    for entry in tqdm(entries,
                      desc=f'{ds_name} attrs', leave=False, unit='attr'):
        entry_path = os.path.join(root_dir, entry)
        if not os.path.isdir(entry_path):
            continue

        # Folder name -> attribute (lowercase). For Kylberg, keep the part
        # before the first '-' and drop trailing digits, so several folders
        # can collapse into one attribute.
        if ds_name == 'Kylberg':
            attribute = re.sub(r'\d+$', '', entry.split('-', 1)[0])
        else:
            attribute = entry
        attribute = attribute.lower()

        attribute_datasets.setdefault(attribute, set()).add(ds_name)
        # Numbering continues from any earlier folder of this dataset that
        # mapped to the same attribute, so new names never collide.
        start = ds_counts.setdefault(attribute, 0)

        # Dot-files are skipped as well: e.g. macOS '._*' AppleDouble files
        # can carry an image extension without being images.
        files = sorted(
            f for f in os.listdir(entry_path)
            if not f.startswith('.')
            and os.path.splitext(f)[1].lower() in valid_exts
        )
        if not files:
            # Don't create an empty class folder in the merged dataset.
            continue

        dst = os.path.join(merged_root, attribute)
        os.makedirs(dst, exist_ok=True)
        for cnt, fname in enumerate(
                tqdm(files, desc=f'{ds_name}/{attribute}',
                     leave=False, unit='img'),
                start=start + 1):
            ext = os.path.splitext(fname)[1].lower()
            new_name = f"{ds_name}_{attribute}_{cnt}{ext}"
            shutil.copy2(os.path.join(entry_path, fname),
                         os.path.join(dst, new_name))
            ds_counts[attribute] = cnt

# -------------------------
# 4. Compute totals & overlaps
# -------------------------
# Total images per attribute, summed over every dataset.
for attr in attribute_datasets:
    attribute_total_counts[attr] = sum(
        dataset_attr_counts[ds].get(attr, 0) for ds in datasets)

# An attribute is "overlapped" only when at least two datasets actually
# contributed images to it; a matching folder with no valid images does not
# count. This agrees with the per-dataset breakdown printed in step 6.
overlapped = [
    attr for attr in attribute_datasets
    if sum(1 for ds in datasets
           if dataset_attr_counts[ds].get(attr, 0) > 0) > 1
]

# -------------------------
# 5. Plot sorted histogram
# -------------------------
# Sort attributes by total image count, largest first.
sorted_attrs = sorted(attribute_total_counts.items(),
                      key=lambda x: x[1], reverse=True)
# zip(*[]) yields nothing to unpack, so handle "no attributes" explicitly;
# the summary below still needs attrs/counts to exist.
attrs, counts = zip(*sorted_attrs) if sorted_attrs else ((), ())

# Red bars mark attributes present in more than one dataset.
overlapped_set = set(overlapped)
colors = ['red' if a in overlapped_set else 'steelblue' for a in attrs]

if attrs:
    plt.figure(figsize=(15, 7))
    plt.bar(attrs, counts, color=colors)
    plt.xticks(rotation=90)
    plt.xlabel('Attribute')
    plt.ylabel('Number of Images')
    plt.title('Merged Dataset: Images per Attribute\n(red = in multiple datasets)')
    plt.tight_layout()
    plt.show()
else:
    print("No attributes found; skipping histogram.")

# -------------------------
# 6. Print summary
# -------------------------
total_images = sum(counts)
total_attrs = len(attrs)
total_overlap = len(overlapped)

# Headline numbers; the trailing "\n" on the last one leaves a blank line.
for summary_line in (
    f"1. Total number of images: {total_images}",
    f"2. Total number of attributes: {total_attrs}",
    f"3. Total number of overlapped attributes: {total_overlap}\n",
):
    print(summary_line)

# Per-attribute totals, in the same descending order as the histogram.
print("4. Attribute : Number of images")
for attr_name, n_images in sorted_attrs:
    print(f"   {attr_name}: {n_images}")
print()

# For each overlapped attribute, list only the datasets that contributed
# images, in the order they appear in `datasets`.
print("5. Overlapped attribute breakdown (dataset: count)")
for attr_name in overlapped:
    breakdown = []
    for ds in datasets:
        n_images = dataset_attr_counts[ds].get(attr_name, 0)
        if n_images > 0:
            breakdown.append(f"{ds}: {n_images}")
    print(f"   {attr_name}: {', '.join(breakdown)}")

# %%
