{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "import requests\n",
    "def download_image_from_url(url, filename, timeout=30):\n",
    "    \"\"\"Download `url` and write the raw response body to `filename`.\n",
    "\n",
    "    Raises requests.HTTPError for 4xx/5xx responses (instead of silently saving an\n",
    "    error page as the image) and a requests Timeout exception if connecting or\n",
    "    reading stalls for longer than `timeout` seconds.\n",
    "    \"\"\"\n",
    "    # Without a timeout, requests.get can block forever on an unresponsive server.\n",
    "    response = requests.get(url, timeout=timeout)\n",
    "    response.raise_for_status()\n",
    "    with open(filename, 'wb') as handler:\n",
    "        handler.write(response.content)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define Hyperparameters of the experiment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Define Hyperparameters of the experiment\n",
    "import os\n",
    "\n",
    "# URL of the Gradio app that `Editor` talks to\n",
    "gradio_link = \"https://.....gradio.live\"\n",
    "\n",
    "# about the LVLM using AWS Bedrock.\n",
    "# Credentials come from the environment so they are never committed with the notebook;\n",
    "# the placeholders are only used when a variable is unset.\n",
    "# NOTE: no trailing commas here -- `x = \"a\",` would make `x` a one-element tuple.\n",
    "aws_access_key_id = os.environ.get(\"AWS_ACCESS_KEY_ID\", \"yours_aws_access_key_id\")\n",
    "aws_secret_access_key = os.environ.get(\"AWS_SECRET_ACCESS_KEY\", \"yours_aws_secret_access_key\")\n",
    "region_name = os.environ.get(\"AWS_REGION\", \"yours_region_name\")\n",
    "model_id = \"anthropic.claude-3-5-sonnet-20241022-v2:0\" # model_id of the LVLM"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Import CECE and load the dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "import cece\n",
    "from cece.queries import *\n",
    "from cece.refine import *\n",
    "from cece.wordnet import *\n",
    "\n",
    "import pickle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `data`: per-image records keyed by image id; rows are read below via \"objects\",\n",
    "# \"claude-3-5-sonnet\" and \"url\". Only unpickle files you trust (pickle can run code).\n",
    "with open(\"vg_data_random.pickle\", mode=\"rb\") as pickle_file:\n",
    "    data = pickle.load(pickle_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "def export_text_edits(edits):\n",
    "    \"\"\"Keep only the plain-word terms (no '.') of a CECE edit description.\n",
    "\n",
    "    `edits` has \"additions\" and \"removals\" (iterables of term groups) and \"transf\"\n",
    "    (pairs of term groups). Terms containing '.' -- presumably WordNet synset ids --\n",
    "    are dropped. Additions/removals keep every remaining term; each transformation\n",
    "    keeps the first remaining term on each side, or None if there is none.\n",
    "    \"\"\"\n",
    "    def first_plain(group):\n",
    "        return next((term for term in group if \".\" not in term), None)\n",
    "\n",
    "    def all_plain(groups):\n",
    "        return [term for group in groups for term in group if \".\" not in term]\n",
    "\n",
    "    transf = [[first_plain(src), first_plain(dst)] for src, dst in edits[\"transf\"]]\n",
    "    return {\n",
    "        \"additions\": all_plain(edits[\"additions\"]),\n",
    "        \"removals\": all_plain(edits[\"removals\"]),\n",
    "        \"transf\": transf,\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "from cece.xDataset import *\n",
    "from cece.xDataset import createMSQ\n",
    "\n",
    "# One CECE instance per image: each object becomes connect_term_to_wordnet(o) (set-like)\n",
    "# plus its bare name (text before the first '.'); the label is row[\"claude-3-5-sonnet\"][0][0].\n",
    "# The loop variables i, k and row remain defined as notebook globals afterwards.\n",
    "dataset, labels = [], []\n",
    "index_to_image_id, image_id_to_index = {}, {}\n",
    "for i, (k, row) in enumerate(data.items()):\n",
    "    msq = [connect_term_to_wordnet(o).union([o.split(\".\")[0]]) for o in row[\"objects\"]]\n",
    "    dataset.append(msq)\n",
    "    labels.append(row[\"claude-3-5-sonnet\"][0][0])\n",
    "    index_to_image_id[i], image_id_to_index[k] = k, i\n",
    "\n",
    "# connect_to_wordnet=False: presumably because the concepts were already expanded above.\n",
    "ds = xDataset(dataset=dataset, labels=labels, connect_to_wordnet=False)\n",
    "\n",
    "def get_local_edits(image_id):\n",
    "    \"\"\"Return the CECE edits for one image as plain words.\n",
    "\n",
    "    `ds.explain` picks a target instance for the image's concepts and label, and\n",
    "    `ds.find_edits` lists the edits from the source to that target; terms containing\n",
    "    '.' are dropped via `export_text_edits`.\n",
    "\n",
    "    Returns (objects_source, added_objs, removed_objs):\n",
    "      objects_source -- dot-free concepts of the source image;\n",
    "      added_objs     -- additions plus the target side of every transformation;\n",
    "      removed_objs   -- removals plus the source side of every transformation.\n",
    "    Transformation sides with no dot-free term (None) are left out.\n",
    "    \"\"\"\n",
    "    source_index = image_id_to_index[image_id]\n",
    "    source = ds.dataset[source_index]\n",
    "    objects_source = [dd for d in source.concepts for dd in d if \".\" not in dd]\n",
    "\n",
    "    target_index, _cost = ds.explain(source, labels[source_index])\n",
    "    _cost, edits = ds.find_edits(source, ds.dataset[target_index])\n",
    "\n",
    "    edits = export_text_edits(edits)\n",
    "    added_objs = edits[\"additions\"] + [new for _, new in edits[\"transf\"] if new is not None]\n",
    "    removed_objs = edits[\"removals\"] + [old for old, _ in edits[\"transf\"] if old is not None]\n",
    "    return objects_source, added_objs, removed_objs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the Pretrained Stable Diffusion Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "from edits import Edits\n",
    "import ast\n",
    "from editor import Editor\n",
    "import boto3\n",
    "from chat import Chat\n",
    "\n",
    "from PIL import Image\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Client for the image-editing app served through Gradio at `gradio_link`.\n",
    "editor = Editor(gradio_link)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## LVLM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the aws runtime and model\n",
    "import os\n",
    "\n",
    "# SECURITY: never hard-code AWS keys in a notebook. boto3 resolves credentials from its\n",
    "# standard chain (AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY env vars, ~/.aws/credentials,\n",
    "# IAM role). Keys that were ever committed in this cell must be treated as leaked and rotated.\n",
    "bedrock_runtime_client = boto3.client(\n",
    "    'bedrock-runtime',\n",
    "    region_name=os.environ.get(\"AWS_REGION\", \"us-west-2\"),\n",
    ")\n",
    "\n",
    "model_id = \"anthropic.claude-3-5-sonnet-20241022-v2:0\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Classifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'office_cubicles'"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "\n",
    "\n",
    "# Scene classes the LVLM classifier may answer with (iterable of class names, from a local pickle).\n",
    "# The former Places365 ResNet classifier (places365_classifier.Classifier) is disabled.\n",
    "with open(\"processed_places_categories.pickle\", mode=\"rb\") as categories_file:\n",
    "    processed_categories = pickle.load(categories_file)\n",
    "    \n",
    "def generate_prompt(categories):\n",
    "    \"\"\"Return the category names as a single comma-separated string (\"a, b, c\").\"\"\"\n",
    "    return \", \".join(categories)\n",
    "\n",
    "# Class list inserted into the classification prompt of `classify`.\n",
    "str_categories = generate_prompt(processed_categories)\n",
    "\n",
    "def classify(filename, n_votes=7):\n",
    "    \"\"\"Classify the scene in the image `filename` with the LVLM, by majority vote.\n",
    "\n",
    "    The image is sent `n_votes` times (default 7, as before), each time in a new `Chat`,\n",
    "    with a prompt listing `str_categories`. Answers are stripped, lower-cased and have\n",
    "    newlines removed. Answers found in `processed_categories` are preferred; if there\n",
    "    are none, all answers are used. Returns the most common one (ties go to the answer\n",
    "    seen first).\n",
    "\n",
    "    Raises ValueError if `n_votes` < 1.\n",
    "    \"\"\"\n",
    "    from collections import Counter\n",
    "\n",
    "    if n_votes < 1:\n",
    "        raise ValueError(\"n_votes must be at least 1\")\n",
    "    model_id = \"anthropic.claude-3-5-sonnet-20241022-v2:0\"\n",
    "    classification_prompt = f\"\"\"\n",
    "Classify each image in their appropriate class according to the scene they depict. \n",
    "Valid classes are {str_categories} and only these, so you need to classify the images in one of these classes.\n",
    "Pay attention to the semantics that define each class.\n",
    "Return me only the label of the scene depicted and nothing else.\n",
    "\"\"\"\n",
    "    answers = []\n",
    "    for _ in range(n_votes):\n",
    "        # create a new chat for every vote, like opening a new chat window\n",
    "        chat = Chat(model_id, bedrock_runtime_client)\n",
    "        chat.add_user_message_image(classification_prompt, filename)  # image + text prompt\n",
    "        answers.append(chat.generate().strip().lower().replace('\\n', ''))\n",
    "\n",
    "    # BUGFIX: the original rebuilt its result dict on every iteration, so only the last\n",
    "    # answer survived and the majority vote was over a single sample.\n",
    "    valid_answers = [answer for answer in answers if answer in processed_categories]\n",
    "    return Counter(valid_answers or answers).most_common(1)[0][0]\n",
    "\n",
    "\n",
    "# quick check on one image\n",
    "classify(\"imgs/random/claude/11/source.jpg\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def global_explanations(source_image_path, orig_label):\n",
    "    \"\"\"Return `ds.global_explanation` for all dataset instances labelled `orig_label`.\n",
    "\n",
    "    Keys containing '.' (presumably WordNet synset ids) are dropped from the result.\n",
    "    `source_image_path` is currently unused.\n",
    "    \"\"\"\n",
    "    matching = [(instance, label) for label, instance in zip(labels, dataset) if label == orig_label]\n",
    "    regional_dataset = [instance for instance, _ in matching]\n",
    "    regional_labels = [label for _, label in matching]\n",
    "\n",
    "    explanation = ds.global_explanation(regional_dataset, regional_labels)\n",
    "    return {concept: weight for concept, weight in explanation.items() if \".\" not in concept}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Let Claude decide the edits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "from prompts import prompt_single_step, prompt_add_object, prompt_remove_object\n",
    "\n",
    "import os\n",
    "import shutil\n",
    "\n",
    "def create_or_replace_dir(directory_name):\n",
    "    \"\"\"Make `directory_name` an empty directory, deleting it first if the path exists.\"\"\"\n",
    "    already_present = os.path.exists(directory_name)\n",
    "    if already_present:\n",
    "        shutil.rmtree(directory_name)  # drop everything from a previous run\n",
    "\n",
    "    os.makedirs(directory_name)\n",
    "\n",
    "def edit_global_edits(image_id):\n",
    "    \"\"\"Apply the class-level (global) CECE edits to one image until its label changes.\n",
    "\n",
    "    1. Recreate imgs/random/claude-3-5-sonnet/global/<image_id>/ and download the image\n",
    "       there as source.jpg.\n",
    "    2. Classify it (`orig_label`) and get `global_explanations` for that label.\n",
    "    3. For every (concept, weight) pair, in order:\n",
    "       - weight <= 0 and the concept is among the image's objects: ask the LVLM\n",
    "         (prompt_remove_object) and edit the image with `editor.replacer`;\n",
    "       - weight > 0 and the concept is absent: ask the LVLM (prompt_add_object)\n",
    "         and edit the image with `editor.replacer`.\n",
    "       Each successful edit is saved as step_<i>.jpg, becomes the input of the next\n",
    "       edit, and is re-classified.\n",
    "    4. Stop as soon as the label differs from `orig_label`, or after 5 exceptions.\n",
    "\n",
    "    The LVLM answers, steps and classifications are logged to logs.txt in the same\n",
    "    directory.\n",
    "\n",
    "    Returns:\n",
    "        list of [\"remove\" | \"add\", concept, lvlm_answer] steps actually applied.\n",
    "    \"\"\"\n",
    "    out_dir = f\"imgs/random/claude-3-5-sonnet/global/{image_id}\"\n",
    "    create_or_replace_dir(out_dir)\n",
    "    source_image_path = f\"{out_dir}/source.jpg\"\n",
    "\n",
    "    download_image_from_url(data[image_id][\"url\"], source_image_path)\n",
    "    steps = []\n",
    "\n",
    "    # Only the objects present in the source image are needed here.\n",
    "    objs, _added_objs, _removed_objs = get_local_edits(image_id)\n",
    "\n",
    "    # One conversation is reused for every edit prompt of this image.\n",
    "    chat = Chat(model_id, bedrock_runtime_client)\n",
    "\n",
    "    logs = \"\"\n",
    "    excs, i = 0, 1\n",
    "    orig_label = classify(source_image_path)\n",
    "    new_label = orig_label\n",
    "    global_edits = global_explanations(source_image_path, orig_label)\n",
    "    logs += f\"Classification: {orig_label}\\n\"\n",
    "    for obj, v in global_edits.items():\n",
    "        try:\n",
    "            # BUGFIX: branch on this concept's weight `v`. The original tested `k`, which is\n",
    "            # not defined in this function and so read a leftover notebook global instead.\n",
    "            if v <= 0:\n",
    "                if obj in objs:\n",
    "                    prompt = prompt_remove_object(obj)\n",
    "                    chat.add_user_message_image(prompt, source_image_path)  # image + text prompt\n",
    "                    background = chat.generate()\n",
    "                    logs += f\"\\n----\\nOutput LVLM: {i}\\n{background}\\n\"\n",
    "                    background = background.strip()\n",
    "\n",
    "                    new_image, _mask = editor.replacer(source_image_path, obj, background)\n",
    "                    logs += f\"\\n{['remove', obj, background]}\\n\"\n",
    "                    steps.append([\"remove\", obj, background])\n",
    "\n",
    "                    # the edited image is the input of the next step\n",
    "                    source_image_path = f\"{out_dir}/step_{i}.jpg\"\n",
    "                    new_image.save(source_image_path)\n",
    "                    i += 1\n",
    "                    new_label = classify(source_image_path)\n",
    "                    logs += f\"Classification: {new_label}\\n\"\n",
    "\n",
    "            elif v > 0:\n",
    "                if obj not in objs:\n",
    "                    prompt = prompt_add_object(obj)\n",
    "                    chat.add_user_message_image(prompt, source_image_path)  # image + text prompt\n",
    "                    add = chat.generate()\n",
    "                    logs += f\"\\n----\\nOutput LVLM: {i}\\n{add}\\n\"\n",
    "                    add = add.strip()\n",
    "\n",
    "                    new_image, _mask = editor.replacer(source_image_path, add, obj)\n",
    "                    logs += f\"\\n{['add', obj, add]}\\n\"\n",
    "                    steps.append([\"add\", obj, add])\n",
    "\n",
    "                    source_image_path = f\"{out_dir}/step_{i}.jpg\"\n",
    "                    new_image.save(source_image_path)\n",
    "                    i += 1\n",
    "                    new_label = classify(source_image_path)\n",
    "                    logs += f\"Classification: {new_label}\\n\"\n",
    "\n",
    "        except Exception as e:\n",
    "            # tolerate a failed edit, but give up on this image after 5 of them\n",
    "            excs += 1\n",
    "            logs += f\"Exception: {e}\\n\"\n",
    "            if excs >= 5:\n",
    "                break\n",
    "\n",
    "        if new_label != orig_label:\n",
    "            break\n",
    "\n",
    "    logs += f\"\\n\\n----\\n\\n{steps}\\n\\n----\\n\\n\"\n",
    "    with open(f\"{out_dir}/logs.txt\", \"w\") as handle:\n",
    "        handle.write(logs)\n",
    "    return steps\n",
    "        \n",
    "    \n",
    "def classified_as(image_path, cl):\n",
    "    \"\"\"Return True if `classify` (the LVLM classifier) labels the image as `cl`.\"\"\"\n",
    "    # The original used `classifier.classify`, but `classifier` appears to be created only in\n",
    "    # a commented-out line of this notebook, so calling it raised NameError.\n",
    "    return classify(image_path) == cl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "# Run the global-edit pipeline on every image. A failure outside edit_global_edits'\n",
    "# per-edit exception handling (download, classification, ...) is recorded and the loop\n",
    "# moves on, instead of aborting the whole run.\n",
    "failed_image_ids = {}\n",
    "for key in tqdm(data):\n",
    "    try:\n",
    "        edit_global_edits(key)\n",
    "    except Exception as exc:\n",
    "        failed_image_ids[key] = repr(exc)\n",
    "\n",
    "failed_image_ids"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "myenv_3.10.12",
   "language": "python",
   "name": "myenv"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
