{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standard-library modules are imported explicitly: argparse is used below and\n",
    "# os / osp are used by the generation cell, so they must not depend on being\n",
    "# re-exported by the star import further down.\n",
    "import argparse\n",
    "import os\n",
    "import os.path as osp\n",
    "import random\n",
    "\n",
    "import cv2\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "# CUDA device index for this notebook. Defaults to 2 (the value the notebook was\n",
    "# authored with); set the INFINITY_GPU_ID environment variable to use another GPU.\n",
    "# The device is selected before the project import, as in the original cell order.\n",
    "gpu_id = int(os.environ.get('INFINITY_GPU_ID', '2'))\n",
    "torch.cuda.set_device(gpu_id)\n",
    "\n",
    "# Expected to provide the helpers used by the following cells (load_tokenizer,\n",
    "# load_visual_tokenizer, load_transformer, gen_one_img, h_div_w_templates,\n",
    "# dynamic_resolution_h_w).\n",
    "from tools.run_infinity import *\n",
    "\n",
    "# Checkpoint locations, given relative to the current working directory.\n",
    "model_path = 'weights/infinity_8b_weights'\n",
    "vae_path = 'weights/infinity_vae_d56_f8_14_patchify.pth'\n",
    "text_encoder_ckpt = 'weights/flan-t5-xl-official'\n",
    "\n",
    "# Settings namespace passed to the loader functions and read by the generation cell.\n",
    "args = argparse.Namespace(\n",
    "    pn='1M',\n",
    "    model_path=model_path,\n",
    "    cfg_insertion_layer=0,\n",
    "    vae_type=14,\n",
    "    vae_path=vae_path,\n",
    "    add_lvl_embeding_only_first_block=1,\n",
    "    use_bit_label=1,\n",
    "    model_type='infinity_8b',\n",
    "    rope2d_each_sa_layer=1,\n",
    "    rope2d_normalized_by_hw=2,\n",
    "    use_scale_schedule_embedding=0,\n",
    "    sampling_per_bits=1,\n",
    "    text_encoder_ckpt=text_encoder_ckpt,\n",
    "    text_channels=2048,\n",
    "    apply_spatial_patchify=1,\n",
    "    h_div_w_template=1.000,\n",
    "    use_flex_attn=0,\n",
    "    cache_dir='/dev/shm',\n",
    "    checkpoint_type='torch_shard',\n",
    "    seed=0,\n",
    "    bf16=1,\n",
    "    save_file='tmp.jpg',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Loading tokenizer and text encoder]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3f68ce998b1546f185e6263884b382ef",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Loading Infinity]\n",
      "self.codebook_dim: 56, self.add_lvl_embeding_only_first_block: 1,             self.use_bit_label: 1, self.rope2d_each_sa_layer: 1, self.rope2d_normalized_by_hw: 2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/mnt/bn/foundation-vision/hanjian.thu123/infinity/pub_release/Infinity/tools/run_infinity.py:179: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
      "  with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16, cache_enabled=True), torch.no_grad():\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "self.num_blocks_in_a_chunk=5, depth=40, block_chunks=8\n",
      "\n",
      "[constructor]  ==== customized_flash_attn=False (using_flash=0/40), fused_mlp=False (fused_mlp=0/40) ==== \n",
      "    [Infinity config ] embed_dim=3584, num_heads=28, depth=40, mlp_ratio=4, swiglu=False num_blocks_in_a_chunk=5\n",
      "    [drop ratios] drop_rate=0.0, drop_path_rate=0.1 (tensor([0.0000, 0.0026, 0.0051, 0.0077, 0.0103, 0.0128, 0.0154, 0.0179, 0.0205,\n",
      "        0.0231, 0.0256, 0.0282, 0.0308, 0.0333, 0.0359, 0.0385, 0.0410, 0.0436,\n",
      "        0.0462, 0.0487, 0.0513, 0.0538, 0.0564, 0.0590, 0.0615, 0.0641, 0.0667,\n",
      "        0.0692, 0.0718, 0.0744, 0.0769, 0.0795, 0.0821, 0.0846, 0.0872, 0.0897,\n",
      "        0.0923, 0.0949, 0.0974, 0.1000]))\n",
      "\n",
      "[you selected Infinity with model_kwargs={'depth': 40, 'embed_dim': 3584, 'num_heads': 28, 'drop_path_rate': 0.1, 'mlp_ratio': 4, 'block_chunks': 8}] model size: 8.38B, bf16=1\n",
      "[Load Infinity weights]\n"
     ]
    }
   ],
   "source": [
    "# Text side: tokenizer and encoder restored from args.text_encoder_ckpt\n",
    "_text_modules = load_tokenizer(t5_path=args.text_encoder_ckpt)\n",
    "text_tokenizer, text_encoder = _text_modules\n",
    "# Visual tokenizer (VAE), configured entirely through args\n",
    "vae = load_visual_tokenizer(args)\n",
    "# Infinity transformer, constructed with the VAE loaded above\n",
    "infinity = load_transformer(vae, args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "prompt=a cat holds a board with the text 'diffusion is dead'\n",
      "cfg: [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3], tau: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/mnt/bn/foundation-vision/hanjian.thu123/infinity/pub_release/Infinity/tools/run_infinity.py:112: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
      "  with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16, cache_enabled=True):\n",
      "/mnt/bn/foundation-vision/hanjian.thu123/infinity/pub_release/Infinity/infinity/models/basic.py:495: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
      "  with torch.cuda.amp.autocast(enabled=False):    # disable half precision\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cost: 1.7465496063232422, infinity cost=1.7265434265136719\n",
      "Save to /mnt/bn/foundation-vision/hanjian.thu123/infinity/pub_release/Infinity/tools/ipynb_tmp.jpg\n"
     ]
    }
   ],
   "source": [
    "# User-tunable generation settings.\n",
    "prompt = \"\"\"a cat holds a board with the text 'diffusion is dead'\"\"\"\n",
    "cfg = 3\n",
    "tau = 1.0\n",
    "h_div_w = 1/1 # aspect ratio, height:width\n",
    "seed = random.randint(0, 10000)\n",
    "# Report the randomly drawn seed so that a particular image can be regenerated\n",
    "# later by assigning this value to `seed` instead of drawing a new one.\n",
    "print(f'seed={seed}')\n",
    "enable_positive_prompt=0\n",
    "\n",
    "# Pick the template aspect ratio numerically closest to the requested h_div_w,\n",
    "# then read the scale schedule registered for it under the args.pn key.\n",
    "h_div_w_template_ = h_div_w_templates[np.argmin(np.abs(h_div_w_templates-h_div_w))]\n",
    "scale_schedule = dynamic_resolution_h_w[h_div_w_template_][args.pn]['scales']\n",
    "# Keep each scale's h and w but force the first entry of every triple to 1.\n",
    "scale_schedule = [(1, h, w) for (_, h, w) in scale_schedule]\n",
    "generated_image = gen_one_img(\n",
    "    infinity,\n",
    "    vae,\n",
    "    text_tokenizer,\n",
    "    text_encoder,\n",
    "    prompt,\n",
    "    g_seed=seed,\n",
    "    gt_leak=0,\n",
    "    gt_ls_Bl=None,\n",
    "    cfg_list=cfg,\n",
    "    tau_list=tau,\n",
    "    scale_schedule=scale_schedule,\n",
    "    cfg_insertion_layer=[args.cfg_insertion_layer],\n",
    "    vae_type=args.vae_type,\n",
    "    sampling_per_bits=args.sampling_per_bits,\n",
    "    enable_positive_prompt=enable_positive_prompt,\n",
    ")\n",
    "args.save_file = 'ipynb_tmp.jpg'\n",
    "# Resolve the destination once so the directory, the write and the log line agree.\n",
    "save_path = osp.abspath(args.save_file)\n",
    "os.makedirs(osp.dirname(save_path), exist_ok=True)\n",
    "# cv2.imwrite reports failure by returning False rather than raising, so check\n",
    "# the result before announcing that the image was saved.\n",
    "if not cv2.imwrite(save_path, generated_image.cpu().numpy()):\n",
    "    raise IOError(f'cv2.imwrite failed to write {save_path}')\n",
    "print(f'Save to {save_path}')"
   ]
  }
 ],
 "metadata": {
  "fileId": "8ac263ab-b18c-41dc-b409-0fb0f32525f0",
  "filePath": "/mnt/bn/foundation-vision/hanjian.thu123/infinity/infinity/tools/interactive_infer.ipynb",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
