{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import requests\n",
    "def download_image_from_url(url, filename):\n",
    "    img_data = requests.get(url).content\n",
    "    with open(filename, 'wb') as handler:\n",
    "        handler.write(img_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define Hyperparameters of the experiment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Define Hyperparameters of the experiment\n",
    "gradio_link = \"https://.....gradio.live\"\n",
    "\n",
    "# about the LVLM using AWS Bedrock"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Calculate the optimal semantic edits "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import cece\n",
    "from cece.queries import *\n",
    "from cece.refine import *\n",
    "from cece.wordnet import *\n",
    "\n",
    "import pickle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# the data have been downloaded from this repo: https://github.com/aggeliki-dimitriou/SGCE \n",
    "# for fair comparison with the method\n",
    "with open(\"data/vg_data_random.pickle\", \"rb\") as handle:\n",
    "    data = pickle.load(handle)"
   ]
  },
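  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick sanity check on the loaded pickle. The keys inspected below are inferred from how `data` is used later in this notebook (\"objects\", \"url\", \"claude-3-haiku\"); this is only a hedged peek at one record, not part of the original pipeline."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Peek at one record to confirm the structure assumed by the cells below\n",
    "# (keys inferred from later usage: \"objects\", \"url\", \"claude-3-haiku\").\n",
    "example_key = next(iter(data))\n",
    "row = data[example_key]\n",
    "print(example_key, \"->\", sorted(row.keys()))\n",
    "print(\"first objects:\", list(row[\"objects\"])[:5])"
   ]
  },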
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def export_text_edits(edits):\n",
    "    \"\"\"\n",
    "    Processes a dictionary of text edits by filtering and transforming elements based on specific criteria.\n",
    "\n",
    "    The function expects a dictionary with three keys: \"transf\", \"additions\", and \"removals\".\n",
    "    Each key should map to a list of elements (strings or list of strings).\n",
    "\n",
    "    The \"transf\" key is expected to contain a list of tuples, where each tuple represents a pair of word lists.\n",
    "    The function transforms each tuple by selecting the first word from each list in the pair that does not contain a period.\n",
    "\n",
    "    For the \"additions\" and \"removals\" keys, which map to lists of lists of strings, the function flattens these lists and includes only those elements that do not contain a period.\n",
    "\n",
    "    Parameters:\n",
    "        edits (dict): A dictionary containing three keys:\n",
    "            - \"transf\": A list of tuples, each containing two lists of words (word pairs).\n",
    "            - \"additions\": A list of lists, where each inner list contains words to be added.\n",
    "            - \"removals\": A list of lists, where each inner list contains words to be removed.\n",
    "\n",
    "    Returns:\n",
    "        dict: A dictionary with the same structure as the input but filtered and transformed based on the criteria:\n",
    "            - \"additions\": List of words to be added, filtered to exclude words containing a period.\n",
    "            - \"removals\": List of words to be removed, filtered to exclude words containing a period.\n",
    "            - \"transf\": A list of transformed word pairs, each selected based on the absence of a period.\n",
    "    \"\"\"\n",
    "     \n",
    "    transf = []\n",
    "    for e1, e2 in edits[\"transf\"]:\n",
    "        ee1, ee2 = None, None\n",
    "        for e in e1:\n",
    "            if \".\" not in e:\n",
    "                ee1 = e\n",
    "                break\n",
    "                \n",
    "        for e in e2:\n",
    "            if \".\" not in e:\n",
    "                ee2 = e\n",
    "                break\n",
    "        transf.append([ee1, ee2])\n",
    "        \n",
    "    return {\n",
    "        \"additions\": [ee for e in edits[\"additions\"] for ee in e if \".\" not in ee],\n",
    "        \"removals\": [ee for e in edits[\"removals\"] for ee in e if \".\" not in ee],\n",
    "        \"transf\": transf\n",
    "    }"
   ]
  },
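  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal sanity check of `export_text_edits` on a hypothetical toy input (the synset-style strings below are illustrative, not drawn from the actual data). WordNet names such as `dog.n.01` contain a period and are therefore filtered out in favour of the plain surface form."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical toy input, for illustration only.\n",
    "toy_edits = {\n",
    "    \"transf\": [[[\"dog.n.01\", \"dog\"], [\"cat.n.01\", \"cat\"]]],\n",
    "    \"additions\": [[\"tree\", \"plant.n.02\"]],\n",
    "    \"removals\": [[\"car.n.01\", \"car\"]],\n",
    "}\n",
    "print(export_text_edits(toy_edits))\n",
    "# Expected: {'additions': ['tree'], 'removals': ['car'], 'transf': [['dog', 'cat']]}"
   ]
  },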
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "from cece.xDataset import *\n",
    "from cece.xDataset import createMSQ\n",
    "\n",
    "dataset = []\n",
    "labels = []\n",
    "index_to_image_id = {}\n",
    "image_id_to_index = {}\n",
    "for i, (k, row) in enumerate(data.items()):\n",
    "    msq = []\n",
    "    for o in row[\"objects\"]:\n",
    "        msq.append (connect_term_to_wordnet(o).union([o.split(\".\")[0]]))\n",
    "    \n",
    "    dataset.append(msq)\n",
    "    labels.append(row[\"claude-3-haiku\"][0][0])\n",
    "    index_to_image_id[i] = k\n",
    "    image_id_to_index[k] = i\n",
    "    \n",
    "# initialize an instance of the Dataset\n",
    "ds = xDataset(dataset = dataset,\n",
    "              labels = labels,\n",
    "              connect_to_wordnet = False)\n",
    "\n",
    "\n",
    "def get_local_edits(image_id):\n",
    "    \"\"\"\n",
    "    Retrieves and processes local edits between two images in a dataset based on their semantic differences.\n",
    "\n",
    "    This function identifies the source image by its image_id, then finds the corresponding target image using\n",
    "    a domain-specific explanation system that prioritizes minimal semantic change (cost). The function retrieves \n",
    "    the edits between the source and target images, transforms these edits using the export_text_edits function,\n",
    "    and compiles the lists of added and removed objects based on these transformations.\n",
    "\n",
    "    Parameters:\n",
    "        image_id (int or str): The identifier for the source image in the dataset. This ID is used to find\n",
    "                               the corresponding source index, and subsequently, the target image and edits.\n",
    "\n",
    "    Returns:\n",
    "        tuple: A tuple containing three elements:\n",
    "            - source_image_id (str or int): The ID of the source image.\n",
    "            - added_objs (list): A list of objects added in the transition from the source to the target image.\n",
    "            - removed_objs (list): A list of objects removed in the transition from the source to the target image.\n",
    "\n",
    "    The function relies on several global data structures:\n",
    "        - image_id_to_index (dict): Maps image IDs to their respective indices in the dataset.\n",
    "        - index_to_image_id (dict): Maps dataset indices back to image IDs.\n",
    "        - ds (object): A dataset object that contains methods to explain differences between images and find edits.\n",
    "        - labels (list): A list of labels for the images in the dataset, used in the explanation process.\n",
    "    \"\"\"\n",
    "    source_index = image_id_to_index[image_id]\n",
    "    source_image_id = index_to_image_id[source_index]\n",
    "    objects_source = [dd for d in ds.dataset[source_index].concepts for dd in d if \".\" not in dd]\n",
    "\n",
    "    target_index, cost = ds.explain(ds.dataset[source_index], labels[source_index])\n",
    "    target_image_id = index_to_image_id[target_index]\n",
    "    cost, edits = ds.find_edits(ds.dataset[source_index], ds.dataset[target_index])\n",
    "    \n",
    "    edits = export_text_edits(edits)\n",
    "    added_objs = edits[\"additions\"] + [e for [_, e] in edits[\"transf\"]]\n",
    "    removed_objs = edits[\"removals\"] + [e for [e, _] in edits[\"transf\"]]\n",
    "    return objects_source, added_objs, removed_objs"
   ]
  },
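  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A hedged usage sketch of `get_local_edits` on a single image id (it assumes the dataset cells above ran successfully; the id is picked arbitrarily)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the local edit extraction for one arbitrary image id.\n",
    "example_id = next(iter(data))\n",
    "objs, added, removed = get_local_edits(example_id)\n",
    "print(\"source objects:\", objs[:5])\n",
    "print(\"to add:\", added)\n",
    "print(\"to remove:\", removed)"
   ]
  },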
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the Pretrained Stable Diffusion Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "from edits import Edits\n",
    "import ast\n",
    "from editor import Editor\n",
    "import boto3\n",
    "from chat import Chat\n",
    "\n",
    "from PIL import Image\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "editor = Editor(gradio_link)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the LVLM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the aws runtime and model\n",
    "bedrock_runtime_client = boto3.client(\n",
    "    'bedrock-runtime',\n",
    "    aws_access_key_id= aws_access_key_id,\n",
    "    aws_secret_access_key= aws_secret_access_key, \n",
    "    region_name= region_name\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Classifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'office_cubicles'"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import sys\n",
    "from collections import defaultdict, Counter\n",
    "\n",
    "# Select between CNN classifier or LVLM classifier\n",
    "# CNN classifier code\n",
    "# from places365_classifier import *\n",
    "# classifier = Classifier(\"resnet18\")\n",
    "\n",
    "with open(\"data/processed_places_categories.pickle\", \"rb\") as f:\n",
    "    processed_categories = pickle.load(f)\n",
    "    \n",
    "def generate_prompt(categories):\n",
    "    # Join the categories into a readable list format\n",
    "    categories_str = \", \".join(categories)\n",
    "    return categories_str\n",
    "\n",
    "str_categories = generate_prompt(processed_categories)\n",
    "\n",
    "def classify(filename):\n",
    "    model_id = \"anthropic.claude-3-5-sonnet-20241022-v2:0\"\n",
    "    classification_prompt = f\"\"\"\n",
    "Classify each image in their appropriate class according to the scene they depict. \n",
    "Valid classes are {str_categories} and only these, so you need to classify the images in one of these classes.\n",
    "Pay attention to the semantics that define each class.\n",
    "Return me only the label of the scene depicted and nothing else.\n",
    "\"\"\"\n",
    "    source_classes_analyze = defaultdict(list)\n",
    "    for ii in range(7):\n",
    "        chat = Chat(model_id, bedrock_runtime_client) # create a chat like openning a new chat in \n",
    "        chat.add_user_message_image(classification_prompt, filename) # add a user message with an image and a text prompt\n",
    "        answer_source = chat.generate().strip().lower()\n",
    "        source_classes_analyze = {filename: [answer_source]}\n",
    "        \n",
    "    source_classes_analyze_clean = defaultdict(list, {k: list(map(lambda x: x.replace('\\n', ''), v)) for k, v in source_classes_analyze.items()})\n",
    "    source_classes_analyze_filtered = {key: list(filter(lambda item: item in processed_categories, items)) for key, items in source_classes_analyze_clean.items()}\n",
    "    for key in source_classes_analyze_filtered:\n",
    "        if not source_classes_analyze_filtered[key]:  # Checks if the list is empty\n",
    "            source_classes_analyze_filtered[key] = source_classes_analyze_clean[key]\n",
    "            \n",
    "    source_classes_analyze = defaultdict(list, {k: (lambda v: Counter(v).most_common(1)[0][0])(v) for k, v in source_classes_analyze_filtered.items()})\n",
    "        \n",
    "    return source_classes_analyze[filename]\n",
    "            \n",
    "            \n",
    "\n",
    "classify(\"imgs/random/claude/11/source.jpg\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "def global_explanations(source_image_path, orig_label):\n",
    "\n",
    "#     orig_label = classify(source_image_path)[0][0]\n",
    "\n",
    "    regional_dataset = []\n",
    "    regional_labels = []\n",
    "    for l, r in zip(labels, dataset):\n",
    "        if l == orig_label:\n",
    "            regional_dataset.append(r)\n",
    "            regional_labels.append(l)\n",
    "\n",
    "    gl = ds.global_explanation(regional_dataset, regional_labels)\n",
    "    return {k: v for k, v in gl.items() if \".\" not in k}"
   ]
  },
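  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note on the scores produced by `global_explanations` (as consumed in the next section): a non-positive score for a concept triggers a removal edit and a positive score triggers an addition, while local edits missing from the global explanation receive small default magnitudes (±0.1) so that, after sorting by absolute value, they are applied last."
   ]
  },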
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Run the editor with the optimal edits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "from prompts import prompt_single_step, prompt_add_object, prompt_remove_object\n",
    "\n",
    "import os\n",
    "import shutil\n",
    "\n",
    "def create_or_replace_dir(directory_name):\n",
    "    # Check if the directory already exists\n",
    "    if os.path.exists(directory_name):\n",
    "        # If it exists, remove it\n",
    "        shutil.rmtree(directory_name)\n",
    "    \n",
    "    # Create the new directory\n",
    "    os.makedirs(directory_name)\n",
    "\n",
    "def edit_global_edits(image_id):\n",
    "    \n",
    "    create_or_replace_dir(f\"imgs/random/claude-3-5-sonnet/global-local/{image_id}\")\n",
    "    source_image_path = f\"imgs/random/claude-3-5-sonnet/global-local/{image_id}/source.jpg\"\n",
    "    \n",
    "    url = data[image_id][\"url\"]\n",
    "    download_image_from_url(url, source_image_path)\n",
    "    steps = []\n",
    "    \n",
    "    objs, added_objs, removed_objs = get_local_edits(image_id)\n",
    "    \n",
    "    chat = Chat(model_id, bedrock_runtime_client)\n",
    "\n",
    "    logs = \"\"\n",
    "    excs, i = 0, 1\n",
    "    orig_label = classify(source_image_path)\n",
    "    global_edits = global_explanations(source_image_path, orig_label)\n",
    "    new_label = orig_label\n",
    "    logs += f\"Classification: {orig_label}\\n\"\n",
    "    \n",
    "    sorted_edits = {}\n",
    "    for e in added_objs + removed_objs:\n",
    "        if e in global_edits:\n",
    "            v = global_edits[e]\n",
    "        else:\n",
    "            if e in removed_objs:\n",
    "                v = -0.1\n",
    "            if e in added_objs:\n",
    "                v = 0.1\n",
    "        \n",
    "        sorted_edits[e] = v\n",
    "        \n",
    "    sorted_edits = [[k, v] for k, v in sorted(sorted_edits.items(), key=lambda item: abs(item[1]), reverse = True)]\n",
    "    \n",
    "    for o in global_edits:\n",
    "        if global_edits[o] == 0:\n",
    "            continue \n",
    "        if o not in added_objs + removed_objs:\n",
    "            sorted_edits.append([o, global_edits[o]])\n",
    "            \n",
    "    for [obj, v] in sorted_edits:\n",
    "        try:\n",
    "                \n",
    "            if  v <= 0:\n",
    "                if obj in added_objects:\n",
    "                    continue\n",
    "                if obj in objs:\n",
    "                    prompt = prompt_remove_object(obj)\n",
    "                    chat.add_user_message_image(prompt, source_image_path) # add a user message with an image and a text prompt\n",
    "                    background = chat.generate()\n",
    "                    logs += f\"\\n----\\nOutput LVLM: {i}\\n{background}\\n\"\n",
    "                    background = background.strip()\n",
    "\n",
    "                    new_image, mask = editor.replacer(source_image_path, obj, background)\n",
    "                    logs += f\"\\n{['remove', obj, background]}\\n\" \n",
    "                    steps.append([\"remove\", obj, background])\n",
    "                    \n",
    "                    source_image_path = f\"imgs/random/claude-3-5-sonnet/global-local/{image_id}/step_{i}.jpg\"\n",
    "                    new_image.save(source_image_path)\n",
    "                    i += 1\n",
    "                    new_label = classify(source_image_path)\n",
    "                    logs += f\"Classification: {new_label}\\n\"\n",
    "\n",
    "\n",
    "            elif v >= 0:    \n",
    "                if obj in removed_objs:\n",
    "                    continue\n",
    "                if obj not in objs: \n",
    "                    prompt = prompt_add_object(obj)\n",
    "                    chat.add_user_message_image(prompt, source_image_path) # add a user message with an image and a text prompt\n",
    "                    add = chat.generate()\n",
    "                    logs += f\"\\n----\\nOutput LVLM: {i}\\n{add}\\n\"\n",
    "                    add = add.strip()\n",
    "\n",
    "\n",
    "                    new_image, mask = editor.replacer(source_image_path, add, obj)\n",
    "                    logs += f\"\\n{['add', obj, add]}\\n\" \n",
    "                    steps.append([\"add\", obj, add])\n",
    "\n",
    "                    source_image_path = f\"imgs/random/claude-3-5-sonnet/global-local/{image_id}/step_{i}.jpg\"\n",
    "                    new_image.save(source_image_path)\n",
    "                    i += 1\n",
    "                    \n",
    "                    new_label = classify(source_image_path)\n",
    "                    logs += f\"Classification: {new_label}\\n\"\n",
    "                    \n",
    "\n",
    "        except Exception as e:\n",
    "            excs += 1\n",
    "            logs += f\"Exception: {e}\\n\"\n",
    "            if excs >= 5:\n",
    "                break\n",
    "                \n",
    "                \n",
    "        if (orig_label != new_label):\n",
    "            break\n",
    "\n",
    "\n",
    "    logs += f\"\\n\\n----\\n\\n{steps}\\n\\n----\\n\\n\"\n",
    "    with open(f\"imgs/random/claude-3-5-sonnet/global-local/{image_id}/logs.txt\", \"w\") as handle:\n",
    "        handle.write(logs)\n",
    "    return steps\n",
    "        \n",
    "    \n",
    "def classified_as(image_path, cl):\n",
    "    preds = classify(image_path)\n",
    "    if preds == cl:\n",
    "        return True\n",
    "    return False "
   ]
  },
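  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before launching the full loop below, a hedged smoke test on a single image can save hours if the Gradio editor or the Bedrock client is misconfigured (the id choice is arbitrary)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional smoke test: run the full pipeline on one image first.\n",
    "# Uncomment to execute (it calls both the editor and the LVLM).\n",
    "# steps = edit_global_edits(next(iter(data)))\n",
    "# print(steps)"
   ]
  },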
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 16%|█▌        | 78/499 [1:50:53<10:30:35, 89.87s/it] "
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "for key in tqdm(data):\n",
    "    edit_global_edits(key)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
