{
 "cells": [
  {
   "cell_type": "code",
   "id": "initial_id",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "import itertools\n",
    "import math\n",
    "import os\n",
    "from shutil import copytree\n",
    "from typing import List, Tuple\n",
    "\n",
    "import PIL.Image\n",
    "import PIL.ImageDraw2\n",
    "import pandas as pd\n",
    "from PIL.ImageDraw2 import Pen\n",
    "from tqdm import tqdm"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# os.listdir, shutil.copytree and PIL.Image.open do not expand `~`, so expand it here\n",
    "# once instead of passing literal \"~/...\" paths to them.\n",
    "source_dir = os.path.expanduser(\"~/Datasets/bongard_hoi_splitted\")\n",
    "target_dir = os.path.expanduser(\"~/Projects/llm-avr-benchmarks/data/raw/bongard_hoi_splitted_mix\")\n",
    "target_labels_path = os.path.expanduser(\"~/Projects/llm-avr-benchmarks/data/raw/bongard_hoi_mix_labels.csv\")"
   ],
   "id": "6bac1f16b0226a47",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Number of problems per category. Problems are the sub-directories of each category,\n",
    "# so count those directly instead of assuming exactly one other entry (the labels CSV).\n",
    "print(\"Unique concepts\")\n",
    "for category in sorted(os.listdir(source_dir)):\n",
    "    category_dir = f\"{source_dir}/{category}\"\n",
    "    if not os.path.isdir(category_dir):\n",
    "        continue\n",
    "    num_problem_dirs = sum(\n",
    "        os.path.isdir(f\"{category_dir}/{entry}\") for entry in os.listdir(category_dir)\n",
    "    )\n",
    "    print(category, num_problem_dirs)"
   ],
   "id": "11f9af79df633451",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# How many problems to take from each test split. Insertion order matters: the\n",
    "# label-merging and copy cells below walk this dict in order to number problems globally.\n",
    "num_concepts_by_category = dict(\n",
    "    bongard_hoi_test_unseen_obj_unseen_act=16,\n",
    "    bongard_hoi_test_seen_obj_seen_act=36,  # out of 102\n",
    "    bongard_hoi_test_unseen_obj_seen_act=27,\n",
    "    bongard_hoi_test_seen_obj_unseen_act=21,\n",
    ")\n",
    "print(\"Total\", sum(num_concepts_by_category.values()))"
   ],
   "id": "19fe9f0468bcd041",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## Combine the labels of the selected test splits into a single CSV",
   "id": "4c5accd1681b343a"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Keep the first `num_concepts` problems (by uid) of every category and shift their uids\n",
    "# by the number of problems already taken, so uids are unique across categories.\n",
    "# NOTE(review): this lines up with the 1-based, consecutive directory numbering of the\n",
    "# copy cell below only if each category's uids are 1..N - TODO confirm.\n",
    "num_problems = 0\n",
    "category_label_dfs = []\n",
    "for category, num_concepts in tqdm(num_concepts_by_category.items()):\n",
    "    category_df = pd.read_csv(f\"{source_dir}/{category}/bongard_hoi_labels.csv\", index_col=\"uid\")\n",
    "    category_df = category_df.sort_index().head(n=num_concepts)\n",
    "    # A short CSV would silently shift every later uid offset; fail loudly instead.\n",
    "    assert len(category_df) == num_concepts, (\n",
    "        f\"{category}: expected {num_concepts} labelled problems, found {len(category_df)}\"\n",
    "    )\n",
    "    category_df.index = category_df.index + num_problems\n",
    "    category_label_dfs.append(category_df)\n",
    "    num_problems += num_concepts\n",
    "df = pd.concat(category_label_dfs)\n",
    "df.to_csv(target_labels_path)\n",
    "df"
   ],
   "id": "f124dfd1647863a4",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## Copy the selected problems from the source dataset directory into the target directory, renumbering them consecutively from 1",
   "id": "57ca8e7bd49efd5e"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Copy the selected problems into target_dir, renumbered consecutively from 1 in the\n",
    "# same category order and ascending-id order as the labels above.\n",
    "# NOTE(review): assumes the problem directory names are the label uids.\n",
    "total_problem_id = 1\n",
    "for category, num_concepts in tqdm(num_concepts_by_category.items()):\n",
    "    category_dir = f\"{source_dir}/{category}\"\n",
    "    # Problems are the sub-directories (this skips the labels CSV). Sort numerically but keep\n",
    "    # the original names, so the source path stays correct even for zero-padded names.\n",
    "    problem_names = sorted(\n",
    "        (entry for entry in os.listdir(category_dir) if os.path.isdir(f\"{category_dir}/{entry}\")),\n",
    "        key=int,\n",
    "    )\n",
    "    assert len(problem_names) >= num_concepts, (\n",
    "        f\"{category}: only {len(problem_names)} problems, {num_concepts} requested\"\n",
    "    )\n",
    "    for problem_name in problem_names[:num_concepts]:\n",
    "        # copytree raises FileExistsError if the destination exists: clear target_dir before re-running.\n",
    "        copytree(f\"{category_dir}/{problem_name}\", f\"{target_dir}/{total_problem_id}\")\n",
    "        total_problem_id += 1"
   ],
   "id": "27696769692c707a",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## Resize each image so that its larger dimension (width or height) is at most 512 px",
   "id": "f516ef66e1b07174"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "max_size = 512\n",
    "for problem_id in tqdm(sorted(os.listdir(target_dir))):\n",
    "    for side in os.listdir(f\"{target_dir}/{problem_id}\"):\n",
    "        side_dir = f\"{target_dir}/{problem_id}/{side}\"\n",
    "        # Skip files such as left.png / whole.png written by the arrangement cells below,\n",
    "        # so that this cell can be re-run after them.\n",
    "        if not os.path.isdir(side_dir):\n",
    "            continue\n",
    "        for filename in os.listdir(side_dir):\n",
    "            filepath = f\"{side_dir}/{filename}\"\n",
    "            try:\n",
    "                # Release the source file before it is overwritten below.\n",
    "                with PIL.Image.open(filepath) as source_image:\n",
    "                    image = source_image.convert(\"RGB\")\n",
    "\n",
    "                width, height = image.size\n",
    "                if width <= max_size and height <= max_size:\n",
    "                    continue\n",
    "\n",
    "                # Shrink the larger side to max_size and scale the other proportionally\n",
    "                # (at least 1 px, since PIL rejects zero-sized images).\n",
    "                if width > height:\n",
    "                    new_width = max_size\n",
    "                    new_height = max(1, int((height / width) * max_size))\n",
    "                else:\n",
    "                    new_width = max(1, int((width / height) * max_size))\n",
    "                    new_height = max_size\n",
    "\n",
    "                image = image.resize((new_width, new_height), PIL.Image.Resampling.LANCZOS)\n",
    "                image.save(filepath)\n",
    "            except Exception as e:\n",
    "                # Best effort: report the file and continue with the rest.\n",
    "                print(filepath, e)"
   ],
   "id": "658e4f58e7e6da89",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## Arrange each problem's panels into a minimal-area layout; save each side (at most 512 px) and the whole puzzle (at most 1024 px per dimension)",
   "id": "dadd096d62aa704b"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Every ordering of the 6 panels of one side; searched exhaustively per side.\n",
    "permutations = list(itertools.permutations(range(6), 6))\n",
    "\n",
    "\n",
    "def get_size(\n",
    "        sizes: List[Tuple[int, int]],\n",
    "        permutation: List[int]\n",
    ") -> Tuple[int, int]:\n",
    "    \"\"\"Return the (width, height) of one side laid out in two columns, margins excluded.\n",
    "\n",
    "    Panels at even positions of ``permutation`` form the left column (flush left) and\n",
    "    panels at odd positions the right column (flush right). Each column is stacked on its\n",
    "    own, so panels from different rows can end up side by side. The width must therefore\n",
    "    fit the widest panel of each column; a row-by-row width sum would let a wide panel in\n",
    "    one column overlap a wide panel from another row of the other column.\n",
    "\n",
    "    :param sizes: (width, height) of each panel, indexed by panel number.\n",
    "    :param permutation: panel numbers in placement order.\n",
    "    :return: (width, height) of the layout without margins.\n",
    "    \"\"\"\n",
    "    left_column = [sizes[p] for p in permutation[0::2]]\n",
    "    right_column = [sizes[p] for p in permutation[1::2]]\n",
    "    width = max(w for w, _ in left_column) + max(w for w, _ in right_column)\n",
    "    height = max(sum(h for _, h in left_column), sum(h for _, h in right_column))\n",
    "    return width, height\n",
    "\n",
    "\n",
    "def get_permutations(\n",
    "        left_sizes: List[Tuple[int, int]],\n",
    "        right_sizes: List[Tuple[int, int]]\n",
    ") -> Tuple[List[int], List[int]]:\n",
    "    \"\"\"Choose the left and right panel orderings that minimise the area of the whole layout.\n",
    "\n",
    "    The sides are placed next to each other, so the combined width is the sum of the side\n",
    "    widths and the combined height is that of the taller side (margins excluded). Ties go\n",
    "    to the first (left, right) pair in ``permutations`` order.\n",
    "\n",
    "    :return: (left_permutation, right_permutation)\n",
    "    \"\"\"\n",
    "    def size_candidates(sizes):\n",
    "        # Evaluate get_size once per permutation rather than once per (left, right) pair,\n",
    "        # keeping only the first permutation for each distinct size. Later permutations\n",
    "        # with an identical size can never win under the strict `<` below, so the result\n",
    "        # is the same as searching all 720 x 720 pairs.\n",
    "        first_permutation_by_size = {}\n",
    "        for permutation in permutations:\n",
    "            first_permutation_by_size.setdefault(get_size(sizes, permutation), permutation)\n",
    "        return list(first_permutation_by_size.items())\n",
    "\n",
    "    right_candidates = size_candidates(right_sizes)\n",
    "    min_area = math.inf\n",
    "    best_permutations = None\n",
    "    for (left_width, left_height), left_permutation in size_candidates(left_sizes):\n",
    "        for (right_width, right_height), right_permutation in right_candidates:\n",
    "            area = (left_width + right_width) * max(left_height, right_height)\n",
    "            if area < min_area:\n",
    "                min_area = area\n",
    "                best_permutations = (left_permutation, right_permutation)\n",
    "    return best_permutations"
   ],
   "id": "4beaa8455a173c90",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# The image class is `PIL.Image.Image`; `PIL.Image` is the module, and subscripting a\n",
    "# generic with it (`List[PIL.Image]`) raises TypeError at definition time on Python < 3.11.\n",
    "def draw_side(\n",
    "        width: int,\n",
    "        height: int,\n",
    "        margin: int,\n",
    "        permutation: List[int],\n",
    "        sizes: List[Tuple[int, int]],\n",
    "        images: List[PIL.Image.Image]\n",
    ") -> PIL.Image.Image:\n",
    "    \"\"\"Paste one side's panels onto a new black RGB canvas in two columns.\n",
    "\n",
    "    Panels at even positions of ``permutation`` go into the left column (flush left),\n",
    "    panels at odd positions into the right column (flush right); each column is filled\n",
    "    top to bottom with ``margin`` pixels between panels.\n",
    "\n",
    "    :param width: layout width without margins (as returned by ``get_size``).\n",
    "    :param height: layout height without margins (as returned by ``get_size``).\n",
    "    :param margin: gap in pixels between neighbouring panels.\n",
    "    :param permutation: panel indices in placement order.\n",
    "    :param sizes: (width, height) of each panel.\n",
    "    :param images: the panel images, indexed like ``sizes``.\n",
    "    :return: canvas of size (width + margin, height + 2 * margin).\n",
    "    \"\"\"\n",
    "    canvas = PIL.Image.new('RGB', (margin + width, 2 * margin + height))\n",
    "    column_tops = [0, 0]  # next free y coordinate in the left / right column\n",
    "    for i, p in enumerate(permutation):\n",
    "        image_width, image_height = sizes[p]\n",
    "        column = i % 2\n",
    "        x = 0 if column == 0 else width + margin - image_width\n",
    "        y = column_tops[column]\n",
    "        column_tops[column] += image_height + margin\n",
    "        canvas.paste(images[p], (x, y))\n",
    "    return canvas"
   ],
   "id": "ad94ab0b1767a6cb",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Annotated with PIL.Image.Image: PIL.Image itself is the module, not the image class.\n",
    "def resize(image: PIL.Image.Image, max_size: int) -> PIL.Image.Image:\n",
    "    \"\"\"Downscale ``image`` so that neither side exceeds ``max_size``, keeping its aspect ratio.\n",
    "\n",
    "    Images that already fit are returned as is. Otherwise the larger side becomes\n",
    "    ``max_size`` and the other side is scaled proportionally (truncated, but at least\n",
    "    1 px because PIL rejects zero-sized images), using LANCZOS resampling.\n",
    "    \"\"\"\n",
    "    width, height = image.size\n",
    "    if width > max_size or height > max_size:\n",
    "        if width > height:\n",
    "            new_width = max_size\n",
    "            new_height = max(1, int((height / width) * max_size))\n",
    "        else:\n",
    "            new_width = max(1, int((width / height) * max_size))\n",
    "            new_height = max_size\n",
    "        image = image.resize((new_width, new_height), PIL.Image.Resampling.LANCZOS)\n",
    "    return image"
   ],
   "id": "8a5dd3e83057a3c2",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "margin = 10\n",
    "side_max_size = 512\n",
    "max_size = 1024\n",
    "for problem_id in tqdm(sorted(os.listdir(target_dir))):\n",
    "    problem_dir = f\"{target_dir}/{problem_id}\"\n",
    "    # NOTE(review): os.listdir order is filesystem-dependent; it fixes the panel indices and\n",
    "    # therefore which of several equally small layouts the search returns.\n",
    "    left_images = [\n",
    "        PIL.Image.open(f\"{problem_dir}/left/{filename}\")\n",
    "        for filename in os.listdir(f\"{problem_dir}/left\")\n",
    "    ]\n",
    "    right_images = [\n",
    "        PIL.Image.open(f\"{problem_dir}/right/{filename}\")\n",
    "        for filename in os.listdir(f\"{problem_dir}/right\")\n",
    "    ]\n",
    "\n",
    "    left_sizes = [image.size for image in left_images]\n",
    "    right_sizes = [image.size for image in right_images]\n",
    "\n",
    "    left_permutation, right_permutation = get_permutations(left_sizes, right_sizes)\n",
    "    left_width, left_height = get_size(left_sizes, left_permutation)\n",
    "    right_width, right_height = get_size(right_sizes, right_permutation)\n",
    "\n",
    "    # Per-side renders (left.png / right.png), each capped at side_max_size.\n",
    "    left_canvas = resize(\n",
    "        draw_side(left_width, left_height, margin, left_permutation, left_sizes, left_images),\n",
    "        side_max_size,\n",
    "    )\n",
    "    right_canvas = resize(\n",
    "        draw_side(right_width, right_height, margin, right_permutation, right_sizes, right_images),\n",
    "        side_max_size,\n",
    "    )\n",
    "\n",
    "    # Whole render: left side, a gap of `margin`, the separator, another gap, right side.\n",
    "    total_width = 4 * margin + left_width + right_width\n",
    "    total_height = 2 * margin + max(left_height, right_height)\n",
    "    canvas = PIL.Image.new('RGB', (total_width, total_height))\n",
    "\n",
    "    # Left panels end at left_width + margin and right panels start at left_width + 3 * margin,\n",
    "    # so no panel covers the separator: drawing it once up front gives the same image as\n",
    "    # redrawing it before every paste.\n",
    "    separator_x = left_width + 2 * margin\n",
    "    draw = PIL.ImageDraw2.Draw(canvas)\n",
    "    draw.line((separator_x, 0, separator_x, total_height), Pen(color=\"white\", width=5))\n",
    "\n",
    "    for side, images, sizes, permutation in [\n",
    "        (\"left\", left_images, left_sizes, left_permutation),\n",
    "        (\"right\", right_images, right_sizes, right_permutation),\n",
    "    ]:\n",
    "        y1, y2 = 0, 0  # next free y in the first / second column of this side\n",
    "        for i, p in enumerate(permutation):\n",
    "            image_width, image_height = sizes[p]\n",
    "\n",
    "            if i % 2 == 0:\n",
    "                y = y1\n",
    "                y1 += image_height + margin\n",
    "            else:\n",
    "                y = y2\n",
    "                y2 += image_height + margin\n",
    "\n",
    "            if side == \"left\":\n",
    "                x = 0 if i % 2 == 0 else left_width + margin - image_width\n",
    "            else:\n",
    "                x = left_width + 3 * margin if i % 2 == 0 else total_width - image_width\n",
    "\n",
    "            canvas.paste(images[p], (x, y))\n",
    "\n",
    "    canvas = resize(canvas, max_size)\n",
    "\n",
    "    left_canvas.save(f\"{problem_dir}/left.png\")\n",
    "    right_canvas.save(f\"{problem_dir}/right.png\")\n",
    "    canvas.save(f\"{problem_dir}/whole.png\")\n",
    "\n",
    "    for image in left_images + right_images:\n",
    "        image.close()"
   ],
   "id": "6c6e2893dfecace7",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## Re-render selected problems as whole puzzles on a white background",
   "id": "6373668081260067"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "margin = 10\n",
    "max_size = 1024\n",
    "# NOTE(review): hand-picked problem ids; the reason these need a white-background variant\n",
    "# is not recorded here - TODO document the selection criterion.\n",
    "white_background_problem_ids = [26, 71, 77, 91, 93]\n",
    "for problem_id in white_background_problem_ids:\n",
    "    problem_dir = f\"{target_dir}/{problem_id}\"\n",
    "    left_images = [\n",
    "        PIL.Image.open(f\"{problem_dir}/left/{filename}\")\n",
    "        for filename in os.listdir(f\"{problem_dir}/left\")\n",
    "    ]\n",
    "    right_images = [\n",
    "        PIL.Image.open(f\"{problem_dir}/right/{filename}\")\n",
    "        for filename in os.listdir(f\"{problem_dir}/right\")\n",
    "    ]\n",
    "\n",
    "    left_sizes = [image.size for image in left_images]\n",
    "    right_sizes = [image.size for image in right_images]\n",
    "\n",
    "    left_permutation, right_permutation = get_permutations(left_sizes, right_sizes)\n",
    "    left_width, left_height = get_size(left_sizes, left_permutation)\n",
    "    right_width, right_height = get_size(right_sizes, right_permutation)\n",
    "\n",
    "    # Same layout computation as the whole.png cell above, but on white with a black separator.\n",
    "    total_width = 4 * margin + left_width + right_width\n",
    "    total_height = 2 * margin + max(left_height, right_height)\n",
    "    canvas = PIL.Image.new('RGB', (total_width, total_height), color=\"white\")\n",
    "\n",
    "    # No panel reaches the separator column (left panels end at left_width + margin, right\n",
    "    # panels start at left_width + 3 * margin), so drawing it once up front is enough.\n",
    "    separator_x = left_width + 2 * margin\n",
    "    draw = PIL.ImageDraw2.Draw(canvas)\n",
    "    draw.line((separator_x, 0, separator_x, total_height), Pen(color=\"black\", width=5))\n",
    "\n",
    "    for side, images, sizes, permutation in [\n",
    "        (\"left\", left_images, left_sizes, left_permutation),\n",
    "        (\"right\", right_images, right_sizes, right_permutation),\n",
    "    ]:\n",
    "        y1, y2 = 0, 0  # next free y in the first / second column of this side\n",
    "        for i, p in enumerate(permutation):\n",
    "            image_width, image_height = sizes[p]\n",
    "\n",
    "            if i % 2 == 0:\n",
    "                y = y1\n",
    "                y1 += image_height + margin\n",
    "            else:\n",
    "                y = y2\n",
    "                y2 += image_height + margin\n",
    "\n",
    "            if side == \"left\":\n",
    "                x = 0 if i % 2 == 0 else left_width + margin - image_width\n",
    "            else:\n",
    "                x = left_width + 3 * margin if i % 2 == 0 else total_width - image_width\n",
    "\n",
    "            canvas.paste(images[p], (x, y))\n",
    "\n",
    "    canvas = resize(canvas, max_size)\n",
    "    canvas.save(f\"{problem_dir}/whole-white.png\")\n",
    "\n",
    "    for image in left_images + right_images:\n",
    "        image.close()"
   ],
   "id": "def09bedce6307f",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": "",
   "id": "8bbaa3087b3b8a14",
   "outputs": [],
   "execution_count": null
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
