{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c125503",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch, random, os\n",
    "from datasets import load_dataset, DatasetDict\n",
    "from diffusers import  StableDiffusionXLImg2ImgPipeline, LCMScheduler\n",
    "from PIL import Image\n",
    "from pathlib import Path\n",
    "import pandas as pd\n",
    "from tqdm import tqdm\n",
    "\n",
    "\n",
    "# ----------  Hyper-parameters  ----------\n",
    "HOLDOUT_FRAC   = 0.75          \n",
    "GEN_RES      = 512        \n",
    "FINAL_RES    = 32           \n",
    "STRENGTH       = 0.7         \n",
    "NUM_SYNTH_PER  = 1             # 1 synthetic  1 real\n",
    "SEED           = 42\n",
    "OUT_DIR        = Path(\"./flowers_augmented\")\n",
    "MODEL_DIR      = \"./vega\"\n",
    "DEVICE         = \"cuda\"         \n",
    "\n",
    "random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "OUT_DIR.mkdir(exist_ok=True, parents=True)\n",
    "\n",
    "Path('flowers_heldout').mkdir(exist_ok=True, parents=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54881312",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "df = pd.read_csv('oxford_flower_102_name.csv')\n",
    "label_dict ={}\n",
    "for i, r in df.iterrows():\n",
    "    label_dict[r['Index']] = r['Name']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "daa0615c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ---------- 1.  Load & split ----------\n",
    "ds = load_dataset(\"dpdl-benchmark/oxford_flowers102\")\n",
    "train_test = ds[\"test\"].train_test_split(test_size=HOLDOUT_FRAC, seed=SEED)  \n",
    "augment_set, holdout_set = train_test[\"train\"], train_test[\"test\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3d0c6a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "for i,ex in tqdm(enumerate(holdout_set)):\n",
    "    init32 = ex[\"image\"].convert(\"RGB\")\n",
    "    init256 = init32.resize((GEN_RES, GEN_RES), Image.NEAREST)\n",
    "\n",
    "    cls_name = label_dict[ex[\"label\"]]\n",
    "\n",
    "    l = ex['label']\n",
    "    img_id = f'{i}_{cls_name}'\n",
    "    init256.save('flowers_heldout' + '/' + f\"{img_id}_real.png\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd91d431",
   "metadata": {},
   "outputs": [],
   "source": [
    "pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(\n",
    "    \"segmind/SSD-1B\",\n",
    "    torch_dtype=torch.float16,\n",
    "    variant=\"fp16\",\n",
    "    use_safetensors=True\n",
    ").to(DEVICE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "748b4fba",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "for i,ex in enumerate(augment_set):\n",
    "    init32 = ex[\"image\"].convert(\"RGB\")\n",
    "    init256 = init32.resize((GEN_RES, GEN_RES), Image.NEAREST)\n",
    "\n",
    "    cls_name = label_dict[ex[\"label\"]]\n",
    "    # ---------- prompt engineering tips ----------\n",
    "    prompt = (\n",
    "    f\"Photorealistic photograph of a single {cls_name}, \"\n",
    "    f\"realistic colors, natural lighting, high detail, sharp focus on petals. \"\n",
    "    f\"Another unique photo of the same flower species.\"\n",
    "    )\n",
    "\n",
    "    # Negative Prompt: Explicitly fight saturation\n",
    "    negative_prompt = (\n",
    "    \"oversaturated, highly saturated, neon colors, garish colors, vibrant colors, \" # Target saturation\n",
    "    \"illustration, painting, drawing, sketch, cartoon, anime, unrealistic, \" # Target style\n",
    "    \"blurry, low quality, text, watermark, signature, border, frame, multiple flowers\" # Target artifacts/composition\n",
    "    )\n",
    "    # ----------------------------------------------\n",
    "    gen = pipe(prompt=prompt,\n",
    "               negative_prompt=negative_prompt,\n",
    "               image=init256,\n",
    "               strength=STRENGTH,\n",
    "               guidance_scale=6,\n",
    "               num_inference_steps=50\n",
    "             ).images[0]\n",
    "\n",
    "\n",
    "    l = ex['label']\n",
    "    img_id = f'{i}_{cls_name}'\n",
    "\n",
    "\n",
    "print(\"Done - saved real & synthetic 32x32 images.\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dai",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
