#%%
import os
import json
from collections import defaultdict

# === Paths to your split JSONs ===
TRAIN_SPLIT_JSON = "/fs/scratch/PAS2099/Jiacheng/Orientation/EgoOrient/train/train.json"
VAL_SPLIT_JSON   = "/fs/scratch/PAS2099/Jiacheng/Orientation/EgoOrient/val/val.json"


def _load_entries(path):
    """Return the parsed JSON content of the split file at *path*."""
    # JSON text is UTF-8; pin the encoding instead of relying on the
    # platform/locale default that a bare open() would use.
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def _group_images(entries):
    """Group unique image basenames by (object, orientation).

    Only entries whose "type" is "general_complex" are considered. The key is
    (entry["category_name"], entry["label"] stripped and lower-cased) and the
    value is the set of basenames of entry["image"], so the same image file
    name is counted once per key.
    """
    groups = defaultdict(set)
    for e in entries:
        if e.get("type") != "general_complex":
            continue
        key = (e["category_name"], e["label"].strip().lower())
        groups[key].add(os.path.basename(e["image"]))
    return groups


# === Load split data ===
train_data = _load_entries(TRAIN_SPLIT_JSON)
val_data = _load_entries(VAL_SPLIT_JSON)

# === Group unique images by (object, orientation) ===
train_groups = _group_images(train_data)
val_groups = _group_images(val_data)

# === Filter combos with total images >= 5 and print with percentage ===
print("Object\tOrientation\tTrain\tVal\tTotal\tTrain%")
for key in sorted(set(train_groups) | set(val_groups)):
    # .get() (not indexing) so the defaultdicts are not grown while reporting.
    t = len(train_groups.get(key, set()))
    v = len(val_groups.get(key, set()))
    tot = t + v
    if tot < 5:
        continue
    # tot >= 5 here, so the division is always safe.
    pct = t / tot * 100
    obj, ori = key
    print(f"{obj}\t{ori}\t{t}\t{v}\t{tot}\t{pct:.2f}%")


# %%
import os
import json
from collections import defaultdict

# === Paths to your split JSONs ===
TRAIN_SPLIT_JSON = "/fs/scratch/PAS2099/Jiacheng/Orientation/EgoOrient/train/train.json"
VAL_SPLIT_JSON   = "/fs/scratch/PAS2099/Jiacheng/Orientation/EgoOrient/val/val.json"


def _load_entries(path):
    """Return the parsed JSON content of the split file at *path*."""
    # JSON text is UTF-8; pin the encoding instead of relying on the
    # platform/locale default that a bare open() would use.
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def _group_images(entries):
    """Group unique image basenames by (object, orientation).

    Only entries whose "type" is "general_complex" are considered. The key is
    (entry["category_name"], entry["label"] stripped and lower-cased) and the
    value is the set of basenames of entry["image"], so the same image file
    name is counted once per key.
    """
    groups = defaultdict(set)
    for e in entries:
        if e.get("type") != "general_complex":
            continue
        key = (e["category_name"], e["label"].strip().lower())
        groups[key].add(os.path.basename(e["image"]))
    return groups


# === Load split data ===
train_data = _load_entries(TRAIN_SPLIT_JSON)
val_data = _load_entries(VAL_SPLIT_JSON)

# === Group unique images by (object, orientation) ===
train_groups = _group_images(train_data)
val_groups = _group_images(val_data)

# === Filter combos where total images is a multiple of 5 and print percentage ===
print("Object\tOrientation\tTrain\tVal\tTotal\tTrain%")
for key in sorted(set(train_groups) | set(val_groups)):
    # .get() (not indexing) so the defaultdicts are not grown while reporting.
    t = len(train_groups.get(key, set()))
    v = len(val_groups.get(key, set()))
    tot = t + v
    # only consider non-empty totals that are multiples of 5
    if tot <= 0 or tot % 5 != 0:
        continue
    pct = t / tot * 100
    obj, ori = key
    print(f"{obj}\t{ori}\t{t}\t{v}\t{tot}\t{pct:.2f}%")

# %%
import os
import json
from collections import defaultdict

# === Paths to your split JSONs ===
TRAIN_SPLIT_JSON = "/fs/scratch/PAS2099/Jiacheng/Orientation/EgoOrient/train/train.json"
VAL_SPLIT_JSON   = "/fs/scratch/PAS2099/Jiacheng/Orientation/EgoOrient/val/val.json"


def _load_entries(path):
    """Return the parsed JSON content of the split file at *path*."""
    # JSON text is UTF-8; pin the encoding instead of relying on the
    # platform/locale default that a bare open() would use.
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def _group_images(entries):
    """Group unique image basenames by (object, orientation).

    Only entries whose "type" is "general_complex" are considered. The key is
    (entry["category_name"], entry["label"] stripped and lower-cased) and the
    value is the set of basenames of entry["image"], so the same image file
    name is counted once per key.
    """
    groups = defaultdict(set)
    for e in entries:
        if e.get("type") != "general_complex":
            continue
        key = (e["category_name"], e["label"].strip().lower())
        groups[key].add(os.path.basename(e["image"]))
    return groups


# === Load split data ===
train_data = _load_entries(TRAIN_SPLIT_JSON)
val_data = _load_entries(VAL_SPLIT_JSON)

# === Group unique images by (object, orientation) ===
train_groups = _group_images(train_data)
val_groups = _group_images(val_data)

# === Check each (object, orientation) where total is multiple of 5 ===
print("Object\tOrientation\tTrain\tVal\tTotal\tTrain%")
all_ok = True
for key in sorted(set(train_groups) | set(val_groups)):
    # .get() (not indexing) so the defaultdicts are not grown while reporting.
    t = len(train_groups.get(key, set()))
    v = len(val_groups.get(key, set()))
    tot = t + v
    if tot <= 0 or tot % 5 != 0:
        continue
    pct = t / tot * 100
    obj, ori = key
    print(f"{obj}\t{ori}\t{t}\t{v}\t{tot}\t{pct:.2f}%")
    # require exactly 80%: t / tot == 4 / 5, compared in integers so the
    # check never depends on float rounding of 0.8 * tot
    if 5 * t != 4 * tot:
        all_ok = False

# === Final verdict ===
# Prints "true" when every checked combo is an exact 80/20 split (also when
# no combo qualified for the check), otherwise "false".
print("true" if all_ok else "false")

# %%
