{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b05c1c61",
   "metadata": {},
   "outputs": [],
   "source": [
    "#\n",
    "# For licensing see accompanying LICENSE file.\n",
    "# Copyright (C) 2024 Apple Inc. All Rights Reserved.\n",
    "#"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d1c35b9f-06d2-4bc6-a13a-5c68730530c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "from PIL import Image\n",
    "\n",
    "import torch as T\n",
    "import transformers, diffusers\n",
    "\n",
    "from llava.conversation import conv_templates\n",
    "from llava.model import *\n",
    "\n",
    "def crop_resize(f, sz=512):\n",
    "    w, h = f.size\n",
    "    if w>h:\n",
    "        p = (w-h)//2\n",
    "        f = f.crop([p, 0, p+h, h])\n",
    "    elif h>w:\n",
    "        p = (h-w)//2\n",
    "        f = f.crop([0, p, w, p+w])\n",
    "    f = f.resize([sz, sz])\n",
    "    return f\n",
    "def remove_alter(s):  # hack expressive instruction\n",
    "    if 'ASSISTANT:' in s: s = s[s.index('ASSISTANT:')+10:].strip()\n",
    "    if '</s>' in s: s = s[:s.index('</s>')].strip()\n",
    "    if 'alternative' in s.lower(): s = s[:s.lower().index('alternative')]\n",
    "    if '[IMG0]' in s: s = s[:s.index('[IMG0]')]\n",
    "    s = '.'.join([s.strip() for s in s.split('.')[:2]])\n",
    "    if s[-1]!='.': s += '.'\n",
    "    return s.strip()"
   ]
  },
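  {
   "cell_type": "markdown",
   "id": "3a9d41f0",
   "metadata": {},
   "source": [
    "`remove_alter` trims the MLLM's decoded output down to a concise \"expressive instruction\": it strips the chat scaffolding (`ASSISTANT:`, `</s>`), drops any trailing \"alternative\" suggestions and the `[IMG0]` visual tokens, and keeps at most the first two sentences. A quick sanity check on a made-up decoded string (illustrative only, not real model output):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c2e58b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# hypothetical decoded output, just to exercise remove_alter\n",
    "_demo = 'ASSISTANT: Make the sky a deep orange sunset. Add warm tones. An alternative would be dusk. [IMG0][IMG1]</s>'\n",
    "print(remove_alter(_demo))  # -> 'Make the sky a deep orange sunset. Add warm tones.'"
   ]
  },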
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df365293-3856-4c98-87dc-3569cab81700",
   "metadata": {},
   "outputs": [],
   "source": [
    "DEFAULT_IMAGE_TOKEN = '<image>'\n",
    "DEFAULT_IMAGE_PATCH_TOKEN = '<im_patch>'\n",
    "DEFAULT_IM_START_TOKEN = '<im_start>'\n",
    "DEFAULT_IM_END_TOKEN = '<im_end>'\n",
    "PATH_LLAVA = './_ckpt/LLaVA-7B-v1'\n",
    "\n",
    "tokenizer = transformers.AutoTokenizer.from_pretrained(PATH_LLAVA)\n",
    "model = LlavaLlamaForCausalLM.from_pretrained(PATH_LLAVA, low_cpu_mem_usage=True, torch_dtype=T.float16, use_cache=True).cuda()\n",
    "image_processor = transformers.CLIPImageProcessor.from_pretrained(model.config.mm_vision_tower, torch_dtype=T.float16)\n",
    "\n",
    "tokenizer.padding_side = 'left'\n",
    "tokenizer.add_tokens(['[IMG0]', '[IMG1]', '[IMG2]', '[IMG3]', '[IMG4]', '[IMG5]', '[IMG6]', '[IMG7]'], special_tokens=True)\n",
    "model.resize_token_embeddings(len(tokenizer))\n",
    "ckpt = T.load('./_ckpt/mgie_7b/mllm.pt', map_location='cpu')\n",
    "model.load_state_dict(ckpt, strict=False)\n",
    "\n",
    "mm_use_im_start_end = getattr(model.config, 'mm_use_im_start_end', False)\n",
    "tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)\n",
    "if mm_use_im_start_end: tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)\n",
    "\n",
    "vision_tower = model.get_model().vision_tower[0]\n",
    "vision_tower = transformers.CLIPVisionModel.from_pretrained(vision_tower.config._name_or_path, torch_dtype=T.float16, low_cpu_mem_usage=True).cuda()\n",
    "model.get_model().vision_tower[0] = vision_tower\n",
    "vision_config = vision_tower.config\n",
    "vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0]\n",
    "vision_config.use_im_start_end = mm_use_im_start_end\n",
    "if mm_use_im_start_end: vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])\n",
    "image_token_len = (vision_config.image_size//vision_config.patch_size)**2\n",
    "\n",
    "_ = model.eval()\n",
    "EMB = ckpt['emb'].cuda()\n",
    "with T.inference_mode(): NULL = model.edit_head(T.zeros(1, 8, 4096).half().to('cuda'), EMB)\n",
    "print('NULL:', NULL.shape)"
   ]
  },
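  {
   "cell_type": "markdown",
   "id": "5d0f72ca",
   "metadata": {},
   "source": [
    "The edit head maps the hidden states of the 8 `[IMG]` tokens (dimension 4096) to prompt embeddings for the diffusion model; `NULL` is the unconditional embedding later passed as `negative_prompt_embeds`. A few sanity checks (assumes the cell above ran; the expected `[IMG0]` id of 32003 is the one hard-coded in the decoding loop below):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b64e3ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "print('[IMG0] id:', tokenizer.convert_tokens_to_ids('[IMG0]'))  # expected: 32003\n",
    "print('visual patches per image:', image_token_len)\n",
    "print('NULL embedding:', NULL.shape, NULL.dtype)"
   ]
  },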
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "86c7224e-ac5b-4395-8e92-908aa4949b7f",
   "metadata": {},
   "outputs": [],
   "source": [
    "pipe = diffusers.StableDiffusionInstructPix2PixPipeline.from_pretrained('timbrooks/instruct-pix2pix', torch_dtype=T.float16, safety_checker=None).to('cuda')\n",
    "pipe.set_progress_bar_config(disable=True)\n",
    "pipe.unet.load_state_dict(T.load('./_ckpt/mgie_7b/unet.pt', map_location='cpu'))"
   ]
  },
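  {
   "cell_type": "markdown",
   "id": "c4187f26",
   "metadata": {},
   "source": [
    "MGIE keeps the InstructPix2Pix pipeline, which conditions the UNet on the input image by concatenating the image latents to the noise latents, and only swaps in its fine-tuned UNet weights. A quick structural check (the values noted in the comments assume the standard SD-1.5-based InstructPix2Pix):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f83b05d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 8 input latent channels = 4 noise + 4 image-conditioning\n",
    "print('UNet in_channels:', pipe.unet.config.in_channels)\n",
    "# cross-attention width the prompt embeddings from edit_head must match (768 for SD 1.5)\n",
    "print('cross_attention_dim:', pipe.unet.config.cross_attention_dim)"
   ]
  },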
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb920c1a-83fa-4746-a74a-0c72566acb73",
   "metadata": {},
   "outputs": [],
   "source": [
    "SEED = 13331\n",
    "\n",
    "ins = ['make the frame red', 'turn the day into night', 'give him a beard', 'make cottage a mansion', \n",
    "       'remove yellow object from dogs paws', 'change the hair from red to blue', 'remove the text', 'increase the image contrast', \n",
    "       'remove the people in the background', 'please make this photo professional looking', 'darken the image, sharpen it', 'photoshop the girl out', \n",
    "       'make more brightness', 'take away the brown filter form the image', 'add more contrast to simulate more light', 'dark on rgb', \n",
    "       'make the face happy', 'change view as ocean', 'replace basketball with soccer ball', 'let the floor be made of wood']\n",
    "for i in tqdm(range(len(ins))):\n",
    "    img, txt = Image.open('_input/%d.jpg'%(i)).convert('RGB'), ins[i]\n",
    "    \n",
    "    img = image_processor.preprocess(img, return_tensors='pt')['pixel_values'][0]\n",
    "    txt = \"what will this image be like if '%s'\"%(txt)\n",
    "    txt = txt+'\\n'+DEFAULT_IM_START_TOKEN+DEFAULT_IMAGE_PATCH_TOKEN*image_token_len+DEFAULT_IM_END_TOKEN\n",
    "    conv = conv_templates['vicuna_v1_1'].copy()\n",
    "    conv.append_message(conv.roles[0], txt), conv.append_message(conv.roles[1], None)\n",
    "    txt = conv.get_prompt()\n",
    "    txt = tokenizer(txt)\n",
    "    txt, mask = T.as_tensor(txt['input_ids']), T.as_tensor(txt['attention_mask'])\n",
    "    \n",
    "    with T.inference_mode():\n",
    "        out = model.generate(txt.unsqueeze(dim=0).cuda(), images=img.half().unsqueeze(dim=0).cuda(), attention_mask=mask.unsqueeze(dim=0).cuda(), \n",
    "                             do_sample=False, max_new_tokens=96, num_beams=1, no_repeat_ngram_size=3, \n",
    "                             return_dict_in_generate=True, output_hidden_states=True)\n",
    "        out, hid = out['sequences'][0].tolist(), T.cat([x[-1] for x in out['hidden_states']], dim=1)[0]\n",
    "        \n",
    "        p = min(out.index(32003)-1 if 32003 in out else len(hid)-9, len(hid)-9)\n",
    "        hid = hid[p:p+8]\n",
    "\n",
    "        out = remove_alter(tokenizer.decode(out))\n",
    "        emb = model.edit_head(hid.unsqueeze(dim=0), EMB)\n",
    "        res = pipe(image=Image.open('_input/%d.jpg'%(i)).convert('RGB'), prompt_embeds=emb, negative_prompt_embeds=NULL, generator=T.Generator(device='cuda').manual_seed(SEED)).images[0]\n",
    "    \n",
    "    display(Image.open('_input/%d.jpg'%(i)).convert('RGB')), print(ins[i])\n",
    "    print('\\n')\n",
    "    print(out), display(res)\n",
    "    print('\\n\\n')"
   ]
  }
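,
  {
   "cell_type": "markdown",
   "id": "2e71a6b4",
   "metadata": {},
   "source": [
    "Optionally, save the last edited image instead of only displaying it (a minimal sketch; the `_output/` directory is an assumed destination, not part of the original demo):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a50c9de",
   "metadata": {},
   "outputs": [],
   "source": [
    "# persist the final result from the loop above (assumes the loop ran to completion)\n",
    "os.makedirs('_output', exist_ok=True)\n",
    "res.save('_output/%d.jpg'%(len(ins)-1))\n",
    "print('saved:', '_output/%d.jpg'%(len(ins)-1))"
   ]
  }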
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
