#%%
import filecmp
import os
import shutil

import matplotlib.pyplot as plt

# Path configuration: both source datasets and the merged output share one
# scratch directory.
_texture_tmp = '/fs/scratch/PAS2099/Jiacheng/Texture_tmp'
root1 = f'{_texture_tmp}/KTH1_original/KTH_TIPS'         # KTH-TIPS1 source root
root2 = f'{_texture_tmp}/KTH2_original/KTH-TIPS2-b'      # KTH-TIPS2-b source root
merged_root = f'{_texture_tmp}/KTH_merged'               # merged output root

# File extensions (compared lower-cased) that are treated as images.
valid_exts = {'.bmp', '.jpeg', '.jpg', '.png', '.tiff'}

# Make sure the merged output directory exists (no error if it already does).
os.makedirs(name=merged_root, exist_ok=True)

# --- 1. Copy KTH-TIPS1 images ---
# Layout handled here: root1/<attribute>/<image>. Each attribute directory is
# mirrored as merged_root/<attribute>, and every file whose lower-cased
# extension is in valid_exts is copied into it (copy2 keeps file metadata and
# overwrites a same-named destination file).
for attr in os.listdir(root1):
    src_dir = os.path.join(root1, attr)
    if os.path.isdir(src_dir):
        dst_dir = os.path.join(merged_root, attr)
        os.makedirs(dst_dir, exist_ok=True)
        image_names = [
            name for name in os.listdir(src_dir)
            if os.path.splitext(name)[1].lower() in valid_exts
        ]
        for name in image_names:
            shutil.copy2(os.path.join(src_dir, name),
                         os.path.join(dst_dir, name))

# --- 2. Copy KTH-TIPS2-b images ---
# Layout handled here: root2/<attribute>/<sample>/<image>. All sample folders of
# an attribute are flattened into the single merged_root/<attribute> directory.
# Name collisions are resolved without overwriting:
#   * an identical file already at the destination (e.g. this cell was re-run)
#     is skipped, so re-running does not create duplicates;
#   * a *different* file with the same name is kept, and the new copy is saved
#     as "<sample>_<original name>".
for attr in os.listdir(root2):
    src_attr = os.path.join(root2, attr)
    if not os.path.isdir(src_attr):
        continue
    dst_dir = os.path.join(merged_root, attr)
    os.makedirs(dst_dir, exist_ok=True)
    for sample in os.listdir(src_attr):
        sample_dir = os.path.join(src_attr, sample)
        if not os.path.isdir(sample_dir):
            continue
        for fname in os.listdir(sample_dir):
            ext = os.path.splitext(fname)[1].lower()
            if ext not in valid_exts:
                continue
            # The source path always uses the ORIGINAL file name; only the
            # destination name may change on a collision.
            src_path = os.path.join(sample_dir, fname)
            dst_path = os.path.join(dst_dir, fname)
            if os.path.exists(dst_path):
                if filecmp.cmp(src_path, dst_path):
                    # Same image already present: nothing to copy.
                    continue
                # Different image under the same name: prefix with the
                # sample folder name so neither file is overwritten.
                dst_path = os.path.join(dst_dir, f"{sample}_{fname}")
            shutil.copy2(src_path, dst_path)

# --- 3. Summarize the merged result ---
def _count_images(folder):
    """Return how many entries of *folder* have a lower-cased extension in valid_exts."""
    return sum(
        1 for entry in os.listdir(folder)
        if os.path.splitext(entry)[1].lower() in valid_exts
    )


# Map each attribute directory under merged_root to its image count
# (non-directory entries of merged_root are ignored).
attribute_counts = {}
for attr in os.listdir(merged_root):
    attr_dir = os.path.join(merged_root, attr)
    if os.path.isdir(attr_dir):
        attribute_counts[attr] = _count_images(attr_dir)

total_attributes = len(attribute_counts)
total_images = sum(attribute_counts.values())

# Print the summary: totals first, then per-attribute counts in name order.
summary_lines = [
    f"1. Number of total attributes: {total_attributes}",
    f"2. Number of total images: {total_images}",
    "3. Attribute: number of images",
]
summary_lines.extend(
    f"   {name}: {n}" for name, n in sorted(attribute_counts.items())
)
print("\n".join(summary_lines))

# --- 4. Plot a bar chart of images per attribute ---
# Bars are drawn in the same alphabetical order as the printed summary
# (os.listdir order is arbitrary), and the data is passed as plain lists
# rather than dict views.
plot_labels = sorted(attribute_counts)
plot_values = [attribute_counts[label] for label in plot_labels]

plt.figure(figsize=(12, 6))
plt.bar(plot_labels, plot_values)
plt.xticks(rotation=45, ha='right')
plt.xlabel('Attribute (Class)')
plt.ylabel('Number of Images')
plt.title('Merged KTH-TIPS1 & KTH-TIPS2-b: Images per Attribute')
plt.tight_layout()
plt.show()

# %%
