{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from pptx import Presentation\n",
    "from pptx.enum.text import PP_ALIGN\n",
    "from pptx.util import Inches\n",
    "from pptx.util import Pt\n",
    "import os\n",
    "from PIL import Image\n",
    "from tqdm import tqdm\n",
    "\n",
    "def add_image_slide(prs, title, image_path, left=None, top=None, width=None):\n",
    "    \"\"\"Append a title-only slide showing one image, scaled and positioned.\n",
    "\n",
    "    By default the image is scaled (aspect ratio preserved) to fit inside a box\n",
    "    80% of the slide width by 60% of the slide height and centred on the slide.\n",
    "\n",
    "    Args:\n",
    "        prs: python-pptx ``Presentation`` the slide is appended to.\n",
    "        title: Text for the slide's title placeholder.\n",
    "        image_path: Path of the image file to insert.\n",
    "        left: Optional x position (EMU / ``Length``); centred when omitted.\n",
    "        top: Optional y position (EMU / ``Length``); centred when omitted.\n",
    "        width: Optional picture width (EMU / ``Length``). When given, the height\n",
    "            follows from the image's aspect ratio; when omitted, the image is\n",
    "            fitted into the default box described above.\n",
    "    \"\"\"\n",
    "    # Index 5 is assumed to be a title-only layout (as in python-pptx's default\n",
    "    # template); confirm if a custom template is ever used.\n",
    "    slide = prs.slides.add_slide(prs.slide_layouts[5])\n",
    "    slide.shapes.title.text = title\n",
    "\n",
    "    # Only the pixel size is needed; the context manager closes the file handle\n",
    "    # right away instead of leaving it open until garbage collection.\n",
    "    with Image.open(image_path) as img:\n",
    "        img_width, img_height = img.size\n",
    "\n",
    "    slide_width = prs.slide_width\n",
    "    slide_height = prs.slide_height\n",
    "\n",
    "    if width is None:\n",
    "        max_width = slide_width * 0.8\n",
    "        max_height = slide_height * 0.6\n",
    "        # Images wider than the box's aspect ratio are limited by width, others by height.\n",
    "        if img_width / img_height >= max_width / max_height:\n",
    "            width = max_width\n",
    "            height = max_width * img_height / img_width\n",
    "        else:\n",
    "            height = max_height\n",
    "            width = max_height * img_width / img_height\n",
    "    else:\n",
    "        height = width * img_height / img_width\n",
    "\n",
    "    if left is None:\n",
    "        left = (slide_width - width) / 2\n",
    "    if top is None:\n",
    "        top = (slide_height - height) / 2\n",
    "\n",
    "    # Shape geometry is integer EMUs, but the scaling above yields floats.\n",
    "    slide.shapes.add_picture(image_path, int(left), int(top), int(width), int(height))\n",
    "\n",
    "def add_side_by_side_images_slide(prs, title, subtitle_text, image1_path, image2_path):\n",
    "    \"\"\"Append a title-only slide with a centred caption and two images side by side.\n",
    "\n",
    "    Each image is scaled (aspect ratio preserved) to fit inside a box 46% of the\n",
    "    slide width by 60% of the slide height. The first image sits 0.2 in from the\n",
    "    left edge, the second 0.2 in from the right edge, and both are vertically\n",
    "    centred on a line 1 in below the middle of the slide.\n",
    "\n",
    "    Args:\n",
    "        prs: python-pptx ``Presentation`` the slide is appended to.\n",
    "        title: Text for the slide's title placeholder.\n",
    "        subtitle_text: Caption shown centred at 22 pt in a 5 in wide textbox\n",
    "            placed 1.3 in from the top of the slide.\n",
    "        image1_path: Path of the image placed on the left.\n",
    "        image2_path: Path of the image placed on the right.\n",
    "    \"\"\"\n",
    "    # Index 5 is assumed to be a title-only layout (python-pptx default template).\n",
    "    slide = prs.slides.add_slide(prs.slide_layouts[5])\n",
    "    slide.shapes.title.text = title\n",
    "\n",
    "    slide_width = prs.slide_width\n",
    "    slide_height = prs.slide_height\n",
    "\n",
    "    # Caption textbox, created directly at its horizontally centred position.\n",
    "    textbox_width = Inches(5)\n",
    "    textbox = slide.shapes.add_textbox(\n",
    "        (slide_width - textbox_width) // 2, Inches(1.3), textbox_width, Inches(1)\n",
    "    )\n",
    "    text_frame = textbox.text_frame\n",
    "    text_frame.text = subtitle_text\n",
    "    for paragraph in text_frame.paragraphs:\n",
    "        paragraph.alignment = PP_ALIGN.CENTER\n",
    "        for run in paragraph.runs:\n",
    "            run.font.size = Pt(22)\n",
    "\n",
    "    max_width = slide_width * 0.46\n",
    "    max_height = slide_height * 0.6\n",
    "\n",
    "    def fitted_size(image_path):\n",
    "        \"\"\"Return (width, height) of the image scaled to fit the max box.\"\"\"\n",
    "        # The context manager closes the file as soon as the size is read.\n",
    "        with Image.open(image_path) as img:\n",
    "            img_width, img_height = img.size\n",
    "        if img_width / img_height >= max_width / max_height:\n",
    "            return max_width, max_width * img_height / img_width\n",
    "        return max_height * img_width / img_height, max_height\n",
    "\n",
    "    width1, height1 = fitted_size(image1_path)\n",
    "    width2, height2 = fitted_size(image2_path)\n",
    "\n",
    "    margin = Inches(0.2)\n",
    "    # Images are pushed 1 in below the vertical centre, presumably to clear the\n",
    "    # title and caption above them.\n",
    "    vertical_shift = Inches(1)\n",
    "    top1 = (slide_height - height1) / 2 + vertical_shift\n",
    "    top2 = (slide_height - height2) / 2 + vertical_shift\n",
    "    left1 = margin\n",
    "    left2 = slide_width - width2 - margin\n",
    "\n",
    "    # Shape geometry is integer EMUs, but the scaling above yields floats.\n",
    "    slide.shapes.add_picture(image1_path, int(left1), int(top1), int(width1), int(height1))\n",
    "    slide.shapes.add_picture(image2_path, int(left2), int(top2), int(width2), int(height2))\n",
    "\n",
    "def create_presentation(ids, comparisons, group1, group2, output_file=\"output_presentation.pptx\"):\n",
    "    \"\"\"Build a deck with four slides per problem and save it to ``output_file``.\n",
    "\n",
    "    For each problem, in order: a title slide (\"Problem <id>\"), a slide with\n",
    "    the ``group1`` image, a slide with the ``group2`` image, and a slide showing\n",
    "    both images side by side under the matching ``comparisons`` caption.\n",
    "\n",
    "    Args:\n",
    "        ids: Problem identifiers, used in slide titles.\n",
    "        comparisons: Caption text for each problem's comparison slide.\n",
    "        group1: Bongard-RWR image paths, one per problem.\n",
    "        group2: Bongard image paths, one per problem.\n",
    "        output_file: Destination path of the saved ``.pptx`` file.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If the four sequences differ in length; ``zip`` would\n",
    "            otherwise silently drop the unmatched problems.\n",
    "    \"\"\"\n",
    "    columns = [list(ids), list(comparisons), list(group1), list(group2)]\n",
    "    lengths = [len(column) for column in columns]\n",
    "    if len(set(lengths)) > 1:\n",
    "        raise ValueError(\n",
    "            f\"ids, comparisons, group1 and group2 must have the same length, got {lengths}\"\n",
    "        )\n",
    "\n",
    "    prs = Presentation()\n",
    "\n",
    "    # total= lets tqdm show a percentage/ETA; a bare zip object has no len().\n",
    "    for problem_id, comparison, img1, img2 in tqdm(zip(*columns), total=lengths[0], desc=\"Problems\"):\n",
    "        # Index 0 is assumed to be the template's title layout.\n",
    "        slide = prs.slides.add_slide(prs.slide_layouts[0])\n",
    "        slide.shapes.title.text = f\"Problem {problem_id}\"\n",
    "\n",
    "        # Drop the second placeholder (assumed subtitle, idx 1) so no empty box is\n",
    "        # left on the title slide. python-pptx indexes slide.placeholders by\n",
    "        # placeholder idx rather than by position.\n",
    "        if len(slide.placeholders) > 1:\n",
    "            subtitle = slide.placeholders[1]\n",
    "            sp = subtitle.element\n",
    "            sp.getparent().remove(sp)\n",
    "\n",
    "        add_image_slide(prs, f\"Problem {problem_id} - Bongard-RWR\", img1)\n",
    "        add_image_slide(prs, f\"Problem {problem_id} - Bongard\", img2)\n",
    "        add_side_by_side_images_slide(prs, f\"Problem {problem_id} - Comparison\", comparison, img1, img2)\n",
    "\n",
    "    prs.save(output_file)\n",
    "    print(f\"Presentation saved as {output_file}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Input locations, relative to the kernel's working directory.\n",
    "RWR_DIR = '../data/raw/bongard_rwr_splitted'\n",
    "BONGARD_DIR = '../data/raw/bongard_splitted'\n",
    "LABELS_CSV = '../data/raw/labels.csv'\n",
    "\n",
    "# Problem folders are named by numeric id; skip anything else (e.g. hidden\n",
    "# files such as .DS_Store) instead of crashing in int().\n",
    "translated_bongard_problems = sorted(int(name) for name in os.listdir(RWR_DIR) if name.isdecimal())\n",
    "\n",
    "bongard_rwr_images = [f\"{RWR_DIR}/{problem_id}/whole-white.png\" for problem_id in translated_bongard_problems]\n",
    "bongard_images = [f\"{BONGARD_DIR}/{problem_id}/whole.png\" for problem_id in translated_bongard_problems]\n",
    "\n",
    "# Fail fast, before any slides are built, if an expected image is missing.\n",
    "missing_images = [path for path in bongard_rwr_images + bongard_images if not os.path.isfile(path)]\n",
    "assert not missing_images, f\"Missing images: {missing_images[:5]}\"\n",
    "\n",
    "# Row k (0-based) of labels.csv is taken to describe problem k + 1, and each row\n",
    "# must have exactly three columns (presumably id, left side, right side) -- confirm.\n",
    "labels = pd.read_csv(LABELS_CSV)\n",
    "label_rows = [problem_id - 1 for problem_id in translated_bongard_problems]\n",
    "comparisons = [f\"{left} - {right}\" for _, left, right in labels.iloc[label_rows].values]\n",
    "\n",
    "create_presentation(translated_bongard_problems, comparisons, bongard_rwr_images, bongard_images, \"problem_presentation.pptx\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
