import os
import sys
import json
import random
from collections import Counter
# Prepend /workspace/data to the import path. Presumably this is where the
# project-local `glossary` module (imported lazily by the functions below)
# lives — TODO confirm.
sys.path.insert(0, '/workspace/data')
def get_confidence_score(confidence):
    """Translate an annotator confidence label into a soft score.

    "yes" scores 0.99, "maybe" 0.5 and "no" 0.01; any other hashable value
    falls back to 0.01.
    """
    label_scores = {
        "yes": 0.99,
        "maybe": 0.5,
        "no": 0.01,
    }
    try:
        return label_scores[confidence]
    except KeyError:
        # Unknown labels are treated like an explicit "no".
        return 0.01
def process_answers(answers):
    """Average confidence scores per normalized answer string.

    Each element of ``answers`` must provide an ``"answer"`` text and an
    ``"answer_confidence"`` label. Texts are normalized with
    ``glossary.normalize_word``, and every confidence label is mapped through
    ``get_confidence_score``. The result maps each distinct normalized answer
    (in first-seen order) to the mean of its scores.
    """
    from glossary import normalize_word

    grouped = {}
    for entry in answers:
        key = normalize_word(entry["answer"])
        score = get_confidence_score(entry["answer_confidence"])
        grouped.setdefault(key, []).append(score)
    return {key: sum(vals) / len(vals) for key, vals in grouped.items()}
def write_jsonl(items, filename):
    """Write ``items`` to ``filename`` as JSON Lines (one JSON object per line).

    The destination is overwritten. Missing parent directories are created.

    Args:
        items: iterable of JSON-serializable objects.
        filename: destination path; may be a bare file name with no directory.
    """
    parent = os.path.dirname(filename)
    # os.makedirs("") raises FileNotFoundError, so only create a directory
    # when the path actually contains one (not for e.g. "out.jsonl").
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(filename, 'w', encoding='utf-8') as f:
        for item in items:
            f.write(json.dumps(item) + '\n')
def _vqa_split_items(annotations, q_map, ans2label, tokenizer, split_dir):
    """Build the index rows for one VQA split (e.g. ``"train2014"``).

    An annotation is skipped when its question id is absent from ``q_map`` or
    when none of its normalized answers is in ``ans2label``. Each kept row
    holds the COCO image path, the token ids of the question, the answer
    labels with their averaged confidence scores, and the question id.
    """
    items = []
    for ann in annotations:
        q_id = ann['question_id']
        if q_id not in q_map:
            continue
        tokens = tokenizer.tokenize(q_map[q_id]['question'])
        token_ids = tokenizer.convert_tokens_to_ids(tokens)
        labels, scores = [], []
        for answer, score in process_answers(ann["answers"]).items():
            if answer in ans2label:
                labels.append(ans2label[answer])
                scores.append(score)
        if labels:
            items.append({
                'image_path': f"{split_dir}/COCO_{split_dir}_{ann['image_id']:012d}.jpg",
                'text_segment': token_ids,
                'labels': labels,
                'scores': scores,
                'qid': q_id,
            })
    return items


def generate_vqa_index():
    """Build the VQA train/val index files and the answer vocabulary.

    Writes these files (JSON Lines) under ``DATASET_ROOT``:
    ``vqa.data.json``, ``vqa_trainable_val.data.json``,
    ``vqa_rest_val.data.json`` and ``answer2label.data.json``.
    Answers whose normalized ``multiple_choice_answer`` occurs at least 9
    times across train+val form the label vocabulary. After a random
    shuffle, the first 1000 val items go to ``vqa_rest_val`` and the rest
    go to ``vqa_trainable_val``.

    Returns:
        True if the index was built, or if an existing train index already
        has more than 10000 lines. False if any step failed; the traceback
        is printed to stderr.
    """
    import traceback

    train_file = f"{DATASET_ROOT}/vqa.data.json"
    # Skip regeneration when a sufficiently large train index already exists.
    # Lines are counted in binary mode so this check cannot fail on decoding.
    if os.path.exists(train_file) and os.path.getsize(train_file) > 10000:
        with open(train_file, 'rb') as f:
            count = sum(1 for _ in f)
        if count > 10000:
            return True
    try:
        from transformers import XLMRobertaTokenizer
        from glossary import normalize_word
        tokenizer = XLMRobertaTokenizer(SENTENCEPIECE_MODEL)
        # NOTE(review): all four reads use the same path; confirm whether the
        # train/val questions and annotations should come from separate files.
        with open(f"{VQA_DIR}/data.json", "r", encoding="utf-8") as fp:
            train_questions = json.load(fp)["questions"]
        with open(f"{VQA_DIR}/data.json", "r", encoding="utf-8") as fp:
            train_annotations = json.load(fp)["annotations"]
        with open(f"{VQA_DIR}/data.json", "r", encoding="utf-8") as fp:
            val_questions = json.load(fp)["questions"]
        with open(f"{VQA_DIR}/data.json", "r", encoding="utf-8") as fp:
            val_annotations = json.load(fp)["annotations"]
        # Later duplicates of a question id overwrite earlier ones.
        q_map = {q['question_id']: q for q in train_questions + val_questions}
        all_answers = [
            normalize_word(ann["multiple_choice_answer"])
            for ann in train_annotations + val_annotations
        ]
        counter = {k: v for k, v in Counter(all_answers).items() if v >= 9}
        ans2label = {k: i for i, k in enumerate(counter)}
        train_items = _vqa_split_items(
            train_annotations, q_map, ans2label, tokenizer, "train2014")
        val_items = _vqa_split_items(
            val_annotations, q_map, ans2label, tokenizer, "val2014")
        random.shuffle(val_items)
        rest_val = val_items[:1000]
        trainable_val = val_items[1000:]
        write_jsonl(train_items, f"{DATASET_ROOT}/vqa.data.json")
        write_jsonl(trainable_val, f"{DATASET_ROOT}/vqa_trainable_val.data.json")
        write_jsonl(rest_val, f"{DATASET_ROOT}/vqa_rest_val.data.json")
        with open(f"{DATASET_ROOT}/answer2label.data.json", 'w', encoding='utf-8') as f:
            for ans, label in ans2label.items():
                f.write(json.dumps({"answer": ans, "label": label}) + '\n')
        return True
    except Exception:
        # Keep the boolean contract used by main(), but do not hide the cause.
        traceback.print_exc(file=sys.stderr)
        return False
def generate_vizwiz_index():
    """Build the VizWiz train/val/test index files and the answer vocabulary.

    Writes these files (JSON Lines) under ``DATASET_ROOT``:
    ``vizwiz.data.json``, ``vizwiz_val.data.json``,
    ``vizwiz_test.data.json`` and ``vizwiz_answer2label.data.json``.
    Individual annotator answers that occur at least 9 times (after
    normalization) across train+val form the label vocabulary.

    Returns:
        True if the index was built, or if an existing train index already
        has more than 5000 lines. False if an annotation file is missing or
        any step failed; the cause is reported on stderr.
    """
    import traceback

    train_file = f"{DATASET_ROOT}/vizwiz.data.json"
    # Skip regeneration when a sufficiently large train index already exists.
    # Lines are counted in binary mode so this check cannot fail on decoding.
    if os.path.exists(train_file) and os.path.getsize(train_file) > 10000:
        with open(train_file, 'rb') as f:
            count = sum(1 for _ in f)
        if count > 5000:
            return True
    try:
        from transformers import XLMRobertaTokenizer
        from glossary import normalize_word
        tokenizer = XLMRobertaTokenizer(SENTENCEPIECE_MODEL)
        # NOTE(review): all three splits point at the same file; confirm the
        # intended per-split annotation paths.
        train_file_path = f"{VIZWIZ_ANNOTATIONS_DIR}/data.json"
        val_file_path = f"{VIZWIZ_ANNOTATIONS_DIR}/data.json"
        test_file_path = f"{VIZWIZ_ANNOTATIONS_DIR}/data.json"
        for file_path in [train_file_path, val_file_path, test_file_path]:
            if not os.path.exists(file_path):
                print(f"VizWiz annotation file not found: {file_path}", file=sys.stderr)
                return False
        with open(train_file_path, "r", encoding="utf-8") as fp:
            train_data = json.load(fp)
        with open(val_file_path, "r", encoding="utf-8") as fp:
            val_data = json.load(fp)
        with open(test_file_path, "r", encoding="utf-8") as fp:
            test_data = json.load(fp)
        all_answers = []
        for item in train_data + val_data:
            if "answers" in item:
                for answer in item["answers"]:
                    all_answers.append(answer["answer"])
        all_answers = [normalize_word(word) for word in all_answers]
        counter = {k: v for k, v in Counter(all_answers).items() if v >= 9}
        ans2label = {k: i for i, k in enumerate(counter)}

        def process_vizwiz_split_data(data, split_name, image_dir):
            """Build index rows for one split.

            Items with answers are kept only if at least one answer is in
            the vocabulary. Items without answers are kept (with empty
            labels and scores) only for the "test" split.
            """
            items = []
            for idx, item in enumerate(data):
                image_name = item["image"]
                tokens = tokenizer.tokenize(item["question"])
                token_ids = tokenizer.convert_tokens_to_ids(tokens)
                labels, scores = [], []
                if "answers" in item and item["answers"]:
                    for answer, score in process_answers(item["answers"]).items():
                        if answer in ans2label:
                            labels.append(ans2label[answer])
                            scores.append(score)
                    if not labels:
                        continue
                elif split_name != "test":
                    continue
                # The qid is derived from the item's position in the split.
                qid = f"vizwiz_{split_name}_{idx:08d}"
                items.append({
                    'image_path': f"{image_dir}/{image_name}",
                    'text_segment': token_ids,
                    'labels': labels,
                    'scores': scores,
                    'qid': qid,
                })
            return items

        train_items = process_vizwiz_split_data(train_data, "train", "vizwiz/train")
        val_items = process_vizwiz_split_data(val_data, "val", "vizwiz/val")
        test_items = process_vizwiz_split_data(test_data, "test", "vizwiz/test")
        write_jsonl(train_items, f"{DATASET_ROOT}/vizwiz.data.json")
        write_jsonl(val_items, f"{DATASET_ROOT}/vizwiz_val.data.json")
        write_jsonl(test_items, f"{DATASET_ROOT}/vizwiz_test.data.json")
        with open(f"{DATASET_ROOT}/vizwiz_answer2label.data.json", 'w', encoding='utf-8') as f:
            for ans, label in ans2label.items():
                f.write(json.dumps({"answer": ans, "label": label}) + '\n')
        return True
    except Exception:
        # Keep the boolean contract used by main(), but do not hide the cause.
        traceback.print_exc(file=sys.stderr)
        return False
def main():
    """Run both index generators and return a process exit status.

    The VQA index is generated first, then VizWiz; both always run.
    Returns 0 only if both report success, otherwise 1.
    """
    outcomes = [generate_vqa_index(), generate_vizwiz_index()]
    if all(outcomes):
        return 0
    return 1
if __name__ == "__main__":
    # Module-level configuration read by the generator functions. These are
    # placeholder locations. NOTE(review): SENTENCEPIECE_MODEL is passed to
    # XLMRobertaTokenizer as its vocab file, so it presumably should name a
    # SentencePiece model file rather than a directory — TODO confirm.
    SENTENCEPIECE_MODEL = "data/"
    DATASET_ROOT = "data/"
    VQA_DIR = "data/"
    VIZWIZ_DIR = "data/"  # not referenced by any function visible here
    VIZWIZ_ANNOTATIONS_DIR = "data/"
    # Use sys.exit: the bare exit() builtin is injected by the `site` module
    # and is not guaranteed to exist (e.g. under `python -S`).
    sys.exit(main())