{
 "cells": [
  {
   "cell_type": "code",
   "id": "initial_id",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "import itertools\n",
    "import math\n",
    "import os\n",
    "from shutil import copyfile\n",
    "from typing import List, Tuple\n",
    "\n",
    "import PIL.Image\n",
    "import PIL.ImageDraw2\n",
    "from PIL.ImageDraw2 import Pen\n",
    "from tqdm import tqdm\n",
    "\n",
    "from src.bongard_problems.data import get_bongard_open_world_labels"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "3671d73e943fdc78",
   "metadata": {},
   "source": [
    "labels_csv_path = \"../data/raw/bongard_open_world_labels.csv\"\n",
    "get_bongard_open_world_labels(labels_csv_path)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "24cd25d6f396bf12",
   "metadata": {},
   "source": [
    "# os.path functions do not expand \"~\" themselves: without expanduser, os.path.isdir() looks for\n",
    "# a literal \"~\" directory under the working directory and os.makedirs() would create one there.\n",
    "dataset_dir_source = os.path.expanduser(\"~/Datasets/BongardOpenWorld\")\n",
    "dataset_dir_target = os.path.expanduser(\"~/Projects/llm-avr-benchmarks/data/raw/bongard_open_world_splitted\")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "id": "57ca8e7bd49efd5e",
   "metadata": {},
   "source": "## Copy images from the global dataset directory into the local split directory"
  },
  {
   "cell_type": "code",
   "id": "60ebd8be4256ece8",
   "metadata": {},
   "source": [
    "problem_ids = list(range(100, 200))\n",
    "for source_id in tqdm(problem_ids):\n",
    "    problem_dir = f\"{dataset_dir_source}/{source_id:04d}\"\n",
    "    if not os.path.isdir(problem_dir):\n",
    "        raise ValueError(f\"Problem directory {problem_dir} doesn't exist\")\n",
    "\n",
    "    # Target directories are numbered source_id + 1 (e.g. source 0100 -> target 101).\n",
    "    target_problem_dir = os.path.join(dataset_dir_target, str(source_id + 1))\n",
    "    # Per-side counters: copied images are renamed 0, 1, 2, ... within each side.\n",
    "    next_index = {\"left\": 0, \"right\": 0}\n",
    "    for filename in sorted(os.listdir(problem_dir)):\n",
    "        # Skip hidden files (e.g. .DS_Store): they are not images and would otherwise land in \"right\".\n",
    "        if filename.startswith(\".\"):\n",
    "            continue\n",
    "\n",
    "        # Files whose names start with \"pos\" go to the left side, everything else to the right.\n",
    "        side = \"left\" if filename.startswith(\"pos\") else \"right\"\n",
    "        target_dir = os.path.join(target_problem_dir, side)\n",
    "        os.makedirs(target_dir, exist_ok=True)\n",
    "\n",
    "        _, extension = os.path.splitext(filename)\n",
    "        target_path = os.path.join(target_dir, f\"{next_index[side]}{extension}\")\n",
    "        next_index[side] += 1\n",
    "\n",
    "        copyfile(os.path.join(problem_dir, filename), target_path)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "id": "f516ef66e1b07174",
   "metadata": {},
   "source": [
    "## Resize each image so that its larger dimension (height or width) is at most 512px"
   ]
  },
  {
   "cell_type": "code",
   "id": "658e4f58e7e6da89",
   "metadata": {},
   "source": [
    "max_size = 512\n",
    "# The copy cell stores source problem `source_id` under target directory `source_id + 1`, so iterate\n",
    "# over the shifted ids; using `problem_ids` directly would skip the last copied problem and look for\n",
    "# a directory this batch never created.\n",
    "for problem_id in tqdm([source_id + 1 for source_id in problem_ids]):\n",
    "    for side in [\"left\", \"right\"]:\n",
    "        side_dir = f\"{dataset_dir_target}/{problem_id}/{side}\"\n",
    "        for filename in os.listdir(side_dir):\n",
    "            filepath = f\"{side_dir}/{filename}\"\n",
    "            try:\n",
    "                # Close the source file before the resized image overwrites it in place.\n",
    "                with PIL.Image.open(filepath) as image:\n",
    "                    width, height = image.size\n",
    "                    if width <= max_size and height <= max_size:\n",
    "                        continue\n",
    "\n",
    "                    # Scale the larger dimension down to max_size, keeping the aspect ratio; the shorter\n",
    "                    # side is kept at >= 1px so very elongated images do not produce an empty size.\n",
    "                    if width > height:\n",
    "                        new_width = max_size\n",
    "                        new_height = max(1, int((height / width) * max_size))\n",
    "                    else:\n",
    "                        new_width = max(1, int((width / height) * max_size))\n",
    "                        new_height = max_size\n",
    "\n",
    "                    resized = image.resize((new_width, new_height), PIL.Image.Resampling.LANCZOS)\n",
    "                resized.save(filepath)\n",
    "            except Exception as e:\n",
    "                # Best effort: report files that cannot be read or written and continue with the rest.\n",
    "                print(filepath, e)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "id": "c8219537e7d82f89",
   "metadata": {},
   "source": [
    "### Notes\n",
    "\n",
    "`~/Datasets/BongardOpenWorld/0001/neg__2__2023-03-18-12-16-30__https:__www.nasa.gov__sites__default__files__thumbnails__image__deepspaceexploration.jpg` is actually a PNG file, not a JPEG. I renamed its extension locally before running the notebook."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dadd096d62aa704b",
   "metadata": {},
   "source": [
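    "## Arrange each problem's images into a single combined image (`whole.png`) with minimal area; resize it to at most 1024px per dimension\n",
    "\n",
    "The per-side images (`left.png`, `right.png`) are also saved, each capped at 512px per dimension."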
   ]
  },
  {
   "cell_type": "code",
   "id": "4beaa8455a173c90",
   "metadata": {},
   "source": [
    "# Every ordering of a side's six images; shared by get_permutations below.\n",
    "permutations = list(itertools.permutations(range(6)))\n",
    "\n",
    "\n",
    "def get_size(\n",
    "        sizes: List[Tuple[int, int]],\n",
    "        permutation: List[int]\n",
    ") -> Tuple[int, int]:\n",
    "    \"\"\"Return the (width, height) of one side laid out as three rows of two images.\n",
    "\n",
    "    ``permutation[i]`` is the index into ``sizes`` of the image placed at slot ``i``:\n",
    "    slots 0, 2, 4 form the first column and slots 1, 3, 5 the second. The width is\n",
    "    the widest row (sum of its two image widths), the height is the taller column\n",
    "    (sum of its three image heights). Margins are not included.\n",
    "\n",
    "    NOTE(review): columns are stacked independently when drawn, so an image can\n",
    "    vertically overlap the next row's image in the other column; the row-based\n",
    "    width may then be too small to avoid horizontal overlap. TODO confirm.\n",
    "    \"\"\"\n",
    "    rows = [\n",
    "        (permutation[0], permutation[1]),\n",
    "        (permutation[2], permutation[3]),\n",
    "        (permutation[4], permutation[5]),\n",
    "    ]\n",
    "    widest_row = max(sizes[first][0] + sizes[second][0] for first, second in rows)\n",
    "    first_column_height = sum(sizes[first][1] for first, _ in rows)\n",
    "    second_column_height = sum(sizes[second][1] for _, second in rows)\n",
    "    return widest_row, max(first_column_height, second_column_height)"
    "\n",
    "\n",
    "def get_permutations(\n",
    "        left_sizes: List[Tuple[int, int]],\n",
    "        right_sizes: List[Tuple[int, int]]\n",
    ") -> Tuple[List[int], List[int]]:\n",
    "    \"\"\"Choose the left and right arrangements that minimise the combined area.\n",
    "\n",
    "    The two sides sit next to each other, so the combined width is the sum of the\n",
    "    side widths and the combined height is the taller side (as given by get_size,\n",
    "    i.e. without margins). Returns ``(left_permutation, right_permutation)``, each\n",
    "    an element of the module-level ``permutations`` list. On ties, the first pair in\n",
    "    iteration order (left outer, right inner) wins.\n",
    "    \"\"\"\n",
    "    # A side's size depends only on its own permutation, so compute it once per permutation\n",
    "    # (1,440 get_size calls) instead of inside the nested loop (~1M calls). Many permutations\n",
    "    # share a size; keeping only the first permutation per distinct size (dicts keep insertion\n",
    "    # order) shrinks the search without changing the returned pair, since ties still go to the\n",
    "    # first occurrence.\n",
    "    left_candidates = {}\n",
    "    for permutation in permutations:\n",
    "        left_candidates.setdefault(get_size(left_sizes, permutation), permutation)\n",
    "    right_candidates = {}\n",
    "    for permutation in permutations:\n",
    "        right_candidates.setdefault(get_size(right_sizes, permutation), permutation)\n",
    "\n",
    "    min_area = math.inf\n",
    "    best_permutations = None\n",
    "    for (left_width, left_height), left_permutation in left_candidates.items():\n",
    "        for (right_width, right_height), right_permutation in right_candidates.items():\n",
    "            area = (left_width + right_width) * max(left_height, right_height)\n",
    "            if area < min_area:\n",
    "                min_area = area\n",
    "                best_permutations = (left_permutation, right_permutation)\n",
    "    return best_permutations"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "ad94ab0b1767a6cb",
   "metadata": {},
   "source": [
    "def draw_side(\n",
    "        width: int,\n",
    "        height: int,\n",
    "        margin: int,\n",
    "        permutation: List[int],\n",
    "        sizes: List[Tuple[int, int]],\n",
    "        images: List[PIL.Image.Image]\n",
    ") -> PIL.Image.Image:\n",
    "    \"\"\"Paste one side's images onto a new RGB canvas in two columns.\n",
    "\n",
    "    The canvas is ``(margin + width, 2 * margin + height)`` pixels with PIL's default\n",
    "    (black) fill. Images at even slots of ``permutation`` are stacked top to bottom at\n",
    "    x=0; images at odd slots are stacked in a second column, right-aligned to the\n",
    "    canvas edge. Consecutive images in a column are ``margin`` pixels apart.\n",
    "    ``sizes[p]`` must be the (width, height) of ``images[p]``.\n",
    "    \"\"\"\n",
    "    # PIL.Image is a module; the class used in the annotations is PIL.Image.Image\n",
    "    # (subscripting typing.List with a module raises TypeError on Python < 3.11).\n",
    "    canvas = PIL.Image.new('RGB', (margin + width, 2 * margin + height))\n",
    "    y1, y2 = 0, 0\n",
    "    for i, p in enumerate(permutation):\n",
    "        image_width, image_height = sizes[p]\n",
    "        image = images[p]\n",
    "        # Even slots: first column at x=0; odd slots: second column, right-aligned.\n",
    "        x = 0 if i % 2 == 0 else width + margin - image_width\n",
    "        if i % 2 == 0:\n",
    "            y = y1\n",
    "            y1 += image_height + margin\n",
    "        else:\n",
    "            y = y2\n",
    "            y2 += image_height + margin\n",
    "        canvas.paste(image, (x, y))\n",
    "    return canvas"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "8a5dd3e83057a3c2",
   "metadata": {},
   "source": [
    "def resize(image: PIL.Image.Image, max_size: int) -> PIL.Image.Image:\n",
    "    \"\"\"Downscale ``image`` so that neither side exceeds ``max_size``, keeping its aspect ratio.\n",
    "\n",
    "    Images already within the limit are returned unchanged (the same object);\n",
    "    otherwise a new LANCZOS-resampled image is returned. The larger side becomes\n",
    "    exactly ``max_size``; the other is truncated to an int but kept at least 1px so\n",
    "    very elongated images still produce a valid size.\n",
    "    \"\"\"\n",
    "    width, height = image.size\n",
    "    if width <= max_size and height <= max_size:\n",
    "        return image\n",
    "    if width > height:\n",
    "        new_width, new_height = max_size, max(1, int((height / width) * max_size))\n",
    "    else:\n",
    "        new_width, new_height = max(1, int((width / height) * max_size)), max_size\n",
    "    return image.resize((new_width, new_height), PIL.Image.Resampling.LANCZOS)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "6c6e2893dfecace7",
   "metadata": {},
   "source": [
    "margin = 10\n",
    "side_max_size = 512\n",
    "max_size = 1024\n",
    "# The copy cell stores source problem `source_id` under target directory `source_id + 1`.\n",
    "# (To process every copied problem instead, iterate over the entries of dataset_dir_target.)\n",
    "for problem_id in tqdm([source_id + 1 for source_id in problem_ids]):\n",
    "    problem_dir = f\"{dataset_dir_target}/{problem_id}\"\n",
    "    # Sorted listings make the image order, and so tie-breaking in get_permutations, reproducible.\n",
    "    left_images = [\n",
    "        PIL.Image.open(f\"{problem_dir}/left/{filename}\")\n",
    "        for filename in sorted(os.listdir(f\"{problem_dir}/left\"))\n",
    "    ]\n",
    "    right_images = [\n",
    "        PIL.Image.open(f\"{problem_dir}/right/{filename}\")\n",
    "        for filename in sorted(os.listdir(f\"{problem_dir}/right\"))\n",
    "    ]\n",
    "\n",
    "    left_sizes = [(image.size[0], image.size[1]) for image in left_images]\n",
    "    right_sizes = [(image.size[0], image.size[1]) for image in right_images]\n",
    "\n",
    "    left_permutation, right_permutation = get_permutations(left_sizes, right_sizes)\n",
    "    left_width, left_height = get_size(left_sizes, left_permutation)\n",
    "    right_width, right_height = get_size(right_sizes, right_permutation)\n",
    "\n",
    "    # Per-side images, each capped at side_max_size.\n",
    "    left_canvas = resize(\n",
    "        draw_side(margin + left_width, 2 * margin + left_height, margin, left_permutation, left_sizes, left_images),\n",
    "        side_max_size,\n",
    "    )\n",
    "    right_canvas = resize(\n",
    "        draw_side(margin + right_width, 2 * margin + right_height, margin, right_permutation, right_sizes, right_images),\n",
    "        side_max_size,\n",
    "    )\n",
    "\n",
    "    # Whole image: left side, a 2 * margin gap with a separator line in its middle, right side.\n",
    "    total_width = 4 * margin + left_width + right_width\n",
    "    total_height = 2 * margin + max(left_height, right_height)\n",
    "    canvas = PIL.Image.new('RGB', (total_width, total_height))\n",
    "    separator_x = left_width + 2 * margin\n",
    "    draw = PIL.ImageDraw2.Draw(canvas)\n",
    "    draw.line((separator_x, 0, separator_x, total_height), Pen(color=\"white\", width=5))\n",
    "\n",
    "    for (side, images, sizes, permutation) in [\n",
    "        (\"left\", left_images, left_sizes, left_permutation),\n",
    "        (\"right\", right_images, right_sizes, right_permutation),\n",
    "    ]:\n",
    "        # Next free y per column; even slots use column 0, odd slots column 1.\n",
    "        column_y = [0, 0]\n",
    "        for i, p in enumerate(permutation):\n",
    "            image_width, image_height = sizes[p]\n",
    "            column = i % 2\n",
    "            y = column_y[column]\n",
    "            column_y[column] += image_height + margin\n",
    "\n",
    "            if side == \"left\":\n",
    "                x = 0 if column == 0 else left_width + margin - image_width\n",
    "            else:\n",
    "                x = left_width + 3 * margin if column == 0 else total_width - image_width\n",
    "            canvas.paste(images[p], (x, y))\n",
    "\n",
    "    canvas = resize(canvas, max_size)\n",
    "\n",
    "    left_canvas.save(f\"{problem_dir}/left.png\")\n",
    "    right_canvas.save(f\"{problem_dir}/right.png\")\n",
    "    canvas.save(f\"{problem_dir}/whole.png\")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "id": "a18d40ea13e79fab",
   "metadata": {},
   "source": [
    "## Arrange again, but with a white background (selected problems only)"
   ]
  },
  {
   "cell_type": "code",
   "id": "8bbaa3087b3b8a14",
   "metadata": {},
   "source": [
    "margin = 10\n",
    "max_size = 1024\n",
    "for problem_id in [42, 62]:\n",
    "    problem_dir = f\"{dataset_dir_target}/{problem_id}\"\n",
    "    # Sorted listings make the image order, and so tie-breaking in get_permutations, reproducible.\n",
    "    left_images = [\n",
    "        PIL.Image.open(f\"{problem_dir}/left/{filename}\")\n",
    "        for filename in sorted(os.listdir(f\"{problem_dir}/left\"))\n",
    "    ]\n",
    "    right_images = [\n",
    "        PIL.Image.open(f\"{problem_dir}/right/{filename}\")\n",
    "        for filename in sorted(os.listdir(f\"{problem_dir}/right\"))\n",
    "    ]\n",
    "\n",
    "    left_sizes = [(image.size[0], image.size[1]) for image in left_images]\n",
    "    right_sizes = [(image.size[0], image.size[1]) for image in right_images]\n",
    "\n",
    "    left_permutation, right_permutation = get_permutations(left_sizes, right_sizes)\n",
    "    left_width, left_height = get_size(left_sizes, left_permutation)\n",
    "    right_width, right_height = get_size(right_sizes, right_permutation)\n",
    "\n",
    "    # Whole image on white: left side, a 2 * margin gap with a black separator in its middle, right side.\n",
    "    total_width = 4 * margin + left_width + right_width\n",
    "    total_height = 2 * margin + max(left_height, right_height)\n",
    "    canvas = PIL.Image.new('RGB', (total_width, total_height), color=\"white\")\n",
    "    separator_x = left_width + 2 * margin\n",
    "    draw = PIL.ImageDraw2.Draw(canvas)\n",
    "    draw.line((separator_x, 0, separator_x, total_height), Pen(color=\"black\", width=5))\n",
    "\n",
    "    for (side, images, sizes, permutation) in [\n",
    "        (\"left\", left_images, left_sizes, left_permutation),\n",
    "        (\"right\", right_images, right_sizes, right_permutation),\n",
    "    ]:\n",
    "        # Next free y per column; even slots use column 0, odd slots column 1.\n",
    "        column_y = [0, 0]\n",
    "        for i, p in enumerate(permutation):\n",
    "            image_width, image_height = sizes[p]\n",
    "            column = i % 2\n",
    "            y = column_y[column]\n",
    "            column_y[column] += image_height + margin\n",
    "\n",
    "            if side == \"left\":\n",
    "                x = 0 if column == 0 else left_width + margin - image_width\n",
    "            else:\n",
    "                x = left_width + 3 * margin if column == 0 else total_width - image_width\n",
    "            canvas.paste(images[p], (x, y))\n",
    "\n",
    "    canvas = resize(canvas, max_size)\n",
    "    canvas.save(f\"{problem_dir}/whole-white.png\")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "98234d453f5269e4",
   "metadata": {},
   "source": [],
   "outputs": [],
   "execution_count": null
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
