{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bbf4cab5-6bf8-40d9-9f1a-4ffd6d128cd2",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline\n",
    "import torch\n",
    "import os\n",
    "import gc\n",
    "from PIL import Image\n",
    "\n",
    "from src.linfusion import LinFusion"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3fa3e7e1-0413-4668-882f-e91bbf20512b",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Text-to-image stage: fetch the fp16 variant of DreamShaper 8 and place it on the GPU.\n",
    "pipeline = StableDiffusionPipeline.from_pretrained(\n",
    "    \"Lykon/dreamshaper-8\",\n",
    "    torch_dtype=torch.float16,\n",
    "    variant=\"fp16\",\n",
    ")\n",
    "pipeline = pipeline.to(torch.device(\"cuda\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "825d6621-4cce-4f44-aec3-9d7cb68f75a9",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Apply LinFusion to the pipeline, naming the checkpoint explicitly (the same one the\n",
    "# img2img stage passes for this base model) instead of relying on construct_for's default.\n",
    "linfusion = LinFusion.construct_for(\n",
    "    pipeline, pretrained_model_name_or_path=\"Yuanshi/LinFusion-1-5\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ee030a1-fe18-4457-a842-b088c3c5d48c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Output folder for every render saved below. exist_ok=True keeps the cell safe to re-run\n",
    "# and avoids the check-then-create race of os.path.exists() followed by os.mkdir().\n",
    "os.makedirs('results', exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "831a565e-009e-4cff-a903-5eafc2321724",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# 1k base render (1024x512); the img2img stages further down start from this image.\n",
    "generator = torch.manual_seed(3)  # fixed seed so the render is repeatable\n",
    "base_size = dict(height=512, width=1024)\n",
    "image = pipeline(\n",
    "    \"A photo of the Milky Way galaxy\", generator=generator, **base_size\n",
    ").images[0]\n",
    "image.save('results/output_1k.jpg')\n",
    "image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3f4ae07-d6a3-425f-940f-6d87000abcb5",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Drop the references to the text-to-image pipeline and the LinFusion object built for it\n",
    "# before loading the img2img weights, so the old weights can be freed instead of sitting\n",
    "# on the GPU next to the new copy while it loads.\n",
    "pipeline = linfusion = None\n",
    "gc.collect()\n",
    "torch.cuda.empty_cache()\n",
    "\n",
    "# Img2img stage: same checkpoint, fp16 variant, moved to the GPU.\n",
    "pipeline = StableDiffusionImg2ImgPipeline.from_pretrained(\n",
    "    \"Lykon/dreamshaper-8\", torch_dtype=torch.float16, variant=\"fp16\"\n",
    ").to(torch.device(\"cuda\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50d23376-cc8f-4d03-9aa2-0a16d72a259d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Attach LinFusion to the current pipeline using the Yuanshi/LinFusion-1-5 checkpoint.\n",
    "linfusion = LinFusion.construct_for(\n",
    "    pipeline,\n",
    "    pretrained_model_name_or_path=\"Yuanshi/LinFusion-1-5\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71e3f3e0-7b8a-404c-b3c1-ed0dfb4c69e9",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# 2k stage: resize the previous render to 2048x1024, then run it through the\n",
    "# img2img pipeline at strength 0.4 with the same prompt and seed.\n",
    "generator = torch.manual_seed(3)\n",
    "init_image = image.resize((2048, 1024))\n",
    "image = pipeline(\n",
    "    \"A photo of the Milky Way galaxy\",\n",
    "    image=init_image,\n",
    "    strength=0.4,\n",
    "    generator=generator,\n",
    ").images[0]\n",
    "image.save('results/output_2k.jpg')\n",
    "image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f373053d-277e-4710-8a1a-9cf631c81229",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Turn on tiled VAE processing and set the tile-size thresholds: 2048 in sample space and\n",
    "# 2048 // 8 in latent space (8 is presumably the VAE's downscale factor -- confirm).\n",
    "vae_tile_px = 2048\n",
    "pipeline.enable_vae_tiling()\n",
    "pipeline.vae.tile_sample_min_size = vae_tile_px\n",
    "pipeline.vae.tile_latent_min_size = vae_tile_px // 8"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ca97038-9e81-465e-ab91-a055a926d87e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# 4k stage: resize the previous render to 4096x2048, then img2img at strength 0.3.\n",
    "generator = torch.manual_seed(3)\n",
    "init_image = image.resize((4096, 2048))\n",
    "image = pipeline(\n",
    "    \"A photo of the Milky Way galaxy\",\n",
    "    image=init_image,\n",
    "    strength=0.3,\n",
    "    generator=generator,\n",
    ").images[0]\n",
    "image.save('results/output_4k.jpg')\n",
    "image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50350beb-5045-4951-a227-fb2269256b2a",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# 8k stage: resize the previous render to 8192x4096, then img2img at strength 0.2.\n",
    "# The result is only written to disk; the cell does not display it.\n",
    "generator = torch.manual_seed(3)\n",
    "init_image = image.resize((8192, 4096))\n",
    "image = pipeline(\n",
    "    \"A photo of the Milky Way galaxy\",\n",
    "    image=init_image,\n",
    "    strength=0.2,\n",
    "    generator=generator,\n",
    ").images[0]\n",
    "image.save('results/output_8k.jpg')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b65d6254-989c-4685-bf4b-39288b14a3c9",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Run Python's garbage collector first, then ask PyTorch to release its cached CUDA memory.\n",
    "for release in (gc.collect, torch.cuda.empty_cache):\n",
    "    release()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "81bd0a32-e1c9-464d-84d9-42bb83d74f46",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# 16k stage: resize the previous render to 16384x8192, then img2img at strength 0.1.\n",
    "# The result is only written to disk; the cell does not display it.\n",
    "generator = torch.manual_seed(3)\n",
    "init_image = image.resize((16384, 8192))\n",
    "image = pipeline(\n",
    "    \"A photo of the Milky Way galaxy\",\n",
    "    image=init_image,\n",
    "    strength=0.1,\n",
    "    generator=generator,\n",
    ").images[0]\n",
    "image.save('results/output_16k.jpg')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee57bcbf-d248-4d1b-b9cf-f7f3fb2b2ccc",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
