#%%
import os
import json
from collections import Counter

# Paths
train_json_path   = '/fs/scratch/PAS2099/Jiacheng/Orientation/train/train.json'
val_json_path     = '/fs/scratch/PAS2099/Jiacheng/Orientation/val/val.json'
val_ans_json_path = '/fs/scratch/PAS2099/Jiacheng/Orientation/val/val_ans.json'


def load_json(path):
    """Parse and return the JSON document stored at *path*."""
    # JSON text is UTF-8; do not depend on the platform's locale default.
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


def image_names(records):
    """Return the set of image file names referenced by *records*.

    Each record's 'image' value is reduced to its base name, so a file that
    stores full paths and a file that stores bare names compare equal.
    """
    return {os.path.basename(item['image']) for item in records}


def extract_direction(answer):
    """Return the orientation label of an answer such as "8. front left".

    Everything up to and including the first ". " is dropped. An answer
    without that separator is returned whole instead of raising, and
    surrounding whitespace is stripped so that labels differing only in
    padding are counted together.
    """
    _, sep, label = answer.partition('. ')
    return (label if sep else answer).strip()


def gpt_answers(records):
    """Yield the text of the first 'gpt' turn of each record that has one.

    Records whose 'conversations' list contains no 'gpt' turn are skipped.
    """
    for item in records:
        reply = next(
            (conv['value'] for conv in item['conversations']
             if conv['from'] == 'gpt'),
            None,
        )
        if reply is not None:
            yield reply


def count_directions(train_records, val_answers):
    """Count orientation labels over train gpt replies and val answers.

    Args:
        train_records: records with a 'conversations' list of
            {'from': ..., 'value': ...} turns.
        val_answers: records with a 'text' answer string.

    Returns:
        A Counter mapping each orientation label to its number of occurrences.
    """
    counts = Counter(extract_direction(text) for text in gpt_answers(train_records))
    counts.update(extract_direction(ans['text']) for ans in val_answers)
    return counts


# The statements below stay at module level (not inside a function) so that
# running the file as a script or as a notebook-style cell still leaves
# train_data, val_data, etc. available as globals; the guard only keeps a
# plain `import` of this file from touching the data files.
if __name__ == '__main__':
    # Load train and val data
    train_data = load_json(train_json_path)
    train_images = image_names(train_data)
    val_data = load_json(val_json_path)
    # Normalised the same way as train_images; otherwise the intersection
    # below can never match when val.json stores paths.
    val_images = image_names(val_data)

    # 1. Check overlap
    overlap = train_images & val_images
    print("Overlap between train and val images:", overlap or "None")

    # 2. Print counts
    print(f"Number of train images: {len(train_images)}")
    print(f"Number of val images:   {len(val_images)}")

    # 3. Count orientations across all data
    val_ans = load_json(val_ans_json_path)
    direction_counts = count_directions(train_data, val_ans)

    print("\nOrientation distribution (direction: count):")
    # Most frequent label first; ties keep first-seen order.
    for direction, cnt in direction_counts.most_common():
        print(f"  {direction}: {cnt}")

    # Every val answer is counted, so any shortfall comes from train records
    # that had no gpt turn.
    skipped = len(train_data) + len(val_ans) - sum(direction_counts.values())
    if skipped:
        print(f"  (skipped {skipped} train records without a gpt response)")

# %%
