{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e5a8e83a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import shutil\n",
    "import numpy as np\n",
    "import json\n",
    "from tqdm import tqdm\n",
    "from PIL import Image\n",
    "from copy import deepcopy\n",
    "from concurrent.futures import ProcessPoolExecutor, as_completed\n",
    "\n",
    "from utils import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "c1ac47fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- configuration ---\n",
    "NUM_SAMPLE = 100  # number of distinct cubes to generate\n",
    "DATA_ROOT = os.getcwd()  # all artifacts are written below the working directory\n",
    "CUBE_DIR = f'{DATA_ROOT}/cubes'  # one zero-padded sub-folder per cube"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "82f300c0",
   "metadata": {},
   "source": [
    "## identical cubes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "20ad3735",
   "metadata": {},
   "outputs": [],
   "source": [
    "# reset: wipe the output of any previous run so stale cubes cannot leak into this one\n",
    "cube_dir_exists = os.path.exists(CUBE_DIR)\n",
    "if cube_dir_exists:\n",
    "    shutil.rmtree(CUBE_DIR)\n",
    "os.makedirs(CUBE_DIR)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d34ef7c0",
   "metadata": {},
   "source": [
    "### pattern-representations and 3d-representations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "fe47dc06",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "scene symmetry 49 22\n",
      "scene symmetry 54 22\n"
     ]
    }
   ],
   "source": [
    "SEED = 2025 * hash_str('sample')\n",
    "np.random.seed(SEED)\n",
    "\n",
    "# Sample random cube patterns until NUM_SAMPLE mutually distinct cubes are accepted.\n",
    "# Each accepted cube stores its 24 rotated 3d representations (sorted) in rep_3d.json;\n",
    "# every later cell indexes that list with range(24), so a cube needs 24 distinct rotations.\n",
    "rep_pattern = []\n",
    "rep_3d_set_all = set()  # every rotation of every accepted cube, used for duplicate detection\n",
    "rep_3d_list_all = []\n",
    "while len(rep_pattern) < NUM_SAMPLE:\n",
    "    cube_idx = len(rep_pattern)\n",
    "    # random pattern\n",
    "    rand_orientation_idx = np.random.randint(0, 4, 6).tolist()\n",
    "    rand_color_idx = np.random.choice(COLOR_NAMES, 6, replace=True).tolist()\n",
    "    rep_3d = get_3d_rep(rand_orientation_idx, rand_color_idx)\n",
    "\n",
    "    # reject a cube that is a rotation of an already accepted cube\n",
    "    if tuple(rep_3d.items()) in rep_3d_set_all:\n",
    "        print('duplicate cube', cube_idx)\n",
    "        continue\n",
    "\n",
    "    # get 24 possible rotated 3d representations and the scenes (visible faces) they show\n",
    "    rep_3d_set_cur = set()\n",
    "    scene_set = set()\n",
    "    for action in group_elements:\n",
    "        rep_3d_rotated = rotate_cube(rep_3d, action)\n",
    "        scene = {k: rep_3d_rotated[k] for k in ['top', 'front', 'right']}\n",
    "        rep_3d_set_cur.add(tuple(rep_3d_rotated.items()))\n",
    "        scene_set.add(tuple(scene.items()))\n",
    "\n",
    "    num_reps = len(rep_3d_set_cur)\n",
    "    num_scenes = len(scene_set)\n",
    "\n",
    "    # self-symmetric cube: fewer than 24 distinct rotations would break range(24) indexing later\n",
    "    if num_reps < 24:\n",
    "        print('self symmetry', cube_idx, num_reps)\n",
    "        continue\n",
    "\n",
    "    # scene symmetry is tolerated: two rotations look identical from the visible faces\n",
    "    if num_scenes < 24:\n",
    "        print('scene symmetry', cube_idx, num_scenes)\n",
    "\n",
    "    rep_3d_list_cur = [dict(v) for v in sorted(rep_3d_set_cur)]\n",
    "\n",
    "    # save info\n",
    "    folder = f'{CUBE_DIR}/{cube_idx:05d}'\n",
    "    os.makedirs(folder, exist_ok=True)\n",
    "    with open(f'{folder}/rep_3d.json', 'w') as f:\n",
    "        json.dump(rep_3d_list_cur, f, indent=2)\n",
    "\n",
    "    # add cubes\n",
    "    rep_pattern.append((rand_orientation_idx, rand_color_idx))\n",
    "    rep_3d_set_all.update(rep_3d_set_cur)\n",
    "    rep_3d_list_all.append(rep_3d_list_cur)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "be129773",
   "metadata": {},
   "source": [
    "### 3d reps --> cube images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "6426f9bf",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:01<00:00, 51.22it/s] \n"
     ]
    }
   ],
   "source": [
    "def process_cube(cube_idx):\n",
    "    \"\"\"Render every rotated view of one cube, plus a 4x6 contact sheet, into cube_positive/.\"\"\"\n",
    "    cube_root = f'{CUBE_DIR}/{cube_idx:05d}/'\n",
    "    if not os.path.exists(cube_root):\n",
    "        return None\n",
    "\n",
    "    with open(f'{cube_root}/rep_3d.json', 'r') as fp:\n",
    "        views = json.load(fp)\n",
    "\n",
    "    out_dir = f'{cube_root}/cube_positive'\n",
    "    os.makedirs(out_dir, exist_ok=True)\n",
    "\n",
    "    image_paths = [f'{out_dir}/cube_{view_idx:02d}.png' for view_idx in range(len(views))]\n",
    "    for view, image_path in zip(views, image_paths):\n",
    "        plot_3d(view, image_path, show=False)\n",
    "\n",
    "    plot_comp(f'{out_dir}/cube_all.png', image_paths, (4,6))\n",
    "    return None\n",
    "\n",
    "# render all cubes in parallel; .result() re-raises any exception from a worker\n",
    "with ProcessPoolExecutor(max_workers=os.cpu_count()) as executor:\n",
    "    futures = [executor.submit(process_cube, idx) for idx in range(NUM_SAMPLE)]\n",
    "    for future in tqdm(as_completed(futures), total=NUM_SAMPLE):\n",
    "        future.result()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b33f1d9c",
   "metadata": {},
   "source": [
    "### 3d reps --> 2d reps --> 2d images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "e454e454",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:08<00:00, 11.65it/s]\n"
     ]
    }
   ],
   "source": [
    "def process_net(cube_idx):\n",
    "    \"\"\"Unfold every rotated view of one cube with each net type; save the PNGs and rep_2d.json.\"\"\"\n",
    "    cube_root = f'{CUBE_DIR}/{cube_idx:05d}/'\n",
    "    if not os.path.exists(cube_root):\n",
    "        return None\n",
    "\n",
    "    with open(f'{cube_root}/rep_3d.json', 'r') as fp:\n",
    "        views = json.load(fp)\n",
    "    out_dir = f'{cube_root}/net_positive'\n",
    "    os.makedirs(out_dir, exist_ok=True)\n",
    "\n",
    "    nets = {}  # image file name -> 2d representation\n",
    "    for view_idx, view in enumerate(views):\n",
    "        ori_idx, colors = from_3d_rep(view)\n",
    "        for net_type in NET_TYPES:\n",
    "            name = f'net_{view_idx:02d}_{net_type}.png'\n",
    "            nets[name] = get_2d_rep(ori_idx, colors, net_type)\n",
    "            plot_2d(nets[name], f'{out_dir}/{name}', show=False)\n",
    "\n",
    "    with open(f'{cube_root}/rep_2d.json', 'w') as fp:\n",
    "        json.dump(nets, fp, indent=4)\n",
    "\n",
    "# unfold all cubes in parallel; .result() re-raises any exception from a worker\n",
    "with ProcessPoolExecutor(max_workers=os.cpu_count()) as executor:\n",
    "    futures = [executor.submit(process_net, idx) for idx in range(NUM_SAMPLE)]\n",
    "    for future in tqdm(as_completed(futures), total=NUM_SAMPLE):\n",
    "        future.result()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4d1bad1f",
   "metadata": {},
   "source": [
    "## negative examples"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "850e7ad3",
   "metadata": {},
   "source": [
    "### 3d reps --> negative 3d reps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "3aff6468",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:00<00:00, 383.48it/s]\n"
     ]
    }
   ],
   "source": [
    "SEED = 2025 * hash_str('3d_negative_sampling')\n",
    "np.random.seed(SEED)\n",
    "\n",
    "# For every rotated view of every cube, build one negative 3d rep (visible faces only) by\n",
    "# mutating the color or the arrow orientation of one visible face. A candidate is kept only\n",
    "# if it does not refer to the same cube AND it is not a duplicate of an earlier negative.\n",
    "visible_faces = ['front', 'top', 'right']\n",
    "invisible_faces = ['back', 'bottom', 'left']\n",
    "for cube_idx in tqdm(range(NUM_SAMPLE)):\n",
    "    folder = f'{CUBE_DIR}/{cube_idx:05d}/'\n",
    "    with open(f'{folder}/rep_3d.json', 'r') as f: # follow the same order as rep_3d.json\n",
    "        rep_3d_list = json.load(f)\n",
    "\n",
    "    rep_3d_neg_set = set()\n",
    "    rep_3d_neg_dict = dict()\n",
    "    for rep_3d_idx in range(24):\n",
    "        rep_3d_cur = deepcopy(rep_3d_list[rep_3d_idx]) # dict(face: (ori, color))\n",
    "        for face_tmp in invisible_faces:\n",
    "            rep_3d_cur.pop(face_tmp)\n",
    "        for face_tmp in visible_faces:\n",
    "            rep_3d_cur[face_tmp] = tuple(rep_3d_cur[face_tmp])\n",
    "\n",
    "        while True:\n",
    "            face_mutate = np.random.choice(visible_faces).tolist()\n",
    "            rep_3d_tmp = deepcopy(rep_3d_cur)\n",
    "            # randomly mutate color or orientation\n",
    "            if np.random.rand() < 0.5: # mutate color\n",
    "                # take the color of a random invisible face\n",
    "                face_rand = np.random.choice(invisible_faces).tolist()\n",
    "                color_neg = rep_3d_list[rep_3d_idx][face_rand][1]\n",
    "                rep_3d_tmp[face_mutate] = (rep_3d_tmp[face_mutate][0], color_neg)\n",
    "            else: # mutate orientation\n",
    "                ori_set = set(FACE_ADJACENCY[face_mutate])\n",
    "                ori_set.remove(rep_3d_cur[face_mutate][0])\n",
    "                # sorted(): str set order depends on PYTHONHASHSEED, which would make the\n",
    "                # seeded draw non-reproducible across interpreter runs\n",
    "                ori_neg = np.random.choice(sorted(ori_set)).tolist()\n",
    "                rep_3d_tmp[face_mutate] = (ori_neg, rep_3d_tmp[face_mutate][1])\n",
    "            # check similarity and existence\n",
    "            similarity = get_similarity(rep_3d_tmp, rep_3d_list, sim_type='scene')\n",
    "            similarity = ''.join([str(x) for x in similarity])\n",
    "            if similarity[0] == '3': # mutated rep_3d refers to the same cube, try again\n",
    "                continue\n",
    "            tmp_tuple = tuple(rep_3d_tmp.items())\n",
    "            if tmp_tuple in rep_3d_neg_set: # duplicate negative sample, try again\n",
    "                continue\n",
    "            rep_3d_neg_set.add(tmp_tuple)\n",
    "            break\n",
    "        rep_3d_neg_dict[f'{rep_3d_idx:02d}_{face_mutate}_{similarity}'] = rep_3d_tmp\n",
    "\n",
    "    with open(f'{folder}/rep_3d_neg.json', 'w') as f:\n",
    "        json.dump(rep_3d_neg_dict, f, indent=4)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "28697ffb",
   "metadata": {},
   "source": [
    "### negative 3d reps --> negative cube images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "05509456",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:01<00:00, 93.73it/s]\n"
     ]
    }
   ],
   "source": [
    "def process_cube_neg(cube_idx):\n",
    "    \"\"\"Render each negative 3d rep of one cube into a freshly emptied cube_negative/ folder.\"\"\"\n",
    "    cube_root = f'{CUBE_DIR}/{cube_idx:05d}/'\n",
    "    if not os.path.exists(cube_root):\n",
    "        return None\n",
    "\n",
    "    out_dir = f'{cube_root}/cube_negative'\n",
    "    # recreate from scratch so images from an earlier run cannot survive\n",
    "    if os.path.exists(out_dir):\n",
    "        shutil.rmtree(out_dir)\n",
    "    os.makedirs(out_dir)\n",
    "\n",
    "    with open(f'{cube_root}/rep_3d_neg.json', 'r') as fp:\n",
    "        negatives = json.load(fp)\n",
    "    for name, view in negatives.items():\n",
    "        plot_3d(view, f'{out_dir}/{name}.png', show=False)\n",
    "    return None\n",
    "\n",
    "with ProcessPoolExecutor(max_workers=os.cpu_count()) as executor:\n",
    "    futures = [executor.submit(process_cube_neg, idx) for idx in range(NUM_SAMPLE)]\n",
    "    for future in tqdm(as_completed(futures), total=NUM_SAMPLE):\n",
    "        future.result()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7d5c3014",
   "metadata": {},
   "source": [
    "### 2d reps --> negative 2d reps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "271072ad",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:00<00:00, 312.80it/s]\n"
     ]
    }
   ],
   "source": [
    "# negative 2d-net\n",
    "SEED = 2025 * hash_str('2d_negative_sampling')\n",
    "np.random.seed(SEED)\n",
    "\n",
    "# For every rotated view, change the arrow orientation of one random face (any of the six)\n",
    "# until get_similarity no longer reports a '6' leading score, then unfold it with a random net type.\n",
    "for cube_idx in tqdm(range(NUM_SAMPLE)):\n",
    "    folder = f'{CUBE_DIR}/{cube_idx:05d}/'\n",
    "    with open(f'{folder}/rep_3d.json', 'r') as f: # follow the same order as rep_3d.json\n",
    "        rep_3d_list = json.load(f)\n",
    "\n",
    "    neg_dict = dict()\n",
    "    for view_idx in range(24):\n",
    "        rep_3d_cur = deepcopy(rep_3d_list[view_idx]) # dict(face: (ori, color))\n",
    "\n",
    "        similarity = '666'\n",
    "        while similarity[0] == '6': # presumably: still a full match with some rotation -> retry\n",
    "            face = np.random.choice(FACE_NAMES).tolist()\n",
    "            ori_set = set(FACE_ADJACENCY[face])\n",
    "            ori_set.remove(rep_3d_cur[face][0])\n",
    "            # sorted(): str set order depends on PYTHONHASHSEED; keep the seeded draw reproducible\n",
    "            ori_neg = np.random.choice(sorted(ori_set)).tolist()\n",
    "            # replace orientation\n",
    "            rep_3d_tmp = deepcopy(rep_3d_cur)\n",
    "            rep_3d_tmp[face] = (ori_neg, rep_3d_tmp[face][1])\n",
    "            # check similarity\n",
    "            similarity = get_similarity(rep_3d_tmp, rep_3d_list, sim_type='rep')\n",
    "            similarity = ''.join([str(x) for x in similarity])\n",
    "        # choose a random net type to unfold the cube\n",
    "        net_type = np.random.choice(NET_TYPES).tolist()\n",
    "        ori_idx, colors = from_3d_rep(rep_3d_tmp)\n",
    "        rep_2d = get_2d_rep(ori_idx, colors, net_type)\n",
    "        neg_dict[f'{view_idx:02d}_{similarity}_{net_type}'] = rep_2d\n",
    "    # save\n",
    "    with open(f'{folder}/rep_2d_neg.json', 'w') as f:\n",
    "        json.dump(neg_dict, f, indent=4)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d321d72b",
   "metadata": {},
   "source": [
    "### negative 2d reps --> negative 2d images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "82b14350",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:00<00:00, 104.30it/s]\n"
     ]
    }
   ],
   "source": [
    "def process_net_neg(cube_idx):\n",
    "    \"\"\"Render every negative 2d net of one cube (from rep_2d_neg.json) into net_negative/.\"\"\"\n",
    "    cube_root = f'{CUBE_DIR}/{cube_idx:05d}/'\n",
    "    if not os.path.exists(cube_root):\n",
    "        return None\n",
    "\n",
    "    with open(f'{cube_root}/rep_2d_neg.json', 'r') as fp:\n",
    "        negatives = json.load(fp)\n",
    "\n",
    "    out_dir = f'{cube_root}/net_negative'\n",
    "    os.makedirs(out_dir, exist_ok=True)\n",
    "    for name, net in negatives.items():\n",
    "        plot_2d(net, f'{out_dir}/{name}.png', show=False)\n",
    "    return None\n",
    "\n",
    "with ProcessPoolExecutor(max_workers=os.cpu_count()) as executor:\n",
    "    futures = [executor.submit(process_net_neg, idx) for idx in range(NUM_SAMPLE)]\n",
    "    for future in tqdm(as_completed(futures), total=NUM_SAMPLE):\n",
    "        future.result()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ffd8ade4",
   "metadata": {},
   "source": [
    "## questions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "c67f93a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from questions import *"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2ad9989a",
   "metadata": {},
   "source": [
    "### perception"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "cafb39cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 24 rep_3d * 3 faces * 6 questions = 432 (2400 image samples for each question pair)\n",
    "# samples: [rep_3d_idx, question_type, question]\n",
    "def get_perception_questions(rep_3d_list):\n",
    "    \"\"\"Build paired positive/negative perception questions for the 24 views of one cube.\n",
    "\n",
    "    For each view and each visible face: a color pair (skipped for the whole view when all\n",
    "    visible faces share one color, since no negative color exists), plus an arrow-orientation\n",
    "    pair and a Euclidean-direction pair that share the same negative orientation.\n",
    "    Returns a list of [rep_3d_idx, question_type, question].\n",
    "    \"\"\"\n",
    "    visible_faces = ['front', 'top', 'right']\n",
    "    samples = []\n",
    "    for rep_3d_idx in range(24):\n",
    "        rep_3d = rep_3d_list[rep_3d_idx]\n",
    "        # color: the negative is another color visible in the same view\n",
    "        color_set = set([rep_3d[face][1] for face in visible_faces])\n",
    "        for face in visible_faces:\n",
    "            if len(color_set) == 1: # all the same color\n",
    "                break\n",
    "            color_pos = rep_3d[face][1]\n",
    "            # sorted(): str set order depends on PYTHONHASHSEED; keep the seeded draw reproducible\n",
    "            color_neg = np.random.choice(sorted(color_set - {color_pos})).tolist() # 3-1=2 at most\n",
    "            samples.append([rep_3d_idx, f'color_{face}_pos', question_perception_cube_color(face, color_pos)])\n",
    "            samples.append([rep_3d_idx, f'color_{face}_neg', question_perception_cube_color(face, color_neg)])\n",
    "        # orientation and Euclidean direction (both use the same negative orientation)\n",
    "        for face in visible_faces:\n",
    "            ori_pos = rep_3d[face][0]\n",
    "            ori_set = set(FACE_ADJACENCY[face])\n",
    "            ori_neg = np.random.choice(sorted(ori_set - {ori_pos})).tolist() # 4-1=3\n",
    "            samples.append([rep_3d_idx, f'ori_{face}_pos', question_perception_cube_orientation(face, ori_pos)])\n",
    "            samples.append([rep_3d_idx, f'ori_{face}_neg', question_perception_cube_orientation(face, ori_neg)])\n",
    "            samples.append([rep_3d_idx, f'Euc_{face}_pos', question_perception_cube_Euclidean(face, ori_pos)])\n",
    "            samples.append([rep_3d_idx, f'Euc_{face}_neg', question_perception_cube_Euclidean(face, ori_neg)])\n",
    "    return samples\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "65a25558",
   "metadata": {},
   "outputs": [],
   "source": [
    "SEED = 2025 * hash_str('perception')\n",
    "np.random.seed(SEED)\n",
    "\n",
    "for cube_idx in range(NUM_SAMPLE):\n",
    "    cur_dir = f'{CUBE_DIR}/{cube_idx:05d}/'\n",
    "    with open(f'{cur_dir}/rep_3d.json', 'r') as f: # context manager closes the handle\n",
    "        rep_3d_list = json.load(f)\n",
    "    samples = get_perception_questions(rep_3d_list)\n",
    "    with open(f'{cur_dir}/perception.json', 'w') as f:\n",
    "        json.dump(samples, f, indent=4)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4b26ac24",
   "metadata": {},
   "source": [
    "### operation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "8f322994",
   "metadata": {},
   "outputs": [],
   "source": [
    "FACE_VISIBLE = ['front', 'top', 'right']\n",
    "# arrow orientations allowed on each face (its four neighbouring faces); note this rebinds the\n",
    "# FACE_ADJACENCY name that earlier cells obtained through a star import\n",
    "FACE_ADJACENCY = {\n",
    "    'front': ['top', 'right', 'bottom', 'left'],\n",
    "    'top': ['back', 'right', 'front', 'left'],\n",
    "    'back': ['bottom', 'right', 'top', 'left'],\n",
    "    'bottom': ['front', 'right', 'back', 'left'],\n",
    "    'left': ['top', 'front', 'bottom', 'back'],\n",
    "    'right': ['top', 'back', 'bottom', 'front'],\n",
    "}\n",
    "# unit vector of each face: x -> right, y -> top, z -> front\n",
    "axis_np = {\n",
    "    name: np.array(vec)\n",
    "    for name, vec in [\n",
    "        ('front', [0, 0, 1]),\n",
    "        ('back', [0, 0, -1]),\n",
    "        ('left', [-1, 0, 0]),\n",
    "        ('right', [1, 0, 0]),\n",
    "        ('top', [0, 1, 0]),\n",
    "        ('bottom', [0, -1, 0]),\n",
    "    ]\n",
    "}\n",
    "def from_np(np_array, original):\n",
    "    \"\"\"Map a 3-vector to the face name of its first +/-1 component; the zero vector maps to `original`.\"\"\"\n",
    "    x, y, z = np_array\n",
    "    # checked in this order, first match wins\n",
    "    for coord, sign, name in ((x, 1, 'right'), (x, -1, 'left'), (y, 1, 'top'),\n",
    "                              (y, -1, 'bottom'), (z, 1, 'front'), (z, -1, 'back')):\n",
    "        if coord == sign:\n",
    "            return name\n",
    "    if x == 0 and y == 0 and z == 0:\n",
    "        return original\n",
    "    raise ValueError(f'Invalid np_array: {np_array}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "6dac1c53",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 24 rep_3d * 3 faces * 4 questions = 288 (2400 image samples for each question pair)\n",
    "# samples: [rep_3d_idx, question_type, question]\n",
    "def get_operation_questions(rep_3d_list):\n",
    "    \"\"\"Build paired positive/negative cube-rotation questions for the 24 views of one cube.\n",
    "\n",
    "    For each view and each visible face, the cube is turned by 90 or 270 degrees about the axis\n",
    "    of a random visible face; the question asks where that face's arrow points afterwards, once in\n",
    "    natural wording and once in Euclidean wording. Returns [rep_3d_idx, question_type, question] rows.\n",
    "    \"\"\"\n",
    "    samples = []\n",
    "    for rep_3d_idx in range(24):\n",
    "        rep_3d = rep_3d_list[rep_3d_idx]\n",
    "        for face in FACE_VISIBLE: ### traverse 3 faces for operation\n",
    "            init_arrow_ori = rep_3d[face][0] # arrow orientation (in FACE_NAMES)\n",
    "            init_arrow_np = axis_np[init_arrow_ori]\n",
    "            init_face_np = axis_np[face]\n",
    "\n",
    "            # .tolist() converts numpy scalars to plain str / int, as the other cells do\n",
    "            rot_face = np.random.choice(FACE_VISIBLE).tolist() # random rotation face\n",
    "            degree = np.random.choice([90, 270]).tolist() # random rotation degree\n",
    "            # a 90 degree turn is modelled as v -> axis x v; a v parallel to the axis gives the zero\n",
    "            # vector, which from_np maps back to v's own name. 270 degrees = same turn about -axis.\n",
    "            rot_axis = axis_np[rot_face] if degree == 90 else -axis_np[rot_face]\n",
    "            final_arrow_np = np.cross(rot_axis, init_arrow_np)\n",
    "            final_face_np = np.cross(rot_axis, init_face_np)\n",
    "            final_arrow = from_np(final_arrow_np, init_arrow_ori) # arrow orientation (in FACE_NAMES)\n",
    "            final_face = from_np(final_face_np, face) # face (in FACE_NAMES)\n",
    "\n",
    "            # wrong answers: the other arrow directions allowed on the face's new position;\n",
    "            # sorted() because str set order depends on PYTHONHASHSEED\n",
    "            ori_set = set(FACE_ADJACENCY[final_face])\n",
    "            final_arrow_neg = np.random.choice(sorted(ori_set - {final_arrow})).tolist() # 3 possible negative choices\n",
    "\n",
    "            # natural description\n",
    "            question_pos = question_operation_cube_rotation(degree, rot_face, face, final_arrow)\n",
    "            question_neg = question_operation_cube_rotation(degree, rot_face, face, final_arrow_neg)\n",
    "            samples.append([rep_3d_idx, f'ori_{face}_{rot_face}{degree}_pos', question_pos])\n",
    "            samples.append([rep_3d_idx, f'ori_{face}_{rot_face}{degree}_neg', question_neg])\n",
    "\n",
    "            # Euclidean description\n",
    "            question_pos = question_operation_cube_Euclidean(degree, rot_face, face, final_arrow)\n",
    "            question_neg = question_operation_cube_Euclidean(degree, rot_face, face, final_arrow_neg)\n",
    "            samples.append([rep_3d_idx, f'Euc_{face}_{rot_face}{degree}_pos', question_pos])\n",
    "            samples.append([rep_3d_idx, f'Euc_{face}_{rot_face}{degree}_neg', question_neg])\n",
    "    return samples\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "4566a439",
   "metadata": {},
   "outputs": [],
   "source": [
    "SEED = 2025 * hash_str('operation')\n",
    "np.random.seed(SEED)\n",
    "\n",
    "for cube_idx in range(NUM_SAMPLE):\n",
    "    cur_dir = f'{CUBE_DIR}/{cube_idx:05d}/'\n",
    "    with open(f'{cur_dir}/rep_3d.json', 'r') as f: # context manager closes the handle\n",
    "        rep_3d_list = json.load(f)\n",
    "    samples = get_operation_questions(rep_3d_list)\n",
    "    with open(f'{cur_dir}/operation.json', 'w') as f:\n",
    "        json.dump(samples, f, indent=4)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ccacbb70",
   "metadata": {},
   "source": [
    "### folding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "741080b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_folding_samples(cur_dir):\n",
    "    \"\"\"Pair each negative cube image with its positive view and one random positive net.\n",
    "\n",
    "    Returns [key, net_path, cube_path] rows (paths relative to cur_dir), positive row first.\n",
    "    \"\"\"\n",
    "    samples = []\n",
    "    neg_files = sorted(os.listdir(f'{cur_dir}/cube_negative'))\n",
    "    for neg_file in neg_files:\n",
    "        view_idx = int(neg_file.split('_')[0])\n",
    "        neg_path = f'cube_negative/{neg_file}'\n",
    "        pos_path = f'cube_positive/cube_{view_idx:02d}.png'\n",
    "\n",
    "        # random reference net: any view, any net type\n",
    "        ref_idx = np.random.choice(list(range(24))).tolist()\n",
    "        ref_type = np.random.choice(NET_TYPES).tolist()\n",
    "        net_path = f'net_positive/net_{ref_idx:02d}_{ref_type}.png'\n",
    "\n",
    "        for rel_path in (neg_path, pos_path, net_path):\n",
    "            assert os.path.exists(f'{cur_dir}/{rel_path}')\n",
    "\n",
    "        samples.append([f'{view_idx:02d}_pos', net_path, pos_path])\n",
    "        samples.append([f'{view_idx:02d}_neg', net_path, neg_path])\n",
    "    return samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "c8d97037",
   "metadata": {},
   "outputs": [],
   "source": [
    "SEED = 2025 * hash_str('folding')\n",
    "np.random.seed(SEED)\n",
    "\n",
    "# one folding.json per cube; the global RNG is consumed cube by cube in index order\n",
    "for cube_idx in range(NUM_SAMPLE):\n",
    "    cur_dir = f'{CUBE_DIR}/{cube_idx:05d}/'\n",
    "    folding_samples = get_folding_samples(cur_dir)\n",
    "    with open(f'{cur_dir}/folding.json', 'w') as fp:\n",
    "        json.dump(folding_samples, fp, indent=4)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1514b639",
   "metadata": {},
   "source": [
    "### matching"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "805f0e47",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_matching_samples(cur_dir):\n",
    "    \"\"\"Pair each negative net image with its positive counterpart and one random reference net.\n",
    "\n",
    "    Negative file names are '<rep_3d_idx>_<similarity>_<net_type>.png' (written by the 2d negative\n",
    "    cell); the positive is the net of the same view unfolded with the same net type.\n",
    "    Returns [key, ref_path, candidate_path] rows (paths relative to cur_dir), positive row first.\n",
    "    \"\"\"\n",
    "    samples = []\n",
    "    for file in sorted(os.listdir(f'{cur_dir}/net_negative')):\n",
    "        # paired positive and negative net images\n",
    "        neg_relative_path = f'net_negative/{file}'\n",
    "\n",
    "        # maxsplit=2 keeps net types containing '_' intact; splitext strips only the extension\n",
    "        rep_3d_idx, _, rep_2d_type = os.path.splitext(file)[0].split('_', 2)\n",
    "        pos_relative_path = f'net_positive/net_{rep_3d_idx}_{rep_2d_type}.png'\n",
    "\n",
    "        # random reference net image\n",
    "        rep_3d_idx_ref = np.random.choice(list(range(24))).tolist()\n",
    "        rep_2d_type_ref = np.random.choice(NET_TYPES).tolist()\n",
    "        ref_relative_path = f'net_positive/net_{rep_3d_idx_ref:02d}_{rep_2d_type_ref}.png'\n",
    "\n",
    "        assert os.path.exists(f'{cur_dir}/{neg_relative_path}')\n",
    "        assert os.path.exists(f'{cur_dir}/{pos_relative_path}')\n",
    "        assert os.path.exists(f'{cur_dir}/{ref_relative_path}')\n",
    "\n",
    "        samples.append([f'{rep_3d_idx}_pos', ref_relative_path, pos_relative_path])\n",
    "        samples.append([f'{rep_3d_idx}_neg', ref_relative_path, neg_relative_path])\n",
    "\n",
    "    return samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "22ab3a47",
   "metadata": {},
   "outputs": [],
   "source": [
    "SEED = 2025 * hash_str('matching')\n",
    "np.random.seed(SEED)\n",
    "\n",
    "# one matching.json per cube; the global RNG is consumed cube by cube in index order\n",
    "for cube_idx in range(NUM_SAMPLE):\n",
    "    cur_dir = f'{CUBE_DIR}/{cube_idx:05d}/'\n",
    "    matching_samples = get_matching_samples(cur_dir)\n",
    "    with open(f'{cur_dir}/matching.json', 'w') as fp:\n",
    "        json.dump(matching_samples, fp, indent=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "904b05e6",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "750bb677",
   "metadata": {},
   "source": [
    "## human test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "f8758970",
   "metadata": {},
   "outputs": [],
   "source": [
    "HUMAN_DIR = f'{DATA_ROOT}/human'\n",
    "human_dir_missing = not os.path.exists(HUMAN_DIR)\n",
    "if human_dir_missing:\n",
    "    os.makedirs(HUMAN_DIR)  # create only when absent; an existing path is reused as-is"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c1331809",
   "metadata": {},
   "source": [
    "### perception"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "f7b9fb8e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_human_test_samples(seed_text, task, save_root, num_cube = 10, num_sample_per_cube = 6):\n",
    "    \"\"\"Draw a small human-study subset of the generated `task` questions and export it.\n",
    "\n",
    "    The global numpy RNG is reseeded from `seed_text`. For each of the first `num_cube` cubes,\n",
    "    `num_sample_per_cube` rows of `{task}.json` are picked at random; rows are expected to be\n",
    "    [rep_3d_idx, question_type, question] as written by the perception/operation cells.\n",
    "    Each pick gets its own folder under `save_root` with a copy of the positive cube image\n",
    "    (named after the question type) and a question.json (cube index, image path, question).\n",
    "\n",
    "    Returns the list of (cube_idx, rep_3d_idx, question_type, question) tuples.\n",
    "    \"\"\"\n",
    "    # seed\n",
    "    SEED = 2025 * hash_str(seed_text)\n",
    "    np.random.seed(SEED)\n",
    "    # generate\n",
    "    human_test_samples = [] # (cube_idx, rep_3d_idx, question_type, question)\n",
    "    for cube_idx in range(num_cube):\n",
    "        with open(f'{CUBE_DIR}/{cube_idx:05d}/{task}.json', 'r') as f:\n",
    "            samples = json.load(f) # [rep_3d_idx, question_type, question] rows\n",
    "        random_indices = np.random.permutation(range(len(samples)))[:num_sample_per_cube] ## the only random part\n",
    "        for rand_idx in random_indices:\n",
    "            rep_3d_idx, question_type, question = samples[rand_idx]\n",
    "            human_test_samples.append((cube_idx, rep_3d_idx, question_type, question))\n",
    "\n",
    "    # save\n",
    "    for idx, (cube_idx, rep_3d_idx, question_type, question) in enumerate(human_test_samples):\n",
    "        save_dir = f'{save_root}/{idx:03d}'\n",
    "        os.makedirs(save_dir, exist_ok=True)\n",
    "\n",
    "        image_path = f'{CUBE_DIR}/{cube_idx:05d}/cube_positive/cube_{rep_3d_idx:02d}.png'\n",
    "        # copy image to save_dir\n",
    "        shutil.copy(image_path, f'{save_dir}/{question_type}.png')\n",
    "        # write question to file\n",
    "        sample_info = {\n",
    "            'cube_idx': cube_idx,\n",
    "            'image_path': image_path,\n",
    "            'question': question,\n",
    "        }\n",
    "        with open(f'{save_dir}/question.json', 'w') as f:\n",
    "            json.dump(sample_info, f, indent=4)\n",
    "    return human_test_samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "b1e7b5de",
   "metadata": {},
   "outputs": [],
   "source": [
    "# reset: rebuild the perception export folder from scratch\n",
    "perception_dir = f'{HUMAN_DIR}/perception'\n",
    "if os.path.exists(perception_dir):\n",
    "    shutil.rmtree(perception_dir)\n",
    "os.makedirs(perception_dir)\n",
    "\n",
    "human_test_samples = generate_human_test_samples(\n",
    "    'human_test_perception', 'perception', perception_dir,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "7f32af94",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Counter({'Euc_right': 11,\n",
       "         'Euc_front': 8,\n",
       "         'color_right': 7,\n",
       "         'Euc_top': 7,\n",
       "         'color_front': 7,\n",
       "         'ori_right': 6,\n",
       "         'ori_top': 6,\n",
       "         'ori_front': 5,\n",
       "         'color_top': 3})"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from collections import Counter\n",
    "# question_type is the third tuple field; group by its first two '_' fields (e.g. 'color_front')\n",
    "question_types = [sample[2] for sample in human_test_samples]\n",
    "question_count_fine_grained = Counter(\n",
    "    '_'.join(question_type.split('_')[:2]) for question_type in question_types\n",
    ")\n",
    "question_count_fine_grained"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "018b87f9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Counter({'Euc': 26, 'ori': 17, 'color': 17})"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# coarse count: only the question kind (text before the first '_')\n",
    "question_count_coarse = Counter(qt.partition('_')[0] for qt in question_types)\n",
    "question_count_coarse"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e5fdda4e",
   "metadata": {},
   "source": [
    "### operation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "9b7a9135",
   "metadata": {},
   "outputs": [],
   "source": [
    "# start from an empty output folder so samples from a previous run cannot linger\n",
    "operation_dir = f'{HUMAN_DIR}/operation'\n",
    "if os.path.exists(operation_dir):\n",
    "    shutil.rmtree(operation_dir)\n",
    "os.makedirs(operation_dir)\n",
    "\n",
    "human_test_samples = generate_human_test_samples(\n",
    "    'human_test_operation', 'operation', operation_dir\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "218f194c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Counter({'Euc_front': 13,\n",
       "         'ori_top': 12,\n",
       "         'ori_front': 12,\n",
       "         'Euc_right': 10,\n",
       "         'Euc_top': 8,\n",
       "         'ori_right': 5})"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from collections import Counter\n",
    "\n",
    "# the question type is the third field of each sample\n",
    "question_types = [sample[2] for sample in human_test_samples]\n",
    "# fine-grained category = first two '_'-separated tokens (e.g. 'Euc_front')\n",
    "question_count_fine_grained = Counter(\n",
    "    '_'.join(qtype.split('_')[:2]) for qtype in question_types\n",
    ")\n",
    "question_count_fine_grained"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "cceb4342",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Counter({'Euc': 31, 'ori': 29})"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# coarse category = first '_'-separated token (e.g. 'Euc')\n",
    "question_count_coarse = Counter(qtype.partition('_')[0] for qtype in question_types)\n",
    "question_count_coarse"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "4ae7223e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Counter({'neg': 30, 'pos': 30})"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# pos-neg balance: the label is the last '_'-separated token of the question type\n",
    "question_types = [sample[2] for sample in human_test_samples]\n",
    "question_count = Counter(qtype.rsplit('_', 1)[-1] for qtype in question_types)\n",
    "question_count"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e1673f4e",
   "metadata": {},
   "source": [
    "### folding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "2ccc7e5e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import Counter\n",
    "with open(f'{DATA_ROOT}/question_dict.json', 'r') as f:\n",
    "    question_dict = json.load(f)\n",
    "\n",
    "def generate_human_test_samples_two_images(seed_text, task, save_root, num_cube = 10, num_sample_per_cube = 6):\n",
    "    \"\"\"Draw two-image questions for the human study and export them to `save_root`.\n",
    "\n",
    "    The global numpy RNG is reseeded from `seed_text`; then, for each of the first\n",
    "    `num_cube` cubes, up to `num_sample_per_cube` entries of `{task}.json` are drawn\n",
    "    without replacement. Sample `idx` is written to `{save_root}/{idx:03d}/` as two\n",
    "    copied images plus `question.txt`, which holds the sample tuple dumped as JSON.\n",
    "\n",
    "    Returns:\n",
    "        list of (cube_idx, question_key, image_1, image_2, question) tuples.\n",
    "    \"\"\"\n",
    "    np.random.seed(2025 * hash_str(seed_text))\n",
    "\n",
    "    # phase 1: choose the samples (every RNG draw happens here, cube by cube)\n",
    "    picked = []\n",
    "    for cube_idx in range(num_cube):\n",
    "        with open(f'{CUBE_DIR}/{cube_idx:05d}/{task}.json', 'r') as fp:\n",
    "            cube_samples = json.load(fp)  # entries: (question_key, image_1, image_2)\n",
    "        chosen = np.random.permutation(range(len(cube_samples)))[:num_sample_per_cube]\n",
    "        for sample_idx in chosen:\n",
    "            question_key, image_1, image_2 = cube_samples[sample_idx]\n",
    "            picked.append((cube_idx, question_key, image_1, image_2, question_dict[f'question_{task}']))\n",
    "\n",
    "    # phase 2: export every chosen sample into its own numbered folder\n",
    "    for out_idx, record in enumerate(picked):\n",
    "        cube_idx, question_key, image_1, image_2, _ = record\n",
    "        out_dir = f'{save_root}/{out_idx:03d}'\n",
    "        if not os.path.exists(out_dir):\n",
    "            os.makedirs(out_dir)\n",
    "        src_dir = f'{CUBE_DIR}/{cube_idx:05d}'\n",
    "        for slot, image_name in enumerate((image_1, image_2), start=1):\n",
    "            shutil.copy(f'{src_dir}/{image_name}', f'{out_dir}/{question_key}_image_{slot}.png')\n",
    "        with open(f'{out_dir}/question.txt', 'w') as fp:\n",
    "            json.dump(record, fp, indent=4)\n",
    "    return picked"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "4c78fd08",
   "metadata": {},
   "outputs": [],
   "source": [
    "# start from an empty output folder so samples from a previous run cannot linger\n",
    "folding_dir = f'{HUMAN_DIR}/folding'\n",
    "if os.path.exists(folding_dir):\n",
    "    shutil.rmtree(folding_dir)\n",
    "os.makedirs(folding_dir)\n",
    "human_test_samples = generate_human_test_samples_two_images(\n",
    "    'human_test_folding', 'folding', folding_dir\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "4ff65fdc",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Counter({'pos': 35, 'neg': 25})"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# pos-neg balance: the label is the last '_'-separated token of the question key (second field)\n",
    "question_types = [sample[1] for sample in human_test_samples]\n",
    "question_count = Counter(qkey.rsplit('_', 1)[-1] for qkey in question_types)\n",
    "question_count"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "15a1ff62",
   "metadata": {},
   "source": [
    "### matching"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "b77b7a43",
   "metadata": {},
   "outputs": [],
   "source": [
    "# start from an empty output folder so samples from a previous run cannot linger\n",
    "matching_dir = f'{HUMAN_DIR}/matching'\n",
    "if os.path.exists(matching_dir):\n",
    "    shutil.rmtree(matching_dir)\n",
    "os.makedirs(matching_dir)\n",
    "\n",
    "human_test_samples = generate_human_test_samples_two_images(\n",
    "    'human_test_matching', 'matching', matching_dir\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "73b6c8ed",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Counter({'neg': 32, 'pos': 28})"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# pos-neg balance: the label is the last '_'-separated token of the question key (second field)\n",
    "question_types = [sample[1] for sample in human_test_samples]\n",
    "question_count = Counter(qkey.rsplit('_', 1)[-1] for qkey in question_types)\n",
    "question_count"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fff26288",
   "metadata": {},
   "source": [
    "## ablation_folding"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3b8aa8a2",
   "metadata": {},
   "source": [
    "### cube easy negative (3d_reps --> easy negative 3d reps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "79cd99b5",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:00<00:00, 540.48it/s]\n"
     ]
    }
   ],
   "source": [
    "SEED = 2025 * hash_str('cube_easy_negative_sampling')\n",
    "np.random.seed(SEED)\n",
    "MAX_MUTATION_TRIES = 1000  # fail loudly instead of looping forever if no fresh negative exists\n",
    "\n",
    "# For each of the 24 3d-reps of a cube: keep only the visible faces, then give one random\n",
    "# visible face a color that the cube does not use (an 'easy' negative).\n",
    "visible_faces = ['front', 'top', 'right']\n",
    "invisible_faces = ['back', 'bottom', 'left']\n",
    "for cube_idx in tqdm(range(NUM_SAMPLE)):\n",
    "    folder = f'{CUBE_DIR}/{cube_idx:05d}/'\n",
    "    with open(f'{folder}/rep_3d.json', 'r') as f: # follow the same order as rep_3d.json\n",
    "        rep_3d_list = json.load(f)\n",
    "\n",
    "    rep_3d_neg_set = set()  # negatives already accepted for this cube\n",
    "    rep_3d_neg_dict = dict()\n",
    "    for rep_3d_idx in range(24):\n",
    "        rep_3d_cur = deepcopy(rep_3d_list[rep_3d_idx]) # dict(face: (ori, color))\n",
    "        for face_tmp in invisible_faces:\n",
    "            rep_3d_cur.pop(face_tmp)\n",
    "        for face_tmp in visible_faces:\n",
    "            rep_3d_cur[face_tmp] = tuple(rep_3d_cur[face_tmp])\n",
    "\n",
    "        # candidate colors do not change between retries, so compute them once; sort them\n",
    "        # because set order of strings is hash-randomized per process, which would make the\n",
    "        # seeded draw below irreproducible across runs\n",
    "        color_set_cube = set([rep_3d_list[rep_3d_idx][face_tmp][1] for face_tmp in FACE_NAMES])\n",
    "        new_colors = sorted(set(COLOR_NAMES) - color_set_cube)\n",
    "\n",
    "        for _ in range(MAX_MUTATION_TRIES):\n",
    "            face_mutate = np.random.choice(visible_faces).tolist()\n",
    "            color_neg = np.random.choice(new_colors).tolist()\n",
    "            rep_3d_tmp = deepcopy(rep_3d_cur)\n",
    "            rep_3d_tmp[face_mutate] = (rep_3d_tmp[face_mutate][0], color_neg)\n",
    "\n",
    "            similarity = get_similarity(rep_3d_tmp, rep_3d_list, sim_type='scene')\n",
    "            similarity = ''.join([str(x) for x in similarity])\n",
    "            tmp_tuple = tuple(rep_3d_tmp.items())\n",
    "            # accept only if the mutated rep no longer refers to the same cube (leading '3')\n",
    "            # AND it does not duplicate an earlier negative of this cube; otherwise retry\n",
    "            if similarity[0] != '3' and tmp_tuple not in rep_3d_neg_set:\n",
    "                rep_3d_neg_set.add(tmp_tuple)\n",
    "                break\n",
    "        else:\n",
    "            raise RuntimeError(f'cube {cube_idx}, rep_3d {rep_3d_idx}: no new easy negative found')\n",
    "        rep_3d_neg_dict[f'{rep_3d_idx:02d}_{face_mutate}_{similarity}'] = rep_3d_tmp\n",
    "\n",
    "    with open(f'{folder}/rep_3d_neg_easy.json', 'w') as f:\n",
    "        json.dump(rep_3d_neg_dict, f, indent=4)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f6c3a99c",
   "metadata": {},
   "source": [
    "### easy negative 3d reps --> images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "b1210547",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:00<00:00, 105.34it/s]\n"
     ]
    }
   ],
   "source": [
    "def process_cube_neg(cube_idx):\n",
    "    \"\"\"Render each easy-negative 3d-rep of cube `cube_idx` to `cube_negative_easy/{key}.png`.\n",
    "\n",
    "    Reads `rep_3d_neg_easy.json` from the cube folder and recreates the output folder\n",
    "    from scratch. Cubes whose folder is missing are skipped. Always returns None.\n",
    "    \"\"\"\n",
    "    folder = f'{CUBE_DIR}/{cube_idx:05d}/'\n",
    "    if not os.path.exists(folder):\n",
    "        return None\n",
    "\n",
    "    folder_neg = f'{folder}/cube_negative_easy'\n",
    "    if os.path.exists(folder_neg):  # drop renders left over from an earlier run\n",
    "        shutil.rmtree(folder_neg)\n",
    "    os.makedirs(folder_neg)\n",
    "\n",
    "    with open(f'{folder}/rep_3d_neg_easy.json', 'r') as fp:\n",
    "        neg_reps = json.load(fp)\n",
    "    for name, rep in neg_reps.items():\n",
    "        plot_3d(rep, f'{folder_neg}/{name}.png', show=False)\n",
    "    return None\n",
    "\n",
    "with ProcessPoolExecutor(max_workers=os.cpu_count()) as executor:\n",
    "    pending = [executor.submit(process_cube_neg, idx) for idx in range(NUM_SAMPLE)]\n",
    "    # result() re-raises any exception raised inside a worker process\n",
    "    for done in tqdm(as_completed(pending), total=len(pending)):\n",
    "        done.result()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "56d230dd",
   "metadata": {},
   "source": [
    "### ablation questions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "58ad663d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_folding_samples(cur_dir):\n",
    "    \"\"\"Build paired positive/negative folding questions for one cube directory.\n",
    "\n",
    "    For every file in `cube_negative_easy/` (sorted), a random positive net is drawn\n",
    "    from the global numpy RNG (random 3d-rep index in [0, 24), random NET_TYPES entry).\n",
    "    It is paired once with `cube_positive/cube_{idx:02d}.png` and once with the negative\n",
    "    image, where idx is the leading number of the negative file name.\n",
    "\n",
    "    Returns:\n",
    "        list of [label, net_path, cube_path]; paths are relative to `cur_dir` and\n",
    "        labels are '{idx:02d}_pos' / '{idx:02d}_neg'.\n",
    "    \"\"\"\n",
    "    neg_folder = 'cube_negative_easy'\n",
    "    samples = []\n",
    "    for file in sorted(os.listdir(f'{cur_dir}/{neg_folder}')):\n",
    "        rep_3d_idx = int(file.split('_')[0])  # negative files are named '{idx:02d}_...'\n",
    "        neg_path = f'{neg_folder}/{file}'\n",
    "        pos_path = f'cube_positive/cube_{rep_3d_idx:02d}.png'\n",
    "\n",
    "        # keep these two draws in this order so the seeded stream is unchanged\n",
    "        ref_idx = np.random.choice(list(range(24))).tolist()\n",
    "        net_type = np.random.choice(NET_TYPES).tolist()\n",
    "        net_path = f'net_positive/net_{ref_idx:02d}_{net_type}.png'\n",
    "\n",
    "        for rel_path in (neg_path, pos_path, net_path):\n",
    "            assert os.path.exists(f'{cur_dir}/{rel_path}')\n",
    "\n",
    "        label_prefix = f'{rep_3d_idx:02d}'\n",
    "        samples += [\n",
    "            [f'{label_prefix}_pos', net_path, pos_path],\n",
    "            [f'{label_prefix}_neg', net_path, neg_path],\n",
    "        ]\n",
    "    return samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "46cd4636",
   "metadata": {},
   "outputs": [],
   "source": [
    "SEED = 2025 * hash_str('ablation_folding')\n",
    "np.random.seed(SEED)\n",
    "\n",
    "# cubes are visited in index order so the seeded draws inside get_folding_samples are reproducible\n",
    "for cube_idx in range(NUM_SAMPLE):\n",
    "    cube_dir = f'{CUBE_DIR}/{cube_idx:05d}/'\n",
    "    cube_samples = get_folding_samples(cube_dir)  # built before opening, so a failed assert leaves no truncated file\n",
    "    with open(f'{cube_dir}/ablation_folding.json', 'w') as fp:\n",
    "        json.dump(cube_samples, fp, indent=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15530a6a",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
