{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "369e44cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import json\n",
    "import json\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2168c8eb-6c6e-4df6-9c19-40bf1208f263",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Project root (hard-coded absolute path)\n",
    "Path = \"/home/user_name/U-MARVEL\"\n",
    "# Path to the training-set candidate pool\n",
    "train_cand_path = os.path.join(Path, \"data/M-BEIR/cand_pool/global/mbeir_union_train_cand_pool.jsonl\")\n",
    "# Path to the test-set candidate pool\n",
    "union_test_cand_pool_path = os.path.join(Path,\"data/M-BEIR/cand_pool/global/mbeir_union_test_cand_pool.jsonl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6858e895-2c04-40ae-b152-4ea132cf39fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print the line count of each global candidate pool (train first, then test).\n",
    "# total_lines is overwritten on each pass, so it ends up holding the test-pool count.\n",
    "for pool_path in (train_cand_path, union_test_cand_pool_path):\n",
    "    with open(pool_path, 'r', encoding='utf-8') as f:\n",
    "        total_lines = sum(1 for _ in f)\n",
    "    print(total_lines)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "886b534e-0c93-43cd-a29c-23bca636c46f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-task test query files whose line counts are summed below\n",
    "file_names = [\n",
    "    \"mbeir_cirr_task7_test.jsonl\",\n",
    "    \"mbeir_edis_task2_test.jsonl\",\n",
    "    \"mbeir_fashion200k_task0_test.jsonl\",\n",
    "    \"mbeir_fashion200k_task3_test.jsonl\",\n",
    "    \"mbeir_fashioniq_task7_test.jsonl\",\n",
    "    \"mbeir_infoseek_task6_test.jsonl\",\n",
    "    \"mbeir_infoseek_task8_test.jsonl\",\n",
    "    \"mbeir_mscoco_task0_test.jsonl\",\n",
    "    \"mbeir_mscoco_task3_test.jsonl\",\n",
    "    \"mbeir_nights_task4_test.jsonl\",\n",
    "    \"mbeir_oven_task6_test.jsonl\",\n",
    "    \"mbeir_oven_task8_test.jsonl\",\n",
    "    \"mbeir_visualnews_task0_test.jsonl\",\n",
    "    \"mbeir_visualnews_task3_test.jsonl\",\n",
    "    \"mbeir_webqa_task1_test.jsonl\",\n",
    "    \"mbeir_webqa_task2_test.jsonl\"\n",
    "]\n",
    "\n",
    "print(len(file_names))\n",
    "total_lines = 0\n",
    "Path_temp = \"/home/user_name/U-MARVEL/data/M-BEIR/query/test\"\n",
    "for file_name in file_names:\n",
    "    file_name = os.path.join(Path_temp,file_name)\n",
    "    try:\n",
    "        with open(file_name, 'r', encoding='utf-8') as file:\n",
    "            # Only the count is needed, so stream the file instead of\n",
    "            # materialising every line in memory with readlines().\n",
    "            line_count = sum(1 for _ in file)\n",
    "        total_lines += line_count\n",
    "        print(f\"{file_name} 的行数: {line_count}\")\n",
    "    except FileNotFoundError:\n",
    "        # A missing file is reported and skipped; it contributes nothing to the total.\n",
    "        print(f\"错误: 文件 {file_name} 未找到。\")\n",
    "    except Exception as e:\n",
    "        # Best effort: report any other read problem and keep going with the remaining files.\n",
    "        print(f\"错误: 读取文件 {file_name} 时发生未知错误: {e}\")\n",
    "\n",
    "print(f\"所有文件的总行数: {total_lines}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fee3b763-e8d3-49c8-ab3d-8f463c7aafdf",
   "metadata": {},
   "outputs": [],
   "source": [
    "query_union_train = \"/home/user_name/U-MARVEL/data/M-BEIR/query/union_train/mbeir_union_up_train.jsonl\"\n",
    "# Stream the file to count its records; readlines() would keep the whole\n",
    "# training query file in memory only to take its length.\n",
    "with open(query_union_train, 'r', encoding='utf-8') as file:\n",
    "    union_train_line_count = sum(1 for _ in file)\n",
    "print(f\"训练集所有文件的总行数: {union_train_line_count}\")\n",
    "# total_lines is expected to still hold the test-query total from the per-file counting cell\n",
    "print(f\"测试集所有文件的总行数: {total_lines}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53dd9458",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load a candidate pool: each key is a did and each value is the full record\n",
    "def load_jsonl(file_path):\n",
    "    \"\"\"\n",
    "    Load a candidate-pool JSONL file into a dict keyed by ``did``.\n",
    "\n",
    "    Args:\n",
    "        file_path: Path to a UTF-8 JSONL file; every record must have a ``did`` field.\n",
    "\n",
    "    Returns:\n",
    "        dict mapping each ``did`` to its parsed record. When a ``did`` appears\n",
    "        more than once, the last record wins.\n",
    "\n",
    "    Raises:\n",
    "        json.JSONDecodeError: if a non-blank line is not valid JSON.\n",
    "        KeyError: if a record has no ``did`` field.\n",
    "    \"\"\"\n",
    "    # First pass: count the lines so the progress bar has a total\n",
    "    with open(file_path, 'r', encoding='utf-8') as f:\n",
    "        total_lines = sum(1 for _ in f)\n",
    "\n",
    "    # Second pass: parse the records\n",
    "    data = {}\n",
    "    with open(file_path, 'r', encoding='utf-8') as f:\n",
    "        for line in tqdm(f, total=total_lines, desc=\"Loading JSONL\"):\n",
    "            # Skip blank lines instead of crashing inside json.loads\n",
    "            if not line.strip():\n",
    "                continue\n",
    "            item = json.loads(line)\n",
    "            # json.loads already returns a fresh dict, so no defensive copy is needed\n",
    "            data[item['did']] = item\n",
    "    print(f\"Loaded {len(data)} records from {file_path}\")\n",
    "    return data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ecbaab48-a3b6-4af3-a631-8aa01c9705b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load queries: each key is a qid and each value is the merged list of positive and negative dids\n",
    "def load_jsonl_query(file_path):\n",
    "    \"\"\"\n",
    "    Load a query JSONL file into a dict mapping ``qid`` to its candidate dids.\n",
    "\n",
    "    Args:\n",
    "        file_path: Path to a UTF-8 JSONL file; every record must have ``qid``,\n",
    "            ``pos_cand_list`` and ``neg_cand_list`` fields.\n",
    "\n",
    "    Returns:\n",
    "        dict mapping each ``qid`` to a new list made of ``pos_cand_list``\n",
    "        followed by ``neg_cand_list``. When a ``qid`` appears more than once,\n",
    "        the last record wins.\n",
    "\n",
    "    Raises:\n",
    "        json.JSONDecodeError: if a non-blank line is not valid JSON.\n",
    "        KeyError: if a record lacks one of the three required fields.\n",
    "    \"\"\"\n",
    "    # First pass: count the lines so the progress bar has a total\n",
    "    with open(file_path, 'r', encoding='utf-8') as f:\n",
    "        total_lines = sum(1 for _ in f)\n",
    "\n",
    "    # Second pass: parse the records\n",
    "    data = {}\n",
    "    with open(file_path, 'r', encoding='utf-8') as f:\n",
    "        for line in tqdm(f, total=total_lines, desc=\"Loading JSONL\"):\n",
    "            # Skip blank lines instead of crashing inside json.loads\n",
    "            if not line.strip():\n",
    "                continue\n",
    "            item = json.loads(line)\n",
    "            pos_cand_list = item[\"pos_cand_list\"]\n",
    "            neg_cand_list = item[\"neg_cand_list\"]\n",
    "            qid = item[\"qid\"]\n",
    "            # List concatenation already builds a new list, so the inputs are never aliased\n",
    "            data[qid] = pos_cand_list + neg_cand_list\n",
    "    print(f\"Loaded {len(data)} records from {file_path}\")\n",
    "    return data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "81fa5a95-7323-4ecf-bd10-0ad8eaaecfcc",
   "metadata": {
    "tags": []
   },
   "source": [
    "### 检查训练集的候选池和测试集候选池互相的包含情况"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1902f4e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the training candidate pool into a dict keyed by did\n",
    "train_cand_pool = load_jsonl(train_cand_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ada2ad0f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the union test candidate pool into a dict keyed by did\n",
    "union_test_cand_pool = load_jsonl(union_test_cand_pool_path)\n",
    "print(f\"Union test candidate pool size: {len(union_test_cand_pool)}\")\n",
    "# Presumably the output recorded from an earlier run (confirm before relying on it):\n",
    "# Union test candidate pool size: 5609079"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba120750",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Find the dids of train_cand_pool that do not appear in union_test_cand_pool\n",
    "def check_candidates_in_union_test(train_cand_pool, union_test_cand_pool):\n",
    "    \"\"\"\n",
    "    Collect the training-pool keys that are missing from the union test pool.\n",
    "\n",
    "    Prints how many keys are missing and returns them as a list, in the\n",
    "    iteration order of ``train_cand_pool``.\n",
    "    \"\"\"\n",
    "    missing = [\n",
    "        did\n",
    "        for did in tqdm(train_cand_pool.keys(), desc=\"Checking candidates\")\n",
    "        if did not in union_test_cand_pool\n",
    "    ]\n",
    "    print(f\"Number of candidates not in union test: {len(missing)}\")\n",
    "    return missing\n",
    "# Check candidates in union test\n",
    "not_in_union_test = check_candidates_in_union_test(train_cand_pool, union_test_cand_pool)\n",
    "print(len(set(not_in_union_test)))\n",
    "\n",
    "\n",
    "# Number of candidates not in union test: 614353\n",
    "# 614353"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "085d9dce-623c-4745-9404-ca86d98a0f40",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Peek at the first ten dids that are missing from the union test pool\n",
    "not_in_union_test[:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b4a58b8",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "7f0f76d6-f765-40b4-8eac-3b78e0190f96",
   "metadata": {
    "tags": []
   },
   "source": [
    "### 检查训练集当中非 union 的 query 的 qid 以及正负样本的 did 是否在 union 候选池当中"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f21ef6a-4af2-4d51-92c9-1cc7108cf852",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Full paths of the per-dataset (non-union) training query files\n",
    "Path_query_train = \"/home/user_name/U-MARVEL/data/M-BEIR/query/train\"\n",
    "query_train_list = [\n",
    "    os.path.join(Path_query_train, name)\n",
    "    for name in (\n",
    "        \"mbeir_cirr_train.jsonl\",\n",
    "        \"mbeir_edis_train.jsonl\",\n",
    "        \"mbeir_fashion200k_train.jsonl\",\n",
    "        \"mbeir_fashioniq_train.jsonl\",\n",
    "        \"mbeir_infoseek_train.jsonl\",\n",
    "        \"mbeir_mscoco_train.jsonl\",\n",
    "        \"mbeir_nights_train.jsonl\",\n",
    "        \"mbeir_oven_train.jsonl\",\n",
    "        \"mbeir_visualnews_train.jsonl\",\n",
    "        \"mbeir_webqa_train.jsonl\",\n",
    "    )\n",
    "]\n",
    "print(\"query_train_list:\", query_train_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0945b94-6ebd-4e22-80f5-5f5712663573",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gather every qid, positive did and negative did from the files in query_train_list\n",
    "query_train_qid_list = []\n",
    "query_train_pos_cand_list = []\n",
    "query_train_neg_cand_list = []\n",
    "for query_train_path in query_train_list:\n",
    "    with open(query_train_path, 'r', encoding='utf-8') as f:\n",
    "        for line in tqdm(f, desc=f\"Loading {query_train_path}\"):\n",
    "            item = json.loads(line)\n",
    "            # Read all three fields before touching any of the accumulators\n",
    "            qid, pos_cand_list, neg_cand_list = item['qid'], item['pos_cand_list'], item[\"neg_cand_list\"]\n",
    "            query_train_qid_list.append(qid)\n",
    "            query_train_pos_cand_list += pos_cand_list\n",
    "            query_train_neg_cand_list += neg_cand_list\n",
    "# Each line shows the raw count followed by the de-duplicated count\n",
    "print(f\"Total number of pos candidates in query train list: {len(query_train_pos_cand_list)} {len(set(query_train_pos_cand_list))}\")\n",
    "print(f\"Total number of neg candidates in query train list: {len(query_train_neg_cand_list)} {len(set(query_train_neg_cand_list))}\")\n",
    "print(f\"Total number of qid in query train list: {len(query_train_qid_list)} {len(set(query_train_qid_list))}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "775f4906-dd3d-4fe9-8c8f-8369b259ae7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect every did of the training candidate pool (train_cand_path).\n",
    "# Reading the pool is the expensive step, so reuse it when a previous cell has\n",
    "# already loaded it and only read the file when it is not in memory yet.\n",
    "if \"train_cand_pool\" not in globals():\n",
    "    train_cand_pool = load_jsonl(train_cand_path)\n",
    "train_cand_did_list = list(train_cand_pool.keys())\n",
    "print(f\"Total number of candidates in train cand pool: {len(train_cand_did_list)} {len(set(train_cand_did_list))}\")\n",
    "# Sets used by the membership comparisons in the following cells\n",
    "query_train_pos_cand_set = set(query_train_pos_cand_list)\n",
    "query_train_neg_cand_set = set(query_train_neg_cand_list)\n",
    "train_cand_did_set = set(train_cand_did_list)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4e198673-38c2-4f64-91f4-fd12fb7fc60b",
   "metadata": {},
   "source": [
    "#### 检查 query 当中的正样本 did 是否在候选池当中"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8693387c-5afc-4147-ae12-16bca3215365",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Positive dids that are present in the candidate pool\n",
    "intersection = query_train_pos_cand_set & train_cand_did_set\n",
    "print(f\"Number of candidates in both query train list and train cand pool: {len(intersection)}\")\n",
    "# Positive dids that are missing from the candidate pool\n",
    "difference = query_train_pos_cand_set - train_cand_did_set\n",
    "print(f\"Number of candidates in query train list but not in train cand pool: {len(difference)}\")\n",
    "# Pool dids that never occur as a positive\n",
    "difference = train_cand_did_set - query_train_pos_cand_set\n",
    "print(f\"Number of candidates in train cand pool but not in query train list: {len(difference)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bb48e417-a8d4-40f8-a28c-c4c31dad091c",
   "metadata": {},
   "source": [
    "#### 检查 query 当中的负样本 did 是否在候选池当中"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ff4ae9a-9520-4a6b-a6b1-a53de84748e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Negative dids that are present in the candidate pool\n",
    "intersection = query_train_neg_cand_set & train_cand_did_set\n",
    "print(f\"Number of candidates in both query train list and train cand pool: {len(intersection)}\")\n",
    "# Negative dids that are missing from the candidate pool\n",
    "difference = query_train_neg_cand_set - train_cand_did_set\n",
    "print(f\"Number of candidates in query train list but not in train cand pool: {len(difference)}\")\n",
    "# Pool dids that never occur as a negative\n",
    "difference = train_cand_did_set - query_train_neg_cand_set\n",
    "print(f\"Number of candidates in train cand pool but not in query train list: {len(difference)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f6a5ca3-5102-4957-aef5-b8e0383a4a85",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "94d82732-66df-4275-a6a5-8f8641fe747c",
   "metadata": {},
   "source": [
    "#### 检验非 union 集合的 qid 和 union 集合的 qid 情况"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "069e7d4f-24c1-4bf4-954a-191be0bbff3e",
   "metadata": {},
   "outputs": [],
   "source": [
    "query_union_train = \"/home/user_name/U-MARVEL/data/M-BEIR/query/union_train/mbeir_union_up_train.jsonl\"\n",
    "# Print the first record of the union training query file to inspect its fields\n",
    "with open(query_union_train, 'r', encoding='utf-8') as f:\n",
    "    # Label the progress bar with the file that is read here, not with a\n",
    "    # loop variable left behind by another cell.\n",
    "    for line in tqdm(f, desc=f\"Loading {query_union_train}\"):\n",
    "        item = json.loads(line)\n",
    "        print(item)\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c845a202-2107-4fae-b83b-fb5011b7c91d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect every qid from the union training query file.\n",
    "# The list keeps its existing spelling (\"qid_ist\") because other cells refer to it by that name.\n",
    "query_union_train_qid_ist = []\n",
    "with open(query_union_train, 'r', encoding='utf-8') as f:\n",
    "    query_union_train_qid_ist.extend(\n",
    "        json.loads(line)['qid']\n",
    "        for line in tqdm(f, desc=f\"Loading {query_union_train}\")\n",
    "    )\n",
    "print(f\"Total number of qid in query union train list: {len(query_union_train_qid_ist)} {len(set(query_union_train_qid_ist))}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1bdb53c2-9c3b-48cf-b3d5-484958c90fea",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare the qids of the per-dataset query files with the qids of the union file\n",
    "from collections import Counter\n",
    "from collections import defaultdict\n",
    "query_train_qid_set = set(query_train_qid_list)\n",
    "query_union_train_qid_set = set(query_union_train_qid_ist)\n",
    "# qids present in both\n",
    "intersection = query_train_qid_set & query_union_train_qid_set\n",
    "print(f\"Number of qid in both query train list and query union train list: {len(intersection)}\")\n",
    "# qids only in the per-dataset files\n",
    "difference = query_train_qid_set - query_union_train_qid_set\n",
    "print(f\"Number of qid in query train list but not in query union train list: {len(difference)}\")\n",
    "# qids only in the union file\n",
    "difference = query_union_train_qid_set - query_train_qid_set\n",
    "print(f\"Number of qid in query union train list but not in query train list: {len(difference)}\")\n",
    "# qids that occur more than once in the union file\n",
    "qid_counts = Counter(query_union_train_qid_ist)\n",
    "duplicates = [\n",
    "    repeated_qid\n",
    "    for repeated_qid, occurrences in qid_counts.items()\n",
    "    if occurrences > 1\n",
    "]\n",
    "print(f\"Number of duplicate qid in query union train list: {len(duplicates)}\")\n",
    "print(f\"Duplicate qid in query union train list: {duplicates[:10]}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "824cbb26-cd3a-42cd-b455-b58e2a86f0cd",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "019f7179-82cc-4594-a46f-3a1e35970e72",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "cba7bcf0-2186-42a0-9546-1440ab1b6121",
   "metadata": {
    "tags": []
   },
   "source": [
    "### 对训练集进行拆分"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "68185082-e259-4ff3-bbfd-4c412ab3d566",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import json\n",
    "from tqdm import tqdm\n",
    "from collections import defaultdict\n",
    "Path_query_train = \"/home/user_name/U-MARVEL/data/M-BEIR/query/train/query_train\"\n",
    "# Dataset name -> the \"<dataset>_task<id>\" stems used in the per-task file names\n",
    "dataset_to_val_data_file_middle_name_map = {\n",
    "    \"VisualNews\": [\"visualnews_task0\", \"visualnews_task3\"],\n",
    "    \"MSCOCO\": [\"mscoco_task0\", \"mscoco_task3\"],\n",
    "    \"Fashion200K\": [\"fashion200k_task0\", \"fashion200k_task3\"],\n",
    "    \"WebQA\": [\"webqa_task1\", \"webqa_task2\"],\n",
    "    \"EDIS\": [\"edis_task2\"],\n",
    "    \"NIGHTS\": [\"nights_task4\"],\n",
    "    \"OVEN\": [\"oven_task6\", \"oven_task8\"],\n",
    "    \"INFOSEEK\": [\"infoseek_task6\", \"infoseek_task8\"],\n",
    "    \"FashionIQ\": [\"fashioniq_task7\"],\n",
    "    \"CIRR\": [\"cirr_task7\"],\n",
    "}\n",
    "\n",
    "# Mapping of dataset names to IDs\n",
    "DATASET_IDS = {\n",
    "    \"visualnews\": 0,\n",
    "    \"fashion200k\": 1,\n",
    "    \"webqa\": 2,\n",
    "    \"edis\": 3,\n",
    "    \"nights\": 4,\n",
    "    \"oven\": 5,\n",
    "    \"infoseek\": 6,\n",
    "    \"fashioniq\": 7,\n",
    "    \"cirr\": 8,\n",
    "    \"mscoco\": 9,\n",
    "}\n",
    "\n",
    "# Task description (\"query modality -> candidate modality\") -> task id\n",
    "MBEIR_TASK = {\n",
    "    \"text -> image\": 0,\n",
    "    \"text -> text\": 1,\n",
    "    \"text -> image,text\": 2,\n",
    "    \"image -> text\": 3,\n",
    "    \"image -> image\": 4,\n",
    "    \"image -> text,image\": 5,\n",
    "    \"image,text -> text\": 6,\n",
    "    \"image,text -> image\": 7,\n",
    "    \"image,text -> image,text\": 8,\n",
    "}\n",
    "# One file per task stem, e.g. \"./data/M-BEIR/query/train/mbeir_train_visualnews_task3.jsonl\"\n",
    "query_train_file_list = [\n",
    "    os.path.join(Path_query_train, f\"mbeir_train_{task_name}.jsonl\")\n",
    "    for task_names in dataset_to_val_data_file_middle_name_map.values()\n",
    "    for task_name in task_names\n",
    "]\n",
    "print(\"query_train_file_list[0]:\", query_train_file_list[0])\n",
    "print(len(query_train_file_list))\n",
    "# Filled by the next cell: (datasetid, task_id) -> file path\n",
    "query_train_file_list2datasetidtaskid = {}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d3a86c8-6958-480e-807c-633a0029f5c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Map every (datasetid, task_id) pair to its per-task query file\n",
    "for query_train_file in query_train_file_list:\n",
    "    # e.g. \"mbeir_train_visualnews_task3.jsonl\" -> stem \"mbeir_train_visualnews_task3\".\n",
    "    # basename/splitext work with whatever separator os.path.join produced, unlike split(\"/\").\n",
    "    stem = os.path.splitext(os.path.basename(query_train_file))[0]\n",
    "    dataset_name, task_token = stem.split(\"_\")[-2:]\n",
    "    # Parse every digit after \"task\" rather than only the last character,\n",
    "    # so a task id of 10 or more would not be truncated.\n",
    "    task_id = int(task_token[len(\"task\"):])\n",
    "    datasetid = DATASET_IDS[dataset_name]\n",
    "    query_train_file_list2datasetidtaskid[(datasetid, task_id)] = query_train_file\n",
    "print(query_train_file_list2datasetidtaskid)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a147ccd-6870-4d44-ae35-d15b3b9f436e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Read the union training queries and group them by (datasetid, task_id)\n",
    "query_union_train = \"/home/user_name/U-MARVEL/data/M-BEIR/query/union_train/mbeir_union_up_train.jsonl\"\n",
    "datasetidtaskid2querylist = defaultdict(list)\n",
    "with open(query_union_train, 'r', encoding='utf-8') as f:\n",
    "    for line in tqdm(f, desc=f\"Loading {query_union_train}\"):\n",
    "        item = json.loads(line)\n",
    "        # The dataset id is the part of the qid in front of the first \":\"\n",
    "        datasetid = int(item['qid'].partition(\":\")[0])\n",
    "        task_id = int(item[\"task_id\"])\n",
    "        datasetidtaskid2querylist[datasetid, task_id].append(item)\n",
    "print(f\"Total number of datasetidtaskid2querylist: {len(datasetidtaskid2querylist)}\")\n",
    "# Show how many queries ended up in each group\n",
    "for key, value in datasetidtaskid2querylist.items():\n",
    "    print(f\"Key: {key}, Value: {len(value)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b248969-c380-42e2-b6ec-b3b4f7a87d4f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Write each (datasetid, task_id) group of datasetidtaskid2querylist to the JSONL file\n",
    "# that query_train_file_list2datasetidtaskid assigns to that pair.\n",
    "# Validate first, so that nothing is written when a group has no target file.\n",
    "unmapped_keys = sorted(set(datasetidtaskid2querylist) - set(query_train_file_list2datasetidtaskid))\n",
    "if unmapped_keys:\n",
    "    raise KeyError(f\"No output file configured for (datasetid, task_id): {unmapped_keys}\")\n",
    "for key, value in datasetidtaskid2querylist.items():\n",
    "    datasetid, task_id = key\n",
    "    query_train_file = query_train_file_list2datasetidtaskid[key]\n",
    "    print(f\"Key: {key}, Value: {len(value)}, query_train_file: {query_train_file}\")\n",
    "    # open(..., 'w') does not create missing parent directories\n",
    "    os.makedirs(os.path.dirname(query_train_file), exist_ok=True)\n",
    "    # One JSON record per line; ensure_ascii=False keeps non-ASCII text readable\n",
    "    with open(query_train_file, 'w', encoding='utf-8') as f:\n",
    "        for item in value:\n",
    "            f.write(json.dumps(item, ensure_ascii=False) + '\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8307da68-b78f-429b-baca-ca81a67c3ab3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-read the per-task query files and collect, per file, the positive dids,\n",
    "# the negative dids and the dataset ids that those dids start with\n",
    "query_train_file_list2pos_cand_list = defaultdict(set)\n",
    "query_train_file_list2neg_cand_list = defaultdict(set)\n",
    "query_train_file_list2datasetid = defaultdict(set)\n",
    "for key, query_train_file in query_train_file_list2datasetidtaskid.items():\n",
    "    with open(query_train_file, 'r', encoding='utf-8') as f:\n",
    "        for line in tqdm(f):\n",
    "            item = json.loads(line)\n",
    "            pos_cand_list, neg_cand_list = item['pos_cand_list'], item['neg_cand_list']\n",
    "            query_train_file_list2pos_cand_list[query_train_file].update(pos_cand_list)\n",
    "            query_train_file_list2neg_cand_list[query_train_file].update(neg_cand_list)\n",
    "            # The dataset id is the part of a did in front of the first \":\"\n",
    "            query_train_file_list2datasetid[query_train_file].update(\n",
    "                int(did.split(\":\")[0]) for did in pos_cand_list + neg_cand_list\n",
    "            )\n",
    "    print(f\"Total number of datasetid in {query_train_file}: {(query_train_file_list2datasetid[query_train_file])}\")\n",
    "print(query_train_file_list2datasetid)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "81e83733-c3e4-4f4e-8206-6af0635f5e6a",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "# Check whether the candidates in query_train_file_list2pos_cand_list and\n",
    "# query_train_file_list2neg_cand_list share a single modality per query file\n",
    "train_cand_pool = load_jsonl(train_cand_path)\n",
    "\n",
    "# Check the positive candidate sets\n",
    "print(\"Checking positive candidates...\")\n",
    "for query_train_file, pos_cand_set in query_train_file_list2pos_cand_list.items():\n",
    "    # Reference modality for this file: taken from the first candidate the loop below yields\n",
    "    modality = None\n",
    "    print(query_train_file,\"正样本数量: \",len(pos_cand_set))\n",
    "    for did in tqdm(pos_cand_set):\n",
    "        if modality is None:\n",
    "            modality = train_cand_pool[did][\"modality\"]\n",
    "        else:\n",
    "            if modality != train_cand_pool[did][\"modality\"]:\n",
    "                print(f\"\\nInconsistency found in positive candidates for query: {query_train_file}\")\n",
    "                print(f\"Modality: {modality}, {train_cand_pool[did]['modality']}\")\n",
    "                print(train_cand_pool[did])\n",
    "                # Stop at the first positive whose modality differs from the reference\n",
    "                break\n",
    "    # Count the negatives whose modality differs from the reference; modality is not\n",
    "    # reset here, so it is the value left by the positive loop (or None if that set was empty)\n",
    "    neg_modality_num = 0\n",
    "    neg_cand_set = query_train_file_list2neg_cand_list[query_train_file]\n",
    "    print(query_train_file,\"负样本数量: \",len(neg_cand_set))\n",
    "    for did in tqdm(neg_cand_set):\n",
    "        if modality is None:\n",
    "            modality = train_cand_pool[did][\"modality\"]\n",
    "        else:\n",
    "            if modality != train_cand_pool[did][\"modality\"]:\n",
    "                neg_modality_num +=1\n",
    "                # print(f\"\\nInconsistency found in negative candidates for query: {query_train_file}\")\n",
    "                # print(f\"Modality: {modality}, {train_cand_pool[did]['modality']}\")\n",
    "                # print(train_cand_pool[did])\n",
    "                # break\n",
    "    print(\"负样本不匹配模态的数量: \",neg_modality_num)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3870862e-0f07-4e7a-9767-fc915e3fe9a8",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Check the negative candidate sets on their own: within each query file, count the\n",
    "# negatives whose modality differs from the modality of the first negative visited.\n",
    "print(\"\\nChecking negative candidates...\")\n",
    "for query_train_file, neg_cand_set in query_train_file_list2neg_cand_list.items():\n",
    "    modality = None\n",
    "    neg_modality_num = 0\n",
    "    print(query_train_file,len(neg_cand_set))\n",
    "    for did in tqdm(neg_cand_set):\n",
    "        cand_modality = train_cand_pool[did][\"modality\"]\n",
    "        if modality is None:\n",
    "            # The first candidate yielded by the set becomes the reference modality\n",
    "            modality = cand_modality\n",
    "        elif modality != cand_modality:\n",
    "            neg_modality_num +=1\n",
    "            # Print the details of the first mismatch only, but keep scanning so that\n",
    "            # the number reported below is the real total instead of being capped at 1.\n",
    "            if neg_modality_num == 1:\n",
    "                print(f\"\\nInconsistency found in negative candidates for query: {query_train_file}\")\n",
    "                print(f\"Modality: {modality}, {cand_modality}\")\n",
    "                print(train_cand_pool[did])\n",
    "    print(\"不匹配模态的数量: \",neg_modality_num)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f49219d7-6c20-441b-b42b-700a6a5642de",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1dbf6b95-97f2-4166-8eca-494898632f8e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "3799d2ae-02f1-4b13-a7d8-0d545d9b2355",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true,
    "tags": []
   },
   "source": [
    "### 对 cand 进行拆分"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c44b914a-bddf-4b6f-979a-47eb624c9d00",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import json\n",
    "from tqdm import tqdm\n",
    "from collections import defaultdict\n",
    "Path_cand_train = \"/home/user_name/U-MARVEL/data/M-BEIR/cand_pool/local/cand_train\"\n",
    "# Dataset name -> the \"<dataset>_task<id>\" stems used in the per-task file names\n",
    "dataset_to_val_data_file_middle_name_map = {\n",
    "    \"VisualNews\": [\"visualnews_task0\", \"visualnews_task3\"],\n",
    "    \"MSCOCO\": [\"mscoco_task0\", \"mscoco_task3\"],\n",
    "    \"Fashion200K\": [\"fashion200k_task0\", \"fashion200k_task3\"],\n",
    "    \"WebQA\": [\"webqa_task1\", \"webqa_task2\"],\n",
    "    \"EDIS\": [\"edis_task2\"],\n",
    "    \"NIGHTS\": [\"nights_task4\"],\n",
    "    \"OVEN\": [\"oven_task6\", \"oven_task8\"],\n",
    "    \"INFOSEEK\": [\"infoseek_task6\", \"infoseek_task8\"],\n",
    "    \"FashionIQ\": [\"fashioniq_task7\"],\n",
    "    \"CIRR\": [\"cirr_task7\"],\n",
    "}\n",
    "\n",
    "# Mapping of dataset names to IDs\n",
    "DATASET_IDS = {\n",
    "    \"visualnews\": 0,\n",
    "    \"fashion200k\": 1,\n",
    "    \"webqa\": 2,\n",
    "    \"edis\": 3,\n",
    "    \"nights\": 4,\n",
    "    \"oven\": 5,\n",
    "    \"infoseek\": 6,\n",
    "    \"fashioniq\": 7,\n",
    "    \"cirr\": 8,\n",
    "    \"mscoco\": 9,\n",
    "}\n",
    "\n",
    "# Task id -> task description (\"query modality -> candidate modality\")\n",
    "MBEIR_TASK = {\n",
    "    0: \"text -> image\",\n",
    "    1: \"text -> text\",\n",
    "    2: \"text -> image,text\",\n",
    "    3: \"image -> text\",\n",
    "    4: \"image -> image\",\n",
    "    5: \"image -> text,image\",\n",
    "    6: \"image,text -> text\",\n",
    "    7: \"image,text -> image\",\n",
    "    8: \"image,text -> image,text\",\n",
    "}\n",
    "# One candidate-pool file per task stem, named like \"mbeir_train_cirr_task7_cand_pool.jsonl\"\n",
    "cand_train_file_list = [\n",
    "    os.path.join(Path_cand_train, f\"mbeir_train_{task_name}_cand_pool.jsonl\")\n",
    "    for task_names in dataset_to_val_data_file_middle_name_map.values()\n",
    "    for task_name in task_names\n",
    "]\n",
    "print(\"cand_train_file_list[0]:\", cand_train_file_list[0])\n",
    "print(len(cand_train_file_list))\n",
    "print(cand_train_file_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a55b6fa8-9543-4477-951d-fd8cc47af649",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the generated candidate-pool file paths\n",
    "cand_train_file_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a7e5ddc-0289-41b8-8659-614e30815113",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Map every per-task candidate-pool file to the set holding its own dataset id\n",
    "cand_train_file_list2datasetid = defaultdict(set)\n",
    "for cand_train_file in cand_train_file_list:\n",
    "    # e.g. \"mbeir_train_cirr_task7_cand_pool.jsonl\": the dataset name is the 4th \"_\" token from the end.\n",
    "    # os.path.basename works with whatever separator os.path.join produced, unlike split(\"/\").\n",
    "    dataset_name = os.path.basename(cand_train_file).split(\"_\")[-4]\n",
    "    datasetid = DATASET_IDS[dataset_name]\n",
    "    cand_train_file_list2datasetid[cand_train_file] = {datasetid}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d01c53ad-8200-4ba1-b94c-c498c253335e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the current file -> dataset-id mapping\n",
    "print(cand_train_file_list2datasetid)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc668d7b-2292-40a4-9e62-26db3291ed3f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hard-coded file -> dataset-id mapping; this assignment replaces whatever\n",
    "# cand_train_file_list2datasetid held before.\n",
    "# The four OVEN / InfoSeek files are each mapped to both ids 5 and 6;\n",
    "# every other file is mapped to a single dataset id.\n",
    "cand_train_file_list2datasetid = {\n",
    "    '/home/user_name/U-MARVEL/data/M-BEIR/cand_pool/local/cand_train/mbeir_train_visualnews_task0_cand_pool.jsonl': {0}, \n",
    "    '/home/user_name/U-MARVEL/data/M-BEIR/cand_pool/local/cand_train/mbeir_train_visualnews_task3_cand_pool.jsonl': {0}, \n",
    "    '/home/user_name/U-MARVEL/data/M-BEIR/cand_pool/local/cand_train/mbeir_train_mscoco_task0_cand_pool.jsonl': {9}, \n",
    "    '/home/user_name/U-MARVEL/data/M-BEIR/cand_pool/local/cand_train/mbeir_train_mscoco_task3_cand_pool.jsonl': {9}, \n",
    "    '/home/user_name/U-MARVEL/data/M-BEIR/cand_pool/local/cand_train/mbeir_train_fashion200k_task0_cand_pool.jsonl': {1}, \n",
    "    '/home/user_name/U-MARVEL/data/M-BEIR/cand_pool/local/cand_train/mbeir_train_fashion200k_task3_cand_pool.jsonl': {1}, \n",
    "    '/home/user_name/U-MARVEL/data/M-BEIR/cand_pool/local/cand_train/mbeir_train_webqa_task1_cand_pool.jsonl': {2}, \n",
    "    '/home/user_name/U-MARVEL/data/M-BEIR/cand_pool/local/cand_train/mbeir_train_webqa_task2_cand_pool.jsonl': {2}, \n",
    "    '/home/user_name/U-MARVEL/data/M-BEIR/cand_pool/local/cand_train/mbeir_train_edis_task2_cand_pool.jsonl': {3}, \n",
    "    '/home/user_name/U-MARVEL/data/M-BEIR/cand_pool/local/cand_train/mbeir_train_nights_task4_cand_pool.jsonl': {4}, \n",
    "    '/home/user_name/U-MARVEL/data/M-BEIR/cand_pool/local/cand_train/mbeir_train_oven_task6_cand_pool.jsonl': {5,6}, \n",
    "    '/home/user_name/U-MARVEL/data/M-BEIR/cand_pool/local/cand_train/mbeir_train_oven_task8_cand_pool.jsonl': {5,6}, \n",
    "    '/home/user_name/U-MARVEL/data/M-BEIR/cand_pool/local/cand_train/mbeir_train_infoseek_task6_cand_pool.jsonl': {5,6}, \n",
    "    '/home/user_name/U-MARVEL/data/M-BEIR/cand_pool/local/cand_train/mbeir_train_infoseek_task8_cand_pool.jsonl': {5,6}, \n",
    "    '/home/user_name/U-MARVEL/data/M-BEIR/cand_pool/local/cand_train/mbeir_train_fashioniq_task7_cand_pool.jsonl': {7}, \n",
    "    '/home/user_name/U-MARVEL/data/M-BEIR/cand_pool/local/cand_train/mbeir_train_cirr_task7_cand_pool.jsonl': {8}\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2fc597a2-1afe-4f9a-8b49-4c29cc363342",
   "metadata": {},
   "outputs": [],
   "source": [
    "Path = \"/home/user_name/U-MARVEL\"\n",
    "train_cand_path = os.path.join(Path, \"data/M-BEIR/cand_pool/global/mbeir_union_train_cand_pool.jsonl\")\n",
    "# Load the union training pool keyed by did and, in the same pass,\n",
    "# bucket the dids by the dataset id that precedes the first \":\" of each did\n",
    "train_cand_pool, datasetid2did = {}, defaultdict(set)\n",
    "with open(train_cand_path, 'r', encoding='utf-8') as f:\n",
    "    for line in tqdm(f):\n",
    "        item = json.loads(line)\n",
    "        did = item['did']\n",
    "        datasetid = int(did.partition(\":\")[0])\n",
    "        datasetid2did[datasetid].add(did)\n",
    "        train_cand_pool[did] = dict(item)\n",
    "\n",
    "print(f\"Total number of datasetid in train cand pool: {len(datasetid2did)}\")\n",
    "print(f\"Total number of did in train cand pool: {len(train_cand_pool)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b31747f0-7a6a-4efe-b51c-d1dce9cdbf08",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print the number of candidate dids per dataset id, then the overall total.\n",
    "for dataset_key, did_set in datasetid2did.items():\n",
    "    print(dataset_key,len(did_set))\n",
    "num = sum(len(did_set) for did_set in datasetid2did.values())\n",
    "print(num)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e8f27e7-6925-4a97-bb4d-fa6d8c9267c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build one local training candidate pool file per (dataset, task): for every\n",
    "# target file in cand_train_file_list2datasetid, write the records of\n",
    "# train_cand_pool that belong to the mapped dataset ids and whose modality\n",
    "# equals the candidate-side modality of the task.\n",
    "from tqdm import tqdm\n",
    "for cand_train_file, datasetid in cand_train_file_list2datasetid.items():\n",
    "    # File names look like mbeir_train_<dataset>_task<N>_cand_pool.jsonl, so the\n",
    "    # third token from the end is 'task<N>'; parse all digits, not just the last one.\n",
    "    task_token = os.path.basename(cand_train_file).split(\"_\")[-3]\n",
    "    task_id = int(task_token[len(\"task\"):])\n",
    "    task_name = MBEIR_TASK[task_id]\n",
    "    # Task names are written as '<query side> -> <candidate side>'.\n",
    "    task_modality = task_name.split(\" -> \")[1]\n",
    "    with open(cand_train_file, 'w', encoding='utf-8') as f:\n",
    "        for dataset_key in datasetid:\n",
    "            dids = datasetid2did[dataset_key]\n",
    "            num = 0  # records written for this dataset id\n",
    "            # A set has a length, so tqdm can show progress without copying it to a list.\n",
    "            for did in tqdm(dids):\n",
    "                item_data = train_cand_pool.get(did)\n",
    "                if item_data is None:\n",
    "                    print(f\"did: {did} not in train_cand_pool\")\n",
    "                    continue\n",
    "                if item_data[\"modality\"] == task_modality:\n",
    "                    f.write(json.dumps(item_data, ensure_ascii=False) + '\\n')\n",
    "                    num += 1\n",
    "            print(task_modality)\n",
    "            print(f\"cand_train_file: {cand_train_file}, datasetid: {datasetid}\")\n",
    "            print(f\"Write {num} of {len(dids)} records to {cand_train_file}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "159e8d3a-6243-49ad-ab78-9d7450dcd57f",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true,
    "tags": []
   },
   "source": [
    "### 检验 16 个文件的正负样本 did 都在对应的候选池当中"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13a51b0c-bf3c-4b9f-884d-b976cf1a18ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Map every query training file to the local candidate pool file of the same\n",
    "# dataset and task.\n",
    "query_train_file_list2cand_train_file_list = {}\n",
    "for query_train_file in query_train_file_list:\n",
    "    # Query file names look like mbeir_train_<dataset>_<task>[...].jsonl.\n",
    "    name_tokens = os.path.basename(query_train_file).split(\"_\")\n",
    "    dataset = name_tokens[2]\n",
    "    task = name_tokens[3].split(\".\")[0]\n",
    "    # Compare whole underscore-delimited tokens of the file name only, so that\n",
    "    # directory names or token prefixes cannot produce a false match.\n",
    "    wanted = f\"_{dataset}_{task}_\"\n",
    "    for cand_train_file in cand_train_file_list:\n",
    "        if wanted in os.path.basename(cand_train_file):\n",
    "            query_train_file_list2cand_train_file_list[query_train_file] = cand_train_file\n",
    "            break\n",
    "    else:\n",
    "        # Make a missing pool visible instead of silently dropping the query file.\n",
    "        print(f\"No candidate pool file found for {query_train_file}\")\n",
    "print(query_train_file_list2cand_train_file_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "530c97db-aa8d-4d31-85f2-555579c10bf3",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# For every query training file, verify that each positive candidate did is\n",
    "# present in the local candidate pool mapped to that file.\n",
    "# NOTE(review): negative candidates (neg_cand_list) are not verified here --\n",
    "# confirm whether that is intended.\n",
    "for query_train_file, cand_train_file in query_train_file_list2cand_train_file_list.items():\n",
    "    print(f\"query_train_file: {query_train_file}, cand_train_file: {cand_train_file}\")\n",
    "    # Collect the dids available in this candidate pool.\n",
    "    with open(cand_train_file, 'r', encoding='utf-8') as f:\n",
    "        cand_train_did_set = {json.loads(line)['did'] for line in tqdm(f)}\n",
    "    print(f\"Total number of did in {cand_train_file}: {len(cand_train_did_set)}\")\n",
    "    # Report every positive did that the pool does not contain.\n",
    "    with open(query_train_file, 'r', encoding='utf-8') as f:\n",
    "        for line in tqdm(f):\n",
    "            for did in json.loads(line)['pos_cand_list']:\n",
    "                if did not in cand_train_did_set:\n",
    "                    print(f\"did: {did} not in cand_train_did_set\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2de4dd5e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7889ce94",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e227ffb3",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "bfbf10ff",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true,
    "tags": []
   },
   "source": [
    "### 检查测试集候选池的模态是否全部一致"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5bff4749-9e04-43ef-9e6c-10f851ab8081",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "from collections import defaultdict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5649c4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Full paths of the local test candidate pool files, one per dataset/task split.\n",
    "Path_cand_test = \"/home/user_name/U-MARVEL/data/M-BEIR/cand_pool/local/cand_test\"\n",
    "test_cand_pool_files = [\n",
    "    os.path.join(Path_cand_test, file_name)\n",
    "    for file_name in (\n",
    "        \"mbeir_cirr_task7_cand_pool.jsonl\",\n",
    "        \"mbeir_edis_task2_cand_pool.jsonl\",\n",
    "        \"mbeir_fashion200k_task0_cand_pool.jsonl\",\n",
    "        \"mbeir_fashion200k_task3_cand_pool.jsonl\",\n",
    "        \"mbeir_fashioniq_task7_cand_pool.jsonl\",\n",
    "        \"mbeir_infoseek_task6_cand_pool.jsonl\",\n",
    "        \"mbeir_infoseek_task8_cand_pool.jsonl\",\n",
    "        \"mbeir_mscoco_task0_test_cand_pool.jsonl\",\n",
    "        \"mbeir_mscoco_task0_val_cand_pool.jsonl\",\n",
    "        \"mbeir_mscoco_task3_test_cand_pool.jsonl\",\n",
    "        \"mbeir_mscoco_task3_val_cand_pool.jsonl\",\n",
    "        \"mbeir_nights_task4_cand_pool.jsonl\",\n",
    "        \"mbeir_oven_task6_cand_pool.jsonl\",\n",
    "        \"mbeir_oven_task8_cand_pool.jsonl\",\n",
    "        \"mbeir_visualnews_task0_cand_pool.jsonl\",\n",
    "        \"mbeir_visualnews_task3_cand_pool.jsonl\",\n",
    "        \"mbeir_webqa_task1_cand_pool.jsonl\",\n",
    "        \"mbeir_webqa_task2_cand_pool.jsonl\",\n",
    "    )\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25ad6735",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "from collections import Counter\n",
    "from tqdm import tqdm\n",
    "from collections import defaultdict\n",
    "\n",
    "def count_modalities(jsonl_path):\n",
    "    \"\"\"Count the records of each modality in a JSON Lines file.\n",
    "\n",
    "    Args:\n",
    "        jsonl_path: Path of the JSONL file to scan (one JSON object per line).\n",
    "\n",
    "    Returns:\n",
    "        defaultdict(int) mapping each truthy 'modality' value to the number of\n",
    "        records carrying it. Records without a truthy 'modality' are ignored;\n",
    "        lines that fail to parse as JSON are reported and skipped.\n",
    "    \"\"\"\n",
    "    modality_counter = defaultdict(int)\n",
    "    # Decode as UTF-8 explicitly, like the other file readers in this notebook,\n",
    "    # instead of depending on the platform's default encoding.\n",
    "    with open(jsonl_path, 'r', encoding='utf-8') as file:\n",
    "        for line in tqdm(file):\n",
    "            try:\n",
    "                data = json.loads(line.strip())\n",
    "            except json.JSONDecodeError:\n",
    "                print(f\"警告：无法解析行: {line}\")\n",
    "                continue\n",
    "            modality = data.get('modality')\n",
    "            if modality:\n",
    "                modality_counter[modality] += 1\n",
    "    return modality_counter\n",
    "# Report the modality histogram of every test candidate pool file.\n",
    "for pool_path in test_cand_pool_files:\n",
    "    per_modality = count_modalities(pool_path)\n",
    "    print(pool_path,\"Modality 统计结果:\")\n",
    "    for modality_name in per_modality:\n",
    "        print(f\"- {modality_name}: {per_modality[modality_name]}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8ef8409",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fdc84d25",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "edacf854",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3938dd7",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "681b0411",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true,
    "tags": []
   },
   "source": [
    "### 拆分 qrels 数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69e10849",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import json\n",
    "from tqdm import tqdm\n",
    "from collections import defaultdict\n",
    "Path_qrels_train = \"/home/user_name/U-MARVEL/data/M-BEIR/qrels/train\"\n",
    "\n",
    "# Dataset name -> '<dataset>_task<N>' middle names of its per-task files.\n",
    "dataset_to_val_data_file_middle_name_map = {\n",
    "    \"VisualNews\": [\"visualnews_task0\", \"visualnews_task3\"],\n",
    "    \"MSCOCO\": [\"mscoco_task0\", \"mscoco_task3\"],\n",
    "    \"Fashion200K\": [\"fashion200k_task0\", \"fashion200k_task3\"],\n",
    "    \"WebQA\": [\"webqa_task1\", \"webqa_task2\"],\n",
    "    \"EDIS\": [\"edis_task2\"],\n",
    "    \"NIGHTS\": [\"nights_task4\"],\n",
    "    \"OVEN\": [\"oven_task6\", \"oven_task8\"],\n",
    "    \"INFOSEEK\": [\"infoseek_task6\", \"infoseek_task8\"],\n",
    "    \"FashionIQ\": [\"fashioniq_task7\"],\n",
    "    \"CIRR\": [\"cirr_task7\"],\n",
    "}\n",
    "\n",
    "# Lower-case dataset name -> integer dataset id (the position in this list).\n",
    "DATASET_IDS = {\n",
    "    name: idx\n",
    "    for idx, name in enumerate([\n",
    "        \"visualnews\", \"fashion200k\", \"webqa\", \"edis\", \"nights\",\n",
    "        \"oven\", \"infoseek\", \"fashioniq\", \"cirr\", \"mscoco\",\n",
    "    ])\n",
    "}\n",
    "\n",
    "# Integer task id -> '<query side> -> <candidate side>' description.\n",
    "MBEIR_TASK = dict(enumerate([\n",
    "    \"text -> image\",\n",
    "    \"text -> text\",\n",
    "    \"text -> image,text\",\n",
    "    \"image -> text\",\n",
    "    \"image -> image\",\n",
    "    \"image -> text,image\",\n",
    "    \"image,text -> text\",\n",
    "    \"image,text -> image\",\n",
    "    \"image,text -> image,text\",\n",
    "]))\n",
    "\n",
    "# One qrels path per task, e.g. <Path_qrels_train>/mbeir_train_cirr_task7_qrels.txt\n",
    "qrels_train_file_list = []\n",
    "for task_names in dataset_to_val_data_file_middle_name_map.values():\n",
    "    qrels_train_file_list.extend(\n",
    "        os.path.join(Path_qrels_train, f\"mbeir_train_{middle_name}_qrels.txt\")\n",
    "        for middle_name in task_names\n",
    "    )\n",
    "print(\"qrels_train_file_list[0]:\", qrels_train_file_list[0])\n",
    "print(len(qrels_train_file_list))\n",
    "print(qrels_train_file_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0f08db4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Full paths of the existing qrels files, one per dataset.\n",
    "Path_qrels_train_old = \"/home/user_name/U-MARVEL/data/M-BEIR/qrels/train\"\n",
    "old_qrels_file_list = [\n",
    "    os.path.join(Path_qrels_train_old, f\"mbeir_{dataset_name}_train_qrels.txt\")\n",
    "    for dataset_name in (\n",
    "        \"cirr\", \"edis\", \"fashion200k\", \"fashioniq\", \"infoseek\",\n",
    "        \"mscoco\", \"nights\", \"oven\", \"visualnews\", \"webqa\",\n",
    "    )\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e7f29219",
   "metadata": {},
   "outputs": [],
   "source": [
    "import shutil\n",
    "\n",
    "# Create one qrels file per (dataset, task) by copying the dataset-level qrels\n",
    "# file; all tasks of a dataset receive a copy of the same file.\n",
    "old_qrels_by_name = {os.path.basename(path): path for path in old_qrels_file_list}\n",
    "for qrels_train_file in qrels_train_file_list:\n",
    "    # Target names look like mbeir_train_<dataset>_task<N>_qrels.txt.\n",
    "    dataset = os.path.basename(qrels_train_file).split(\"_\")[2]\n",
    "    # Look the source up by its exact file name rather than by a substring test\n",
    "    # on the whole path, which could also match a directory component.\n",
    "    old_qrels_file = old_qrels_by_name.get(f\"mbeir_{dataset}_train_qrels.txt\")\n",
    "    if old_qrels_file is None:\n",
    "        print(f\"No dataset-level qrels file found for {qrels_train_file}\")\n",
    "        continue\n",
    "    print(qrels_train_file,old_qrels_file)\n",
    "    # copy2 also carries over the file metadata (e.g. timestamps).\n",
    "    shutil.copy2(old_qrels_file, qrels_train_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09d3281f",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90ea603e-d64a-4e59-bbf9-f26959e9af20",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5e4c8e6-e372-41a5-a434-6b6fe04fa3b4",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "d5ccf655-56e5-475e-bf6f-59b1daa10967",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "### 根据 qid 对训练集进行去重"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a6836f2a-a913-4194-9677-06ef92883d75",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import json\n",
    "from tqdm import tqdm\n",
    "from collections import defaultdict\n",
    "Path_query_train = \"/home/user_name/U-MARVEL/data/M-BEIR/query/train/query_train\"\n",
    "\n",
    "# Dataset name -> '<dataset>_task<N>' middle names of its per-task files.\n",
    "dataset_to_val_data_file_middle_name_map = {\n",
    "    \"VisualNews\": [\"visualnews_task0\", \"visualnews_task3\"],\n",
    "    \"MSCOCO\": [\"mscoco_task0\", \"mscoco_task3\"],\n",
    "    \"Fashion200K\": [\"fashion200k_task0\", \"fashion200k_task3\"],\n",
    "    \"WebQA\": [\"webqa_task1\", \"webqa_task2\"],\n",
    "    \"EDIS\": [\"edis_task2\"],\n",
    "    \"NIGHTS\": [\"nights_task4\"],\n",
    "    \"OVEN\": [\"oven_task6\", \"oven_task8\"],\n",
    "    \"INFOSEEK\": [\"infoseek_task6\", \"infoseek_task8\"],\n",
    "    \"FashionIQ\": [\"fashioniq_task7\"],\n",
    "    \"CIRR\": [\"cirr_task7\"],\n",
    "}\n",
    "\n",
    "# Lower-case dataset name -> integer dataset id (the position in this list).\n",
    "DATASET_IDS = {\n",
    "    name: idx\n",
    "    for idx, name in enumerate([\n",
    "        \"visualnews\", \"fashion200k\", \"webqa\", \"edis\", \"nights\",\n",
    "        \"oven\", \"infoseek\", \"fashioniq\", \"cirr\", \"mscoco\",\n",
    "    ])\n",
    "}\n",
    "\n",
    "# Task description -> integer task id.\n",
    "# NOTE(review): this rebinds MBEIR_TASK as description -> id, the inverse of the\n",
    "# id -> description mapping defined in the qrels section; cells that index\n",
    "# MBEIR_TASK by integer id will fail if they are re-run after this cell.\n",
    "MBEIR_TASK = {\n",
    "    description: idx\n",
    "    for idx, description in enumerate([\n",
    "        \"text -> image\",\n",
    "        \"text -> text\",\n",
    "        \"text -> image,text\",\n",
    "        \"image -> text\",\n",
    "        \"image -> image\",\n",
    "        \"image -> text,image\",\n",
    "        \"image,text -> text\",\n",
    "        \"image,text -> image\",\n",
    "        \"image,text -> image,text\",\n",
    "    ])\n",
    "}\n",
    "\n",
    "# One query file per task, e.g. <Path_query_train>/mbeir_train_visualnews_task3.jsonl\n",
    "query_train_file_list = []\n",
    "for task_names in dataset_to_val_data_file_middle_name_map.values():\n",
    "    query_train_file_list.extend(\n",
    "        os.path.join(Path_query_train, f\"mbeir_train_{middle_name}.jsonl\")\n",
    "        for middle_name in task_names\n",
    "    )\n",
    "print(\"query_train_file_list[0]:\", query_train_file_list[0])\n",
    "print(len(query_train_file_list))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e45fb98e-8dbc-4204-9dac-eb9a6e1bce33",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Deduplicate every query training file by qid (the first occurrence wins) and\n",
    "# save the result next to the original file with a '_dedup' suffix.\n",
    "num_1 = 0  # records kept after deduplication, summed over all files\n",
    "num_2 = 0  # records read before deduplication, summed over all files\n",
    "for query_train_file in query_train_file_list:\n",
    "    # qid -> first record seen. A dict keeps insertion order, so the output\n",
    "    # file preserves the order of first occurrences.\n",
    "    data = {}\n",
    "\n",
    "    with open(query_train_file, 'r', encoding='utf-8') as f:\n",
    "        print(query_train_file)\n",
    "        for line in tqdm(f):\n",
    "            item = json.loads(line)\n",
    "            qid = item['qid']\n",
    "            num_2 += 1\n",
    "            if qid not in data:\n",
    "                data[qid] = item\n",
    "            elif data[qid] != item:\n",
    "                # Same qid but different content: keep the first record and\n",
    "                # print both so the conflict can be inspected.\n",
    "                print(data[qid])\n",
    "                print(item)\n",
    "    num_1 += len(data)\n",
    "    new_file = query_train_file.replace('.jsonl', '_dedup.jsonl')\n",
    "    with open(new_file, 'w', encoding='utf-8') as f:\n",
    "        for item in data.values():\n",
    "            f.write(json.dumps(item, ensure_ascii=False) + '\\n')\n",
    "print(f\"未去重 Total number of qid in query train list: {num_2}\")\n",
    "print(f\"去重 Total number of qid in query train list: {num_1}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "72c4b5f7-1a2e-4d3a-b79e-804140574532",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-read the '_dedup' files and count total vs. unique qids; nothing is\n",
    "# written here. The two totals printed at the end are equal when no file\n",
    "# contains a repeated qid.\n",
    "num_1 = 0  # unique qids, summed over all files\n",
    "num_2 = 0  # records read, summed over all files\n",
    "for query_train_file in query_train_file_list:\n",
    "    new_file = query_train_file.replace('.jsonl', '_dedup.jsonl')\n",
    "    # qid -> first record seen in this file.\n",
    "    data = {}\n",
    "    with open(new_file, 'r', encoding='utf-8') as f:\n",
    "        print(new_file)\n",
    "        for line in tqdm(f):\n",
    "            item = json.loads(line)\n",
    "            qid = item['qid']\n",
    "            num_2 += 1\n",
    "            if qid not in data:\n",
    "                data[qid] = item\n",
    "            elif data[qid] != item:\n",
    "                # Same qid but different content: print both records so the\n",
    "                # conflict can be inspected.\n",
    "                print(data[qid])\n",
    "                print(item)\n",
    "    num_1 += len(data)\n",
    "print(f\"未去重 Total number of qid in query train list: {num_2}\")\n",
    "print(f\"去重 Total number of qid in query train list: {num_1}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a716b0e-fe7a-4e20-bb52-20434f8c868d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "d50f8c58",
   "metadata": {},
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python(3.8.8)",
   "language": "python",
   "name": "env-3.8.8"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
