{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "60013ef6-dffb-41c0-9de7-0c99cc4cdef8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def load_state_dict(model, state_dict, prefix='', ignore_missing=\"relative_position_index\"):\n",
    "    \"\"\"Non-strictly load `state_dict` into `model`, reporting key mismatches.\n",
    "\n",
    "    Recursively calls `_load_from_state_dict` on `model` and each of its\n",
    "    submodules, collecting missing keys, unexpected keys and error messages\n",
    "    into lists that are printed (not raised) at the end.\n",
    "\n",
    "    Args:\n",
    "        model: the `torch.nn.Module` to load into.\n",
    "        state_dict: mapping from key names to tensors, e.g. from `torch.load`.\n",
    "        prefix: prepended to the model's own key names when looking them up\n",
    "            in `state_dict`.\n",
    "        ignore_missing: '|'-separated substrings; a missing key containing any\n",
    "            of them is reported as \"ignored\" instead of being warned about.\n",
    "            Empty segments are skipped, so '' or None ignores nothing.\n",
    "\n",
    "    Returns:\n",
    "        (missing_keys, unexpected_keys); missing_keys excludes ignored keys.\n",
    "    \"\"\"\n",
    "    missing_keys = []\n",
    "    unexpected_keys = []\n",
    "    error_msgs = []\n",
    "    # Copy so _load_from_state_dict can modify it, re-attaching the `_metadata`\n",
    "    # attribute that `.copy()` does not carry over.\n",
    "    metadata = getattr(state_dict, '_metadata', None)\n",
    "    state_dict = state_dict.copy()\n",
    "    if metadata is not None:\n",
    "        state_dict._metadata = metadata\n",
    "\n",
    "    def load(module, module_prefix=''):\n",
    "        local_metadata = {} if metadata is None else metadata.get(module_prefix[:-1], {})\n",
    "        # The positional True is the `strict` flag of _load_from_state_dict.\n",
    "        module._load_from_state_dict(\n",
    "            state_dict, module_prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)\n",
    "        for name, child in module._modules.items():\n",
    "            if child is not None:\n",
    "                load(child, module_prefix + name + '.')\n",
    "\n",
    "    load(model, prefix)\n",
    "\n",
    "    # Drop empty segments: '' is a substring of every key, so e.g. ignore_missing=''\n",
    "    # would otherwise silently suppress every missing-key warning.\n",
    "    ignore_patterns = [p for p in (ignore_missing or '').split('|') if p]\n",
    "    ignore_missing_keys = [k for k in missing_keys if any(p in k for p in ignore_patterns)]\n",
    "    missing_keys = [k for k in missing_keys if not any(p in k for p in ignore_patterns)]\n",
    "\n",
    "    model_name = model.__class__.__name__\n",
    "    if missing_keys:\n",
    "        print(\"Weights of {} not initialized from pretrained model: {}\".format(\n",
    "            model_name, missing_keys))\n",
    "    if unexpected_keys:\n",
    "        print(\"Weights from pretrained model not used in {}: {}\".format(\n",
    "            model_name, unexpected_keys))\n",
    "    if ignore_missing_keys:\n",
    "        print(\"Ignored weights of {} not initialized from pretrained model: {}\".format(\n",
    "            model_name, ignore_missing_keys))\n",
    "    if error_msgs:\n",
    "        print('\\n'.join(error_msgs))\n",
    "    return missing_keys, unexpected_keys"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "eee7c83d-1bbc-45d6-8e5b-0ef52399e304",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4015f37485574b6a96fe5e46d9e1326e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config[\"id2label\"]` will be overriden.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b4add4e030a045538b5e667d0955b9e4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config[\"id2label\"]` will be overriden.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3a392c12f9b943d38f0d227d90110fe5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a17ab6cc46204b87b17e3f1be166675e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f9efca8e5d2042e0947b8704103fd000",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5aeff66e777146b096da1ceb7bc3b8cd",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "06e1515d05f745e68efd79878f8902fd",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "20c768105821463688fa1939b7a51f45",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c19166d7f38b4634b804436598200b1c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "48ab3882aa2a4462b6cea2bcf1d90596",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6fd078f4cba24341b3c882b5ec614f27",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ffc0daf21c17487e8517bcfe45928f1d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6a8103a971be4467850b216a05f3cab8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c5a75872da464ff0953aae1c4aef1767",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a5843825a4a9452b834a7c73b8c7a5b8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f1464992d5e444d692b08029d57a8c3a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "02f8783f6e0c4dc68269e5228103e00d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "61fdd058f16d4ee9bfe46e351e1115b5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "392aeb23189544f9b92e5f8fc133d19f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a27edf8721e44238a5fc12831aece1d2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "536d46ab37b64bdf94be66ab7904a044",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "870903749d5b4a5aa39b9c1531621067",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4ea246cf64a042a3960f80d3486360af",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "import torch\n",
    "from diffusers import StableDiffusionPipeline\n",
    "from peft import PeftConfig, PeftModel\n",
    "from peft import LoraConfig, get_peft_model\n",
    "\n",
    "# Configuration. Lighter alternative base model: /home/yangmingzhao/2024_5/bk-sdm-tiny\n",
    "BASE_MODEL_DIR = \"/home/yangmingzhao/2024_5/sd_v1_5\"\n",
    "CKPT_ROOT = \"/home/yangmingzhao/2024_5/Fedundm/ckpt\"  # one entry per concept\n",
    "DATA_DIR = \"./data/celebs\"\n",
    "OUT_ROOT = \"gene/celebs\"\n",
    "SEED = 5\n",
    "N_IMAGES = 10  # images per concept prompt and per person prompt\n",
    "\n",
    "\n",
    "def load_base_pipeline():\n",
    "    \"\"\"Load a fresh fp16 copy of the base Stable Diffusion pipeline.\"\"\"\n",
    "    return StableDiffusionPipeline.from_pretrained(BASE_MODEL_DIR, torch_dtype=torch.float16)\n",
    "\n",
    "\n",
    "def render(pipe, prompt, filenames, seed=SEED):\n",
    "    \"\"\"Generate one image per entry of `filenames` from `prompt` and save it there.\n",
    "\n",
    "    The generator is seeded once and shared by the whole loop, so every call\n",
    "    with the same seed, prompt and pipeline draws the same noise sequence.\n",
    "    \"\"\"\n",
    "    generator = torch.Generator(device=\"cuda\").manual_seed(seed)\n",
    "    for filename in filenames:\n",
    "        image = pipe(prompt, generator=generator).images[0]\n",
    "        image.save(filename)\n",
    "\n",
    "\n",
    "def run_probes(pipe, concept, out_dir, tag):\n",
    "    \"\"\"Render the car / concept / generic-person probes as `<tag>_*.png` in out_dir.\"\"\"\n",
    "    render(pipe, \"an image of car\", [os.path.join(out_dir, tag + \"_car.png\")])\n",
    "    render(pipe, \"an image of \" + concept,\n",
    "           [os.path.join(out_dir, f\"{tag}_{i}.png\") for i in range(N_IMAGES)])\n",
    "    render(pipe, \"an image of a person\",\n",
    "           [os.path.join(out_dir, f\"{tag}_person_{i}.png\") for i in range(N_IMAGES)])\n",
    "\n",
    "\n",
    "# Every entry of DATA_DIR except the index file names a concept.\n",
    "concepts = sorted(name for name in os.listdir(DATA_DIR) if name != \"paths.txt\")\n",
    "\n",
    "for concept in concepts:\n",
    "    out_dir = os.path.join(OUT_ROOT, concept)\n",
    "    os.makedirs(out_dir, exist_ok=True)\n",
    "    ckpt_dir = os.path.join(CKPT_ROOT, concept)\n",
    "\n",
    "    # 1) Original model.\n",
    "    pipe = load_base_pipeline().to(\"cuda\")\n",
    "    run_probes(pipe, concept, out_dir, \"ori\")\n",
    "\n",
    "    # 2) Text-encoder LoRA only.\n",
    "    pipe.text_encoder = PeftModel.from_pretrained(pipe.text_encoder, ckpt_dir)\n",
    "    pipe = pipe.to(\"cuda\")\n",
    "    run_probes(pipe, concept, out_dir, \"text_only\")\n",
    "\n",
    "    # 3) Text-encoder LoRA + UNet LoRA, starting from a freshly loaded base model.\n",
    "    del pipe\n",
    "    torch.cuda.empty_cache()\n",
    "    pipe = load_base_pipeline()\n",
    "    pipe.text_encoder = PeftModel.from_pretrained(pipe.text_encoder, ckpt_dir)\n",
    "    # NOTE(review): torch.load expects a file path, but ckpt_dir is also used above\n",
    "    # as a PEFT adapter directory; this presumably should point at the UNet LoRA\n",
    "    # weights file inside it. TODO confirm the filename.\n",
    "    unet_state_dict = torch.load(ckpt_dir, map_location=\"cpu\")\n",
    "    unet_lora_config = LoraConfig(\n",
    "        r=4,\n",
    "        lora_alpha=2,\n",
    "        init_lora_weights=\"gaussian\",\n",
    "        target_modules=[\"to_k\", \"to_q\"],\n",
    "    )\n",
    "    pipe.unet.add_adapter(unet_lora_config)\n",
    "    load_state_dict(pipe.unet, unet_state_dict)  # defined in the first cell\n",
    "    pipe = pipe.to(\"cuda\")\n",
    "    run_probes(pipe, concept, out_dir, \"unet\")\n",
    "\n",
    "    # Release GPU memory before the next concept reloads the pipeline.\n",
    "    del pipe\n",
    "    torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7bd7145e-5108-4ff0-97ef-b4d539ae81ad",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
