{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standard-library modules are imported explicitly: this notebook uses argparse,\n",
    "# os and osp directly and must not depend on the star import below happening to\n",
    "# re-export them.\n",
    "import argparse\n",
    "import os\n",
    "import os.path as osp\n",
    "import random\n",
    "\n",
    "import torch\n",
    "torch.cuda.set_device(0)  # make CUDA device 0 the current device for this kernel\n",
    "import cv2\n",
    "import numpy as np\n",
    "# Star import kept as in the original notebook. The later cells use load_tokenizer,\n",
    "# load_visual_tokenizer, load_transformer, gen_one_img, h_div_w_templates and\n",
    "# dynamic_resolution_h_w, none of which are defined in this notebook, so they are\n",
    "# presumably provided by this module -- TODO replace with explicit imports once confirmed.\n",
    "from tools.run_infinity import *\n",
    "\n",
    "# Checkpoint locations, relative to the working directory of the kernel.\n",
    "model_path='weights/infinity_2b_reg.pth'\n",
    "vae_path='weights/infinity_vae_d32_reg.pth'\n",
    "text_encoder_ckpt = 'weights/flan-t5-xl'\n",
    "\n",
    "# Hand-built stand-in for parsed command-line arguments. The whole namespace is passed\n",
    "# to load_visual_tokenizer / load_transformer; individual fields (pn, text_encoder_ckpt,\n",
    "# cfg_insertion_layer, vae_type, sampling_per_bits, save_file) are also read directly\n",
    "# by the following cells.\n",
    "args=argparse.Namespace(\n",
    "    pn='1M',  # key used to look up the scale schedule in dynamic_resolution_h_w\n",
    "    model_path=model_path,\n",
    "    cfg_insertion_layer=0,\n",
    "    vae_type=32,\n",
    "    vae_path=vae_path,\n",
    "    add_lvl_embeding_only_first_block=1,\n",
    "    use_bit_label=1,\n",
    "    model_type='infinity_2b',\n",
    "    rope2d_each_sa_layer=1,\n",
    "    rope2d_normalized_by_hw=2,\n",
    "    use_scale_schedule_embedding=0,\n",
    "    sampling_per_bits=1,\n",
    "    text_encoder_ckpt=text_encoder_ckpt,\n",
    "    text_channels=2048,\n",
    "    apply_spatial_patchify=0,\n",
    "    h_div_w_template=1.000,\n",
    "    use_flex_attn=0,\n",
    "    cache_dir='/dev/shm',\n",
    "    checkpoint_type='torch',\n",
    "    seed=0,  # NOTE(review): the generation cell draws its own random seed and does not read this field\n",
    "    bf16=1,\n",
    "    save_file='tmp.jpg'  # overwritten with 'ipynb_tmp.jpg' by the generation cell before saving\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Text side: tokenizer and encoder are both restored from the checkpoint directory\n",
    "# named in args.text_encoder_ckpt.\n",
    "text_tokenizer, text_encoder = load_tokenizer(\n",
    "    t5_path=args.text_encoder_ckpt,\n",
    ")\n",
    "\n",
    "# Image side: the VAE has to exist first because load_transformer receives it\n",
    "# as its first argument.\n",
    "vae = load_visual_tokenizer(args)\n",
    "infinity = load_transformer(vae, args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- User-tunable generation settings ---\n",
    "prompt = \"\"\"alien spaceship enterprise\"\"\"\n",
    "cfg = 3    # forwarded to gen_one_img as cfg_list (presumably the guidance scale -- TODO confirm)\n",
    "tau = 0.5  # forwarded to gen_one_img as tau_list (presumably the sampling temperature -- TODO confirm)\n",
    "h_div_w = 1/1 # aspect ratio, height:width\n",
    "seed = random.randint(0, 10000)\n",
    "# A fresh seed is drawn on every run; report it so a particular image can be regenerated\n",
    "# by assigning this value to `seed` instead of calling random.randint.\n",
    "print(f'Seed: {seed}')\n",
    "enable_positive_prompt=0\n",
    "\n",
    "# Snap the requested aspect ratio to the closest supported template, then fetch the\n",
    "# scale schedule registered for that template and the configured args.pn.\n",
    "h_div_w_template_ = h_div_w_templates[np.argmin(np.abs(h_div_w_templates-h_div_w))]\n",
    "scale_schedule = dynamic_resolution_h_w[h_div_w_template_][args.pn]['scales']\n",
    "# Force the first component of every scale triple to 1 while keeping (h, w).\n",
    "scale_schedule = [(1, h, w) for (_, h, w) in scale_schedule]\n",
    "generated_image = gen_one_img(\n",
    "    infinity,\n",
    "    vae,\n",
    "    text_tokenizer,\n",
    "    text_encoder,\n",
    "    prompt,\n",
    "    g_seed=seed,\n",
    "    gt_leak=0,\n",
    "    gt_ls_Bl=None,\n",
    "    cfg_list=cfg,\n",
    "    tau_list=tau,\n",
    "    scale_schedule=scale_schedule,\n",
    "    cfg_insertion_layer=[args.cfg_insertion_layer],\n",
    "    vae_type=args.vae_type,\n",
    "    sampling_per_bits=args.sampling_per_bits,\n",
    "    enable_positive_prompt=enable_positive_prompt,\n",
    ")\n",
    "\n",
    "# --- Save the result ---\n",
    "args.save_file = 'ipynb_tmp.jpg'\n",
    "os.makedirs(osp.dirname(osp.abspath(args.save_file)), exist_ok=True)\n",
    "# cv2.imwrite reports failure by returning False instead of raising, so the result is\n",
    "# checked explicitly; otherwise a failed write would still be announced as saved.\n",
    "if not cv2.imwrite(args.save_file, generated_image.cpu().numpy()):\n",
    "    raise OSError(f'cv2.imwrite failed to write {osp.abspath(args.save_file)}')\n",
    "print(f'Save to {osp.abspath(args.save_file)}')"
   ]
  }
 ],
 "metadata": {
  "fileId": "8ac263ab-b18c-41dc-b409-0fb0f32525f0",
  "filePath": "/mnt/bn/foundation-vision/hanjian.thu123/infinity/infinity/tools/interactive_infer.ipynb",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
