{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "initial_id",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-04-30T16:38:56.715535Z",
          "start_time": "2025-04-30T16:38:51.586093Z"
        }
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "import torchvision\n",
        "import torchvision.transforms as transforms\n",
        "import cv2\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "import load\n",
        "from PIL.ImageEnhance import Brightness\n",
        "from compute_thickness import thickness_batch, brightness_batch\n",
        "from FM_funcs import OTFlowMatching\n",
        "import FM_funcs\n",
        "import configs\n",
        "\n",
        "import torch.nn.functional as F\n",
        "# import PGFM_funcs\n",
        "from PGFMv2_funcs import PGFM\n",
        "import tqdm\n",
        "import evaluate\n",
        "from fid.fid_compute import compute_fid_from_batches"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "d9cf3bff9b7f2403",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-04-30T16:38:56.748241Z",
          "start_time": "2025-04-30T16:38:56.744685Z"
        }
      },
      "outputs": [],
      "source": [
        "# (intentionally empty: a stray scratch expression that only displayed the literal 2 was removed)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "46b01215b27415b8",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-04-30T16:38:56.863826Z",
          "start_time": "2025-04-30T16:38:56.826380Z"
        }
      },
      "outputs": [],
      "source": [
        "# Wrappers and dataset helper shared by every cell below.\n",
        "FM_class = OTFlowMatching()  # provides get_untrained_model() for the stage-1 model\n",
        "PGFM_class = PGFM()  # its .policy is loaded with the stage-2 checkpoint later\n",
        "mnist_dataset = load.myMMIST()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "6d6d032b483a451d",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-04-30T16:39:04.420949Z",
          "start_time": "2025-04-30T16:39:01.760538Z"
        }
      },
      "outputs": [],
      "source": [
        "# Reference training set from get_data (presumably the one meant for the thickness experiments)\n",
        "# and the distribution evaluator fitted on it (its SWD_after_PCA is used inside eval).\n",
        "train_set = mnist_dataset.get_data(plot_hist=True)\n",
        "train_set = train_set.to(configs.device)\n",
        "MNIST_distribution_evaluator = evaluate.distribution_evaluator(train_set)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "28b071cccff15318",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-04-25T20:55:48.868499Z",
          "start_time": "2025-04-25T20:55:45.176795Z"
        }
      },
      "outputs": [],
      "source": [
        "# NOTE: this cell overwrites train_set and MNIST_distribution_evaluator from the previous cell.\n",
        "# Under Run All, eval's data_set default (bound when eval is defined below) and the FID reference\n",
        "# cells therefore use this brightness set, including in the Thickness section.\n",
        "train_set = mnist_dataset.get_data_brightness(plot_hist=True)\n",
        "train_set = train_set.to(configs.device)\n",
        "MNIST_distribution_evaluator = evaluate.distribution_evaluator(train_set)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "7e5706417a378395",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-04-30T16:39:04.464012Z",
          "start_time": "2025-04-30T16:39:04.457868Z"
        }
      },
      "outputs": [],
      "source": [
        "def eval(distribution_evaluator, data_set = train_set, batch_size = 1000, rep_num = 10, type = \"brightness\"):\n",
        "    \"\"\"Compare FM and PGFM samples over `rep_num` independent batches.\n",
        "\n",
        "    Each repetition draws `batch_size` samples from the stage-1 FM model and from PGFM\n",
        "    (using the notebook globals stage1model and stage2model). It counts how many samples pass the\n",
        "    configs validity rule for `type` (\"brightness\" or \"thickness\"). It also computes the\n",
        "    SWD-after-PCA of each sample set against a fresh reference batch drawn from `data_set`.\n",
        "    Prints mean/std of the valid fraction, valid count and SWD for both methods.\n",
        "\n",
        "    NOTE: the `data_set` default is whatever train_set was when this cell ran.\n",
        "    Raises NotImplementedError for an unsupported `type` (checked before any sampling).\n",
        "    \"\"\"\n",
        "    if type not in (\"brightness\", \"thickness\"):\n",
        "        raise NotImplementedError(type)\n",
        "\n",
        "    num_record_FM = np.zeros(rep_num)\n",
        "    num_record_PGFM = np.zeros(rep_num)\n",
        "    prob_record_FM = np.zeros(rep_num)\n",
        "    prob_record_PGFM = np.zeros(rep_num)\n",
        "    swd_record_FM = np.zeros(rep_num)\n",
        "    swd_record_PGFM = np.zeros(rep_num)\n",
        "\n",
        "    for i in tqdm.tqdm(range(rep_num)):\n",
        "        FM_res = FM_funcs.sampler(stage1model, batch_size, stoptime=1, default_generation_step = 100)\n",
        "        PGFM_res = PGFM_class.RLFMsample( stage1model, stage2model, batch_size, default_generation_step= 100,\n",
        "                   default_stage1t =configs.default_stage1_t, default_RLstep_S = configs.default_RL_Steps_S)\n",
        "\n",
        "        if type == \"brightness\":\n",
        "            _, num_FM = configs.count_valid_num_bright(brightness_batch(FM_res.cpu()))\n",
        "            _, num_PGFM = configs.count_valid_num_bright(brightness_batch(PGFM_res.cpu()))\n",
        "        else:  # \"thickness\" (validated above)\n",
        "            _, num_FM = configs.count_valid_num_thick(thickness_batch(FM_res.cpu()))\n",
        "            _, num_PGFM = configs.count_valid_num_thick(thickness_batch(PGFM_res.cpu()))\n",
        "\n",
        "        prob_record_FM[i] = num_FM / batch_size\n",
        "        prob_record_PGFM[i] = num_PGFM / batch_size\n",
        "        num_record_FM[i] = num_FM\n",
        "        num_record_PGFM[i] = num_PGFM\n",
        "\n",
        "        ref = PGFM_class.get_samples(data_set, batch_size)\n",
        "        swd_record_FM[i] = distribution_evaluator.SWD_after_PCA(ref, FM_res)\n",
        "        # Compare against PGFM_res: previously FM_res was passed here, so the 'PGFM' SWD measured FM samples.\n",
        "        swd_record_PGFM[i] = distribution_evaluator.SWD_after_PCA(ref, PGFM_res)\n",
        "\n",
        "    print(\"FM prob mean:\", np.mean(prob_record_FM), \"std:\", np.std(prob_record_FM))\n",
        "    print(\"PGFM prob mean:\", np.mean(prob_record_PGFM), \"std:\", np.std(prob_record_PGFM))\n",
        "    print(\"FM num mean:\", np.mean(num_record_FM), \"std:\", np.std(num_record_FM))\n",
        "    print(\"PGFM num mean:\", np.mean(num_record_PGFM), \"std:\", np.std(num_record_PGFM))\n",
        "    print(\"FM SWD mean:\", np.mean(swd_record_FM), \"std:\", np.std(swd_record_FM))\n",
        "    print(\"PGFM SWD mean:\", np.mean(swd_record_PGFM), \"std:\", np.std(swd_record_PGFM))"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "6aaea86d0267a131",
      "metadata": {},
      "source": [
        "### Thickness"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "9f1661dd5c126bef",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-04-30T16:39:07.047046Z",
          "start_time": "2025-04-30T16:39:07.024005Z"
        }
      },
      "outputs": [],
      "source": [
        "# Thickness checkpoints. Alternatives kept for reference:\n",
        "#   stage 1: FM_MNIST_bright_iter_200000 (bright), FM_MNIST_iter_200000_23 (thick)\n",
        "#   stage 2: Apr17RLFM_MNIST_thick_00s60_iter_train2_100000, Apr17RLFM_MNIST_thick_06s20_iter_train2_60000,\n",
        "#            RLFM_MNIST_thick_iter_train2_Mar26, RLFM_MNIST_thick_06s40_iter_train2_95000\n",
        "ckpt1 = torch.load('./saved_model/FM_MNIST_iter_200000_23.pth', map_location=configs.device, weights_only=True)\n",
        "stage1model = FM_class.get_untrained_model()\n",
        "stage1model.load_state_dict(ckpt1)\n",
        "\n",
        "ckpt2 = torch.load('./saved_model/Apr15RLFM_MNIST_thick_06s20_iter_train2_100000.pth', map_location=configs.device, weights_only=True)\n",
        "stage2model = PGFM_class.policy  # alias: loading below mutates PGFM_class.policy itself\n",
        "stage2model.load_state_dict(ckpt2)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "41d37834c6dbdc8b",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-04-30T16:39:10.228828Z",
          "start_time": "2025-04-30T16:39:07.694849Z"
        }
      },
      "outputs": [],
      "source": [
        "eval(MNIST_distribution_evaluator, rep_num=100, type=\"thickness\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "a8630a08a468162e",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-04-18T17:54:27.126441Z",
          "start_time": "2025-04-18T17:54:27.123391Z"
        }
      },
      "outputs": [],
      "source": [
        "# Scratch arithmetic (presumably batch size minus a recorded mean valid count; unverified).\n",
        "1000 - 766.43"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "f99a92a2983dd884",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-04-30T16:39:33.556048Z",
          "start_time": "2025-04-30T16:39:12.892828Z"
        }
      },
      "outputs": [],
      "source": [
        "# 30000 stage-1 FM samples, used by the thickness histogram and FID cells below.\n",
        "res = FM_funcs.sampler(stage1model, 30000, default_generation_step=100, stoptime=1)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "f1ad4d70d275c194",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-04-30T16:39:36.596400Z",
          "start_time": "2025-04-30T16:39:33.574755Z"
        }
      },
      "outputs": [],
      "source": [
        "# Histogram of per-sample thickness for the FM samples in `res`.\n",
        "thickness_vec = thickness_batch(res.cpu())\n",
        "unique_vals, counts = np.unique(thickness_vec, return_counts=True)\n",
        "ind, num = configs.count_valid_num_thick(thickness_vec)\n",
        "\n",
        "fig, ax = plt.subplots(figsize=(8, 5))\n",
        "ax.bar(unique_vals, counts, width=0.1, edgecolor='black')\n",
        "ax.set_xlabel(\"Thickness Value\")\n",
        "ax.set_ylabel(\"Frequency\")\n",
        "ax.set_title(\"Valid num: \" + str(num))\n",
        "ax.set_xlim(0.9, 4.5)\n",
        "ax.grid(axis='y', linestyle='--', alpha=0.7)\n",
        "fig.savefig(\"./figs/thickness_x.png\", dpi=300)\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "cf2dc693b1b94593",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-04-30T16:46:15.644472Z",
          "start_time": "2025-04-30T16:46:13.248673Z"
        }
      },
      "outputs": [],
      "source": [
        "# FID between the FM samples in `res` and real training images. Both are clipped to [0, 1] and\n",
        "# resized to 32x32. The resized copies get their own names so `res` keeps its original-resolution\n",
        "# samples; previously `res` was overwritten, corrupting it for re-runs or later use.\n",
        "ref = PGFM_class.get_samples(train_set, 21405)  # NOTE(review): reference size 21405 is unexplained\n",
        "\n",
        "ref_fid = F.interpolate(torch.clip(ref, 0, 1), size=(32, 32), mode='bilinear', align_corners=False)\n",
        "res_fid = F.interpolate(torch.clip(res, 0, 1), size=(32, 32), mode='bilinear', align_corners=False)\n",
        "compute_fid_from_batches(res_fid, ref_fid)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "1cd4fd2fac334b96",
      "metadata": {
        "ExecuteTime": {
          "start_time": "2025-04-30T16:46:26.888726Z"
        },
        "jupyter": {
          "is_executing": true
        }
      },
      "outputs": [],
      "source": [
        "# PGFM samples with settings hard-coded here (not taken from configs as eval(...) does).\n",
        "res = PGFM_class.RLFMsample(\n",
        "    stage1model, stage2model, 10000,\n",
        "    default_generation_step=60,\n",
        "    default_stage1t=0.6,\n",
        "    default_RLstep_S=40,\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "b9157342a89e104a",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-04-30T14:35:14.095736Z",
          "start_time": "2025-04-30T14:35:11.205989Z"
        }
      },
      "outputs": [],
      "source": [
        "# Histogram of per-sample thickness for the PGFM samples in `res`.\n",
        "thickness_vec = thickness_batch(res.cpu())\n",
        "unique_vals, counts = np.unique(thickness_vec, return_counts=True)\n",
        "ind, num = configs.count_valid_num_thick(thickness_vec)\n",
        "\n",
        "fig, ax = plt.subplots(figsize=(8, 5))\n",
        "ax.bar(unique_vals, counts, width=0.1, edgecolor='black')\n",
        "ax.set_xlabel(\"Thickness Value\")\n",
        "ax.set_ylabel(\"Frequency\")\n",
        "ax.set_title(\"Valid num: \" + str(num))\n",
        "ax.set_xlim(0.9, 4.5)\n",
        "ax.grid(axis='y', linestyle='--', alpha=0.7)\n",
        "fig.savefig(\"./figs/thickness_PGFM.png\", dpi=300)\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "c28409f436950044",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-04-30T14:31:19.813655Z",
          "start_time": "2025-04-30T14:31:19.501269Z"
        }
      },
      "outputs": [],
      "source": [
        "# Show `display_num` random samples. Row 1 has the raw images; row 2 has binarised copies\n",
        "# (threshold 128) titled with their thickness. Subplot indices derive from display_num, so it can be changed.\n",
        "rand_ind = torch.randperm(len(res))\n",
        "display_num = 5\n",
        "plt.figure(figsize=(12,6))\n",
        "for i in range(display_num):\n",
        "    sample = res[rand_ind[i]].cpu().numpy()[0]\n",
        "\n",
        "    plt.subplot(3, display_num, i + 1)\n",
        "    plt.imshow(sample, cmap='gray')\n",
        "    plt.axis('off')\n",
        "\n",
        "    image_np0 = (np.clip(sample, 0, 1) * 255).astype(np.uint8)\n",
        "    image_np = cv2.threshold(image_np0, 128, 255, cv2.THRESH_BINARY)[1]\n",
        "    plt.subplot(3, display_num, display_num + i + 1)  # was hard-coded i+6 (only valid for display_num=5)\n",
        "    plt.imshow(image_np, cmap='gray')\n",
        "    plt.grid(False)\n",
        "    plt.axis('off')\n",
        "    plt.title(f\"{thickness_vec[rand_ind[i]]:.3f}\")\n",
        "\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "12cbaeae626d43d6",
      "metadata": {},
      "source": [
        "### Brightness"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "1b0519f60ebb9c1d",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-04-25T20:46:59.536532Z",
          "start_time": "2025-04-25T20:46:59.479537Z"
        }
      },
      "outputs": [],
      "source": [
        "# Brightness checkpoints. Alternatives kept for reference:\n",
        "#   stage 1: FM_MNIST_bright_iter_200000 (bright), FM_MNIST_iter_200000_23 (thick)\n",
        "#   stage 2: RLFM_MNIST_bright_iter_train2_Mar26, RLFM_MNIST_bright_06s20_iter_train2_Mar28\n",
        "ckpt1 = torch.load('./saved_model/FM_MNIST_bright_iter_200000.pth', map_location=configs.device, weights_only=True)\n",
        "stage1model = FM_class.get_untrained_model()\n",
        "stage1model.load_state_dict(ckpt1)\n",
        "\n",
        "ckpt2 = torch.load('./saved_model/Apr25RLFM_MNIST_bright_06s20_iter_train2_20000.pth', map_location=configs.device, weights_only=True)\n",
        "stage2model = PGFM_class.policy  # alias: loading below mutates PGFM_class.policy itself\n",
        "stage2model.load_state_dict(ckpt2)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "32c2318a0b0adda5",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-04-25T21:02:02.710753Z",
          "start_time": "2025-04-25T21:02:01.165011Z"
        }
      },
      "outputs": [],
      "source": [
        "eval(MNIST_distribution_evaluator, rep_num=100, type=\"brightness\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "58acbee9768457b0",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-04-25T21:06:22.410669Z",
          "start_time": "2025-04-25T21:06:21.520750Z"
        }
      },
      "outputs": [],
      "source": [
        "# FID between the current `res` samples and real training images. Both are clipped to [0, 1]\n",
        "# (matching the thickness FID cell) and resized to 32x32 under new names, so `res` is not overwritten.\n",
        "# NOTE(review): in top-to-bottom order `res` here still holds the thickness-section PGFM samples;\n",
        "# run the brightness sampling cells below first.\n",
        "ref = PGFM_class.get_samples(train_set, 30379)  # NOTE(review): reference size 30379 is unexplained\n",
        "\n",
        "ref_fid = F.interpolate(torch.clip(ref, 0, 1), size=(32, 32), mode='bilinear', align_corners=False)\n",
        "res_fid = F.interpolate(torch.clip(res, 0, 1), size=(32, 32), mode='bilinear', align_corners=False)\n",
        "compute_fid_from_batches(res_fid, ref_fid)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "51e00288a36602cc",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-04-25T20:06:58.562714Z",
          "start_time": "2025-04-25T20:06:45.924715Z"
        }
      },
      "outputs": [],
      "source": [
        "# 10000 stage-1 FM samples for the brightness histogram below.\n",
        "res = FM_funcs.sampler(stage1model, 10000, default_generation_step=100, stoptime=1)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "3eeaace8c3ced482",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-04-25T20:07:00.682893Z",
          "start_time": "2025-04-25T20:06:58.582873Z"
        }
      },
      "outputs": [],
      "source": [
        "# Histogram of per-sample brightness for the FM samples in `res` (19 equal-width bins on [60, 250]).\n",
        "brightness_vec = brightness_batch(res.cpu())\n",
        "interval_mid = bin_edges = np.linspace(60, 250, 20)\n",
        "ind, num = configs.count_valid_num_bright(brightness_vec)\n",
        "\n",
        "fig, ax = plt.subplots(figsize=(8, 5))\n",
        "ax.hist(brightness_vec, bins=bin_edges, edgecolor='black', align='mid')\n",
        "ax.set_xlabel(\"brightness Value\")\n",
        "ax.set_ylabel(\"Frequency\")\n",
        "ax.set_title(\"Valid num: \" + str(num))\n",
        "ax.set_xlim(60, 250)\n",
        "ax.grid(axis='y', linestyle='--', alpha=0.7)\n",
        "fig.savefig(\"./figs/brightness_x.png\", dpi=300)\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "f178f399b281b7fd",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-04-25T21:06:21.500805Z",
          "start_time": "2025-04-25T21:06:13.218200Z"
        }
      },
      "outputs": [],
      "source": [
        "# PGFM samples with settings hard-coded here. To match eval(...), pass\n",
        "# configs.default_stage1_t / configs.default_RL_Steps_S instead of 0.6 / 40.\n",
        "res = PGFM_class.RLFMsample(\n",
        "    stage1model, stage2model, 10000,\n",
        "    default_generation_step=100,\n",
        "    default_stage1t=0.6,\n",
        "    default_RLstep_S=40,\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "4aa36ef70251753c",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-04-25T21:04:17.246459Z",
          "start_time": "2025-04-25T21:04:16.948623Z"
        }
      },
      "outputs": [],
      "source": [
        "# Histogram of per-sample brightness for the PGFM samples in `res` (19 equal-width bins on [60, 250]).\n",
        "brightness_vec = brightness_batch(res.cpu())\n",
        "interval_mid = bin_edges = np.linspace(60, 250, 20)\n",
        "ind, num = configs.count_valid_num_bright(brightness_vec)\n",
        "\n",
        "fig, ax = plt.subplots(figsize=(8, 5))\n",
        "ax.hist(brightness_vec, bins=bin_edges, edgecolor='black', align='mid')\n",
        "ax.set_xlabel(\"brightness Value\")\n",
        "ax.set_ylabel(\"Frequency\")\n",
        "ax.set_title(\"Valid num: \" + str(num))\n",
        "ax.set_xlim(60, 250)\n",
        "ax.grid(axis='y', linestyle='--', alpha=0.7)\n",
        "fig.savefig(\"./figs/brightness_PGFM.png\", dpi=300)\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "dd570389f522cf25",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-04-25T21:04:19.429152Z",
          "start_time": "2025-04-25T21:04:19.248197Z"
        }
      },
      "outputs": [],
      "source": [
        "# Show `display_num` random samples. Row 1 has the raw images; row 2 has binarised copies\n",
        "# (threshold 128) titled with their brightness. Subplot indices derive from display_num, so it can be changed.\n",
        "rand_ind = torch.randperm(len(res))\n",
        "display_num = 5\n",
        "plt.figure(figsize=(12,6))\n",
        "for i in range(display_num):\n",
        "    sample = res[rand_ind[i]].cpu().numpy()[0]\n",
        "\n",
        "    plt.subplot(3, display_num, i + 1)\n",
        "    plt.imshow(sample, cmap='gray')\n",
        "    plt.axis('off')\n",
        "\n",
        "    image_np = (np.clip(sample, 0, 1) * 255).astype(np.uint8)\n",
        "    image_np = cv2.threshold(image_np, 128, 255, cv2.THRESH_BINARY)[1]\n",
        "    plt.subplot(3, display_num, display_num + i + 1)  # was hard-coded i+6 (only valid for display_num=5)\n",
        "    plt.imshow(image_np, cmap='gray')\n",
        "    plt.grid(False)\n",
        "    plt.axis('off')\n",
        "    plt.title(f\"{brightness_vec[rand_ind[i]]:.3f}\")\n",
        "\n",
        "plt.show()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3 (ipykernel)",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.12.2"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}