{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b04afb09",
   "metadata": {},
   "source": [
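    "# Watermark Training and Evaluation Demo\n",
    "\n",
    "End-to-end demo: generate training data with the diffusion model, train the watermark together with its score model (the watermark detector), then evaluate detection performance, robustness to augmentations, detection time, and image quality (FID and CLIP score)."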
   ]
  },
  {
   "cell_type": "markdown",
   "id": "242e7f03",
   "metadata": {},
   "source": [
    "## Config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67838602",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import sys\n",
    "from pathlib import Path\n",
    "\n",
    "# Make the parent of the working directory importable so the `src` package resolves\n",
    "# when Jupyter is started from a subfolder of the project.\n",
    "# The membership check keeps the cell idempotent: re-running it adds no duplicate entries.\n",
    "project_root = str(Path().resolve().parent)\n",
    "if project_root not in sys.path:\n",
    "    sys.path.append(project_root)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d434aff3",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.data import Dataset, Augment\n",
    "from src.models import (\n",
    "    WatermarkScoreModel,\n",
    "    Watermark,\n",
    "    ModelWrapper,\n",
    "    DiffusionModel,\n",
    "    get_vae,\n",
    "    decode,\n",
    ")\n",
    "from src.training import Trainer, AugSampler\n",
    "from src.evaluation import Evaluation, visualize, generate, scores\n",
    "from src.utils.utils import cleanup_cuda_memory\n",
    "from src.utils import Config\n",
    "\n",
    "# Load the project configuration and pick the compute device from it\n",
    "config = Config('config.yaml')\n",
    "device = config.get_device()\n",
    "print(f\"Using device: {device}\")\n",
    "\n",
    "# NOTE(review): the return value is discarded; presumably this pre-loads/caches the VAE - confirm.\n",
    "get_vae(config)\n",
    "cleanup_cuda_memory(True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ccb55251",
   "metadata": {},
   "source": [
    "## Initialize models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a64817b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Augmentation pipeline, shared by data generation, training and evaluation\n",
    "augment = Augment()\n",
    "dataset = Dataset(config)\n",
    "\n",
    "# Diffusion model (wrapped) and the watermark, both placed on `device`\n",
    "diffusion_model = ModelWrapper(\n",
    "    DiffusionModel(config).to(device)\n",
    ").to(device)\n",
    "watermark = Watermark(config).to(device)\n",
    "\n",
    "# Watermark score model: the detector trained further below\n",
    "score_model = WatermarkScoreModel(config).to(device)\n",
    "\n",
    "print(\"Watermark Score Model:\")\n",
    "print(f\"Total parameters: {sum(param.numel() for param in score_model.parameters()):,}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "70b8a3b5",
   "metadata": {},
   "source": [
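    "## Create Data\n",
    "\n",
    "Data generation only needs to be done once per diffusion model version; on later runs, skip ahead to **Load Data**."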
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "697204ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Only needs to run once per diffusion model version; the Load Data section\n",
    "# below presumably reads back what save_dataset() writes.\n",
    "generated_data = dataset.generate_data(diffusion_model=diffusion_model, use_test_prompts=False)\n",
    "\n",
    "# Alternatively, skip generation and augment previously generated data:\n",
    "# generated_data = Dataset.load_dataset(config, config.dataset.paths.generated_data_file)\n",
    "\n",
    "augmented_data = dataset.generate_augmented_data(data=generated_data, augmentation_pipeline=augment)\n",
    "\n",
    "dataset.save_dataset()\n",
    "print(\"Data saved!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c665c88f",
   "metadata": {},
   "source": [
    "## Load Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3c04f16",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_data_gen, train_data_aug = Dataset.load_data_pair(config)\n",
    "train_prompts = Dataset.get_prompts(config, use_test=False)\n",
    "test_prompts = Dataset.get_prompts(config, use_test=True)\n",
    "\n",
    "# Show one training prompt as a sanity check. The index is clamped so the cell\n",
    "# still works when fewer than 18 prompts are available.\n",
    "SAMPLE_PROMPT_IDX = 17\n",
    "if len(train_prompts) > 0:\n",
    "    print(f\"Sample prompt: {train_prompts[min(SAMPLE_PROMPT_IDX, len(train_prompts) - 1)]}\")\n",
    "else:\n",
    "    print(\"No training prompts found.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23201eae",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize the generated training data alongside its augmented counterpart\n",
    "# (the last argument, 2, presumably sets how many samples are shown)\n",
    "Dataset.visualize(train_data_gen, train_data_aug, 2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0461f4b2",
   "metadata": {},
   "source": [
    "## Train the Watermark Detector"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff6585be",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The augmentation sampler is sized by the number of augmentations in the pipeline\n",
    "aug_sampler = AugSampler(len(augment))\n",
    "\n",
    "# The trainer receives both the watermark and the score model, plus the clean\n",
    "# and augmented training data and the training prompts\n",
    "trainer = Trainer(\n",
    "    config=config,\n",
    "    diffusion_model=diffusion_model,\n",
    "    score_model=score_model,\n",
    "    watermark=watermark,\n",
    "    aug_sampler=aug_sampler,\n",
    "    data=train_data_gen,\n",
    "    data_aug=train_data_aug,\n",
    "    prompts=train_prompts,\n",
    "    augment=augment,\n",
    ")\n",
    "print(\"Trainer initialized!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8736972",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Watermark training\n",
    "# NOTE(review): the checkpoint loaded in the \"Load models\" section is checkpoint_epoch_1.pt;\n",
    "# update that path if EPOCHS changes.\n",
    "EPOCHS = 1\n",
    "\n",
    "print(\"\\n=== WATERMARK TRAINING ===\")\n",
    "train_losses = trainer.train(epochs=EPOCHS, verbose=True)\n",
    "\n",
    "# Each entry is indexed with [0]; presumably the total loss of that step - confirm in Trainer.train.\n",
    "if len(train_losses) > 0:\n",
    "    print(f\"Training complete. Final loss: {train_losses[-1][0]:.6f}\")\n",
    "else:\n",
    "    print(\"Training complete. No losses were recorded.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "afc74142",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the trained score model and watermark; they are reloaded in \"Load models\"\n",
    "score_model.save()\n",
    "watermark.save()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97680fb5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Clean up CUDA memory before the next section\n",
    "cleanup_cuda_memory(True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2828d4fb",
   "metadata": {},
   "source": [
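    "## Load models\n",
    "\n",
    "Load the trained score model and watermark, either from the files saved after training (first cell) or from a training checkpoint (second cell)."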
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34cc5555",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the saved models with gradients disabled and in eval mode\n",
    "score_model = (\n",
    "    WatermarkScoreModel.load(config)\n",
    "    .requires_grad_(False)\n",
    "    .eval()\n",
    "    .to(device)\n",
    ")\n",
    "watermark = (\n",
    "    Watermark.load(config)\n",
    "    .requires_grad_(False)\n",
    "    .eval()\n",
    "    .to(device)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd142625",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Alternatively, load the models from a training checkpoint. This overrides the models\n",
    "# loaded in the previous cell and is skipped (instead of raising) if the file is missing.\n",
    "# NOTE(review): the path is hardcoded; presumably it should follow the results directory\n",
    "# configured in config.yaml - confirm.\n",
    "checkpoint_path = \"./results/SERUM_demo/checkpoints/checkpoint_epoch_1.pt\"\n",
    "if Path(checkpoint_path).exists():\n",
    "    score_model, watermark = Trainer.load_models_from_checkpoint(config, checkpoint_path)\n",
    "    score_model = score_model.requires_grad_(False).eval().to(device)\n",
    "    watermark = watermark.requires_grad_(False).eval().to(device)\n",
    "else:\n",
    "    print(f\"Checkpoint not found: {checkpoint_path}; keeping the models loaded above.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd5df948",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Clean up CUDA memory before evaluation\n",
    "cleanup_cuda_memory(True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "092e85bb",
   "metadata": {},
   "source": [
    "## Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f5c3816",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The evaluator runs on the test prompts (training used the train prompts)\n",
    "evaluator = Evaluation(\n",
    "    config=config,\n",
    "    diffusion_model=diffusion_model,\n",
    "    score_model=score_model,\n",
    "    watermark=watermark,\n",
    "    prompts=test_prompts,\n",
    "    augment=augment,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ed12b9c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quick evaluation (show_images=True presumably also displays the evaluated images)\n",
    "print(\"\\n=== QUICK EVALUATION ===\")\n",
    "quick_results = evaluator.quick_eval(show_images=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "058ca33b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Thorough evaluation (TPR / ROC) on 50 samples, presumably without augmentations\n",
    "print(\"\\n=== EVALUATION ===\")\n",
    "non_aug_results = evaluator.eval_tpr(num_samples=50)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c6c09f41",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the results of the evaluation above\n",
    "evaluator.plot_results()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "536fa2c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Robustness to augmentations; run after the thorough evaluation (eval_tpr) above.\n",
    "augmentation_results = evaluator.augmentation_robustness_eval(\n",
    "    num_samples=50,\n",
    "    # augmentation_ids=[0, 1, 2],  # optional; presumably restricts the run to a subset of augmentations\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "257ad0dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Measure watermark detection time; assigning to `_` suppresses the return-value display\n",
    "_ = evaluator.watermark_detection_time(num_samples=100, batch_size=32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15085926",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Clean up CUDA memory before the FID & CLIP section\n",
    "cleanup_cuda_memory(True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e8275c06",
   "metadata": {},
   "source": [
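    "## FID & CLIP\n",
    "\n",
    "Measure the image-quality impact of the watermark: generate images for COCO captions with and without the watermark, then compute FID and CLIP score."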
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "366a238d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load 100 prompts (COCO captions) together with the COCO val2014 images\n",
    "generate.load_coco(\n",
    "    config=config,\n",
    "    annotation_path=\"./mscoco2014val/annotations/captions_val2014.json\",\n",
    "    images_path=\"./mscoco2014val/val2014/\",\n",
    "    num_samples=100,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8757ebce",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Alternative to the previous cell: load the COCO images matching the captions listed\n",
    "# in an existing prompts file (which may also be your own selection of COCO captions).\n",
    "# NOTE(review): presumably only one of the two loaders needs to run - confirm.\n",
    "generate.load_coco_images_from_prompts(\n",
    "    config=config,\n",
    "    annotation_path=\"./mscoco2014val/annotations/captions_val2014.json\",\n",
    "    images_path=\"./mscoco2014val/val2014/\",\n",
    "    prompts_path=\"./mscoco2014val/prompts.txt\",\n",
    "    num_samples=100,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d06e9216",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate images for the loaded prompts, both with and without the watermark.\n",
    "# NOTE(review): force_reload=True presumably regenerates images even if they already exist.\n",
    "generate.generate(\n",
    "    config=config,\n",
    "    diffusion_model=diffusion_model,\n",
    "    watermark=watermark,\n",
    "    num_samples=100,\n",
    "    force_reload=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c60f16ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize 2 of the generated samples\n",
    "scores.visualize_samples(config, num_samples=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "800ed006",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Frechet Inception Distance; assigning to `_` suppresses the return-value display\n",
    "_ = scores.calculate_fid(config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6339bfc7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# CLIP score using the ViT-B/32 CLIP model, batch size 64\n",
    "_ = scores.calculate_clip_score(config=config, clip_model_name=\"ViT-B/32\", batch_size=64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2f8a4ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Final CUDA memory cleanup\n",
    "cleanup_cuda_memory(True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "serum",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
