{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "5a5c08e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import glob\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from arc_utils.utils import GridConverter, hamming_distance\n",
    "from matplotlib.colors import ListedColormap, Normalize\n",
    "from tqdm import tqdm\n",
    "import matplotlib as mpl\n",
    "mpl.rcParams['text.usetex'] = True  # Enable full LaTeX\n",
    "\n",
    "\n",
    "def get_bb_score(completion1, completion2):\n",
    "   return 1 - hamming_distance(completion1, completion2)\n",
    "\n",
    "with open('kaggle/input/arc-prize-2024/arc-agi_evaluation_challenges.json', 'r') as file:\n",
    "    problems = json.load(file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "c21fcd88",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tasks Solved: 89\n"
     ]
    }
   ],
   "source": [
    "# All Tasks\n",
    "all_tasks = ['00576224', '009d5c81', '00dbd492', '03560426', '05a7bcf2', '0607ce86', '0692e18c', '070dd51e', '08573cc6', '0934a4d8', '09c534e7', '0a1d4ef5', '0a2355a6', '0b17323b', '0bb8deee', '0becf7df', '0c786b71', '0c9aba6e', '0d87d2a6', '0e671a1a', '0f63c0b9', '103eff5b', '11e1fe23', '12422b43', '12997ef3', '12eac192', '136b0064', '13713586', '137f0df0', '140c817e', '14754a24', '15113be4', '15663ba9', '15696249', '16b78196', '17b80ad2', '17cae0c1', '18419cfa', '184a9768', '195ba7dc', '1990f7a8', '19bb5feb', '1a2e2828', '1a6449f1', '1acc24af', '1c02dbbe', '1c0d0a4b', '1c56ad9f', '1d0a4b61', '1d398264', '1da012fc', '1e81d6f9', '1e97544e', '2037f2c7', '2072aba6', '20818e16', '20981f0e', '212895b5', '21f83797', '22a4bbc2', '25094a63', '2546ccf6', '256b0a75', '2685904e', '2697da3f', '2753e76c', '27a77e38', '27f8ce4f', '281123b4', '292dd178', '29700607', '2a5f8217', '2b01abd0', '2c0b0aff', '2c737e39', '2f0c5170', '310f3251', '3194b014', '319f2597', '31adaf00', '31d5ba1a', '32e9702f', '332efdb3', '3391f8c0', '33b52de3', '3490cc26', '34b99a2b', '351d6448', '358ba94e', '37d3e8b2', '3979b1a8', '3a301edc', '3b4c2228', '3d31c5b3', '3ed85e70', '3ee1011a', '3f23242b', '40f6cd08', '414297c0', '423a55dc', '42918530', '42a15761', '4364c1c4', '456873bc', '45737921', '45bbe264', '477d2879', '47996f11', '48131b3c', '4852f2fa', '48f8583b', '4aab4007', '4acc7107', '4b6b68e5', '4c177718', '4cd1b7b2', '4e45f183', '4e469f39', '4f537728', '4ff4c9da', '505fff84', '506d28a5', '50a16a69', '50aad11f', '50f325b5', '516b51b7', '5207a7b5', '5289ad53', '52fd389e', '54db823b', '55059096', '551d5bf1', '55783887', '575b1a71', '5783df64', '5833af48', '58743b76', '58e15b12', '59341089', '5a5a2103', '5af49b42', '5b526a93', '5b692c0f', '5b6cbef5', '5d2a5c43', '5ffb2104', '604001fa', '60a26a3e', '60c09cac', '626c0bcc', '62ab2642', '62b74c02', '639f5a19', '642248e4', '642d658d', '64a7c07e', '66e6c45b', '66f2d22f', '67636eac', '67b4a34d', '67c52801', '68b67ca3', '692cd3b6', '695367ec', '696d4842', '69889d6e', '6a11f6da', '6ad5bdfd', '6df30ad6', '6ea4a07e', '6f473927', '7039b2d7', '705a3229', '712bf12e', '72207abc', '72a961c9', '73182012', '73c3b0d8', '73ccf9c2', '759f3fd3', '762cd429', '770cc55f', '782b5218', '79369cc6', '7953d61e', '79fb03f4', '7bb29440', '7c8af763', '7c9b52a0', '7d18a6fb', '7d1f7ee8', '7d419a02', '7e02026e', '7ee1c6ea', '817e6c09', '81c0276b', '833dafe3', '845d6e51', '84db8fc4', '84f2aca1', '8597cfd7', '85b81ff1', '85fa5666', '8719f442', '88207623', '891232d6', '896d5239', '8a371977', '8b28cd80', '8ba14f53', '8cb8642d', '8dae5dfc', '8e2edd66', '8ee62060', '8fbca751', '90347967', '903d1b4a', '9110e3c5', '917bccba', '929ab4e9', '92e50de0', '9356391f', '93b4f4b3', '93c31fbe', '94133066', '94414823', '94be5b80', '95a58926', '963f59bc', '96a8c0cd', '97239e3d', '9772c176', '981571dc', '992798f6', '99306f82', '9a4bb226', '9b2a60aa', '9b365c51', '9b4c17c4', '9bebae7a', '9c1e755f', '9c56f360', '9caba7c3', '9ddd00f0', '9def23fe', '9f27f097', 'a04b2602', 'a096bf4d', 'a3f84088', 'a406ac07', 'a57f2f04', 'a59b95c0', 'a680ac02', 'a8610ef7', 'a934301b', 'aa18de87', 'aa300dc3', 'aa4ec2a5', 'aab50785', 'ac0c5833', 'ac2e8ecf', 'ac3e2b04', 'ac605cbb', 'ad7e01d0', 'ae58858e', 'aee291af', 'af22c60d', 'af24b4cc', 'b0722778', 'b0f4d537', 'b15fca0b', 'b1fc8b8e', 'b20f7c8b', 'b457fec5', 'b4a43f3b', 'b7999b51', 'b7cb93ac', 'b7f8a4d8', 'b7fb29bc', 'b942fd60', 'b9630600', 'ba9d41b8', 'baf41dbf', 'bb52a14b', 'bbb1b8b6', 'bc4146bd', 'bcb3040b', 'bd14c3bf', 'be03b35f', 'bf32578f', 'bf699163', 'bf89d739', 'c074846d', 'c1990cce', 'c3202e5a', 'c35c1b4c', 'c48954c1', 'c62e2108', 'c64f1187', 'c658a4bd', 'c663677b', 'c6e1b8da', 'c7d4e6ad', 'c87289bb', 'c8b7cc0f', 'c92b942c', 'c97c0139', 'ca8de6ea', 'ca8f78db', 'cad67732', 'cb227835', 'ccd554ac', 'cd3c21df', 'ce039d91', 'ce8d95cc', 'cf133acc', 'cfb2ce5a', 'd017b73f', 'd19f7514', 'd282b262', 'd2acf2cb', 'd304284e', 'd37a1ef5', 'd47aa2ff', 'd492a647', 'd4b1c2b1', 'd4c90558', 'd56f2372', 'd5c634a2', 'd931c21c', 'd94c3b52', 'da2b0fe3', 'da515329', 'dc2aa30b', 'dc2e9a9d', 'dd2401ed', 'de493100', 'df8cc377', 'e0fb7511', 'e133d23d', 'e1baa8a4', 'e1d2900e', 'e2092e0c', 'e21a174a', 'e345f17b', 'e4075551', 'e41c6fd3', 'e57337a4', 'e5790162', 'e5c44e8f', 'e619ca6e', 'e633a9e5', 'e66aafb8', 'e681b708', 'e69241bd', 'e6de6e8f', 'e74e1818', 'e760a62e', 'e7639916', 'e78887d1', 'e7a25a18', 'e7b06bea', 'e7dd8335', 'e872b94a', 'e88171ec', 'e95e3d8e', 'e99362f0', 'e9ac8c9e', 'e9b4f6fc', 'e9bb6954', 'e9c9d9a1', 'ea959feb', 'ea9794b1', 'ecaa0ec1', 'ed74f2f2', 'ed98d772', 'ef26cbf6', 'f0afb749', 'f0df5ff0', 'f21745ec', 'f3b10344', 'f3cdc58f', 'f3e62deb', 'f4081712', 'f45f5ca7', 'f5aa3634', 'f5c89df1', 'f823c43c', 'f83cb3f6', 'f8be4b64', 'f9a67cb5', 'f9d67f8b', 'fafd9572', 'fb791726', 'fc754716', 'fd096ab6', 'fd4b2b02', 'fe9372f3', 'fea12743', 'ff72ca3e']\n",
    "# Small Tasks\n",
    "small_tasks = ['00576224', '0c786b71', '0c9aba6e', 'o17cae0c1', '195ba7dc', '2072aba6', '27a77e38', '281123b4', '31d5ba1a', '32e9702f', '34b99a2b', '3b4c2228', '3d31c5b3', '48131b3c', '4852f2fa', '4cd1b7b2', '506d28a5', '5783df64', '59341089', '5d2a5c43', '60c09cac', '626c0bcc', '62b74c02', '66e6c45b', '66f2d22f', '68b67ca3', '6a11f6da', '6ad5bdfd', '6ea4a07e', '7953d61e', '833dafe3', '8597cfd7', '8ba14f53', '9110e3c5', 'a8610ef7', 'aa18de87', 'af24b4cc', 'b0722778', 'b1fc8b8e', 'bbb1b8b6', 'be03b35f', 'c074846d', 'c8b7cc0f', 'ca8de6ea', 'd017b73f', 'd19f7514', 'e133d23d', 'e345f17b', 'e633a9e5', 'e6de6e8f', 'e99362f0', 'ed74f2f2', 'ed98d772', 'fc754716']\n",
    "\n",
    "def get_solved(tasks):\n",
    "    solved = []\n",
    "    for task in tasks:\n",
    "        files = glob.glob(f\"remote/induction_400/greedy_ns/{task}/*.log\")\n",
    "        with open(files[0], 'r') as file:\n",
    "            data = json.load(file)\n",
    "            if 'pass2' in data and data['pass2'] and data['pass2'][0] is not None:\n",
    "                if np.mean([x['score'] for x in data['pass2']]) == 1:\n",
    "                    solved.append(task)\n",
    "    return solved\n",
    "\n",
    "migrate_solved  = get_solved(all_tasks)\n",
    "print(\"Tasks Solved:\", len(migrate_solved))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ead2e75",
   "metadata": {},
   "outputs": [],
   "source": [
    "import contextlib\n",
    "import io\n",
    "gridConverter = GridConverter(use_barc_format=True, use_induction=True)\n",
    "\n",
    "unsolved = list(set(all_tasks) - set(migrate_solved))\n",
    "\n",
    "solved_unsolved_pairs = []\n",
    "for task in tqdm(unsolved):\n",
    "    train_inputs = [np.array(x['input']) for x in problems[task]['train']]\n",
    "    train_outputs = [np.array(x['output']) for x in problems[task]['train']]\n",
    "    best_solved = []\n",
    "    for t in migrate_solved:\n",
    "        files = glob.glob(f\"remote/induction_400/greedy_ns/{t}/*.log\")\n",
    "        with open(files[0], 'r') as file:\n",
    "            data = json.load(file)\n",
    "            program = [x for x in data['test_samples'] if x['train_score'] == 1][0]['code']\n",
    "\n",
    "        scores = []\n",
    "        for input, output in zip(train_inputs, train_outputs):\n",
    "            with contextlib.redirect_stdout(io.StringIO()):\n",
    "                guess = gridConverter.decode(program, input)\n",
    "            scores.append(get_bb_score(output, guess))\n",
    "        best_solved.append((t, np.mean(scores)))\n",
    "    best_solved = sorted(best_solved, key=lambda x : x[1], reverse=True)\n",
    "    solved_unsolved_pairs.append({'unsolved': task, 'solved': best_solved[0][0]})\n",
    "\n",
    "print(\"Unsolved task:\", [x['unsolved'] for x in solved_unsolved_pairs])\n",
    "print(\"Bootstrapped weights from:\", [x['solved'] for x in solved_unsolved_pairs])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ml",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
