{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8301596-d405-4d6d-a86d-4c0060da77f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Only these two names are used from utils.utils in this notebook; import them explicitly.\n",
    "from utils.utils import FineTunedModel, StableDiffuser\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bfac7476-6b52-4d49-b044-4bc3a84fe849",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ESD checkpoint to evaluate; its file name encodes the train_method it was trained with.\n",
    "esd_path = 'models/esd-vangogh_from_vangogh-xattn_1-epochs_200.pt'\n",
    "# REMEMBER: must be the same train_method used for training (it is present in the saved name).\n",
    "train_method = 'xattn'\n",
    "\n",
    "diffuser = StableDiffuser(scheduler='DDIM')\n",
    "diffuser = diffuser.to('cuda')\n",
    "\n",
    "# Wrap the diffuser and restore the fine-tuned weights from the checkpoint.\n",
    "finetuner = FineTunedModel(diffuser, train_method=train_method)\n",
    "finetuner.load_state_dict(torch.load(esd_path))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "77386f14-35a7-43d3-bbc9-c6fede9c8fca",
   "metadata": {},
   "source": [
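    "## Original Model\n",
    "\n",
    "Images generated by the unmodified `diffuser` (outside `with finetuner:`), as a baseline."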
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4cb0941-24bb-41af-991d-a271a0406b03",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Baseline: Van Gogh-style prompt through the unmodified diffuser.\n",
    "seed = 42\n",
    "prompt = \"A vase of vibrant flowers, in the style of Van Gogh's still lifes\"\n",
    "images = diffuser(\n",
    "    prompt,\n",
    "    img_size=512,\n",
    "    n_steps=50,\n",
    "    n_imgs=1,\n",
    "    generator=torch.Generator().manual_seed(seed),\n",
    "    guidance_scale=6,\n",
    ")[0][0]\n",
    "images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a48c7af0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Second baseline prompt (Caravaggio style) through the unmodified diffuser.\n",
    "seed = 42\n",
    "prompt = \"A still life featuring bold contrasts between light and shadow, and dramatic use of color, reminiscent of Caravaggio's paintings\"\n",
    "images = diffuser(\n",
    "    prompt,\n",
    "    img_size=512,\n",
    "    n_steps=50,\n",
    "    n_imgs=1,\n",
    "    # A dedicated generator seeded from `seed` honours the variable above and\n",
    "    # leaves torch's global RNG state untouched (torch.manual_seed would reset it).\n",
    "    generator=torch.Generator().manual_seed(seed),\n",
    "    guidance_scale=6,\n",
    ")[0][0]\n",
    "images"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3bb2d3e7-e8be-4c9f-ad40-2396d02d3752",
   "metadata": {},
   "source": [
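    "## Erased Model\n",
    "\n",
    "Images generated inside `with finetuner:`, i.e. using the fine-tuned ESD weights loaded from `esd_path`."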
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9e84c33-93ca-4e88-8399-3d83279f1b5c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed defined here so this cell does not depend on an earlier cell having been run.\n",
    "seed = 42\n",
    "# NOTE(review): this prompt has no original-model baseline above, and guidance_scale (7.5)\n",
    "# differs from the 6 used there - add a matching baseline for a like-for-like comparison.\n",
    "with finetuner:\n",
    "    images = diffuser(\n",
    "        \"The play of light and shadow in Rembrandt's iconic Night Watch\",\n",
    "        img_size=512,\n",
    "        n_steps=50,\n",
    "        n_imgs=1,\n",
    "        generator=torch.Generator().manual_seed(seed),\n",
    "        guidance_scale=7.5,\n",
    "    )[0][0]\n",
    "images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "109e889c-d39f-4252-8413-a876686c0c8b",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
