#%%
import os
import json

# Dataset paths.
train_json_path    = '/fs/scratch/PAS2099/Jiacheng/Scene_classification_v2/train/train.json'
train_image_dir    = '/fs/scratch/PAS2099/Jiacheng/Scene_classification_v2/train/image'
val_json_path      = '/fs/scratch/PAS2099/Jiacheng/Scene_classification_v2/val/val.json'
val_ans_json_path  = '/fs/scratch/PAS2099/Jiacheng/Scene_classification_v2/val/val_ans.json'


def _duplicates(values):
    """Return the values that occur more than once in *values*, sorted by str form.

    The checks below compare sets, and building a set silently collapses
    repeats, so duplicates have to be detected and reported separately.
    """
    seen, dups = set(), set()
    for value in values:
        if value in seen:
            dups.add(value)
        else:
            seen.add(value)
    return sorted(dups, key=str)


# 1. Check that the ids in train.json match the image files in train/image.
#    An image file is matched to an id by its file stem (name without extension).
with open(train_json_path, 'r', encoding='utf-8') as f:
    train_data = json.load(f)
train_id_list = [entry['id'] for entry in train_data]
train_ids = set(train_id_list)

train_files = [
    os.path.splitext(fn)[0]
    for fn in os.listdir(train_image_dir)
    if fn.lower().endswith(('.jpg', '.jpeg', '.png'))
]
train_file_ids = set(train_files)

dup_train_ids = _duplicates(train_id_list)
if dup_train_ids:
    print("⚠️ train.json 中存在重复的 ID：", dup_train_ids)

# e.g. "x.jpg" and "x.png" share the stem "x" and would otherwise look like one file.
dup_train_stems = _duplicates(train_files)
if dup_train_stems:
    print("⚠️ train/image 目录中存在仅扩展名不同的同名文件：", dup_train_stems)

missing_in_folder = train_ids - train_file_ids
extra_in_folder   = train_file_ids - train_ids

if missing_in_folder:
    print("⚠️ 以下 train.json 中的 ID 在 train/image 目录找不到对应文件：", sorted(missing_in_folder, key=str))
else:
    print("✅ train.json 中所有 ID 都在 train/image 目录找到对应文件。")

if extra_in_folder:
    print("⚠️ train/image 目录中存在未在 train.json 列出的文件 ID：", sorted(extra_in_folder, key=str))
else:
    print("✅ train/image 目录中没有多余的文件。")

# 2. Check that val.json and val_ans.json contain the same question_ids.
with open(val_json_path, 'r', encoding='utf-8') as f:
    val_data = json.load(f)
val_id_list = [entry['question_id'] for entry in val_data]
val_ids = set(val_id_list)

with open(val_ans_json_path, 'r', encoding='utf-8') as f:
    val_ans_data = json.load(f)
val_ans_id_list = [entry['question_id'] for entry in val_ans_data]
val_ans_ids = set(val_ans_id_list)

for json_name, id_list in (('val.json', val_id_list), ('val_ans.json', val_ans_id_list)):
    dup_qids = _duplicates(id_list)
    if dup_qids:
        print(f"⚠️ {json_name} 中存在重复的 question_id：", dup_qids)

if val_ids == val_ans_ids:
    print("✅ val.json 和 val_ans.json 的 question_id 完全一致。")
else:
    print("⚠️ val.json 与 val_ans.json 的 question_id 不一致：")
    print("   只在 val.json 中的 IDs：", sorted(val_ids - val_ans_ids, key=str))
    print("   只在 val_ans.json 中的 IDs：", sorted(val_ans_ids - val_ids, key=str))

# 3. Check that no id appears in both train and val.
#    Ids are compared as strings. If one file stores ids as JSON strings and the
#    other as numbers (e.g. "123" vs 123), a raw set intersection would always
#    be empty and this check would pass vacuously.
overlap = {str(i) for i in train_ids} & {str(q) for q in val_ids}
if overlap:
    print("⚠️ train 和 val 之间存在重复的 ID：", sorted(overlap))
else:
    print("✅ train 和 val 之间没有重复的 ID。")

# %%
import os
import json
import pandas as pd
from collections import Counter

# Paths.
train_json_path    = '/fs/scratch/PAS2099/Jiacheng/Scene_classification_v2/train/train.json'
val_json_path      = '/fs/scratch/PAS2099/Jiacheng/Scene_classification_v2/val/val.json'
val_ans_json_path  = '/fs/scratch/PAS2099/Jiacheng/Scene_classification_v2/val/val_ans.json'
labels_csv_path    = '/fs/scratch/PAS2099/Jiacheng/Places_merge/output/llm_optimized_scene_labels_v2.csv'

# Labels with this prefix are remote-sensing classes. They are excluded on both
# the data side and the CSV side, so only non-remote-sensing labels are compared.
REMOTE_PREFIX = 'remote sensing'


def _strip_index(ans):
    """Return the label from a numbered answer such as "2. bedroom" -> "bedroom".

    Everything after the first ". " is the label, with surrounding whitespace
    removed so it compares equal to the stripped CSV labels. Returns None when
    the answer has no ". " separator or nothing follows it. A bare
    ``split('. ', 1)[1]`` would raise IndexError in that case.
    """
    _, sep, label = ans.partition('. ')
    label = label.strip()
    return label if sep and label else None


# 1. Read the optimized_label column from the CSV.
#    pandas reads empty cells as float NaN, which has no .startswith and would
#    crash the prefix filter below. Drop them and normalise to stripped strings.
df_labels = pd.read_csv(labels_csv_path)
csv_labels = set(df_labels['optimized_label'].dropna().astype(str).str.strip())

# (source file, raw answer) pairs whose numbering prefix could not be removed.
unparsed_answers = []

# 2. Read train.json and take the label from each model answer (numbering removed).
with open(train_json_path, 'r', encoding='utf-8') as f:
    train_data = json.load(f)

train_labels = []
for entry in train_data:
    # Answer looks like "2. bedroom". Assumes conversations[1] is the model turn.
    ans = entry['conversations'][1]['value']
    label = _strip_index(ans)
    if label is None:
        unparsed_answers.append(('train.json', ans))
    elif not label.startswith(REMOTE_PREFIX):   # skip remote-sensing labels
        train_labels.append(label)

# 3. Read val_ans.json and take the label from each answer the same way.
with open(val_ans_json_path, 'r', encoding='utf-8') as f:
    val_ans_data = json.load(f)

val_labels = []
for entry in val_ans_data:
    ans = entry['text']                           # e.g. "5. kitchen"
    label = _strip_index(ans)
    if label is None:
        unparsed_answers.append(('val_ans.json', ans))
    elif not label.startswith(REMOTE_PREFIX):     # skip remote-sensing labels
        val_labels.append(label)

if unparsed_answers:
    print(f"⚠️ 有 {len(unparsed_answers)} 条答案缺少 '. ' 编号分隔符，已跳过（最多显示 10 条）:",
          unparsed_answers[:10])

# 4. Combine both splits and count how often each label occurs.
all_non_remote = train_labels + val_labels
counts = Counter(all_non_remote)

# 5. Compare against the CSV, restricted to its non-remote-sensing labels.
csv_non_remote = {lbl for lbl in csv_labels if not lbl.startswith(REMOTE_PREFIX)}
data_non_remote = set(counts)

missing_in_data = csv_non_remote - data_non_remote
extra_in_data   = data_non_remote - csv_non_remote

# 6. Report the results.
print(f"Distinct non-remote-sensing labels in data: {len(data_non_remote)}")
print(f"Distinct non-remote-sensing labels in CSV:  {len(csv_non_remote)}")
if missing_in_data:
    print("⚠️ CSV 有但数据中未使用的标签:", sorted(missing_in_data))
else:
    print("✅ CSV 中所有非遥感标签都在数据中出现过。")
if extra_in_data:
    print("⚠️ 数据中有但 CSV 未包含的标签:", sorted(extra_in_data))
else:
    print("✅ 数据中所有非遥感标签都出现在 CSV 中。")

print("\nLabel usage counts:")
for label, cnt in counts.most_common():
    print(f"  {label}: {cnt}")

# %%
