#%%
import os
import json
import re
from collections import Counter, defaultdict

# Paths
train_json_path = '/fs/scratch/PAS2099/Jiacheng/Texture_/train/train.json'
val_json_path   = '/fs/scratch/PAS2099/Jiacheng/Texture_/val/val.json'
val_ans_path    = '/fs/scratch/PAS2099/Jiacheng/Texture_/val/val_ans.json'
train_img_dir   = '/fs/scratch/PAS2099/Jiacheng/Texture_/train/image'
val_img_dir     = '/fs/scratch/PAS2099/Jiacheng/Texture_/val/image'

def load_json(path):
    """Load and return the JSON document stored at *path*.

    The file is decoded explicitly as UTF-8 (the encoding JSON text is
    required to use) so the result does not depend on the process locale.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the content is not valid JSON.
    """
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)

def extract_options(prompt_text, marker='Choose one from below.'):
    """Return the answer options listed after *marker* in a prompt.

    Expects a prompt such as
    "What texture is this? Choose one from below. 1. opt1, 2. opt2, ..., 10. opt10"
    and returns ["opt1", "opt2", ..., "opt10"] with surrounding whitespace
    stripped. Options are comma-separated, so an option that itself contains
    a comma is cut at that comma.

    Args:
        prompt_text: The full human prompt.
        marker: Text that introduces the option list. Only the text after
            its first occurrence is parsed, so numbers in the question part
            are never mistaken for options.

    Returns:
        The list of option strings, or [] when *marker* does not occur.
    """
    # Keep only what follows the FIRST occurrence of the marker.
    _, found, opts_str = prompt_text.partition(marker)
    if not found:
        return []
    # Each option is "<number>. <text>" and runs up to the next comma.
    matches = re.findall(r'\d+\.\s*([^,]+)', opts_str)
    return [m.strip() for m in matches]

def _duplicates(values):
    """Return the values that occur more than once, in first-seen order."""
    return [value for value, cnt in Counter(values).items() if cnt > 1]


def _answer_matches_options(answer, opts):
    """Return True if *answer* ("<n>. <option>") names option n of *opts*.

    Answers are expected in the "2. sand" form used throughout this file.
    A plain ``answer.endswith(opt)`` test against every option would also
    accept an answer whose number points at a different option, or whose
    text merely ends with a shorter option name; here the number and the
    text must agree.
    """
    num, sep, text = answer.strip().partition('. ')
    if not sep or not num.isdigit():
        return False
    idx = int(num) - 1  # options are numbered from 1
    return 0 <= idx < len(opts) and opts[idx].strip() == text.strip()


def check_json_integrity(expected_option_counts=(10,)):
    """Print a report of consistency problems in the train/val JSON files.

    Uses the module-level ``train_json_path``, ``val_json_path``,
    ``val_ans_path``, ``train_img_dir`` and ``val_img_dir`` and checks, in
    order: duplicate ids/images within each split, id/image overlap between
    the splits, referenced vs on-disk image files, option count and answer
    validity of every prompt, and agreement between val.json and
    val_ans.json. Findings are printed; nothing is returned.

    Args:
        expected_option_counts: Accepted number(s) of options per prompt, as
            an int or an iterable of ints. Defaults to (10,), the original
            fixed check.

    Raises:
        OSError: if a JSON file or an image directory cannot be read.
        json.JSONDecodeError: if a JSON file is malformed.
    """
    if isinstance(expected_option_counts, int):
        expected_option_counts = (expected_option_counts,)
    accepted_counts = frozenset(expected_option_counts)
    count_label = '/'.join(str(n) for n in sorted(accepted_counts))

    # Load data (load_json raises on unreadable or malformed JSON)
    train_data   = load_json(train_json_path)
    val_data     = load_json(val_json_path)
    val_ans_data = load_json(val_ans_path)

    # 1. JSON load validity -- reaching this line means all three files parsed
    print("JSON load: OK")

    # 2. Duplicate check within train.json
    train_ids  = [item['id'] for item in train_data]
    train_imgs = [os.path.basename(item['image']) for item in train_data]
    print("\nDuplicate IDs in train.json:", _duplicates(train_ids) or "None")
    print("Duplicate images in train.json:", _duplicates(train_imgs) or "None")

    # 3. Duplicate check within val.json
    val_qids = [item['question_id'] for item in val_data]
    val_imgs = [os.path.basename(item['image']) for item in val_data]
    print("\nDuplicate question_ids in val.json:", _duplicates(val_qids) or "None")
    print("Duplicate images in val.json:", _duplicates(val_imgs) or "None")

    # 4. ID uniqueness across train & val
    overlap_ids = set(train_ids) & set(val_qids)
    print("\nOverlap of IDs between train & val:", overlap_ids or "None")

    # 5. Overlap of images train vs val
    overlap_imgs = set(train_imgs) & set(val_imgs)
    print("Overlap of images between train & val:", overlap_imgs or "None")

    # 6-7. Image names referenced by the JSON vs files present in each image dir
    for split, img_dir, imgs in (('train', train_img_dir, train_imgs),
                                 ('val', val_img_dir, val_imgs)):
        on_disk = set(os.listdir(img_dir))
        referenced = set(imgs)
        print(f"\n{split.capitalize()} images missing on disk: "
              f"{sorted(referenced - on_disk) or ['None']}")
        print(f"Extra files in {split} dir: {sorted(on_disk - referenced) or ['None']}")

    # 8. Option count & answer validity in train.json
    bad_train_prompts = []
    bad_train_answers = []
    for item in train_data:
        human = next(conv for conv in item['conversations'] if conv['from'] == 'human')['value']
        gpt   = next(conv for conv in item['conversations'] if conv['from'] == 'gpt')['value']
        opts = extract_options(human)
        if len(opts) not in accepted_counts:
            bad_train_prompts.append(item['id'])
        if not _answer_matches_options(gpt, opts):
            bad_train_answers.append(item['id'])
    print(f"\nTrain prompts with ≠{count_label} options: {bad_train_prompts or 'None'}")
    print(f"Train samples with invalid answer: {bad_train_answers or 'None'}")

    # 9. Option count & answer validity in val.json (answers live in val_ans.json)
    val_ans_map = {ans['question_id']: ans for ans in val_ans_data}
    bad_val_prompts = []
    bad_val_answers = []
    for item in val_data:
        qid = item['question_id']
        opts = extract_options(item['text'])
        if len(opts) not in accepted_counts:
            bad_val_prompts.append(qid)
        ans = val_ans_map.get(qid)
        # Questions with no answer at all are reported by step 10 instead.
        if ans and not _answer_matches_options(ans['text'], opts):
            bad_val_answers.append(qid)
    print(f"\nVal prompts with ≠{count_label} options: {bad_val_prompts or 'None'}")
    print(f"Val samples with invalid answer: {bad_val_answers or 'None'}")

    # 10. Consistency between val.json and val_ans.json
    missing_ans = set(val_qids) - set(val_ans_map)
    extra_ans   = set(val_ans_map) - set(val_qids)
    print(f"\nVal entries missing answers: {sorted(missing_ans) or ['None']}")
    print(f"Answers without val entries: {sorted(extra_ans) or ['None']}")

    # 11. The 'prompt' stored with each answer must equal its val.json 'text'.
    # val.json is indexed once (first occurrence wins, like a front-to-back
    # search) instead of being scanned per answer. Answers with no val.json
    # entry were already listed in step 10, so they are skipped here rather
    # than crashing the lookup.
    val_prompt_by_qid = {}
    for item in val_data:
        val_prompt_by_qid.setdefault(item['question_id'], item['text'])
    inconsistent_prompts = [
        ans['question_id']
        for ans in val_ans_data
        if ans['question_id'] in val_prompt_by_qid
        and ans['prompt'] != val_prompt_by_qid[ans['question_id']]
    ]
    print(f"\nVal prompt mismatch in answer file: {inconsistent_prompts or 'None'}")


if __name__ == '__main__':
    # NOTE(review): if prompts list every attribute of their source dataset
    # (see the option-count -> dataset table later in this file), pass
    # expected_option_counts=(10, 15, 47) here -- confirm against the data.
    check_json_integrity()

# %%
import json
from collections import defaultdict

# Paths
metadata_path = '/fs/scratch/PAS2099/Jiacheng/Texture_tmp/output/merged_data_v4_metadata.json'
train_json    = '/fs/scratch/PAS2099/Jiacheng/Texture_/train/train.json'
# Must share train_json's dataset root (Texture_); otherwise the two splits
# being compared come from different copies of the dataset.
val_json      = '/fs/scratch/PAS2099/Jiacheng/Texture_/val/val.json'

# Load files (JSON text is UTF-8 by specification)
with open(metadata_path, 'r', encoding='utf-8') as f:
    metadata = json.load(f)
with open(train_json, 'r', encoding='utf-8') as f:
    train_data = json.load(f)
with open(val_json, 'r', encoding='utf-8') as f:
    val_data = json.load(f)

# Count occurrences per attribute in each split
total_counts = defaultdict(int)
train_counts = defaultdict(int)
val_counts   = defaultdict(int)

for entry in metadata:
    total_counts[entry['texture attribute']] += 1
for entry in train_data:
    # The 'gpt' turn holds the answer, e.g. "2. sand" -> "sand"; it is found by
    # role rather than assuming it is conversations[1].
    answer = next(conv['value'] for conv in entry['conversations'] if conv['from'] == 'gpt')
    attr = answer.split('. ', 1)[1]
    train_counts[attr] += 1

# val.json carries only the question; the answers live in val_ans.json, which
# this cell does not read. The per-attribute val count is therefore inferred
# as total - train, assuming every metadata sample is in exactly one split.
for attr, total in total_counts.items():
    val_counts[attr] = total - train_counts.get(attr, 0)

# Sanity checks on that assumption: the inferred val total must equal the
# number of val.json entries, and every train attribute must exist in the
# metadata (otherwise it would never appear in the table below).
inferred_val_total = sum(val_counts.values())
if inferred_val_total != len(val_data):
    print(f"WARNING: inferred val total {inferred_val_total} != "
          f"{len(val_data)} entries in val.json")
unknown_train_attrs = sorted(set(train_counts) - set(total_counts))
if unknown_train_attrs:
    print(f"WARNING: train attributes not in metadata: {unknown_train_attrs}")

# Check and report: the expected split is int(80%) train, the remainder val
print(f"{'Attribute':20s} {'Total':>5s} {'Train':>5s} {'Val':>5s} {'Train%':>7s} {'Val%':>7s}  Status")
for attr in sorted(total_counts):
    total = total_counts[attr]
    t = train_counts.get(attr, 0)
    v = val_counts.get(attr, 0)
    exp_t = int(0.8 * total)
    exp_v = total - exp_t
    train_pct = t / total * 100
    val_pct   = v / total * 100
    status = "OK" if (t == exp_t and v == exp_v) else "MISMATCH"
    print(f"{attr:20s} {total:5d} {t:5d} {v:5d} {train_pct:6.1f}% {val_pct:6.1f}%  {status}")

# %%
import os
import json
import math
from collections import defaultdict

# File paths
metadata_path     = '/fs/scratch/PAS2099/Jiacheng/Texture_tmp/output/merged_data_v4_metadata.json'
train_json_path   = '/fs/scratch/PAS2099/Jiacheng/Texture_/train/train.json'
val_json_path     = '/fs/scratch/PAS2099/Jiacheng/Texture_/val/val.json'
val_ans_json_path = '/fs/scratch/PAS2099/Jiacheng/Texture_/val/val_ans.json'

# 1. Load the metadata and count the total number of samples per attribute
with open(metadata_path, 'r', encoding='utf-8') as f:
    metadata = json.load(f)
total_counts = defaultdict(int)
for entry in metadata:
    total_counts[entry['texture attribute']] += 1

# 2. Load train.json and count the samples per attribute in train
with open(train_json_path, 'r', encoding='utf-8') as f:
    train_data = json.load(f)
train_counts = defaultdict(int)
for sample in train_data:
    # The GPT answer looks like "3. sand" -> keep the attribute name. The turn
    # is located by role instead of assuming it is conversations[1].
    ans = next(c['value'] for c in sample['conversations'] if c['from'] == 'gpt')
    attr = ans.split('. ', 1)[1]
    train_counts[attr] += 1

# 3. Load val.json and val_ans.json and count the samples per attribute in val
with open(val_json_path, 'r', encoding='utf-8') as f:
    val_data = json.load(f)
with open(val_ans_json_path, 'r', encoding='utf-8') as f:
    val_ans_data = json.load(f)
# question_id -> attribute name parsed from the answer text ("3. sand" -> "sand")
val_ans_map = {ans['question_id']: ans['text'].split('. ', 1)[1]
               for ans in val_ans_data}
val_counts = defaultdict(int)
unanswered_qids = []  # val questions with no entry in val_ans.json
for sample in val_data:
    qid = sample['question_id']
    if qid in val_ans_map:
        val_counts[val_ans_map[qid]] += 1
    else:
        unanswered_qids.append(qid)

# 4. Compare against the expected floor(80%) / remainder split and print
print(f"{'Attribute':25s} {'Total':>5s} {'Train':>6s} {'Val':>6s}   Status")
for attr, total in sorted(total_counts.items()):
    exp_train = math.floor(0.8 * total)
    exp_val   = total - exp_train
    act_train = train_counts.get(attr, 0)
    act_val   = val_counts.get(attr, 0)
    status    = "OK" if (act_train == exp_train and act_val == exp_val) else "MISMATCH"
    print(f"{attr:25s} {total:5d} {act_train:6d} {act_val:6d}   {status}")

# Problems the per-attribute table cannot show on its own
if unanswered_qids:
    print(f"\n{len(unanswered_qids)} val question(s) without an answer in val_ans.json, "
          f"e.g. {unanswered_qids[:5]}")
unknown_attrs = sorted((set(train_counts) | set(val_counts)) - set(total_counts))
if unknown_attrs:
    print(f"\nAttributes in train/val but not in metadata: {unknown_attrs}")

# %%
import os
import json
import re
from collections import Counter, defaultdict
from PIL import Image

# Paths
# NOTE(review): this cell reads the Texture/ copy while the other cells read
# Texture_/ -- confirm that is intended.
train_json_path   = '/fs/scratch/PAS2099/Jiacheng/Texture/train/train.json'
val_json_path     = '/fs/scratch/PAS2099/Jiacheng/Texture/val/val.json'
val_ans_json_path = '/fs/scratch/PAS2099/Jiacheng/Texture/val/val_ans.json'
train_img_dir     = '/fs/scratch/PAS2099/Jiacheng/Texture/train/image'
val_img_dir       = '/fs/scratch/PAS2099/Jiacheng/Texture/val/image'

def load_json(path):
    """Load and return the JSON document stored at *path*.

    The file is decoded explicitly as UTF-8 (the encoding JSON text is
    required to use) so the result does not depend on the process locale.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the content is not valid JSON.
    """
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)

def extract_options(text):
    """Return the options listed as "1. opt1, 2. opt2, ..., 10. opt10" in *text*.

    Every "<number>. <option>" run up to the next comma is matched anywhere
    in *text*, so a number followed by a period inside the question itself
    would also be picked up. Whitespace around each option is stripped;
    otherwise the last option keeps e.g. a trailing newline and fails exact
    comparisons against the answer.
    """
    return [opt.strip() for opt in re.findall(r'\d+\.\s*([^,]+)', text)]

def check_image_openable(path):
    """Return True if PIL can open *path* and ``verify()`` raises nothing.

    Any exception (missing file, unidentified format, truncated or corrupt
    data) yields False. This is a best-effort probe, so the broad ``except``
    is intentional. The image is opened in a ``with`` block so its file
    handle is released deterministically even when ``verify()`` raises,
    instead of waiting for garbage collection during a full-dataset scan.
    """
    try:
        with Image.open(path) as img:
            img.verify()
        return True
    except Exception:
        return False

# Load the three JSON files inspected by the checks that follow.
train_data = load_json(train_json_path)
val_data = load_json(val_json_path)
val_ans_data = load_json(val_ans_json_path)

# 1. ID uniqueness: ids repeated within each file, and ids shared by train & val.
train_ids = [rec['id'] for rec in train_data]
val_qids = [rec['question_id'] for rec in val_data]
ans_qids = [rec['question_id'] for rec in val_ans_data]

# Ids seen more than once, in first-seen order (Counter keeps insertion order).
dup_train_ids, dup_val_qids, dup_ans_qids = (
    [key for key, seen in Counter(ids).items() if seen > 1]
    for ids in (train_ids, val_qids, ans_qids)
)
shared_ids = set(train_ids).intersection(val_qids)

print("1. ID uniqueness")
print("  duplicate train IDs:", dup_train_ids if dup_train_ids else "None")
print("  duplicate val question_ids:", dup_val_qids if dup_val_qids else "None")
print("  duplicate val_ans question_ids:", dup_ans_qids if dup_ans_qids else "None")
print("  overlap train vs val IDs:", shared_ids if shared_ids else "None")
print()

# 2. Schema & field presence
def check_schema(data, required_keys):
    """Find records that lack at least one of *required_keys*.

    Returns a list of ``(index, item.keys())`` pairs, one per offending
    record, in input order. The keys shown are the ones the record HAS.
    """
    return [
        (position, record.keys())
        for position, record in enumerate(data)
        if any(key not in record for key in required_keys)
    ]

# Each file's records must carry every key in its required set; at most the
# first three offenders per file are printed.
train_bad = check_schema(train_data, {'conversations', 'id', 'image'})
val_bad = check_schema(val_data, {'category', 'image', 'question_id', 'text'})
ans_bad = check_schema(
    val_ans_data,
    {'answer_id', 'metadata', 'model_id', 'prompt', 'question_id', 'text'},
)

print("2. Schema validity")
print("  train.json bad entries:", train_bad[:3] if train_bad else "None")
print("  val.json bad entries:  ", val_bad[:3] if val_bad else "None")
print("  val_ans.json bad entries:", ans_bad[:3] if ans_bad else "None")
print()

# 3. Prompt options & answer correctness
def _answer_matches_options(answer, opts):
    """Return True if *answer* ("<n>. <option>") names option n of *opts*.

    A plain ``answer.endswith(opt)`` test against every option would also
    accept an answer whose number points at a different option, or whose
    text merely ends with a shorter option name; here the number and the
    text must agree. Both sides are stripped before comparing.
    """
    num, sep, text = answer.strip().partition('. ')
    if not sep or not num.isdigit():
        return False
    idx = int(num) - 1  # options are numbered from 1
    return 0 <= idx < len(opts) and opts[idx].strip() == text.strip()


bad_train_prompt = []
bad_train_answer = []
for item in train_data:
    human = next(c['value'] for c in item['conversations'] if c['from'] == 'human')
    gpt   = next(c['value'] for c in item['conversations'] if c['from'] == 'gpt')
    opts  = extract_options(human)
    if len(opts) != 10:
        bad_train_prompt.append(item['id'])
    if not _answer_matches_options(gpt, opts):
        bad_train_answer.append(item['id'])

val_ans_map = {a['question_id']: a for a in val_ans_data}
bad_val_prompt = []
bad_val_answer = []
for item in val_data:
    opts = extract_options(item['text'])
    if len(opts) != 10:
        bad_val_prompt.append(item['question_id'])
    # A question with no entry in val_ans.json counts as an invalid answer.
    ans = val_ans_map.get(item['question_id'], {})
    if not _answer_matches_options(ans.get('text', ''), opts):
        bad_val_answer.append(item['question_id'])

print("3. Prompt & answer checks")
print("  train prompts ≠10 opts:", bad_train_prompt or "None")
print("  train invalid answers:", bad_train_answer or "None")
print("  val prompts ≠10 opts:  ", bad_val_prompt or "None")
print("  val invalid answers:  ", bad_val_answer or "None")
print()

# 4. Image file existence & openability
def _summarize_names(names, limit=5):
    """Return "None" if *names* is empty, else its first *limit* sorted names.

    When entries are cut off, the total count is appended so a truncated
    list is never mistaken for the complete one.
    """
    if not names:
        return "None"
    shown = sorted(names)[:limit]
    more = f" ... ({len(names)} total)" if len(names) > limit else ""
    return f"{shown}{more}"


train_imgs = [os.path.basename(i['image']) for i in train_data]
val_imgs   = [os.path.basename(i['image']) for i in val_data]

missing_train = set(train_imgs) - set(os.listdir(train_img_dir))
missing_val   = set(val_imgs)   - set(os.listdir(val_img_dir))

# Only files that exist are probed; missing ones are reported above.
bad_open = []
for split, img_dir, names in (('train', train_img_dir, train_imgs),
                              ('val', val_img_dir, val_imgs)):
    for fname in names:
        path = os.path.join(img_dir, fname)
        if os.path.exists(path) and not check_image_openable(path):
            bad_open.append((split, fname))

print("4. Image files")
print("  missing in train dir:", _summarize_names(missing_train))
print("  missing in val dir:  ", _summarize_names(missing_val))
print("  unreadable images:", bad_open or "None")
print()

# 5. Distribution coverage (Counter is already imported at the top of this cell)
train_attr_counts = Counter(
    next(c['value'] for c in item['conversations'] if c['from'] == 'gpt').split('. ', 1)[1]
    for item in train_data
)
val_attr_counts = Counter(
    val_ans_map[q]['text'].split('. ', 1)[1]
    for q in val_qids if q in val_ans_map
)
print("5. Distribution")
print("  sample train distribution:", train_attr_counts.most_common(5))
print("  sample val distribution:  ", val_attr_counts.most_common(5))
print()

# 6. Prompt format consistency: each val_ans 'prompt' must equal the 'text' of
# the val.json question with the same question_id. val.json is indexed once
# (first occurrence wins, matching a front-to-back search) instead of being
# scanned per answer. Answers whose question_id is absent from val.json are
# listed separately; a bare next(...) lookup would raise StopIteration on them.
val_text_by_qid = {}
for v in val_data:
    val_text_by_qid.setdefault(v['question_id'], v['text'])
orphan_answers = [a['question_id'] for a in val_ans_data
                  if a['question_id'] not in val_text_by_qid]
inconsistent = [a['question_id'] for a in val_ans_data
                if a['question_id'] in val_text_by_qid
                and a['prompt'] != val_text_by_qid[a['question_id']]]
print("6. Prompt consistency in val_ans:", inconsistent or "None")
print("   val_ans entries with no val.json question:", orphan_answers or "None")

# %%
import os
import json
import re
from collections import Counter, defaultdict  # Counter is used below; don't rely on an earlier cell

# --- Path configuration ---
metadata_path    = '/fs/scratch/PAS2099/Jiacheng/Texture_tmp/output/merged_data_v4_metadata.json'
val_json_path    = '/fs/scratch/PAS2099/Jiacheng/Texture_/val/val.json'
val_ans_json_path= '/fs/scratch/PAS2099/Jiacheng/Texture_/val/val_ans.json'

# --- Load metadata ---
with open(metadata_path, 'r', encoding='utf-8') as f:
    meta = json.load(f)

# Build image -> base_dataset and base_dataset -> set-of-attributes maps
img2ds = { e['image']: e['base_dataset'] for e in meta }
ds2attrs = defaultdict(set)
for e in meta:
    ds2attrs[e['base_dataset']].add(e['texture attribute'])

# --- Load the validation JSON files ---
with open(val_json_path, 'r', encoding='utf-8') as f:
    val_data = json.load(f)
with open(val_ans_json_path, 'r', encoding='utf-8') as f:
    val_ans  = json.load(f)

# question_id -> answer attribute from val_ans ("3. sand" -> "sand")
ans_map = { a['question_id']: a['text'].split('. ', 1)[1].strip() for a in val_ans }

# Problem records; every record is a dict so they print and filter uniformly
errors = []

for item in val_data:
    qid = item['question_id']
    img = item['image']
    prompt_text = item['text']

    # 1) Find the image's original dataset
    ds = img2ds.get(img)
    if ds is None:
        errors.append({"question_id": qid, "image": img, "error": "Unknown dataset"})
        continue

    # 2) Parse every option out of the prompt
    opts = [o.strip() for o in re.findall(r'\d+\.\s*([^,]+)', prompt_text)]

    # 3) The attributes that should appear as options
    expected = ds2attrs[ds]

    # 4) Check count, missing, extra, duplicates, and that the answer is an option
    opts_set = set(opts)
    missing = expected - opts_set
    extra   = opts_set - expected
    dupes   = [o for o, c in Counter(opts).items() if c > 1]
    answer  = ans_map.get(qid)  # None when val_ans has no entry for this qid
    answer_ok = answer in opts_set

    if len(opts) != len(expected) or missing or extra or dupes or not answer_ok:
        errors.append({
            "question_id": qid,
            "image": img,
            "dataset": ds,
            "expected_count": len(expected),
            "found_count": len(opts),
            "missing": sorted(missing),
            "extra": sorted(extra),
            "duplicates": dupes,
            "answer": answer,
            "answer_in_options": answer_ok,
        })

# Print the results
if not errors:
    print("All prompts correctly cover their original dataset's attributes.")
else:
    print(f"Found {len(errors)} prompt(s) with coverage issues:")
    for err in errors:
        print(err)

# %%
import os
import json
from collections import Counter, defaultdict

# Paths
metadata_path  = '/fs/scratch/PAS2099/Jiacheng/Texture_tmp/output/merged_data_v4_metadata.json'
train_json     = '/fs/scratch/PAS2099/Jiacheng/Texture_/train/train.json'
val_json       = '/fs/scratch/PAS2099/Jiacheng/Texture_/val/val.json'

# 1. Load metadata -> image_name -> base_dataset
with open(metadata_path, 'r', encoding='utf-8') as f:
    metadata = json.load(f)
img2ds = { entry['image']: entry['base_dataset'] for entry in metadata }

# 2. Load train.json and val.json -> lists of image file names.
# Both splits are reduced to the base name so they are looked up in img2ds the
# same way (a no-op for entries that are already bare file names).
with open(train_json, 'r', encoding='utf-8') as f:
    train_data = json.load(f)
train_imgs = [ os.path.basename(sample['image']) for sample in train_data ]

with open(val_json, 'r', encoding='utf-8') as f:
    val_data = json.load(f)
val_imgs = [ os.path.basename(sample['image']) for sample in val_data ]

# 3. Count per original dataset. total_counts is over unique metadata image
# names (img2ds is a dict, so a repeated name is counted once).
total_counts = Counter(img2ds.values())
train_counts = Counter()
val_counts   = Counter()
unmatched    = {'train': [], 'val': []}  # split images with no metadata entry

for split, imgs, counts in (('train', train_imgs, train_counts),
                            ('val', val_imgs, val_counts)):
    for img in imgs:
        ds = img2ds.get(img)
        if ds is None:
            unmatched[split].append(img)
        else:
            counts[ds] += 1

# 4. Print results
print(f"{'Dataset':15s} {'Total':>6s} {'Train':>6s} {'Val':>6s}")
for ds in sorted(total_counts):
    total = total_counts[ds]
    tr    = train_counts.get(ds, 0)
    vl    = val_counts.get(ds, 0)
    print(f"{ds:15s} {total:6d} {tr:6d} {vl:6d}")

# Images that could not be attributed to any dataset are otherwise invisible
for split, imgs in unmatched.items():
    if imgs:
        print(f"\n{len(imgs)} {split} image(s) not found in metadata, e.g. {imgs[:5]}")

# %%
import json
import re
from collections import Counter

# Paths
train_json_path = '/fs/scratch/PAS2099/Jiacheng/Texture_/train/train.json'
val_json_path   = '/fs/scratch/PAS2099/Jiacheng/Texture_/val/val.json'

# Mapping from number of options → dataset name
OPTIONS_TO_DATASET = {
    47: 'Textual_dtd',
    10: 'KTH',
    15: 'Kylberg',
}

# Helper to extract options from prompt text
def extract_options(prompt_text):
    """Return every "<number>. <option>" option in *prompt_text*, stripped.

    Matches anywhere in the text and cuts options at commas, e.g.
    "1. sand, 2. wood" -> ["sand", "wood"]. Surrounding whitespace (such as a
    trailing newline after the last option) is stripped so the options can
    be compared exactly.
    """
    return [opt.strip() for opt in re.findall(r'\d+\.\s*([^,]+)', prompt_text)]

# Load both splits.
with open(train_json_path, 'r') as f:
    train_data = json.load(f)
with open(val_json_path, 'r') as f:
    val_data = json.load(f)

# Per-dataset sample counts, plus a tally of option counts that map to no
# known dataset.
train_counts = Counter()
val_counts   = Counter()
unknown_train = Counter()
unknown_val   = Counter()

# Each split is (records, how to get the prompt, known tally, unknown tally);
# train is processed fully before val, as before.
for records, get_prompt, known, unknown in (
    (train_data,
     lambda rec: next(turn['value'] for turn in rec['conversations'] if turn['from'] == 'human'),
     train_counts,
     unknown_train),
    (val_data,
     lambda rec: rec['text'],
     val_counts,
     unknown_val),
):
    for rec in records:
        n_opts = len(extract_options(get_prompt(rec)))
        dataset = OPTIONS_TO_DATASET.get(n_opts)
        if not dataset:
            unknown[n_opts] += 1
            continue
        known[dataset] += 1

# Summary table: train, val and their sum per dataset.
datasets = set(OPTIONS_TO_DATASET.values())
print(f"{'Dataset':15s} {'Train':>6s} {'Val':>6s} {'Total':>6s}")
for name in sorted(datasets):
    n_train, n_val = train_counts.get(name, 0), val_counts.get(name, 0)
    print(f"{name:15s} {n_train:6d} {n_val:6d} {n_train + n_val:6d}")

# Report any option counts that did not match a dataset.
for header, tally in (("\nUnexpected option counts in train.json:", unknown_train),
                      ("\nUnexpected option counts in val.json:", unknown_val)):
    if not tally:
        continue
    print(header)
    for n, cnt in tally.items():
        print(f"  {n} options → {cnt} samples")

# %%
