{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import requests\n",
    "def download_image_from_url(url, filename):\n",
    "    img_data = requests.get(url).content\n",
    "    with open(filename, 'wb') as handler:\n",
    "        handler.write(img_data)"
   ]
  },
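  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick, hypothetical usage of the helper above (a sketch only: the URL is a placeholder for any reachable image and is not part of the experiment)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical usage; the URL below is a placeholder for any reachable image.\n",
    "download_image_from_url(\"https://example.com/sample.jpg\", \"sample.jpg\")"
   ]
  },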
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define Hyperparameters of the experiment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Define Hyperparameters of the experiment\n",
    "gradio_link = \"https://.....gradio.live\"\n",
    "\n",
    "# about the LVLM using AWS Bedrock\n",
    "aws_access_key_id= \"yours_aws_access_key_id\",\n",
    "aws_secret_access_key= \"yours_aws_secret_access_key\", \n",
    "region_name= \"yours_region_name\"\n",
    "model_id = \"anthropic.claude-3-5-sonnet-20241022-v2:0\" # model_id of the LVLM"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## LVLM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import boto3\n",
    "# Define the aws runtime and model\n",
    "bedrock_runtime_client = boto3.client(\n",
    "    'bedrock-runtime',\n",
    "    aws_access_key_id= aws_access_key_id,\n",
    "    aws_secret_access_key= aws_secret_access_key,\n",
    "    region_name=region_name,\n",
    ")"
   ]
  },
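  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Optional connectivity check for the Bedrock client. This is a minimal sketch using boto3's `converse` API with a throwaway prompt; it assumes the credentials, region and `model_id` above are valid and that the account has access to the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Minimal sanity check (assumes valid credentials and model access).\n",
    "response = bedrock_runtime_client.converse(\n",
    "    modelId=model_id,\n",
    "    messages=[{\"role\": \"user\", \"content\": [{\"text\": \"Reply with the single word: ok\"}]}],\n",
    ")\n",
    "print(response[\"output\"][\"message\"][\"content\"][0][\"text\"])"
   ]
  },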
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Calculate the optimal semantic edits "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import cece\n",
    "from cece.queries import *\n",
    "from cece.refine import *\n",
    "from cece.wordnet import *\n",
    "\n",
    "import pickle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# import sys\n",
    "from BDD100k_classifier import BDD100k_classifier\n",
    "\n",
    "classifier = BDD100k_classifier()\n",
    "\n",
    "from claude_predictor import *\n",
    "\n",
    "# classification_prompt = f\"\"\"\n",
    "# Classify each image in their appropriate class according to the driving situation they depict. \n",
    "# Valid class labels are 'start' or 'stop' and only these, depending on whether the car has to move or stop based on its surroundings.\n",
    "# You need to classify the images in one of these classes.\n",
    "# Pay attention to the semantics that define each class.\n",
    "# Return me only the label of the scene depicted and nothing else.\n",
    "# \"\"\"\n",
    "\n",
    "# prompt_analyze = \"\"\"\n",
    "# Please analyze the images in detail and answer the following question with reason based on these images. \n",
    "# \"\"\"\n",
    "\n",
    "# text_prompt = f\"\"\"\n",
    "# Based on your analysis above, classify each image in their appropriate class according to the driving situation they depict. \n",
    "# Valid class labels are 'start' or 'stop' and only these, depending on whether the car has to move or stop based on its surroundings.\n",
    "# Pay attention to the semantics that define each class.\n",
    "# You need to classify the images in one of these classes.\n",
    "# Return me only the label of the scene depicted and nothing else.\n",
    "# \"\"\"\n",
    "\n",
    "\n",
    "image_names = os.listdir(\"bdd100k/images/10k/train/\")\n",
    "\n",
    "\n",
    "# class Claude_classifier: \n",
    "    \n",
    "#     def __init__(self):\n",
    "#         pass\n",
    "#     def classify(self, image_name):\n",
    "#         source_classes = defaultdict(list)\n",
    "#         pred = predict_classes_claude([image_name], source_classes, \n",
    "#                                                 classification_prompt, prompt_analyze, text_prompt, analyze=False)\n",
    "#         pred = pred[image_name][0]\n",
    "#         if pred == \"stop\":\n",
    "#             return 0\n",
    "#         else:\n",
    "#             return 1\n",
    "\n",
    "# classifier = Claude_classifier()    \n",
    "classifier.classify(\"bdd100k/images/10k/train/\" + image_names[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json \n",
    "\n",
    "with open (\"bdd100k/labels/sem_seg/rles/sem_seg_train.json\") as handle:\n",
    "    segs = json.load(handle)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "7000it [14:20,  8.13it/s]\n"
     ]
    }
   ],
   "source": [
    "dataset = []\n",
    "labels = []\n",
    "index_to_image_id = {}\n",
    "image_id_to_index = {}\n",
    "\n",
    "for i, row in tqdm(enumerate(segs[\"frames\"])):\n",
    "    objs = []\n",
    "    for obj in row[\"labels\"]:\n",
    "        objs.append(obj[\"category\"])\n",
    "    \n",
    "    dataset.append(objs.copy())\n",
    "    image_id = row[\"name\"]\n",
    "    labels.append(classifier.classify(os.path.join(\"bdd100k/images/10k/train\", image_id)))\n",
    "#     labels.append(labels_pickle[image_id])\n",
    "    index_to_image_id[i] = image_id\n",
    "    image_id_to_index[image_id] = i\n",
    "        "
   ]
  },
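  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A small sanity check on the data that was just built: number of indexed frames, an example object list, and the distribution of classifier labels (uses in-memory data only)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import Counter\n",
    "\n",
    "# Sanity check of the symbolic dataset and the classifier labels.\n",
    "print(len(dataset), \"frames indexed\")\n",
    "print(\"example objects:\", dataset[0][:10])\n",
    "print(\"label distribution:\", Counter(labels))"
   ]
  },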
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def export_text_edits(edits):\n",
    "    transf = []\n",
    "    for e1, e2 in edits[\"transf\"]:\n",
    "        ee1, ee2 = None, None\n",
    "        for e in e1:\n",
    "            if \".\" not in e:\n",
    "                ee1 = e\n",
    "                break\n",
    "                \n",
    "        for e in e2:\n",
    "            if \".\" not in e:\n",
    "                ee2 = e\n",
    "                break\n",
    "        transf.append([ee1, ee2])\n",
    "        \n",
    "    return {\n",
    "        \"additions\": [ee for e in edits[\"additions\"] for ee in e if \".\" not in ee],\n",
    "        \"removals\": [ee for e in edits[\"removals\"] for ee in e if \".\" not in ee],\n",
    "        \"transf\": transf\n",
    "    }"
   ]
  },
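  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A tiny, made-up illustration of `export_text_edits`: WordNet synset names (the entries containing a dot) are filtered out and only the plain terms are kept for each addition, removal, and transformation. The concept sets below are hypothetical."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical edit sets, just to illustrate the filtering.\n",
    "example_edits = {\n",
    "    \"additions\": [{\"traffic light\", \"traffic_light.n.01\"}],\n",
    "    \"removals\": [{\"pedestrian\", \"pedestrian.n.01\"}],\n",
    "    \"transf\": [[{\"car\", \"car.n.01\"}, {\"truck\", \"truck.n.01\"}]],\n",
    "}\n",
    "export_text_edits(example_edits)"
   ]
  },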
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "from cece.xDataset import *\n",
    "from cece.xDataset import createMSQ\n",
    "\n",
    "\n",
    "msq_dataset = []\n",
    "for row in dataset:\n",
    "    msq = []\n",
    "    for obj in row:\n",
    "        try:\n",
    "            msq.append(connect_term_to_wordnet(obj).union([obj]))\n",
    "        except:\n",
    "            try:\n",
    "                msq.append(connect_term_to_wordnet(obj.replace(\" \", \"\")).union([obj.replace(\" \", \"\")]))\n",
    "            except:\n",
    "                pass\n",
    "    msq_dataset.append(msq.copy())\n",
    "\n",
    "ds = xDataset(dataset = msq_dataset,\n",
    "              labels = labels,\n",
    "              connect_to_wordnet = False)\n",
    "\n",
    "def get_local_edits(image_id):\n",
    "    source_index = image_id_to_index[image_id]\n",
    "    source_image_id = index_to_image_id[source_index]\n",
    "    objects_source = [dd for d in ds.dataset[source_index].concepts for dd in d if \".\" not in dd]\n",
    "\n",
    "    target_index, cost = ds.explain(ds.dataset[source_index], labels[source_index])\n",
    "    target_image_id = index_to_image_id[target_index]\n",
    "    cost, edits = ds.find_edits(ds.dataset[source_index], ds.dataset[target_index])\n",
    "    \n",
    "    edits = export_text_edits(edits)\n",
    "    added_objs = edits[\"additions\"] + [e for [_, e] in edits[\"transf\"]]\n",
    "    removed_objs = edits[\"removals\"] + [e for [e, _] in edits[\"transf\"]]\n",
    "    return objects_source, added_objs, removed_objs"
   ]
  },
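  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Example: compute the local edits for a single frame. This is a sketch that simply reuses the first image id indexed above; it assumes the previous cells have been executed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Example call on the first indexed image.\n",
    "example_id = index_to_image_id[0]\n",
    "objs, added_objs, removed_objs = get_local_edits(example_id)\n",
    "print(\"source objects:\", objs)\n",
    "print(\"objects to add:\", added_objs)\n",
    "print(\"objects to remove:\", removed_objs)"
   ]
  },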
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the Pretrained Stable Diffusion Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "from edits import Edits\n",
    "import ast\n",
    "from editor import Editor\n",
    "import boto3\n",
    "from chat import Chat\n",
    "\n",
    "from PIL import Image\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "editor = Editor(gradio_link)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def global_explanations(source_image_path, orig_label):\n",
    "\n",
    "    regional_dataset = []\n",
    "    regional_labels = []\n",
    "    for l, r in zip(labels, dataset):\n",
    "        if l == orig_label:\n",
    "            regional_dataset.append(r)\n",
    "            regional_labels.append(l)\n",
    "\n",
    "    gl = ds.global_explanation(regional_dataset, regional_labels)\n",
    "    return {k: v for k, v in gl.items() if \".\" not in k}\n",
    "\n",
    "with open('global_explanations_0_classifier_classifier.pickle', 'rb') as handle:\n",
    "    global_explanations_0 = pickle.load(handle)\n",
    "    \n",
    "with open('global_explanations_1_classifier_classifier.pickle', 'rb') as handle:\n",
    "    global_explanations_1 = pickle.load(handle)\n",
    "    \n",
    "\n",
    "def global_explanations(source_image_path, orig_label):\n",
    "    if orig_label == 0:\n",
    "        return global_explanations_0\n",
    "    else:\n",
    "        return global_explanations_1\n",
    "    \n"
   ]
  },
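  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Optional: inspect the cached global explanations. The editing loop below treats each entry as a concept with a signed weight (non-positive weights trigger removals, positive weights trigger additions), so sorting by absolute weight shows the most influential concepts per class. This sketch assumes the pickles follow that `{concept: weight}` structure."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Peek at the strongest-weighted concepts of each cached global explanation\n",
    "# (assumes each pickle is a {concept: signed weight} dictionary, as used below).\n",
    "for name, gl in [(\"class 0\", global_explanations_0), (\"class 1\", global_explanations_1)]:\n",
    "    top = sorted(gl.items(), key=lambda kv: abs(kv[1]), reverse=True)[:5]\n",
    "    print(name, top)"
   ]
  },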
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Claude decide the edits!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "from prompts import prompt_single_step, prompt_add_object, prompt_remove_object\n",
    "\n",
    "import os\n",
    "import shutil\n",
    "\n",
    "def create_or_replace_dir(directory_name):\n",
    "    # Check if the directory already exists\n",
    "    if os.path.exists(directory_name):\n",
    "        # If it exists, remove it\n",
    "        shutil.rmtree(directory_name)\n",
    "    \n",
    "    # Create the new directory\n",
    "    os.makedirs(directory_name)\n",
    "\n",
    "def edit_global_edits(image_id):\n",
    "    \n",
    "    create_or_replace_dir(f\"imgs/bdd100k/classifier/global/{image_id}\")\n",
    "    source_image_path = f\"imgs/bdd100k/classifier/global/{image_id}/source.jpg\"\n",
    "    \n",
    "    url = \"dd100k/images/10k/train/\" + image_id #data[image_id][\"url\"]\n",
    "#     download_image_from_url(url, source_image_path)\n",
    "    shutil.copyfile(url, source_image_path)\n",
    "    steps = []\n",
    "    \n",
    "    objs, added_objs, removed_objs = get_local_edits(image_id)\n",
    "    \n",
    "    chat = Chat(model_id, bedrock_runtime_client)\n",
    "\n",
    "    logs = \"\"\n",
    "    excs, i = 0, 1\n",
    "    orig_label = classifier.classify(source_image_path)\n",
    "    new_label = orig_label\n",
    "    global_edits = global_explanations(source_image_path, orig_label)\n",
    "    logs += f\"Classification: {orig_label}\\n\"\n",
    "    for obj, v in global_edits.items():\n",
    "        try:\n",
    "            \n",
    "            if  v <= 0:\n",
    "                \n",
    "                if obj in objs:\n",
    "                    prompt = prompt_remove_object(obj)\n",
    "                    chat.add_user_message_image(prompt, source_image_path) # add a user message with an image and a text prompt\n",
    "                    background = chat.generate()\n",
    "                    logs += f\"\\n----\\nOutput LVLM: {i}\\n{background}\\n\"\n",
    "                    background = background.strip()\n",
    "\n",
    "                    new_image, mask = editor.replacer(source_image_path, obj, background)\n",
    "                    logs += f\"\\n{['remove', obj, background]}\\n\" \n",
    "                    steps.append([\"remove\", obj, background])\n",
    "                    \n",
    "                    source_image_path = f\"imgs/bdd100k/classifier/global/{image_id}/step_{i}.jpg\"\n",
    "                    new_image.save(source_image_path)\n",
    "                    i += 1\n",
    "                    new_label = classifier.classify(source_image_path)\n",
    "                    logs += f\"Classification: {new_label}\\n\"\n",
    "\n",
    "\n",
    "            elif v > 0:    \n",
    "                \n",
    "                if obj not in objs: \n",
    "                    prompt = prompt_add_object(obj)\n",
    "                    chat.add_user_message_image(prompt, source_image_path) # add a user message with an image and a text prompt\n",
    "                    add = chat.generate()\n",
    "                    logs += f\"\\n----\\nOutput LVLM: {i}\\n{add}\\n\"\n",
    "                    add = add.strip()\n",
    "\n",
    "\n",
    "                    new_image, mask = editor.replacer(source_image_path, add, obj)\n",
    "                    logs += f\"\\n{['add', obj, add]}\\n\" \n",
    "                    steps.append([\"add\", obj, add])\n",
    "\n",
    "                    source_image_path = f\"imgs/bdd100k/classifier/global/{image_id}/step_{i}.jpg\"\n",
    "                    new_image.save(source_image_path)\n",
    "                    i += 1\n",
    "                    new_label = classifier.classify(source_image_path)\n",
    "                    logs += f\"Classification: {new_label}\\n\"\n",
    "\n",
    "        except Exception as e:\n",
    "            excs += 1\n",
    "            logs += f\"Exception: {e}\\n\"\n",
    "            if excs >= 5:\n",
    "                break\n",
    "                \n",
    "                \n",
    "        if (new_label != orig_label):\n",
    "            break\n",
    "\n",
    "\n",
    "    logs += f\"\\n\\n----\\n\\n{steps}\\n\\n----\\n\\n\"\n",
    "    with open(f\"imgs/bdd100k/classifier/global/{image_id}/logs.txt\", \"w\") as handle:\n",
    "        handle.write(logs)\n",
    "    return steps\n",
    "        \n",
    "    \n",
    "# def classified_as(image_path, cl):\n",
    "#     preds = classifier.classify(image_path)\n",
    "#     if preds == cl:\n",
    "#         return True\n",
    "#     return False \n"
   ]
  },
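  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before launching the full batch below, it can help to dry-run the pipeline on a single image. This is a sketch: it picks an arbitrary image id, runs the editing loop once, and prints the edit steps that were applied."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dry run on a single image (assumes all previous cells have been executed).\n",
    "example_id = next(iter(image_id_to_index))\n",
    "example_steps = edit_global_edits(example_id)\n",
    "print(example_steps)"
   ]
  },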
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 13%|█▎        | 69/543 [4:05:13<26:37:57, 202.27s/it]"
     ]
    }
   ],
   "source": [
    "for key in image_id_to_index:\n",
    "    if not os.path.exists(f\"imgs/bdd100k/claude-sonnet/global/{key}/step_1.jpg\"):\n",
    "        edit_global_edits(key)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "myenv_3.10.12",
   "language": "python",
   "name": "myenv"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
