{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "initial_id",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-03-31T19:31:32.599824Z",
          "start_time": "2025-03-31T19:31:29.683083Z"
        }
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "\n",
        "import plot_func\n",
        "import FMfuncs\n",
        "import torch\n",
        "\n",
        "\n",
        "import configs\n",
        "import dataset\n",
        "import PGFM\n",
        "import FM2 #i.e., FMDD"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "b417625c6f112acd",
      "metadata": {},
      "source": [
        "## Stage-1 FM vs. two-stage sampling on projected 10-D Gaussian data\n",
        "\n",
        "This notebook loads `./data/projected_10DGaussian.npy`, trains (or loads checkpoints for) a stage-1 flow-matching model (`FMfuncs.OTFlowMatching`) and a stage-2 model (`FM2.FM2` or `PGFM.PGFM`), and compares samples from the stage-1 sampler with samples from the two-stage sampler using:\n",
        "\n",
        "- the distance of each sample to the plane given by the coefficient vector `mycoeff` (all ones, last entry 10), and the number of samples farther than a tolerance from it;\n",
        "- the SWD between samples and the data (`plot_func.compute_SWD`, presumably the sliced Wasserstein distance)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "c83600ef75a23d2a",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-03-31T19:31:32.939465Z",
          "start_time": "2025-03-31T19:31:32.906814Z"
        }
      },
      "outputs": [],
      "source": [
        "# Load the projected 10-D Gaussian samples and move them to the configured device as float32.\n",
        "data = torch.tensor(\n",
        "    np.load(\"./data/projected_10DGaussian.npy\"),\n",
        "    dtype=torch.float32,\n",
        "    device=configs.device,\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "7dacbb770487006f",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-03-31T19:31:32.969054Z",
          "start_time": "2025-03-31T19:31:32.965200Z"
        }
      },
      "outputs": [],
      "source": [
        "# Stage-1 baseline: OT flow matching.\n",
        "FMclass = FMfuncs.OTFlowMatching()\n",
        "# Stage-2 variants compared below: PGFM and FM2 (a.k.a. FMDD).\n",
        "PGFMclass = PGFM.PGFM()\n",
        "FM2class = FM2.FM2()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "b608d22aa4643e72",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-03-28T17:39:20.100738Z",
          "start_time": "2025-03-28T17:39:17.002268Z"
        }
      },
      "outputs": [],
      "source": [
        "# Train the stage-1 flow-matching model on the full dataset. Later cells load\n",
        "# './saved_model/FM_subspace_500000.pth' as the stage-1 weights; presumably this training\n",
        "# run is what writes that file (TODO confirm in FMfuncs.OTFlowMatching.train).\n",
        "FMclass.train(data)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "7f60d5da429f27ad",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-03-31T19:34:10.200916Z",
          "start_time": "2025-03-31T19:31:33.025052Z"
        }
      },
      "outputs": [],
      "source": [
        "# Stage-2 training for FM2 (FMDD), given the stage-1 FM checkpoint path (the same file\n",
        "# later loaded as `ckpt1`). The equivalent PGFM run is in the next cell.\n",
        "FM2class.train2_2stage(data, './saved_model/FM_subspace_500000.pth')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "1faba649-e2c6-43b8-9480-32114bdb71e6",
      "metadata": {},
      "outputs": [],
      "source": [
        "# Stage-2 training for PGFM, with the same stage-1 checkpoint path as the FM2 run above.\n",
        "# NOTE(review): the evaluation cell further down rebinds `PGFMclass = FM2class`; if this\n",
        "# cell is re-run after that one, it trains FM2, not PGFM.\n",
        "PGFMclass.train2_2stage(data, './saved_model/FM_subspace_500000.pth')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "bef608155106edc6",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-03-31T19:35:27.796661Z",
          "start_time": "2025-03-31T19:35:27.790754Z"
        }
      },
      "outputs": [],
      "source": [
        "# ckpt1: stage-1 FM weights. ckpt2: stage-2 weights, later loaded into `PGFMclass.policy`,\n",
        "# which the evaluation cell binds to FM2class, so the FM2 checkpoint is loaded here.\n",
        "# To evaluate PGFM instead, load './saved_model/PGFM_subspace_20000.pth' (and do not rebind\n",
        "# PGFMclass). Previously noted name (purpose unclear, kept for reference): PGFM_subspace_Mar27.\n",
        "ckpt1 = torch.load('./saved_model/FM_subspace_500000.pth', map_location=configs.device, weights_only=True)\n",
        "ckpt2 = torch.load('./saved_model/FM2_subspace_20000.pth', map_location=configs.device, weights_only=True)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "1a1076db4173b3",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-03-31T19:35:40.866297Z",
          "start_time": "2025-03-31T19:35:40.721652Z"
        }
      },
      "outputs": [],
      "source": [
        "# NOTE(review): this rebinds `PGFMclass` to the FM2 instance, so this cell and the evaluation\n",
        "# loop below (which calls `PGFMclass.PGFMsample_train2`) sample with FM2. The PGFM.PGFM()\n",
        "# instance created earlier is unreachable afterwards, so re-running the PGFM training cell\n",
        "# would train FM2 instead. TODO: use a dedicated name here and in the loop.\n",
        "PGFMclass = FM2class\n",
        "\n",
        "# Two-stage sampler: a fresh stage-1 model with the ckpt1 weights, plus the class's\n",
        "# `policy` module with the stage-2 weights from ckpt2.\n",
        "stage1model = PGFMclass.get_untrained_model()\n",
        "stage1model.load_state_dict(ckpt1)\n",
        "stage2model = PGFMclass.policy\n",
        "stage2model.load_state_dict(ckpt2)\n",
        "res = PGFMclass.PGFMsample_train2(stage1model, stage2model, 10000)\n",
        "res_np = res.cpu().numpy()\n",
        "\n",
        "# Plane coefficients: all ones, last entry 10.\n",
        "mycoeff = np.ones(configs.d_model+1)\n",
        "mycoeff[-1] = 10\n",
        "# distance_data_to_plane returns (per-sample distances, average distance); print only the\n",
        "# average rather than the whole tuple with its per-sample array.\n",
        "dist_per_sample, dist_avg = dataset.distance_data_to_plane(res_np, mycoeff)\n",
        "print(dist_avg)\n",
        "print(plot_func.compute_SWD(data[:10000].cpu().numpy(), res_np))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "7657812310799201",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-03-27T17:34:47.333299Z",
          "start_time": "2025-03-27T17:34:47.166794Z"
        }
      },
      "outputs": [],
      "source": [
        "# Baseline: stage-1 FM model only. NOTE: this rebinds `stage1model`; if this cell runs before\n",
        "# the evaluation loop, the loop uses this FMclass model (same ckpt1 weights) for both samplers.\n",
        "stage1model = FMclass.get_untrained_model()\n",
        "stage1model.load_state_dict(ckpt1)\n",
        "\n",
        "res = FMfuncs.sampler(stage1model, 10000, stoptime=1)\n",
        "res_np = res.cpu().numpy()\n",
        "\n",
        "# Plane coefficients: all ones, last entry 10.\n",
        "mycoeff = np.ones(configs.d_model+1)\n",
        "mycoeff[-1] = 10\n",
        "# distance_data_to_plane returns (per-sample distances, average distance); print the average.\n",
        "dist_per_sample, dist_avg = dataset.distance_data_to_plane(res_np, mycoeff)\n",
        "print(dist_avg)\n",
        "print(plot_func.compute_SWD(data[:10000].cpu().numpy(), res_np))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "f866da5d-e500-4fa1-83ba-3275133bb968",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-03-31T19:36:28.919611Z",
          "start_time": "2025-03-31T19:36:02.382Z"
        }
      },
      "outputs": [],
      "source": [
        "# Repeated evaluation: each trial draws n_samples points from the two-stage sampler\n",
        "# (PGFMclass) and from the stage-1 sampler, and records for each method\n",
        "#   - the number of samples farther than `tolerance` from the plane `mycoeff`,\n",
        "#   - the average distance to that plane,\n",
        "#   - the SWD against a fresh random subset of the data.\n",
        "# Uses `stage1model` / `stage2model` from whichever of the cells above ran last.\n",
        "SEED = 0\n",
        "np.random.seed(SEED)     # data subsets (and any numpy randomness inside the helpers)\n",
        "torch.manual_seed(SEED)  # sampler noise, assuming the samplers use torch's global RNG\n",
        "\n",
        "tolerance = 5e-4\n",
        "n_trials = 100\n",
        "n_samples = 10000\n",
        "avg_distance_FM = np.zeros(n_trials)\n",
        "avg_distance_PGFM = np.zeros(n_trials)\n",
        "SWD_FM = np.zeros(n_trials)\n",
        "SWD_PGFM = np.zeros(n_trials)\n",
        "invalid_num_FM = np.zeros(n_trials)\n",
        "invalid_num_PGFM = np.zeros(n_trials)\n",
        "mycoeff = np.ones(configs.d_model+1)\n",
        "mycoeff[-1] = 10\n",
        "\n",
        "for i in range(n_trials):\n",
        "    if (i+1) % 10 == 0:\n",
        "        print(i+1)\n",
        "    random_index = np.random.choice(data.shape[0], n_samples, replace=False)\n",
        "    resPGFM = PGFMclass.PGFMsample_train2(stage1model, stage2model, n_samples)\n",
        "    resFM = FMfuncs.sampler(stage1model, n_samples, stoptime=1)\n",
        "\n",
        "    # Convert once per trial; each array feeds both the distance and the SWD metrics.\n",
        "    data_subset = data[random_index].cpu().numpy()\n",
        "    resPGFM_np = resPGFM.cpu().numpy()\n",
        "    resFM_np = resFM.cpu().numpy()\n",
        "\n",
        "    distancePGFM_batch, distancePGFM = dataset.distance_data_to_plane(resPGFM_np, mycoeff)\n",
        "    distanceFM_batch, distanceFM = dataset.distance_data_to_plane(resFM_np, mycoeff)\n",
        "\n",
        "    invalid_num_FM[i] = np.sum(distanceFM_batch > tolerance)\n",
        "    invalid_num_PGFM[i] = np.sum(distancePGFM_batch > tolerance)\n",
        "\n",
        "    avg_distance_PGFM[i] = distancePGFM\n",
        "    avg_distance_FM[i] = distanceFM\n",
        "\n",
        "    # Keep PGFM before FM: compute_SWD may draw random projections, so call order can matter.\n",
        "    SWD_PGFM[i] = plot_func.compute_SWD(data_subset, resPGFM_np)\n",
        "    SWD_FM[i] = plot_func.compute_SWD(data_subset, resFM_np)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "38028c52-2225-4d37-be10-1fe72f4e0fd3",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-03-31T19:36:28.944719Z",
          "start_time": "2025-03-31T19:36:28.939779Z"
        }
      },
      "outputs": [],
      "source": [
        "def mean_pm_std(values, digits=7):\n",
        "    \"\"\"Return (mean, LaTeX plus-minus command, std) strings for `values`, each with `digits` decimals.\"\"\"\n",
        "    return f\"{np.mean(values):.{digits}f}\", r\"\\pm\", f\"{np.std(values):.{digits}f}\"\n",
        "\n",
        "\n",
        "# Same precision for every metric and for both methods, so the rows are directly comparable.\n",
        "print(\" FM SWD: \", *mean_pm_std(SWD_FM))\n",
        "print(\" FM num out: \", *mean_pm_std(invalid_num_FM))\n",
        "print(\"FM avg distance: \", *mean_pm_std(avg_distance_FM))\n",
        "\n",
        "print(\"PGFM SWD: \", *mean_pm_std(SWD_PGFM))\n",
        "print(\"PGFM num out: \", *mean_pm_std(invalid_num_PGFM))\n",
        "print(\"PGFM avg distance: \", *mean_pm_std(avg_distance_PGFM))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "bd8338f8-1543-46fc-b05e-cff9804a7712",
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-03-27T17:25:47.014123Z",
          "start_time": "2025-03-27T17:25:47.011425Z"
        }
      },
      "outputs": [],
      "source": [
        "# Average plane distance of the stage-1 FM samples from the last trial of the loop above.\n",
        "distanceFM"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "0161c324-e930-4952-b9b7-68365204d413",
      "metadata": {},
      "outputs": [],
      "source": []
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3 (ipykernel)",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.12.2"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}