{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d90bd3a",
   "metadata": {},
   "outputs": [],
   "source": [
    "#\n",
    "# For licensing see accompanying LICENSE file.\n",
    "# Copyright (C) 2024 Apple Inc. All Rights Reserved.\n",
    "#"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39465df4-3ac7-4340-b71d-02c11c78d1be",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, pickle, io, base64, json\n",
    "\n",
    "from glob import glob\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "from PIL import Image\n",
    "\n",
    "import torch as T\n",
    "import transformers\n",
    "\n",
    "from llava.conversation import conv_templates\n",
    "from llava.model import *\n",
    "\n",
    "def f2b(f):\n",
    "    b = io.BytesIO()\n",
    "    f.save(b, format='JPEG')\n",
    "    b = str(base64.b64encode(b.getvalue()))[2:-1]\n",
    "    return b\n",
    "def b2f(b):\n",
    "    '''Decode base64-encoded image bytes `b` and return them as an RGB image.'''\n",
    "    raw = base64.b64decode(b)\n",
    "    picture = Image.open(io.BytesIO(raw))\n",
    "    return picture.convert('RGB')\n",
    "def crop_resize(f, sz=512):\n",
    "    w, h = f.size\n",
    "    if w>h:\n",
    "        p = (w-h)//2\n",
    "        f = f.crop([p, 0, p+h, h])\n",
    "    elif h>w:\n",
    "        p = (h-w)//2\n",
    "        f = f.crop([0, p, w, p+w])\n",
    "    f = f.resize([sz, sz])\n",
    "    return f\n",
    "def remove_alter(s):  # hack expressive instruction\n",
    "    if 'ASSISTANT:' in s: s = s[s.index('ASSISTANT:')+10:].strip()\n",
    "    if '</s>' in s: s = s[:s.index('</s>')].strip()\n",
    "    if 'alternative' in s.lower(): s = s[:s.lower().index('alternative')]\n",
    "    if '[IMG0]' in s: s = s[:s.index('[IMG0]')]\n",
    "    s = '.'.join([s.strip() for s in s.split('.')[:2]])\n",
    "    if s[-1]!='.': s += '.'\n",
    "    return s.strip()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7004c5ab-7ccc-4914-b1d9-0314ebfd5a6f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Placeholder token strings; the patch/start/end tokens are registered with the tokenizer below.\n",
    "DEFAULT_IMAGE_TOKEN = '<image>'\n",
    "DEFAULT_IMAGE_PATCH_TOKEN = '<im_patch>'\n",
    "DEFAULT_IM_START_TOKEN = '<im_start>'\n",
    "DEFAULT_IM_END_TOKEN = '<im_end>'\n",
    "\n",
    "# Checkpoint location; expanduser allows MODEL_NAME to start with ~.\n",
    "MODEL_NAME = './_ckpt/LLaVA-7B-v1'\n",
    "model_name = os.path.expanduser(MODEL_NAME)\n",
    "\n",
    "# Load the tokenizer, the fp16 model (moved to the GPU) and the image processor named by the model config.\n",
    "tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)\n",
    "model = LlavaLlamaForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True, torch_dtype=T.float16, use_cache=True).cuda()\n",
    "image_processor = transformers.CLIPImageProcessor.from_pretrained(model.config.mm_vision_tower, torch_dtype=T.float16)\n",
    "\n",
    "tokenizer.padding_side = 'left'\n",
    "\n",
    "# Register the image placeholder tokens; start/end are added only when the model config enables them.\n",
    "mm_use_im_start_end = getattr(model.config, 'mm_use_im_start_end', False)\n",
    "tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)\n",
    "if mm_use_im_start_end: tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)\n",
    "\n",
    "# Reload the vision tower in fp16 on the GPU from its recorded name/path and put it back into the model.\n",
    "vision_tower = model.get_model().vision_tower[0]\n",
    "vision_tower = transformers.CLIPVisionModel.from_pretrained(vision_tower.config._name_or_path, torch_dtype=T.float16, low_cpu_mem_usage=True).cuda()\n",
    "model.get_model().vision_tower[0] = vision_tower\n",
    "# Store the placeholder token ids on the vision config (must run after the add_tokens calls above).\n",
    "vision_config = vision_tower.config\n",
    "vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0]\n",
    "vision_config.use_im_start_end = mm_use_im_start_end\n",
    "if mm_use_im_start_end: vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])\n",
    "# Number of patch placeholders per image: (image_size // patch_size) squared.\n",
    "image_token_len = (vision_config.image_size//vision_config.patch_size)**2\n",
    "\n",
    "# Switch to eval mode; binding the result to _ keeps the cell from echoing the module.\n",
    "_ = model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e479793-f2f9-45b6-a05b-9f5b0d09e396",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Summarization pipeline (bf16 weights on GPU 0).\n",
    "summer = transformers.pipeline(\n",
    "    task='summarization',\n",
    "    model='jordiclive/flan-t5-11b-summarizer-filtered',\n",
    "    torch_dtype=T.bfloat16,\n",
    "    device=0,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25ac2e16-32a8-45ea-acc9-37b28a0a9681",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the dataset: a TSV of base64 image pairs, a pickle index of row offsets and a JSON of instructions.\n",
    "pkl, ei = {'task': []}, {}\n",
    "\n",
    "lst = sorted(glob('_data/*/prompt.json'))  # glob order is arbitrary; sort for reproducible output\n",
    "with open('./_data/ipr2pr.tsv', 'w') as tsv:  # closed even if generation fails part-way\n",
    "    for file in tqdm(lst):\n",
    "        with open(file, 'r') as fp:  # close each prompt file instead of leaking one handle per sample\n",
    "            prompt = json.load(fp)\n",
    "        folder = os.path.dirname(file)\n",
    "\n",
    "        # the start/end markers are only registered with the tokenizer when mm_use_im_start_end is set\n",
    "        patches = DEFAULT_IMAGE_PATCH_TOKEN*image_token_len\n",
    "        if mm_use_im_start_end: patches = DEFAULT_IM_START_TOKEN+patches+DEFAULT_IM_END_TOKEN\n",
    "        txt = \"what will this image be like if '%s'  (in a short paragraph)\"%(prompt['edit'])\n",
    "        txt = txt+'\\n'+patches\n",
    "        conv = conv_templates['vicuna_v1_1'].copy()\n",
    "        conv.append_message(conv.roles[0], txt)\n",
    "        conv.append_message(conv.roles[1], None)\n",
    "        enc = tokenizer(conv.get_prompt())\n",
    "        ids, mask = T.as_tensor(enc['input_ids']), T.as_tensor(enc['attention_mask'])\n",
    "\n",
    "        for src in sorted(glob(os.path.join(folder, '*_0.jpg'))):\n",
    "            # item is <folder name>_<file name without extension>; the glob guarantees it ends with _0\n",
    "            item = os.path.basename(folder)+'_'+os.path.splitext(os.path.basename(src))[0]\n",
    "            # swap only the trailing suffix: str.replace would also rewrite an earlier '_0' in the name\n",
    "            tgt = src[:-len('_0.jpg')]+'_1.jpg'\n",
    "            inp, ans = Image.open(src).convert('RGB'), Image.open(tgt).convert('RGB')\n",
    "\n",
    "            img = image_processor.preprocess(inp, return_tensors='pt')['pixel_values'][0]\n",
    "            with T.inference_mode():\n",
    "                out = model.generate(ids.unsqueeze(dim=0).cuda(), images=img.half().unsqueeze(dim=0).cuda(), attention_mask=mask.unsqueeze(dim=0).cuda(),\n",
    "                                     do_sample=False, max_new_tokens=1024)[0].tolist()\n",
    "\n",
    "                out = remove_alter(tokenizer.decode(out))\n",
    "                res = summer(['summarize the following paragraph in 32 words:\\n\\n%s'%(out)], num_beams=5, min_length=5, max_length=64,\n",
    "                             do_sample=False, no_repeat_ngram_size=3, truncation=True)[0]['summary_text']\n",
    "\n",
    "            # lineidx is the TSV position of the row about to be written, so it can be seek()-ed to later\n",
    "            pkl['task'].append([{'input': item, 'answer': item[:-2]+'_1', 'instruction': prompt['edit'], 'lineidx': tsv.tell()}])\n",
    "            tsv.write('%s\\t%s\\n'%(f2b(inp), f2b(ans)))\n",
    "            ei[item] = {'instruction': prompt['edit'], 'expressive': res}\n",
    "\n",
    "with open('./_data/ipr2pr.pkl', 'wb') as fp:\n",
    "    pickle.dump(pkl, fp)\n",
    "with open('./_data/ipr2pr_expressive.json', 'w') as fp:\n",
    "    json.dump(ei, fp, indent=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c36ea959-ff15-4ebd-848f-13de020886a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: re-read the generated files and show each sample next to its instructions.\n",
    "PREVIEW_LIMIT = None  # set to an int to show only the first few samples; None shows all of them\n",
    "\n",
    "with open('./_data/ipr2pr.pkl', 'rb') as fp:\n",
    "    pkl = pickle.load(fp)  # NOTE: pickle can execute code on load; only open index files you generated yourself\n",
    "with open('./_data/ipr2pr_expressive.json', 'r') as fp:\n",
    "    ei = json.load(fp)\n",
    "\n",
    "with open('./_data/ipr2pr.tsv', 'r') as tsv:\n",
    "    for task in pkl['task'][:PREVIEW_LIMIT]:\n",
    "        task = task[0]\n",
    "        tsv.seek(task['lineidx'])  # jump to the row position stored in the index\n",
    "        b = tsv.readline().strip().split('\\t')\n",
    "        print(task)\n",
    "        display(b2f(b[0]), b2f(b[1]))  # input image, then edited image\n",
    "        print(ei[task['input']])\n",
    "        print('\\n-----\\n')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
